[Model] Use merge_by_field_config for MM models (G) (#26117)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-10-03 13:38:29 +08:00
committed by GitHub
parent 711f485643
commit 39b643dc1a
5 changed files with 56 additions and 108 deletions

View File

@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable, Mapping, Sequence
from typing import Any, Literal, Optional, TypedDict, Union, cast
from typing import Annotated, Any, Literal, Optional, Union, cast
import numpy as np
import torch
@@ -41,6 +41,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
# yapf: enable
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import (MultiModalEmbeddings, SupportsMultiModal,
SupportsTranscription)
@@ -54,17 +55,28 @@ TOKENS_PER_IMAGE = 256
TOKENS_PER_AUDIO = 188
class Gemma3nImagePixelInputs(TypedDict):
pixel_values: torch.Tensor
"""Shape: `(batch_size * num_images, num_channels, height, width)`"""
class Gemma3nImagePixelInputs(TensorSchema):
"""
Dimensions:
- bn: Batch size * number of images
- c: Number of channels (3)
- h: Height of each patch
- w: Width of each patch
"""
type: Literal["pixel_values"] = "pixel_values"
pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
class Gemma3nAudioInputs(TypedDict):
input_features: Union[torch.Tensor, list[torch.Tensor]]
input_features_padded: torch.Tensor
"""Shape: `(batch_size * num_audio, seq_length, num_features)`"""
input_features_mask: torch.Tensor
"""Shape: `(batch_size * num_audio, seq_length)`"""
class Gemma3nAudioInputs(TensorSchema):
"""
Dimensions:
- bn: Batch size * number of audios
- s: seq_length
- f: num_features
"""
type: Literal["audio"] = "audio"
input_features_padded: Annotated[torch.Tensor, TensorShape("bn", "s", "f")]
input_features_mask: Annotated[torch.Tensor, TensorShape("bn", "s")]
Gemma3nImageInputs = Gemma3nImagePixelInputs
@@ -212,9 +224,9 @@ class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo]
return dict(
pixel_values=MultiModalFieldConfig.batched("image"),
input_features=MultiModalFieldConfig.batched("audio"),
input_features_padded=MultiModalFieldConfig.batched("audio"),
input_features_mask=MultiModalFieldConfig.batched("audio"))
input_features_mask=MultiModalFieldConfig.batched("audio"),
)
def _get_prompt_updates(
self,
@@ -422,6 +434,7 @@ class Gemma3nMultimodalEmbedder(nn.Module):
dummy_inputs=Gemma3nDummyInputsBuilder)
class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsTranscription):
merge_by_field_config = True
supported_languages = ISO639_1_SUPPORTED_LANGS
packed_modules_mapping = {
@@ -482,14 +495,6 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
device=self.language_model.model.embed_tokens.weight.device,
dtype=self.language_model.model.embed_tokens.weight.dtype)
@property
def dtype(self):
return next(self.parameters()).dtype
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
# TODO check if there are any
return data
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[Gemma3nImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
@@ -499,34 +504,22 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
if pixel_values is None:
return None
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
pixel_values = flatten_bn(pixel_values, concat=True)
pixel_values = pixel_values.contiguous()
return Gemma3nImagePixelInputs(
pixel_values=self._validate_pixel_values(pixel_values), )
return Gemma3nImagePixelInputs(pixel_values=pixel_values)
def _parse_and_validate_audio_input(
self, **kwargs: object) -> Optional[Gemma3nAudioInputs]:
input_features = kwargs.pop("input_features", None)
if input_features is None:
input_features_padded = kwargs.pop("input_features_padded", None)
if input_features_padded is None:
return None
input_features_mask = kwargs.pop("input_features_mask", None)
if input_features_mask is None:
return None
input_features_padded = kwargs.pop("input_features_padded", None)
if input_features_padded is None:
return None
return Gemma3nAudioInputs(
input_features=input_features,
input_features_mask=input_features_mask,
input_features_padded=input_features_padded,
input_features_mask=input_features_mask,
)
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
@@ -539,7 +532,7 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
) and "image" not in mm_input_by_modality:
mm_input_by_modality[
"image"] = self._parse_and_validate_image_input(**kwargs)
if input_key == "input_features" \
if input_key == "input_features_padded" \
and "audio" not in mm_input_by_modality:
mm_input_by_modality[
"audio"] = self._parse_and_validate_audio_input(**kwargs)