[Model] Use merge_by_field_config for MM models (G) (#26117)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Any, Literal, Optional, TypedDict, Union, cast
+from typing import Annotated, Any, Literal, Optional, Union, cast

 import numpy as np
 import torch
@@ -41,6 +41,7 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
 # yapf: enable
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
+from vllm.utils.tensor_schema import TensorSchema, TensorShape

 from .interfaces import (MultiModalEmbeddings, SupportsMultiModal,
                          SupportsTranscription)
@@ -54,17 +55,28 @@ TOKENS_PER_IMAGE = 256
 TOKENS_PER_AUDIO = 188


-class Gemma3nImagePixelInputs(TypedDict):
-    pixel_values: torch.Tensor
-    """Shape: `(batch_size * num_images, num_channels, height, width)`"""
+class Gemma3nImagePixelInputs(TensorSchema):
+    """
+    Dimensions:
+        - bn: Batch size * number of images
+        - c: Number of channels (3)
+        - h: Height of each patch
+        - w: Width of each patch
+    """
+    type: Literal["pixel_values"] = "pixel_values"
+    pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]


-class Gemma3nAudioInputs(TypedDict):
-    input_features: Union[torch.Tensor, list[torch.Tensor]]
-    input_features_padded: torch.Tensor
-    """Shape: `(batch_size * num_audio, seq_length, num_features)`"""
-    input_features_mask: torch.Tensor
-    """Shape: `(batch_size * num_audio, seq_length)`"""
+class Gemma3nAudioInputs(TensorSchema):
+    """
+    Dimensions:
+        - bn: Batch size * number of audios
+        - s: seq_length
+        - f: num_features
+    """
+    type: Literal["audio"] = "audio"
+    input_features_padded: Annotated[torch.Tensor, TensorShape("bn", "s", "f")]
+    input_features_mask: Annotated[torch.Tensor, TensorShape("bn", "s")]


 Gemma3nImageInputs = Gemma3nImagePixelInputs
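Note on the schema classes above: a TensorSchema subclass validates tensor shapes when the inputs object is constructed, and symbolic dimensions such as "bn" must agree wherever they recur, while literal dimensions such as 3 are checked exactly. A minimal standalone sketch of that idea (illustrative only, not vLLM's actual implementation):

import torch


def check_shape(t: torch.Tensor, dims: tuple, bindings: dict) -> None:
    # Validate rank first, then each dimension.
    if t.ndim != len(dims):
        raise ValueError(f"expected {len(dims)} dims, got {t.ndim}")
    for actual, expected in zip(t.shape, dims):
        if isinstance(expected, int):
            # Literal dimension, e.g. the 3 colour channels.
            if actual != expected:
                raise ValueError(f"expected dim {expected}, got {actual}")
        else:
            # Symbolic dimension, e.g. "bn": bind on first use, then enforce.
            bound = bindings.setdefault(expected, actual)
            if bound != actual:
                raise ValueError(
                    f"dim {expected!r} bound to {bound}, got {actual}")


# pixel_values declared as ("bn", 3, "h", "w") accepts (4, 3, 224, 224):
check_shape(torch.zeros(4, 3, 224, 224), ("bn", 3, "h", "w"), {})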
@@ -212,9 +224,9 @@ class Gemma3nMultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo]

         return dict(
             pixel_values=MultiModalFieldConfig.batched("image"),
-            input_features=MultiModalFieldConfig.batched("audio"),
             input_features_padded=MultiModalFieldConfig.batched("audio"),
-            input_features_mask=MultiModalFieldConfig.batched("audio"))
+            input_features_mask=MultiModalFieldConfig.batched("audio"),
+        )

     def _get_prompt_updates(
         self,
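For reference, MultiModalFieldConfig.batched marks a keyword argument as carrying one leading-dimension slice per multimodal item, which is what lets the engine split and re-merge these tensors per request. A rough sketch of the "batched" semantics (names here are illustrative, not vLLM internals):

import torch


def split_batched(field: torch.Tensor) -> list[torch.Tensor]:
    # One slice per multimodal item along the leading (batch) dimension.
    return [field[i] for i in range(field.shape[0])]


feats = torch.zeros(2, 188, 128)  # two audio clips: (bn, s, f)
per_item = split_batched(feats)
assert len(per_item) == 2
assert per_item[0].shape == (188, 128)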
@@ -422,6 +434,7 @@ class Gemma3nMultimodalEmbedder(nn.Module):
                                         dummy_inputs=Gemma3nDummyInputsBuilder)
 class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
                                       SupportsTranscription):
+    merge_by_field_config = True
     supported_languages = ISO639_1_SUPPORTED_LANGS

     packed_modules_mapping = {
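Setting merge_by_field_config = True opts this model into engine-side merging of multimodal kwargs according to the field config above, so the model receives already-batched tensors; this is why the flatten_bn call disappears from the parsing code below. A hedged sketch of the merge performed for batched fields (illustrative names, not the engine's actual code):

import torch


def merge_batched(items: list[torch.Tensor]) -> torch.Tensor:
    # Per-request tensors of shape (n_i, ...) become one (bn, ...) tensor.
    return torch.cat(items, dim=0)


a = torch.zeros(1, 3, 224, 224)  # request with one image
b = torch.zeros(2, 3, 224, 224)  # request with two images
assert merge_batched([a, b]).shape == (3, 3, 224, 224)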
@@ -482,14 +495,6 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
             device=self.language_model.model.embed_tokens.weight.device,
             dtype=self.language_model.model.embed_tokens.weight.dtype)

-    @property
-    def dtype(self):
-        return next(self.parameters()).dtype
-
-    def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
-        # TODO check if there are any
-        return data
-
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[Gemma3nImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
@@ -499,34 +504,22 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
         if pixel_values is None:
             return None

-        if not isinstance(pixel_values, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of pixel values. "
-                             f"Got type: {type(pixel_values)}")
-
-        pixel_values = flatten_bn(pixel_values, concat=True)
-        pixel_values = pixel_values.contiguous()
-
-        return Gemma3nImagePixelInputs(
-            pixel_values=self._validate_pixel_values(pixel_values), )
+        return Gemma3nImagePixelInputs(pixel_values=pixel_values)

     def _parse_and_validate_audio_input(
             self, **kwargs: object) -> Optional[Gemma3nAudioInputs]:
-        input_features = kwargs.pop("input_features", None)
-        if input_features is None:
+        input_features_padded = kwargs.pop("input_features_padded", None)
+        if input_features_padded is None:
             return None

         input_features_mask = kwargs.pop("input_features_mask", None)
         if input_features_mask is None:
             return None

-        input_features_padded = kwargs.pop("input_features_padded", None)
-        if input_features_padded is None:
-            return None
-
         return Gemma3nAudioInputs(
-            input_features=input_features,
-            input_features_mask=input_features_mask,
             input_features_padded=input_features_padded,
+            input_features_mask=input_features_mask,
         )

     def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
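With the schema classes in place, the isinstance/flatten_bn/_validate_pixel_values boilerplate removed above collapses into plain construction of the schema object, which checks shapes on instantiation. Reusing the standalone check_shape sketch from earlier (again illustrative, not vLLM's code), a wrong channel count now fails at parse time rather than deep inside the vision tower:

import torch

# check_shape as defined in the earlier sketch; 4 channels instead of 3 fails.
try:
    check_shape(torch.zeros(4, 4, 224, 224), ("bn", 3, "h", "w"), {})
except ValueError as exc:
    print(exc)  # -> expected dim 3, got 4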
@@ -539,7 +532,7 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal,
                              ) and "image" not in mm_input_by_modality:
                 mm_input_by_modality[
                     "image"] = self._parse_and_validate_image_input(**kwargs)
-            if input_key == "input_features" \
+            if input_key == "input_features_padded" \
                     and "audio" not in mm_input_by_modality:
                 mm_input_by_modality[
                     "audio"] = self._parse_and_validate_audio_input(**kwargs)