[Model] Add Gemma3 GGUF multimodal support (#27772)

Signed-off-by: Luciano Martins <lucianommartins@users.noreply.github.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Luciano Martins <lucianommartins@users.noreply.github.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
Luciano Martins
2025-11-18 13:56:29 -03:00
committed by GitHub
parent 49a986ecd4
commit c2612371ad
14 changed files with 752 additions and 86 deletions

View File

@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from collections.abc import Iterable, Mapping, Sequence
from typing import Annotated, Any, Literal, TypeAlias
from typing import Annotated, Any, Literal
import torch
from torch import nn
@@ -20,12 +20,7 @@ from vllm.multimodal.inputs import (
MultiModalFieldConfig,
MultiModalKwargsItems,
)
from vllm.multimodal.parse import (
ImageEmbeddingItems,
ImageProcessorItems,
ImageSize,
MultiModalDataItems,
)
from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems
from vllm.multimodal.processing import (
BaseMultiModalProcessor,
BaseProcessingInfo,
@@ -76,15 +71,7 @@ class Gemma3ImagePixelInputs(TensorSchema):
num_patches: Annotated[torch.Tensor, TensorShape("bn")]
class Gemma3ImageEmbeddingInputs(TensorSchema):
type: Literal["image_embeds"] = "image_embeds"
image_embeds: Annotated[
torch.Tensor,
TensorShape("ni", "nf", "hs"),
]
Gemma3ImageInputs: TypeAlias = Gemma3ImagePixelInputs | Gemma3ImageEmbeddingInputs
Gemma3ImageInputs = Gemma3ImagePixelInputs
class Gemma3ProcessingInfo(BaseProcessingInfo):
@@ -191,9 +178,8 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
def get_image_repl(
self,
*,
image_width: int | None,
image_height: int | None,
num_crops: int | None = None,
image_width: int,
image_height: int,
processor: Gemma3Processor | None,
) -> PromptUpdateDetails[str]:
if processor is None:
@@ -201,13 +187,11 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
boi_token = processor.boi_token
if num_crops is None:
assert image_width is not None and image_height is not None
num_crops = self.get_num_crops(
image_width=image_width,
image_height=image_height,
processor=processor,
)
num_crops = self.get_num_crops(
image_width=image_width,
image_height=image_height,
processor=processor,
)
if num_crops == 0:
image_text = boi_token
@@ -337,7 +321,6 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
return dict(
pixel_values=MultiModalFieldConfig.flat_from_sizes("image", num_patches),
num_patches=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"),
)
def _get_prompt_updates(
@@ -350,19 +333,7 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
image_token = hf_processor.boi_token
def get_replacement_gemma3(item_idx: int):
images = mm_items.get_items(
"image", (ImageEmbeddingItems, ImageProcessorItems)
)
if isinstance(images, ImageEmbeddingItems):
# For image embedding inputs, only support no crops cases
# since it's not supported in hf processor anyway
return self.info.get_image_repl(
image_width=None,
image_height=None,
num_crops=0,
processor=hf_processor,
)
images = mm_items.get_items("image", ImageProcessorItems)
image_size = images.get_image_size(item_idx)
return self.info.get_image_repl(
@@ -586,19 +557,17 @@ class Gemma3ForConditionalGeneration(
pixel_values = kwargs.pop("pixel_values", None)
num_patches = kwargs.pop("num_patches", None)
image_embeds = kwargs.pop("image_embeds", None)
assert image_embeds is None, "Gemma3 does not support image_embeds."
if pixel_values is None:
return None
if pixel_values is not None:
image_size = self.config.vision_config.image_size
return Gemma3ImagePixelInputs(
pixel_values=pixel_values,
num_patches=num_patches,
resolve_bindings={"h": image_size, "w": image_size},
)
elif image_embeds is not None:
return Gemma3ImageEmbeddingInputs(
image_embeds=image_embeds,
type="image_embeds",
)
image_size = self.config.vision_config.image_size
return Gemma3ImagePixelInputs(
pixel_values=pixel_values,
num_patches=num_patches,
resolve_bindings={"h": image_size, "w": image_size},
)
def _image_pixels_to_features(
self,
@@ -610,9 +579,7 @@ class Gemma3ForConditionalGeneration(
def _process_image_input(
self,
image_input: Gemma3ImageInputs,
) -> torch.Tensor | list[torch.Tensor]:
if image_input["type"] == "image_embeds":
return image_input["image_embeds"]
) -> list[torch.Tensor]:
assert self.vision_tower is not None
pixel_values = image_input["pixel_values"]
@@ -629,13 +596,33 @@ class Gemma3ForConditionalGeneration(
def get_language_model(self) -> torch.nn.Module:
return self.language_model
def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return []
return self._process_image_input(image_input)
def embed_input_ids(
self,
input_ids: torch.Tensor,
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = True,
) -> torch.Tensor:
# Early return for text-only inference (no multimodal data)
if multimodal_embeddings is None or is_multimodal is None:
return super().embed_input_ids(input_ids)
# Use interface default with OOV handling enabled
return super().embed_input_ids(
input_ids,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)
def forward(
self,
input_ids: torch.Tensor,
@@ -657,6 +644,79 @@ class Gemma3ForConditionalGeneration(
return hidden_states
def generate_attention_masks(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
mask_dtype: torch.dtype,
) -> dict[str, Any]:
"""Generate custom attention masks for Gemma3 multimodal inputs.
This is called by V1 engine's gpu_model_runner during preprocessing
to generate attention masks that allow bidirectional attention between
image tokens while maintaining causal attention for text.
"""
# NOTE(woosuk): Here, we distinguish the sequences by the position id 0.
# This is a HACK. Fix this.
start_indices = (positions == 0).cpu().nonzero()
num_seqs = len(start_indices)
seq_lens = []
for i in range(num_seqs):
start_idx = start_indices[i]
end_idx = start_indices[i + 1] if i < num_seqs - 1 else len(input_ids)
seq_lens.append(end_idx - start_idx)
global_attn_masks = []
local_attn_masks = []
start_idx = 0
for seq_idx, seq_len in enumerate(seq_lens):
end_idx = start_idx + seq_len
input_token_ids = input_ids[start_idx:end_idx]
# Find image token positions
img_pos = input_token_ids == self.config.image_token_index
start_idx = end_idx
# Create a global causal mask
global_attn_mask = torch.empty(
1,
1,
seq_len,
seq_len,
dtype=mask_dtype,
device=input_ids.device,
)
global_attn_mask.fill_(float("-inf"))
# Fill the lower triangle with 0 (causal attention)
global_attn_mask = global_attn_mask.triu(diagonal=1)
# Enable bidirectional attention between image tokens
img_mask = torch.zeros_like(global_attn_mask)
img_mask[:, :, :, img_pos] += 1
img_mask[:, :, img_pos, :] += 1
global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask)
global_attn_masks.append(global_attn_mask)
# GGUF compatibility: config might be Gemma3TextConfig directly
text_config = getattr(self.config, "text_config", self.config)
sliding_window = text_config.sliding_window
if sliding_window is not None:
# Create a local causal mask with sliding window (1024)
local_attn_mask = torch.ones_like(global_attn_mask)
local_attn_mask = torch.tril(local_attn_mask, diagonal=-sliding_window)
local_attn_mask = torch.where(
local_attn_mask == 0, global_attn_mask, float("-inf")
)
local_attn_masks.append(local_attn_mask)
return {
"has_images": True,
"seq_lens": seq_lens,
"global_attn_masks": global_attn_masks,
"local_attn_masks": local_attn_masks,
}
def prepare_attn_masks(
self,
input_ids: torch.Tensor,