[Model] Use merge_by_field_config for MM models (G) (#26117)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-10-03 13:38:29 +08:00
committed by GitHub
parent 711f485643
commit 39b643dc1a
5 changed files with 56 additions and 108 deletions

View File

@@ -36,7 +36,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP)
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
from .utils import (AutoWeightsLoader, WeightsMapper,
init_vllm_registered_model, maybe_prefix)
logger = init_logger(__name__)
@@ -289,7 +289,7 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
processor=hf_processor)
for size in image_sizes
]
processed_outputs["num_crops"] = torch.tensor(num_crops)
processed_outputs["num_patches"] = torch.tensor(num_crops) + 1
return processed_outputs
@@ -298,12 +298,12 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
num_crops = hf_inputs.get("num_crops", torch.empty(0))
num_patches = hf_inputs.get("num_patches", torch.empty(0))
return dict(
pixel_values=MultiModalFieldConfig.flat_from_sizes(
"image", num_crops + 1),
num_crops=MultiModalFieldConfig.batched("image"),
"image", num_patches),
num_patches=MultiModalFieldConfig.batched("image"),
)
def _get_prompt_updates(
@@ -460,6 +460,8 @@ class Gemma3MultiModalProjector(nn.Module):
dummy_inputs=Gemma3DummyInputsBuilder)
class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
SupportsLoRA):
merge_by_field_config = True
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
@@ -526,29 +528,20 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[Gemma3ImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
num_crops = kwargs.pop("num_crops", None)
num_patches = kwargs.pop("num_patches", None)
image_embeds = kwargs.pop("image_embeds", None)
assert image_embeds is None, "Gemma3 does not support image_embeds."
if pixel_values is None:
return None
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
if not isinstance(num_crops, (torch.Tensor, list)):
raise ValueError("Incorrect type of num_crops. "
f"Got type: {type(num_crops)}")
image_size = self.config.vision_config.image_size
return Gemma3ImagePixelInputs(
pixel_values=flatten_bn(pixel_values, concat=True),
num_patches=flatten_bn(num_crops, concat=True) + 1,
resolve_bindings={
"h": image_size,
"w": image_size
})
return Gemma3ImagePixelInputs(pixel_values=pixel_values,
num_patches=num_patches,
resolve_bindings={
"h": image_size,
"w": image_size
})
def _image_pixels_to_features(
self,