[Model] Use merge_by_field_config for MM models (InternVL family) (#26153)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-10-03 16:59:06 +08:00
committed by GitHub
parent 3e70e3d4d5
commit f9a8084e48
9 changed files with 84 additions and 182 deletions

View File

@@ -25,7 +25,7 @@ from vllm.model_executor.models.interns1_vit import InternS1VisionModel
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargsItems, NestedTensors)
MultiModalKwargsItems)
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
ImageSize, MultiModalDataItems)
from vllm.multimodal.processing import (BaseMultiModalProcessor,
@@ -39,7 +39,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP)
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
from .utils import (AutoWeightsLoader, WeightsMapper,
init_vllm_registered_model, maybe_prefix)
@@ -304,7 +304,7 @@ class InternS1MultiModalProcessor(
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> Mapping[str, NestedTensors]:
) -> BatchFeature:
mm_data = dict(mm_data)
videos = mm_data.pop("videos", [])
images = mm_data.pop("images", [])
@@ -342,7 +342,7 @@ class InternS1MultiModalProcessor(
image_placeholder, 1)
num_patches = [len(item) for item in image_pixel_values]
image_outputs: dict[str, NestedTensors] = {
image_outputs = {
"pixel_values": torch.concat(image_pixel_values),
"image_num_patches": torch.tensor(num_patches),
"image_token_id": torch.tensor(hf_processor.image_token_id),
@@ -370,7 +370,7 @@ class InternS1MultiModalProcessor(
video_placeholder, 1)
num_frames = [len(item) for item in video_pixel_values]
video_outputs: dict[str, NestedTensors] = {
video_outputs = {
"pixel_values_videos": torch.concat(video_pixel_values),
"video_num_patches": torch.tensor(num_frames),
"video_token_id": torch.tensor(video_token_id),
@@ -382,16 +382,11 @@ class InternS1MultiModalProcessor(
prompt)
text_outputs = tokenizer(prompt, **tok_kwargs, return_tensors="pt")
combined_outputs = dict(
**text_outputs,
**image_outputs,
**video_outputs,
)
return BatchFeature(combined_outputs)
return BatchFeature({**text_outputs, **image_outputs, **video_outputs})
def _get_mm_fields_config(
self,
hf_inputs: Mapping[str, NestedTensors],
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
@@ -487,6 +482,7 @@ class InternS1MultiModalProcessor(
dummy_inputs=InternS1DummyInputsBuilder)
class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP, SupportsLoRA):
merge_by_field_config = True
# To ensure correct weight loading and mapping.
hf_to_vllm_mapper = WeightsMapper(
@@ -561,7 +557,7 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal,
prefix=prefix,
)
def _init_mlp1(self, config: PretrainedConfig) -> nn.Sequential:
def _init_mlp1(self, config: PretrainedConfig) -> nn.Module:
return InternS1MultiModalProjector(config)
def pixel_shuffle(self, x, scale_factor=0.5):
@@ -599,13 +595,9 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal,
return None
if image_embeds is not None:
if not isinstance(image_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}")
return InternS1ImageEmbeddingInputs(
type="image_embeds",
data=flatten_bn(image_embeds),
data=image_embeds,
)
image_token_id = kwargs["image_token_id"]
@@ -613,17 +605,6 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal,
self.img_context_token_id = image_token_id.flatten().unique().item()
if pixel_values is not None:
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
if not isinstance(image_num_patches, (torch.Tensor, list)):
raise ValueError("Incorrect type of image_num_patches. "
f"Got type: {type(image_num_patches)}")
pixel_values = flatten_bn(pixel_values, concat=True)
image_num_patches = flatten_bn(image_num_patches, concat=True)
h, w = self.config.vision_config.image_size
return InternS1ImagePixelInputs(
type="pixel_values",
@@ -638,7 +619,7 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal,
raise AssertionError("This line should be unreachable.")
def _parse_and_validate_video_input(
self, **kwargs: object) -> Optional[InternS1VideoPixelInputs]:
self, **kwargs: object) -> Optional[InternS1VideoInputs]:
pixel_values_flat_video = kwargs.pop("pixel_values_videos", None)
video_num_patches = kwargs.pop("video_num_patches", None)
video_embeds = kwargs.pop("video_embeds", None)
@@ -647,13 +628,9 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal,
return None
if video_embeds is not None:
if not isinstance(video_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of video embeddings. "
f"Got type: {type(video_embeds)}")
return InternS1ImageEmbeddingInputs(
return InternS1VideoEmbeddingInputs(
type="video_embeds",
data=flatten_bn(video_embeds),
data=video_embeds,
)
video_token_id = kwargs["video_token_id"]
@@ -661,18 +638,6 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal,
self.video_context_token_id = video_token_id.flatten().unique().item()
if pixel_values_flat_video is not None:
if not isinstance(pixel_values_flat_video, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values_flat_video)}")
if not isinstance(video_num_patches, (torch.Tensor, list)):
raise ValueError("Incorrect type of image_num_patches. "
f"Got type: {type(video_num_patches)}")
pixel_values_flat_video = flatten_bn(pixel_values_flat_video,
concat=True)
video_num_patches = flatten_bn(video_num_patches, concat=True)
h, w = self.config.vision_config.image_size
return InternS1VideoPixelInputs(
type="pixel_values_videos",
@@ -686,11 +651,12 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal,
raise AssertionError("This line should be unreachable.")
def _process_image_input(
def _process_vision_input(
self,
image_input: Union[InternS1ImageInputs, InternS1VideoPixelInputs],
image_input: Union[InternS1ImageInputs, InternS1VideoInputs],
) -> tuple[torch.Tensor, ...]:
if image_input["type"] == "image_embeds":
if (image_input["type"] == "image_embeds"
or image_input["type"] == "video_embeds"):
return image_input["data"]
assert self.vision_tower is not None
@@ -753,11 +719,11 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal,
for modality in modalities:
if modality == "images":
image_input = modalities["images"]
vision_embeddings = self._process_image_input(image_input)
vision_embeddings = self._process_vision_input(image_input)
multimodal_embeddings += vision_embeddings
if modality == "videos":
video_input = modalities["videos"]
video_embeddings = self._process_image_input(video_input)
video_embeddings = self._process_vision_input(video_input)
multimodal_embeddings += video_embeddings
return multimodal_embeddings