[Model] Use merge_by_field_config for MM models (D-F) (#26076)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -20,8 +20,7 @@ from vllm.model_executor.model_loader.utils import set_default_torch_dtype
|
||||
from vllm.model_executor.models.transformers import replace_linear_class
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
|
||||
MultiModalKwargsItems, MultiModalUUIDDict,
|
||||
NestedTensors)
|
||||
MultiModalKwargsItems, MultiModalUUIDDict)
|
||||
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
|
||||
ImageSize, MultiModalDataItems)
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
@@ -40,7 +39,7 @@ from vllm.utils import is_list_of
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
|
||||
from .utils import (AutoWeightsLoader, WeightsMapper,
|
||||
init_vllm_registered_model, maybe_prefix)
|
||||
|
||||
# The image token id may be various
|
||||
@@ -50,15 +49,15 @@ _IMAGE_TOKEN = "<image>"
|
||||
class DeepseekVL2ImagePixelInputs(TensorSchema):
|
||||
"""
|
||||
Dimensions:
|
||||
- bn: Batch size * number of images
|
||||
- bnp: Batch size * number of images * number of patches
|
||||
- p: Number of patches
|
||||
- c: Number of channels (3)
|
||||
- h: Height of each image
|
||||
- w: Width of each image
|
||||
"""
|
||||
type: Literal["pixel_values"]
|
||||
data: Annotated[Union[torch.Tensor, list[torch.Tensor]],
|
||||
TensorShape("bn", "p", 3, "h", "w", dynamic_dims={"p"})]
|
||||
data: Annotated[torch.Tensor,
|
||||
TensorShape("bnp", 3, "h", "w", dynamic_dims={"bnp"})]
|
||||
images_spatial_crop: Annotated[torch.Tensor, TensorShape("bn", 2)]
|
||||
|
||||
|
||||
@@ -228,12 +227,8 @@ class DeepseekVL2MultiModalProcessor(
|
||||
tok_kwargs=tok_kwargs,
|
||||
)
|
||||
|
||||
pixel_values = processed_outputs["pixel_values"]
|
||||
# split pixel values into patches corresponding to each image
|
||||
images_spatial_crop = processed_outputs["images_spatial_crop"]
|
||||
patches_per_image = [x.prod().item() + 1 for x in images_spatial_crop]
|
||||
pixel_values = pixel_values.split(patches_per_image)
|
||||
processed_outputs["pixel_values"] = pixel_values
|
||||
processed_outputs["num_patches"] = (
|
||||
processed_outputs["images_spatial_crop"].prod(-1) + 1)
|
||||
|
||||
return processed_outputs
|
||||
|
||||
@@ -242,8 +237,11 @@ class DeepseekVL2MultiModalProcessor(
|
||||
hf_inputs: BatchFeature,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
num_patches = hf_inputs.get("num_patches", torch.empty(0))
|
||||
|
||||
return dict(
|
||||
pixel_values=MultiModalFieldConfig.batched("image"),
|
||||
pixel_values=MultiModalFieldConfig.flat_from_sizes(
|
||||
"image", num_patches),
|
||||
images_spatial_crop=MultiModalFieldConfig.batched("image"),
|
||||
image_embeds=MultiModalFieldConfig.batched("image"),
|
||||
)
|
||||
@@ -318,6 +316,7 @@ class DeepseekVL2MultiModalProcessor(
|
||||
info=DeepseekVL2ProcessingInfo,
|
||||
dummy_inputs=DeepseekVL2DummyInputsBuilder)
|
||||
class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
merge_by_field_config = True
|
||||
|
||||
hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={
|
||||
"language.": "language_model.",
|
||||
@@ -460,37 +459,30 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
|
||||
if pixel_values is not None:
|
||||
expected_h = expected_w = self.vision_config.image_size
|
||||
return DeepseekVL2ImagePixelInputs(type="pixel_values",
|
||||
data=flatten_bn(pixel_values),
|
||||
images_spatial_crop=flatten_bn(
|
||||
images_spatial_crop,
|
||||
concat=True),
|
||||
resolve_bindings={
|
||||
"h": expected_h,
|
||||
"w": expected_w,
|
||||
})
|
||||
return DeepseekVL2ImagePixelInputs(
|
||||
type="pixel_values",
|
||||
data=pixel_values,
|
||||
images_spatial_crop=images_spatial_crop,
|
||||
resolve_bindings={
|
||||
"h": expected_h,
|
||||
"w": expected_w,
|
||||
})
|
||||
|
||||
if image_embeds is not None:
|
||||
return DeepseekVL2VImageEmbeddingInputs(
|
||||
type="image_embeds",
|
||||
data=flatten_bn(image_embeds),
|
||||
data=image_embeds,
|
||||
)
|
||||
|
||||
raise AssertionError("This line should be unreachable.")
|
||||
|
||||
def _pixel_values_to_embedding(
|
||||
self,
|
||||
pixel_values: NestedTensors,
|
||||
pixel_values: torch.Tensor,
|
||||
images_spatial_crop: torch.Tensor,
|
||||
) -> NestedTensors:
|
||||
# Pixel_values: n_image * batch_size * [patch_per_img, 3, height, width]
|
||||
total_tiles = [x for x in pixel_values]
|
||||
|
||||
# [batch_all_tiles, 3, height, width]
|
||||
total_tiles = torch.cat(total_tiles, dim=0)
|
||||
|
||||
) -> list[torch.Tensor]:
|
||||
# [batch_all_tiles, vit_seq_len, c]
|
||||
images_feature = self.vision.forward_features(total_tiles)
|
||||
images_feature = self.vision.forward_features(pixel_values)
|
||||
|
||||
# [batch_all_tiles, hw, D]
|
||||
images_embeds = self.projector(images_feature)
|
||||
@@ -573,7 +565,7 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
return vision_embeddings
|
||||
|
||||
def _process_image_input(
|
||||
self, image_input: DeepseekVL2ImageInputs) -> torch.Tensor:
|
||||
self, image_input: DeepseekVL2ImageInputs) -> list[torch.Tensor]:
|
||||
if image_input["type"] == "image_embeds":
|
||||
image_data = image_input["data"]
|
||||
if is_list_of(image_data, torch.Tensor):
|
||||
|
||||
Reference in New Issue
Block a user