[Model] Add Gemma3 GGUF multimodal support (#27772)
Signed-off-by: Luciano Martins <lucianommartins@users.noreply.github.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Luciano Martins <lucianommartins@users.noreply.github.com>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
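Gemma3 now accepts only pixel-value image inputs: the image_embeds path (Gemma3ImageEmbeddingInputs and its processor plumbing) is removed. For GGUF checkpoints, where the model config may be a bare Gemma3TextConfig, the new generate_attention_masks() hook reads sliding_window through a text_config fallback and builds the V1 engine's attention masks, giving image tokens bidirectional attention while text stays causal.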
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import math
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Annotated, Any, Literal, TypeAlias
+from typing import Annotated, Any, Literal
 
 import torch
 from torch import nn
@@ -20,12 +20,7 @@ from vllm.multimodal.inputs import (
     MultiModalFieldConfig,
     MultiModalKwargsItems,
 )
-from vllm.multimodal.parse import (
-    ImageEmbeddingItems,
-    ImageProcessorItems,
-    ImageSize,
-    MultiModalDataItems,
-)
+from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems
 from vllm.multimodal.processing import (
     BaseMultiModalProcessor,
     BaseProcessingInfo,
@@ -76,15 +71,7 @@ class Gemma3ImagePixelInputs(TensorSchema):
     num_patches: Annotated[torch.Tensor, TensorShape("bn")]
 
 
-class Gemma3ImageEmbeddingInputs(TensorSchema):
-    type: Literal["image_embeds"] = "image_embeds"
-    image_embeds: Annotated[
-        torch.Tensor,
-        TensorShape("ni", "nf", "hs"),
-    ]
-
-
-Gemma3ImageInputs: TypeAlias = Gemma3ImagePixelInputs | Gemma3ImageEmbeddingInputs
+Gemma3ImageInputs = Gemma3ImagePixelInputs
 
 
 class Gemma3ProcessingInfo(BaseProcessingInfo):
@@ -191,9 +178,8 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
     def get_image_repl(
         self,
         *,
-        image_width: int | None,
-        image_height: int | None,
-        num_crops: int | None = None,
+        image_width: int,
+        image_height: int,
         processor: Gemma3Processor | None,
     ) -> PromptUpdateDetails[str]:
         if processor is None:
@@ -201,13 +187,11 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
 
         boi_token = processor.boi_token
 
-        if num_crops is None:
-            assert image_width is not None and image_height is not None
-            num_crops = self.get_num_crops(
-                image_width=image_width,
-                image_height=image_height,
-                processor=processor,
-            )
+        num_crops = self.get_num_crops(
+            image_width=image_width,
+            image_height=image_height,
+            processor=processor,
+        )
 
         if num_crops == 0:
             image_text = boi_token
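For reference, the replacement rule this hunk simplifies can be sketched as a plain function. This is an illustration only: the zero-crop branch is taken from the hunk above, but the helper name and the multi-crop wording are invented; the real string comes from the Gemma3 HF processor.

import textwrap

# Hypothetical sketch, not vLLM code: num_crops == 0 collapses to the bare
# boi_token, otherwise the placeholder expands to one token per crop.
def image_repl_text(boi_token: str, num_crops: int) -> str:
    if num_crops == 0:
        return boi_token
    crops = " ".join([boi_token] * num_crops)
    # Invented phrasing for illustration only.
    return textwrap.dedent(
        f"Here is the original image {boi_token} and here are some crops {crops}"
    )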
@@ -337,7 +321,6 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
         return dict(
             pixel_values=MultiModalFieldConfig.flat_from_sizes("image", num_patches),
             num_patches=MultiModalFieldConfig.batched("image"),
-            image_embeds=MultiModalFieldConfig.batched("image"),
         )
 
     def _get_prompt_updates(
@@ -350,19 +333,7 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
         image_token = hf_processor.boi_token
 
         def get_replacement_gemma3(item_idx: int):
-            images = mm_items.get_items(
-                "image", (ImageEmbeddingItems, ImageProcessorItems)
-            )
-
-            if isinstance(images, ImageEmbeddingItems):
-                # For image embedding inputs, only support no crops cases
-                # since it's not supported in hf processor anyway
-                return self.info.get_image_repl(
-                    image_width=None,
-                    image_height=None,
-                    num_crops=0,
-                    processor=hf_processor,
-                )
+            images = mm_items.get_items("image", ImageProcessorItems)
 
             image_size = images.get_image_size(item_idx)
             return self.info.get_image_repl(
@@ -586,19 +557,17 @@ class Gemma3ForConditionalGeneration(
         pixel_values = kwargs.pop("pixel_values", None)
         num_patches = kwargs.pop("num_patches", None)
         image_embeds = kwargs.pop("image_embeds", None)
+        assert image_embeds is None, "Gemma3 does not support image_embeds."
+        if pixel_values is None:
+            return None
 
-        if pixel_values is not None:
-            image_size = self.config.vision_config.image_size
-            return Gemma3ImagePixelInputs(
-                pixel_values=pixel_values,
-                num_patches=num_patches,
-                resolve_bindings={"h": image_size, "w": image_size},
-            )
-        elif image_embeds is not None:
-            return Gemma3ImageEmbeddingInputs(
-                image_embeds=image_embeds,
-                type="image_embeds",
-            )
+        image_size = self.config.vision_config.image_size
+
+        return Gemma3ImagePixelInputs(
+            pixel_values=pixel_values,
+            num_patches=num_patches,
+            resolve_bindings={"h": image_size, "w": image_size},
+        )
 
     def _image_pixels_to_features(
         self,
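The retained Gemma3ImagePixelInputs schema does the shape checking here; resolve_bindings pins the symbolic "h"/"w" dimensions to the vision tower's image size at validation time. A minimal sketch of what that check amounts to; the ("bn", 3, "h", "w") layout and the 896 input size are assumptions for illustration, not vLLM's actual TensorSchema API:

import torch

# Hypothetical stand-in for the TensorSchema check; the real validation lives
# in vLLM's TensorSchema/TensorShape machinery.
def check_pixel_values(pixel_values: torch.Tensor, image_size: int) -> None:
    # Assumed layout: (batch * num_images, channels, height, width).
    bn, c, h, w = pixel_values.shape
    if c != 3 or h != image_size or w != image_size:
        raise ValueError(
            f"expected (*, 3, {image_size}, {image_size}), "
            f"got {tuple(pixel_values.shape)}"
        )

# 896x896 is the SigLIP input resolution commonly used by Gemma3 (assumption).
check_pixel_values(torch.zeros(2, 3, 896, 896), image_size=896)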
@@ -610,9 +579,7 @@
     def _process_image_input(
         self,
         image_input: Gemma3ImageInputs,
-    ) -> torch.Tensor | list[torch.Tensor]:
-        if image_input["type"] == "image_embeds":
-            return image_input["image_embeds"]
+    ) -> list[torch.Tensor]:
         assert self.vision_tower is not None
 
         pixel_values = image_input["pixel_values"]
@@ -629,13 +596,33 @@
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model
 
-    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
+    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return []
 
         return self._process_image_input(image_input)
 
+    def embed_input_ids(
+        self,
+        input_ids: torch.Tensor,
+        multimodal_embeddings: MultiModalEmbeddings | None = None,
+        *,
+        is_multimodal: torch.Tensor | None = None,
+        handle_oov_mm_token: bool = True,
+    ) -> torch.Tensor:
+        # Early return for text-only inference (no multimodal data)
+        if multimodal_embeddings is None or is_multimodal is None:
+            return super().embed_input_ids(input_ids)
+
+        # Use interface default with OOV handling enabled
+        return super().embed_input_ids(
+            input_ids,
+            multimodal_embeddings=multimodal_embeddings,
+            is_multimodal=is_multimodal,
+            handle_oov_mm_token=handle_oov_mm_token,
+        )
+
     def forward(
         self,
         input_ids: torch.Tensor,
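The override simply forwards to the interface default. For intuition, the merge that default performs is essentially a masked scatter of multimodal embeddings into the text-embedding rows flagged by is_multimodal. A self-contained sketch with hypothetical names and toy shapes, not the vLLM interface itself:

import torch

def merge_embeddings(
    text_embeds: torch.Tensor,   # [num_tokens, hidden]
    mm_embeds: torch.Tensor,     # [num_mm_tokens, hidden]
    is_multimodal: torch.Tensor, # [num_tokens], bool
) -> torch.Tensor:
    # Rows flagged as multimodal are overwritten, in order, by mm_embeds.
    return text_embeds.masked_scatter(is_multimodal.unsqueeze(-1), mm_embeds)

text_embeds = torch.zeros(6, 4)
mm_embeds = torch.ones(2, 4)
is_multimodal = torch.tensor([False, True, False, False, True, False])
merged = merge_embeddings(text_embeds, mm_embeds, is_multimodal)
assert merged[1].eq(1).all() and merged[4].eq(1).all()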
@@ -657,6 +644,79 @@
 
         return hidden_states
 
+    def generate_attention_masks(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        mask_dtype: torch.dtype,
+    ) -> dict[str, Any]:
+        """Generate custom attention masks for Gemma3 multimodal inputs.
+
+        This is called by the V1 engine's gpu_model_runner during preprocessing
+        to generate attention masks that allow bidirectional attention between
+        image tokens while maintaining causal attention for text.
+        """
+        # NOTE(woosuk): Here, we distinguish the sequences by the position id 0.
+        # This is a HACK. Fix this.
+        start_indices = (positions == 0).cpu().nonzero()
+        num_seqs = len(start_indices)
+        seq_lens = []
+        for i in range(num_seqs):
+            start_idx = start_indices[i]
+            end_idx = start_indices[i + 1] if i < num_seqs - 1 else len(input_ids)
+            seq_lens.append(end_idx - start_idx)
+
+        global_attn_masks = []
+        local_attn_masks = []
+        start_idx = 0
+        for seq_idx, seq_len in enumerate(seq_lens):
+            end_idx = start_idx + seq_len
+            input_token_ids = input_ids[start_idx:end_idx]
+
+            # Find image token positions
+            img_pos = input_token_ids == self.config.image_token_index
+
+            start_idx = end_idx
+
+            # Create a global causal mask
+            global_attn_mask = torch.empty(
+                1,
+                1,
+                seq_len,
+                seq_len,
+                dtype=mask_dtype,
+                device=input_ids.device,
+            )
+            global_attn_mask.fill_(float("-inf"))
+            # Fill the lower triangle with 0 (causal attention)
+            global_attn_mask = global_attn_mask.triu(diagonal=1)
+
+            # Enable bidirectional attention between image tokens
+            img_mask = torch.zeros_like(global_attn_mask)
+            img_mask[:, :, :, img_pos] += 1
+            img_mask[:, :, img_pos, :] += 1
+            global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask)
+            global_attn_masks.append(global_attn_mask)
+
+            # GGUF compatibility: config might be Gemma3TextConfig directly
+            text_config = getattr(self.config, "text_config", self.config)
+            sliding_window = text_config.sliding_window
+            if sliding_window is not None:
+                # Create a local causal mask with sliding window
+                local_attn_mask = torch.ones_like(global_attn_mask)
+                local_attn_mask = torch.tril(local_attn_mask, diagonal=-sliding_window)
+                local_attn_mask = torch.where(
+                    local_attn_mask == 0, global_attn_mask, float("-inf")
+                )
+                local_attn_masks.append(local_attn_mask)
+
+        return {
+            "has_images": True,
+            "seq_lens": seq_lens,
+            "global_attn_masks": global_attn_masks,
+            "local_attn_masks": local_attn_masks,
+        }
+
     def prepare_attn_masks(
         self,
         input_ids: torch.Tensor,
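To make the masking rule above concrete, here is a standalone toy version that reproduces the same construction for a single sequence. The helper name and token id 9 are invented for illustration (9 stands in for config.image_token_index); this is not part of the commit:

import torch

def toy_gemma3_mask(input_ids: torch.Tensor, image_token_id: int) -> torch.Tensor:
    seq_len = len(input_ids)
    # Base causal mask: -inf strictly above the diagonal, 0 on and below it.
    mask = torch.full((1, 1, seq_len, seq_len), float("-inf")).triu(diagonal=1)
    # Mark (row, col) pairs where both positions hold image tokens.
    img_pos = input_ids == image_token_id
    pair = torch.zeros_like(mask)
    pair[:, :, :, img_pos] += 1
    pair[:, :, img_pos, :] += 1
    # Those pairs become 0, even above the diagonal.
    return torch.where(pair == 2, torch.zeros_like(mask), mask)

ids = torch.tensor([5, 9, 9, 7, 9])  # text, img, img, text, img
print(toy_gemma3_mask(ids, image_token_id=9)[0, 0])

In the printed mask, the rows for positions 1, 2, and 4 have 0 at each other's columns even above the diagonal, so image tokens attend to one another bidirectionally while all text positions keep the causal pattern.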