[Model] Use merge_by_field_config for MM models (D-F) (#26076)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Cyrus Leung
Date: 2025-10-02 23:17:35 +08:00
Committed by: GitHub
parent 7d6fb905d9
commit cc253b73d3
4 changed files with 102 additions and 180 deletions


@@ -20,8 +20,7 @@ from vllm.model_executor.model_loader.utils import set_default_torch_dtype
 from vllm.model_executor.models.transformers import replace_linear_class
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
-                                    MultiModalKwargsItems, MultiModalUUIDDict,
-                                    NestedTensors)
+                                    MultiModalKwargsItems, MultiModalUUIDDict)
 from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
                                    ImageSize, MultiModalDataItems)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
@@ -40,7 +39,7 @@ from vllm.utils import is_list_of
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
-from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
+from .utils import (AutoWeightsLoader, WeightsMapper,
                     init_vllm_registered_model, maybe_prefix)
 
 # The image token id may be various
@@ -50,15 +49,15 @@ _IMAGE_TOKEN = "<image>"
 class DeepseekVL2ImagePixelInputs(TensorSchema):
     """
     Dimensions:
-        - bn: Batch size * number of images
+        - bnp: Batch size * number of images * number of patches
         - p: Number of patches
         - c: Number of channels (3)
         - h: Height of each image
         - w: Width of each image
     """
 
     type: Literal["pixel_values"]
-    data: Annotated[Union[torch.Tensor, list[torch.Tensor]],
-                    TensorShape("bn", "p", 3, "h", "w", dynamic_dims={"p"})]
+    data: Annotated[torch.Tensor,
+                    TensorShape("bnp", 3, "h", "w", dynamic_dims={"bnp"})]
 
     images_spatial_crop: Annotated[torch.Tensor, TensorShape("bn", 2)]
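The schema change above collapses the old per-image list into one flat tensor whose leading dimension spans all patches of all images. As a rough illustration of what the new annotation enforces, here is a hand-written check (not vLLM's TensorSchema machinery; the 384x384 tile size is a made-up example):

import torch

def check_pixel_values(data: torch.Tensor, h: int, w: int) -> None:
    # Mirrors TensorShape("bnp", 3, "h", "w", dynamic_dims={"bnp"}):
    # the leading "bnp" dim is unconstrained, the rest must match exactly.
    assert data.ndim == 4 and data.shape[1:] == (3, h, w)

# Nine tiles drawn from any number of images now live in one flat batch.
check_pixel_values(torch.randn(9, 3, 384, 384), h=384, w=384)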
@@ -228,12 +227,8 @@ class DeepseekVL2MultiModalProcessor(
             tok_kwargs=tok_kwargs,
         )
 
-        pixel_values = processed_outputs["pixel_values"]
-
-        # split pixel values into patches corresponding to each image
-        images_spatial_crop = processed_outputs["images_spatial_crop"]
-        patches_per_image = [x.prod().item() + 1 for x in images_spatial_crop]
-        pixel_values = pixel_values.split(patches_per_image)
-        processed_outputs["pixel_values"] = pixel_values
+        processed_outputs["num_patches"] = (
+            processed_outputs["images_spatial_crop"].prod(-1) + 1)
 
         return processed_outputs
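Instead of eagerly splitting pixel_values, the processor now only records the per-image tile counts that the old list comprehension computed one image at a time. A small numeric check (the crop values here are made up; in DeepSeek-VL2 the +1 appears to account for the global thumbnail view):

import torch

# Two images cropped into 2x2 and 3x1 local tile grids (hypothetical values).
images_spatial_crop = torch.tensor([[2, 2], [3, 1]])

num_patches = images_spatial_crop.prod(-1) + 1
print(num_patches)  # tensor([5, 4]): local tiles plus one global view each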
@@ -242,8 +237,11 @@ class DeepseekVL2MultiModalProcessor(
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
num_patches = hf_inputs.get("num_patches", torch.empty(0))
return dict(
pixel_values=MultiModalFieldConfig.batched("image"),
pixel_values=MultiModalFieldConfig.flat_from_sizes(
"image", num_patches),
images_spatial_crop=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"),
)
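Where batched treats the first dimension as one entry per image, flat_from_sizes declares that per-image items are concatenated along it and must be sliced apart using the given sizes. Conceptually (a sketch with hypothetical shapes, not the actual field-config internals):

import torch

pixel_values = torch.randn(9, 3, 384, 384)  # flat batch: 5 + 4 patches
num_patches = torch.tensor([5, 4])

# flat_from_sizes("image", num_patches) lets vLLM recover per-image slices:
per_image = pixel_values.split(num_patches.tolist(), dim=0)
print([chunk.shape[0] for chunk in per_image])  # [5, 4]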
@@ -318,6 +316,7 @@ class DeepseekVL2MultiModalProcessor(
     info=DeepseekVL2ProcessingInfo,
     dummy_inputs=DeepseekVL2DummyInputsBuilder)
 class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
+    merge_by_field_config = True
 
     hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={
         "language.": "language_model.",
@@ -460,37 +459,30 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         if pixel_values is not None:
             expected_h = expected_w = self.vision_config.image_size
-            return DeepseekVL2ImagePixelInputs(type="pixel_values",
-                                               data=flatten_bn(pixel_values),
-                                               images_spatial_crop=flatten_bn(
-                                                   images_spatial_crop,
-                                                   concat=True),
-                                               resolve_bindings={
-                                                   "h": expected_h,
-                                                   "w": expected_w,
-                                               })
+            return DeepseekVL2ImagePixelInputs(
+                type="pixel_values",
+                data=pixel_values,
+                images_spatial_crop=images_spatial_crop,
+                resolve_bindings={
+                    "h": expected_h,
+                    "w": expected_w,
+                })
 
         if image_embeds is not None:
             return DeepseekVL2VImageEmbeddingInputs(
                 type="image_embeds",
-                data=flatten_bn(image_embeds),
+                data=image_embeds,
             )
 
         raise AssertionError("This line should be unreachable.")
 
     def _pixel_values_to_embedding(
         self,
-        pixel_values: NestedTensors,
+        pixel_values: torch.Tensor,
         images_spatial_crop: torch.Tensor,
-    ) -> NestedTensors:
-        # Pixel_values: n_image * batch_size * [patch_per_img, 3, height, width]
-        total_tiles = [x for x in pixel_values]
-
-        # [batch_all_tiles, 3, height, width]
-        total_tiles = torch.cat(total_tiles, dim=0)
+    ) -> list[torch.Tensor]:
 
         # [batch_all_tiles, vit_seq_len, c]
-        images_feature = self.vision.forward_features(total_tiles)
+        images_feature = self.vision.forward_features(pixel_values)
 
         # [batch_all_tiles, hw, D]
         images_embeds = self.projector(images_feature)
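Since the input now arrives pre-flattened, _pixel_values_to_embedding feeds it straight to the vision tower; the per-image structure is recovered afterwards from images_spatial_crop. Roughly (hypothetical shapes and a simplified split; the real method does more work per image than this sketch shows):

import torch

images_embeds = torch.randn(9, 576, 2048)  # (batch_all_tiles, hw, D)
patches_per_image = [5, 4]                 # images_spatial_crop.prod(-1) + 1

# One embedding tensor per image, matching the list[torch.Tensor] return type.
per_image = images_embeds.split(patches_per_image, dim=0)
print([t.shape for t in per_image])  # [(5, 576, 2048), (4, 576, 2048)]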
@@ -573,7 +565,7 @@ class DeepseekVLV2ForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
         return vision_embeddings
 
     def _process_image_input(
-            self, image_input: DeepseekVL2ImageInputs) -> torch.Tensor:
+            self, image_input: DeepseekVL2ImageInputs) -> list[torch.Tensor]:
         if image_input["type"] == "image_embeds":
             image_data = image_input["data"]
             if is_list_of(image_data, torch.Tensor):