[Model] Use merge_by_field_config for MM models (M-N) (#26710)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Cyrus Leung
Date: 2025-10-14 01:27:01 +08:00
Committed by: GitHub
Parent: e3b90c1ba2
Commit: afc47e4de7
11 changed files with 127 additions and 331 deletions


@@ -11,7 +11,7 @@ import copy
 import warnings
 from abc import ABC, abstractmethod
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Annotated, Any, Literal, TypeAlias, TypedDict, TypeVar
+from typing import Annotated, Any, Literal, TypeAlias, TypeVar

 import numpy.typing as npt
 import torch
@@ -40,7 +40,6 @@ from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.model_executor.models.nemotron_h import NemotronHForCausalLM
 from vllm.model_executor.models.radio import RadioModel
 from vllm.model_executor.models.utils import (
-    flatten_bn,
     init_vllm_registered_model,
     maybe_prefix,
 )
@@ -96,31 +95,35 @@ MAX_FRAMES = 16
 DEFAULT_NUM_TILES = 12


-class NanoNemotronVLImagePixelInputs(TypedDict):
+class NanoNemotronVLImagePixelInputs(TensorSchema):
+    """
+    Dimensions:
+        - bn: Batch size * number of images
+        - bnp: Batch size * number of images * (1 + num_patches)
+        - c: Number of channels (3)
+        - h: Height of each image patch
+        - w: Width of each image patch
+    """
+
     type: Literal["pixel_values"]
-    pixel_values_flat: torch.Tensor
-    """
-    Shape:
-    `(batch_size * num_images * (1 + num_patches), num_channels, height, width)`
-    """
-    num_patches: torch.Tensor
-    """Shape: `(batch_size * num_images)`"""
+    pixel_values_flat: Annotated[torch.Tensor, TensorShape("bnp", 3, "h", "w")]
+    num_patches: Annotated[torch.Tensor, TensorShape("bn")]


-class NanoNemotronVLImageEmbeddinInputs(TypedDict):
+class NanoNemotronVLImageEmbeddingInputs(TensorSchema):
+    """
+    Dimensions:
+        - n: Number of images
+        - f: Total image feature size
+        - h: Hidden size (must match the hidden size of language model backbone)
+    """
+
     type: Literal["image_embeds"]
-    data: torch.Tensor | list[torch.Tensor]
-    """
-    A tensor of shape `(num_images, total_image_feature_size, hidden_size)`
-    or a list of tensors of shape `(total_image_feature_size, hidden_size)`
-
-    `hidden_size` must match the hidden size of language model backbone.
-    """
+    data: Annotated[torch.Tensor | list[torch.Tensor], TensorShape("n", "f", "h")]


 NanoNemotronVLImageInputs: TypeAlias = (
-    NanoNemotronVLImagePixelInputs | NanoNemotronVLImageEmbeddinInputs
+    NanoNemotronVLImagePixelInputs | NanoNemotronVLImageEmbeddingInputs
 )
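Note: the TypedDict-plus-docstring pattern above is replaced by TensorSchema, which turns shapes that were previously documented only in prose into machine-checked TensorShape annotations. A minimal standalone sketch of the pattern, assuming the tensor-schema import path used by vLLM around this commit (the class name here is hypothetical):

from typing import Annotated, Literal

import torch

from vllm.utils.tensor_schema import TensorSchema, TensorShape


class ExamplePixelInputs(TensorSchema):
    """
    Dimensions:
        - bnp: Batch size * number of images * (1 + num_patches)
        - h: Height of each image patch
        - w: Width of each image patch
    """

    type: Literal["pixel_values"]
    # Dim 1 is pinned to 3 channels; the named dims ("bnp", "h", "w") are
    # validated against the actual tensor when the schema is constructed.
    pixel_values_flat: Annotated[torch.Tensor, TensorShape("bnp", 3, "h", "w")]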
@@ -710,37 +713,12 @@ class NanoNemotronVLProcessingInfo(BaseNanoNemotronVLProcessingInfo):
 class NanoNemotronBaseVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
     """Basic image-only MultiModalProcessor for InternVL-style models."""

-    def _call_hf_processor(
-        self,
-        prompt: str,
-        mm_data: Mapping[str, object],
-        mm_kwargs: Mapping[str, object],
-        tok_kwargs: Mapping[str, object],
-    ) -> BatchFeature:
-        processed_outputs = super()._call_hf_processor(
-            prompt=prompt,
-            mm_data=mm_data,
-            mm_kwargs=mm_kwargs,
-            tok_kwargs=tok_kwargs,
-        )
-
-        hf_processor = self.info.get_hf_processor(**mm_kwargs)
-        image_token_id = hf_processor.image_token_id
-
-        # Since there may be extra tokens in the feature placeholders,
-        # we need to pass the image token ID to the model to select the
-        # tokens to merge from the vision encoder outputs
-        processed_outputs["image_token_id"] = torch.tensor(image_token_id)
-
-        return processed_outputs
-
     def _get_mm_fields_config(
         self,
         hf_inputs: BatchFeature,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
         image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0))
-        num_images = len(image_num_patches)

         return dict(
             pixel_values_flat=MultiModalFieldConfig.flat_from_sizes(
@@ -748,7 +726,6 @@ class NanoNemotronBaseVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
             ),
             image_num_patches=MultiModalFieldConfig.batched("image"),
             image_embeds=MultiModalFieldConfig.batched("image"),
-            image_token_id=MultiModalFieldConfig.shared("image", num_images),
         )

     def _get_prompt_updates(
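Note: MultiModalFieldConfig.flat_from_sizes("image", image_num_patches) declares that the flat pixel tensor splits into per-image chunks along dim 0; that declaration is what lets the engine merge and re-split fields on the model's behalf once merge_by_field_config is set. An illustrative sketch with hypothetical sizes:

import torch

image_num_patches = torch.tensor([2, 5, 1])  # patches contributed by each image
flat = torch.randn(int(image_num_patches.sum()), 3, 512, 512)

# Per-image slices are recoverable from the flat tensor plus the sizes:
chunks = torch.split(flat, image_num_patches.tolist(), dim=0)
assert [c.shape[0] for c in chunks] == [2, 5, 1]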
@@ -814,25 +791,6 @@ class NanoNemotronVLMultiModalProcessor(
 ):
     """MultiModalProcessor extended for video support"""

-    def _call_hf_processor(
-        self,
-        prompt: str,
-        mm_data: Mapping[str, object],
-        mm_kwargs: Mapping[str, object],
-        tok_kwargs: Mapping[str, object],
-    ) -> BatchFeature:
-        processed_outputs = super()._call_hf_processor(
-            prompt, mm_data, mm_kwargs, tok_kwargs
-        )
-
-        hf_processor = self.info.get_hf_processor(**mm_kwargs)
-        if (
-            self.info.supports_video
-            and (video_token_id := hf_processor.video_token_id) is not None
-        ):
-            processed_outputs["video_token_id"] = torch.tensor(video_token_id)
-
-        return processed_outputs
-
     def _get_mm_fields_config(
         self,
         hf_inputs: BatchFeature,
@@ -841,13 +799,12 @@ class NanoNemotronVLMultiModalProcessor(
         image_fields = super()._get_mm_fields_config(hf_inputs, hf_processor_mm_kwargs)

         if self.info.supports_video:
             video_num_patches = hf_inputs.get("video_num_patches", torch.empty(0))
-            num_videos = len(video_num_patches)
             video_fields = dict(
                 pixel_values_flat_video=MultiModalFieldConfig.flat_from_sizes(
                     "video", video_num_patches
                 ),
                 video_num_patches=MultiModalFieldConfig.batched("video"),
-                video_token_id=MultiModalFieldConfig.shared("video", num_videos),
             )
         else:
             video_fields = {}
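Note: the deleted shared(...) entries existed only to route the placeholder token IDs to the model through the processor outputs; shared broadcasts a single value to every item of a modality instead of splitting it per item. A rough sketch of that (now removed) behavior, with a hypothetical token ID:

import torch

num_videos = 3
video_token_id = torch.tensor(32000)  # hypothetical placeholder token ID

# shared("video", num_videos) kept this one tensor intact and associated it
# with all num_videos items, rather than slicing it like a batched field.
per_item = [video_token_id] * num_videos
assert all(t is video_token_id for t in per_item)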
@@ -999,6 +956,8 @@ class NanoNemotronVLDummyInputsBuilder(
 class NemotronH_Nano_VL_V2(
     nn.Module, HasInnerState, IsHybrid, SupportsMultiModal, SupportsMultiModalPruning
 ):
+    merge_by_field_config = True
+
     @classmethod
     def get_placeholder_str(cls, modality: str, i: int) -> str | None:
         if modality.startswith("image"):
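Note: merge_by_field_config = True opts the model into engine-side merging: batched multimodal kwargs are combined according to each field's MultiModalFieldConfig before they reach the model, which is why the flatten_bn calls below become unnecessary. A rough sketch of the difference, with hypothetical shapes:

import torch

# Legacy path: the model received one entry per request in the batch and
# had to collapse them itself (what flatten_bn(..., concat=True) did).
per_request = [torch.randn(2, 3, 512, 512), torch.randn(5, 3, 512, 512)]
legacy = torch.cat(per_request, dim=0)  # shape: (7, 3, 512, 512)

# New path: the engine delivers the already-merged tensor directly.
merged = torch.randn(7, 3, 512, 512)
assert legacy.shape == merged.shape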
@@ -1051,8 +1010,6 @@ class NemotronH_Nano_VL_V2(
         )
         self.mlp1 = self.mlp1.to(self.language_model.config.torch_dtype)

-        self.img_context_token_id = None
-        self.video_context_token_id = None
         self.config = config
         self.model_config = vllm_config.model_config
@@ -1106,37 +1063,12 @@ class NemotronH_Nano_VL_V2(
             return None

         if image_embeds is not None:
-            if not isinstance(image_embeds, (torch.Tensor, list)):
-                raise ValueError(
-                    "Incorrect type of image embeddings. "
-                    f"Got type: {type(image_embeds)}"
-                )
-
-            return NanoNemotronVLImageEmbeddinInputs(
+            return NanoNemotronVLImageEmbeddingInputs(
                 type="image_embeds",
-                data=flatten_bn(image_embeds),
+                data=image_embeds,
             )

-        image_token_id = kwargs["image_token_id"]
-        assert isinstance(image_token_id, torch.Tensor)
-        self.img_context_token_id = image_token_id.flatten().unique().item()
-
         if pixel_values_flat is not None:
-            if not isinstance(pixel_values_flat, (torch.Tensor, list)):
-                raise ValueError(
-                    "Incorrect type of pixel values. "
-                    f"Got type: {type(pixel_values_flat)}"
-                )
-
-            if not isinstance(image_num_patches, (torch.Tensor, list)):
-                raise ValueError(
-                    "Incorrect type of image_num_patches. "
-                    f"Got type: {type(image_num_patches)}"
-                )
-
-            pixel_values_flat = flatten_bn(pixel_values_flat, concat=True)
-            image_num_patches = flatten_bn(image_num_patches, concat=True)
-
             return NanoNemotronVLImagePixelInputs(
                 type="pixel_values",
                 pixel_values_flat=pixel_values_flat,
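Note: the removed isinstance/ValueError checks and flatten_bn calls are not lost behavior. The engine now delivers pixel_values_flat pre-concatenated, and TensorSchema validates types and shapes when the inputs object is constructed. A sketch of the failure mode the schema still catches (construction commented out since the model module is not importable standalone):

import torch

bad_pixels = torch.randn(7, 4, 512, 512)  # 4 channels where the schema pins 3

# NanoNemotronVLImagePixelInputs(
#     type="pixel_values",
#     pixel_values_flat=bad_pixels,  # violates TensorShape("bnp", 3, "h", "w")
#     num_patches=torch.tensor([3, 4]),
# )  # -> raises a shape-validation error, replacing the manual checks above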
@@ -1285,28 +1217,10 @@ class NemotronH_Nano_VL_V2(
         if video_embeds is not None:
             return NanoNemotronVLVideoEmbeddingInputs(
                 type="video_embeds",
-                data=flatten_bn(video_embeds),
+                data=video_embeds,
             )

-        video_token_id = kwargs["video_token_id"]
-        assert isinstance(video_token_id, torch.Tensor)
-        self.video_context_token_id = video_token_id.flatten().unique().item()
-
         if pixel_values_flat_video is not None:
-            if not isinstance(pixel_values_flat_video, (torch.Tensor, list)):
-                raise ValueError(
-                    "Incorrect type of pixel values. "
-                    f"Got type: {type(pixel_values_flat_video)}"
-                )
-
-            if not isinstance(video_num_patches, (torch.Tensor, list)):
-                raise ValueError(
-                    "Incorrect type of image_num_patches. "
-                    f"Got type: {type(video_num_patches)}"
-                )
-
-            pixel_values_flat_video = flatten_bn(pixel_values_flat_video, concat=True)
-            video_num_patches = flatten_bn(video_num_patches, concat=True)
-
             expected_h = expected_w = self.config.force_image_size
             resolve_bindings = {"h": expected_h, "w": expected_w}
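Note: resolve_bindings pins the schema's symbolic dimensions to known values, so TensorShape("bnp", 3, "h", "w") is checked against the configured image size rather than inferred from the data. A sketch under the assumption that the bindings are applied at schema construction, as the surrounding code suggests:

import torch

expected_h = expected_w = 512  # hypothetical force_image_size
resolve_bindings = {"h": expected_h, "w": expected_w}

# With these bindings, an (N, 3, 512, 512) tensor satisfies
# TensorShape("bnp", 3, "h", "w"); an (N, 3, 448, 448) tensor would not.
pixels = torch.randn(7, 3, expected_h, expected_w)
assert pixels.shape[2:] == (expected_h, expected_w)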