[Model] Use merge_by_field_config for MM models (G) (#26117)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -36,7 +36,7 @@ from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
||||
SupportsMultiModal, SupportsPP)
|
||||
from .siglip import SiglipVisionModel
|
||||
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
|
||||
from .utils import (AutoWeightsLoader, WeightsMapper,
|
||||
init_vllm_registered_model, maybe_prefix)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -289,7 +289,7 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
|
||||
processor=hf_processor)
|
||||
for size in image_sizes
|
||||
]
|
||||
processed_outputs["num_crops"] = torch.tensor(num_crops)
|
||||
processed_outputs["num_patches"] = torch.tensor(num_crops) + 1
|
||||
|
||||
return processed_outputs
|
||||
|
||||
@@ -298,12 +298,12 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
|
||||
hf_inputs: BatchFeature,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
num_crops = hf_inputs.get("num_crops", torch.empty(0))
|
||||
num_patches = hf_inputs.get("num_patches", torch.empty(0))
|
||||
|
||||
return dict(
|
||||
pixel_values=MultiModalFieldConfig.flat_from_sizes(
|
||||
"image", num_crops + 1),
|
||||
num_crops=MultiModalFieldConfig.batched("image"),
|
||||
"image", num_patches),
|
||||
num_patches=MultiModalFieldConfig.batched("image"),
|
||||
)
|
||||
|
||||
def _get_prompt_updates(
|
||||
@@ -460,6 +460,8 @@ class Gemma3MultiModalProjector(nn.Module):
|
||||
dummy_inputs=Gemma3DummyInputsBuilder)
|
||||
class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
|
||||
SupportsLoRA):
|
||||
merge_by_field_config = True
|
||||
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
@@ -526,29 +528,20 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
|
||||
def _parse_and_validate_image_input(
|
||||
self, **kwargs: object) -> Optional[Gemma3ImageInputs]:
|
||||
pixel_values = kwargs.pop("pixel_values", None)
|
||||
num_crops = kwargs.pop("num_crops", None)
|
||||
num_patches = kwargs.pop("num_patches", None)
|
||||
image_embeds = kwargs.pop("image_embeds", None)
|
||||
assert image_embeds is None, "Gemma3 does not support image_embeds."
|
||||
if pixel_values is None:
|
||||
return None
|
||||
|
||||
if not isinstance(pixel_values, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of pixel values. "
|
||||
f"Got type: {type(pixel_values)}")
|
||||
|
||||
if not isinstance(num_crops, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of num_crops. "
|
||||
f"Got type: {type(num_crops)}")
|
||||
|
||||
image_size = self.config.vision_config.image_size
|
||||
|
||||
return Gemma3ImagePixelInputs(
|
||||
pixel_values=flatten_bn(pixel_values, concat=True),
|
||||
num_patches=flatten_bn(num_crops, concat=True) + 1,
|
||||
resolve_bindings={
|
||||
"h": image_size,
|
||||
"w": image_size
|
||||
})
|
||||
return Gemma3ImagePixelInputs(pixel_values=pixel_values,
|
||||
num_patches=num_patches,
|
||||
resolve_bindings={
|
||||
"h": image_size,
|
||||
"w": image_size
|
||||
})
|
||||
|
||||
def _image_pixels_to_features(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user