[Bugfix] Re-enable Gemma3 for V1 (#14980)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung authored on 2025-03-19 14:58:22 +08:00, committed by GitHub
parent 05ccd0aa35
commit 61f412187d
8 changed files with 419 additions and 175 deletions

vllm/model_executor/models/gemma3_mm.py

@@ -1,34 +1,43 @@
# SPDX-License-Identifier: Apache-2.0
import math
from typing import (Any, Iterable, Literal, Mapping, Optional, Sequence, Set,
Tuple, TypedDict, Union)
from collections.abc import Iterable, Mapping, Sequence
from typing import Any, Literal, Optional, Set, Tuple, TypedDict, Union
import torch
from torch import nn
from transformers import BatchFeature, Gemma3Config, Gemma3Processor
from transformers.models.gemma3.processing_gemma3 import Gemma3ProcessorKwargs
import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import GemmaRMSNorm
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.inputs import MultiModalFieldConfig
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems)
# yapf: disable
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptUpdate, encode_tokens)
BaseProcessingInfo, BoundPromptUpdate,
PlaceholderFeaturesInfo,
PromptReplacement, PromptTargetMatch,
PromptUpdate, PromptUpdateDetails,
encode_tokens, find_mm_placeholders,
replace_token_matches)
# yapf: enable
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.utils import flatten_2d_lists
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP, SupportsV0Only)
SupportsMultiModal, SupportsPP)
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
from .vision import scatter_patch_features, select_patch_features
logger = init_logger(__name__)
@@ -37,13 +46,25 @@ class Gemma3ImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
pixel_values: torch.Tensor
"""
Shape: `(num_crops_total, num_channels, height, width)`
Shape: `(num_patches_total, num_channels, height, width)`
`num_crops_total` is the total number of crops
`num_patches_total` is the total number of patches
over each image over each prompt in the batch.
"""
num_crops: torch.Tensor
"""Shape: `(batch_size * num_images,)`"""
num_patches: torch.Tensor
"""Shape: `(batch_size * num_images)`"""
embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
"""
A boolean mask indicating which image embeddings correspond
to patch tokens.
Shape: `(batch_size, num_images, num_embeds)`
"""
num_embeds: Union[torch.Tensor, list[torch.Tensor]]
"""Shape: `(batch_size, num_images)`"""
Gemma3ImageInputs = Gemma3ImagePixelInputs
@@ -51,6 +72,9 @@ Gemma3ImageInputs = Gemma3ImagePixelInputs
class Gemma3ProcessingInfo(BaseProcessingInfo):
def get_hf_config(self):
return self.ctx.get_hf_config(Gemma3Config)
def get_hf_processor(self, **kwargs: object):
return self.ctx.get_hf_processor(Gemma3Processor, **kwargs)
@@ -114,6 +138,11 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
if not do_pan_and_scan:
return 0
if envs.VLLM_USE_V1:
logger.warning_once(
"`do_pan_and_scan=True` has suboptimal results on V1 "
"because of the simplified attention pattern being used.")
# Based on Gemma3ImageProcessor.pan_and_scan
if image_width >= image_height:
if image_width / image_height < pan_and_scan_min_ratio_to_activate:
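Note: the hunk is cut off after the landscape branch. For orientation, here is a minimal sketch of the crop-count logic, assuming the behavior of HF's Gemma3ImageProcessor.pan_and_scan; the default values below are illustrative, not the processor's actual config:

import math

def estimate_num_crops(image_width: int,
                       image_height: int,
                       *,
                       min_crop_size: int = 256,
                       max_num_crops: int = 4,
                       min_ratio_to_activate: float = 1.2) -> int:
    if image_width >= image_height:
        # Landscape: split only along the width.
        if image_width / image_height < min_ratio_to_activate:
            return 0  # aspect ratio too square -> no extra crops
        num_crops_w = min(int(math.floor(image_width / image_height + 0.5)),
                          int(math.floor(image_width / min_crop_size)))
        num_crops_w = max(2, min(max_num_crops, num_crops_w))
        num_crops_h = 1
    else:
        # Portrait: symmetric logic along the height.
        if image_height / image_width < min_ratio_to_activate:
            return 0
        num_crops_h = min(int(math.floor(image_height / image_width + 0.5)),
                          int(math.floor(image_height / min_crop_size)))
        num_crops_h = max(2, min(max_num_crops, num_crops_h))
        num_crops_w = 1
    return num_crops_w * num_crops_h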
@@ -154,7 +183,7 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
image_width: int,
image_height: int,
processor: Optional[Gemma3Processor],
) -> str:
) -> PromptUpdateDetails:
if processor is None:
processor = self.get_hf_processor()
@@ -175,7 +204,11 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
f"Here is the original image {image_token} and here are some "
f"crops to help you see better {crops_image_tokens}")
return image_text.replace(image_token, processor.full_image_sequence)
repl_full = image_text.replace(image_token,
processor.full_image_sequence)
repl_features = repl_full.strip("\n")
return PromptUpdateDetails(full=repl_full, features=repl_features)
def get_num_image_tokens(
self,
@@ -193,7 +226,7 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
image_repl_tokens = encode_tokens(
tokenizer,
image_repl,
image_repl.features,
add_special_tokens=False,
)
return len(image_repl_tokens)
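Note: get_image_repl now returns a PromptUpdateDetails instead of a bare string: `full` is the entire text spliced into the prompt, while `features` is the sub-span actually backed by vision embeddings (the same text with the surrounding newlines stripped). A simplified stand-in to illustrate the split; the example string is hypothetical, the real one comes from processor.full_image_sequence:

from dataclasses import dataclass

@dataclass
class PromptUpdateDetails:  # simplified stand-in for vllm's class
    full: str      # everything spliced into the prompt
    features: str  # the sub-span backed by vision embeddings

repl_full = "\n\n<start_of_image><image_soft_token>...<end_of_image>\n\n"
details = PromptUpdateDetails(full=repl_full,
                              features=repl_full.strip("\n"))
# Token counting (get_num_image_tokens) now measures only
# details.features, so the surrounding newlines are not counted
# as image-feature positions.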
@@ -240,12 +273,8 @@ class Gemma3DummyInputsBuilder(BaseDummyInputsBuilder[Gemma3ProcessingInfo]):
num_images=num_images)
}
# NOTE: We need to separate the image tokens here because
# encode("\n\n\n\n") != encode("\n\n") * 2, which interferes
# with the detection of prompt updates when the image tokens are
# right next to each other
return ProcessorInputs(
prompt_text=" ".join([image_token] * num_images),
prompt_text=image_token * num_images,
mm_data=mm_data,
)
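Note: the deleted NOTE describes a real property of the Gemma tokenizer, which the new _apply_token_matches override (below) now compensates for, so the dummy prompt no longer needs separators between image tokens. A quick way to observe it; the checkpoint name is an assumption, any Gemma3 tokenizer should behave the same:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("google/gemma-3-4b-it")
# "\n\n\n\n" maps to a single token, not two "\n\n" tokens, so naive
# concatenation of per-image replacements does not match what the
# tokenizer produces for the final prompt.
assert tok.encode("\n\n", add_special_tokens=False) * 2 \
       != tok.encode("\n\n\n\n", add_special_tokens=False)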
@@ -278,13 +307,39 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
]
hf_processor = self.info.get_hf_processor(**mm_kwargs)
image_repl_features = [
self.info.get_image_repl(image_width=size.width,
image_height=size.height,
processor=hf_processor).features
for size in image_sizes
]
tokenizer = self.info.get_tokenizer()
image_repls_feature_tokens = [
tokenizer.encode(image_repl, add_special_tokens=False)
for image_repl in image_repl_features
]
num_embeds = [
len(image_repl_feature_tokens)
for image_repl_feature_tokens in image_repls_feature_tokens
]
processed_outputs["num_embeds"] = torch.tensor(num_embeds)
vocab = tokenizer.get_vocab()
image_token_id = vocab[tokenizer.image_token]
embed_is_patch = [
torch.tensor(image_repl_tokens) == image_token_id
for image_repl_tokens in image_repls_feature_tokens
]
processed_outputs["embed_is_patch"] = embed_is_patch
num_crops = [
self.info.get_num_crops(image_width=size.width,
image_height=size.height,
processor=hf_processor)
for size in image_sizes
]
processed_outputs["num_crops"] = torch.tensor(num_crops)
return processed_outputs
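Note: a toy example of the embed_is_patch mask built above (all token ids are made up): only positions holding the image soft token are marked, so the delimiter tokens around the patch run keep their text embeddings:

import torch

image_token_id = 7                  # hypothetical <image_soft_token> id
feature_tokens = [5, 7, 7, 7, 6]    # e.g. <start_of_image> patches <end_of_image>
embed_is_patch = torch.tensor(feature_tokens) == image_token_id
# tensor([False, True, True, True, False]): only the middle positions
# receive vision-encoder patch embeddings when features are scattered back.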
@@ -300,6 +355,8 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
pixel_values=MultiModalFieldConfig.flat_from_sizes(
"image", num_crops + 1),
num_crops=MultiModalFieldConfig.batched("image"),
embed_is_patch=MultiModalFieldConfig.batched("image"),
num_embeds=MultiModalFieldConfig.batched("image"),
)
def _get_prompt_updates(
@@ -329,6 +386,91 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
)
]
def _apply_token_matches(
self,
prompt: list[int],
mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
mm_item_counts: Mapping[str, int],
) -> list[int]:
token_ids = super()._apply_token_matches(
prompt,
mm_matches,
mm_item_counts,
)
# "\n\n\n" and "\n\n\n\n" are single tokens
# Since our replacement can insert "\n\n" next to "\n"
# tokens, we have to combine them to be consistent with
# the output of the tokenizer
tokenizer = self.info.get_tokenizer()
vocab = tokenizer.get_vocab()
newline_1 = vocab["\n"]
newline_2 = vocab["\n\n"]
newline_3 = vocab["\n\n\n"]
newline_4 = vocab["\n\n\n\n"]
token_ids = replace_token_matches(
token_ids,
[newline_1, newline_2],
[newline_3],
)
token_ids = replace_token_matches(
token_ids,
[newline_2, newline_1],
[newline_3],
)
token_ids = replace_token_matches(
token_ids,
[newline_2, newline_2],
[newline_4],
)
return token_ids
def _find_mm_placeholders(
self,
mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
new_token_ids: list[int],
mm_item_counts: Mapping[str, int],
) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
# We need to detect "\n\n" inside "\n\n\n" and "\n\n\n\n"
tokenizer = self.info.get_tokenizer()
vocab = tokenizer.get_vocab()
newline_1 = vocab["\n"]
newline_2 = vocab["\n\n"]
newline_3 = vocab["\n\n\n"]
newline_4 = vocab["\n\n\n\n"]
def get_repl_toks(tok: int) -> list[int]:
if tok == newline_3:
return [newline_1, newline_2]
if tok == newline_4:
return [newline_2, newline_2]
return [tok]
repl_token_ids = list[int]()
repl_orig_idxs = list[int]()
for orig_idx, orig_tok in enumerate(new_token_ids):
repl_toks = get_repl_toks(orig_tok)
repl_token_ids.extend(repl_toks)
repl_orig_idxs.extend(orig_idx for _ in range(len(repl_toks)))
repls = find_mm_placeholders(mm_prompt_updates, repl_token_ids,
mm_item_counts)
return {
modality: [
PlaceholderFeaturesInfo(
modality=p.modality,
item_idx=p.item_idx,
start_idx=repl_orig_idxs[p.start_idx],
tokens=p.tokens,
) for p in placeholders
]
for modality, placeholders in repls.items()
}
class Gemma3MultiModalProjector(nn.Module):
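Note: to see why the newline canonicalization and the matching inverse expansion are needed, here is a self-contained stand-in for replace_token_matches; it mirrors the helper's observable behavior as used above, not its implementation, and the token ids are hypothetical:

def replace_seq(tokens: list[int], match: list[int],
                repl: list[int]) -> list[int]:
    # Replace every non-overlapping occurrence of `match` with `repl`.
    out, i, n = [], 0, len(match)
    while i < len(tokens):
        if tokens[i:i + n] == match:
            out.extend(repl)
            i += n
        else:
            out.append(tokens[i])
            i += 1
    return out

NL1, NL2, NL3, NL4 = 107, 108, 109, 110   # hypothetical newline token ids
prompt = [1, NL1, NL2, 42, NL2, NL2, 2]
prompt = replace_seq(prompt, [NL1, NL2], [NL3])   # the override also
prompt = replace_seq(prompt, [NL2, NL2], [NL4])   # handles [NL2, NL1]
assert prompt == [1, NL3, 42, NL4, 2]
# _find_mm_placeholders then performs the inverse expansion, so the
# placeholder search still sees the "\n\n" tokens hidden inside NL3/NL4.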
@@ -374,7 +516,7 @@ class Gemma3MultiModalProjector(nn.Module):
info=Gemma3ProcessingInfo,
dummy_inputs=Gemma3DummyInputsBuilder)
class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
SupportsLoRA, SupportsV0Only):
SupportsLoRA):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
@@ -415,6 +557,10 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
@property
def dtype(self):
return next(self.parameters()).dtype
@property
def sampler(self):
return self.language_model.sampler
@@ -438,6 +584,8 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
self, **kwargs: object) -> Optional[Gemma3ImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
num_crops = kwargs.pop("num_crops", None)
embed_is_patch = kwargs.pop("embed_is_patch", None)
num_embeds = kwargs.pop("num_embeds", None)
image_embeds = kwargs.pop("image_embeds", None)
assert image_embeds is None, "Gemma3 does not support image_embeds."
if pixel_values is None:
@@ -448,16 +596,26 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
f"Got type: {type(pixel_values)}")
if not isinstance(num_crops, (torch.Tensor, list)):
raise ValueError("Incorrect type of num_crops values. "
raise ValueError("Incorrect type of num_crops. "
f"Got type: {type(num_crops)}")
if not isinstance(embed_is_patch, (torch.Tensor, list)):
raise ValueError("Incorrect type of embed_is_patch. "
f"Got type: {type(embed_is_patch)}")
if not isinstance(num_embeds, (torch.Tensor, list)):
raise ValueError("Incorrect type of num_embeds. "
f"Got type: {type(num_embeds)}")
pixel_values = flatten_bn(pixel_values, concat=True)
num_crops = flatten_bn(num_crops, concat=True)
return Gemma3ImagePixelInputs(
type="pixel_values",
pixel_values=self._validate_pixel_values(pixel_values),
num_crops=num_crops,
num_patches=num_crops + 1,
embed_is_patch=embed_is_patch,
num_embeds=num_embeds,
)
def _image_pixels_to_features(
@@ -472,36 +630,51 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
def _process_image_input(
self,
image_input: Gemma3ImageInputs,
) -> torch.Tensor:
) -> tuple[torch.Tensor, ...]:
assert self.vision_tower is not None
pixel_values = image_input["pixel_values"]
vision_outputs = self._image_pixels_to_features(
num_patches = image_input["num_patches"]
image_features = self._image_pixels_to_features(
self.vision_tower,
pixel_values,
)
return self.multi_modal_projector(vision_outputs)
image_embeds = self.multi_modal_projector(image_features)
return image_embeds.split(num_patches.tolist())
def get_multimodal_embeddings(
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is None:
return None
vision_embeddings = self._process_image_input(image_input)
return vision_embeddings
image_features = self._process_image_input(image_input)
if kwargs.get("v0_path", False):
return image_features
return flatten_2d_lists(
scatter_patch_features(*args) for args in zip(
image_features,
image_input["num_embeds"],
image_input["embed_is_patch"],
))
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
) -> torch.Tensor:
if multimodal_embeddings is None:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
else:
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
if multimodal_embeddings is not None:
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, multimodal_embeddings,
self.config.image_token_index)
input_ids,
inputs_embeds,
select_patch_features(multimodal_embeddings),
self.config.image_token_index,
)
return inputs_embeds
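Note: the scatter/select pair routes each image's projector output to exactly the positions flagged by embed_is_patch. A simplified plain-torch sketch of the assumed semantics of the helpers imported from .vision (the real select_patch_features operates on a sequence of per-image tensors):

import torch

def scatter_patch_features(patches, num_embeds, is_patch):
    # Place patch embeddings at the positions marked True, leaving
    # delimiter positions as NaN sentinels to be dropped later.
    out = patches.new_full((num_embeds, patches.shape[-1]), torch.nan)
    out[is_patch] = patches
    return out

def select_patch_features(embeds):
    # Keep only the rows that actually hold patch embeddings.
    return embeds[~embeds.isnan().any(dim=-1)]

patches = torch.randn(3, 8)
is_patch = torch.tensor([False, True, True, True, False])
full = scatter_patch_features(patches, 5, is_patch)
assert torch.equal(select_patch_features(full), patches)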
def forward(self,
@@ -516,6 +689,7 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
# NOTE: In v1, inputs_embeds is always generated at model runner, this
# condition is for v0 compatibility.
elif inputs_embeds is None:
kwargs.update({"v0_path": True})
vision_embeddings = self.get_multimodal_embeddings(**kwargs)
inputs_embeds = self.get_input_embeddings(input_ids,
@@ -524,8 +698,9 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
kwargs = self.prepare_attn_masks(
input_ids,
positions,
mask_dtype=vision_embeddings.dtype,
**kwargs)
mask_dtype=self.dtype,
**kwargs,
)
input_ids = None
hidden_states = self.language_model.model(input_ids,
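Note: the mask_dtype change matters because vision_embeddings may be unset on the V1 path, where inputs_embeds arrives precomputed from the model runner (and it is None for text-only requests on V0); the mask dtype is now read from the model parameters via the new dtype property instead. A minimal equivalent:

import torch
from torch import nn

# The dtype of the first parameter stands in for self.dtype, so the
# attention-mask dtype no longer depends on whether any image was
# present in the batch.
model = nn.Linear(4, 4, dtype=torch.bfloat16)
mask_dtype = next(model.parameters()).dtype   # torch.bfloat16
mask = torch.zeros(8, 8, dtype=mask_dtype)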