[Bugfix] Re-enable Gemma3 for V1 (#14980)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
@@ -1,34 +1,43 @@
 # SPDX-License-Identifier: Apache-2.0
 import math
-from typing import (Any, Iterable, Literal, Mapping, Optional, Sequence, Set,
-                    Tuple, TypedDict, Union)
+from collections.abc import Iterable, Mapping, Sequence
+from typing import Any, Literal, Optional, Set, Tuple, TypedDict, Union
 
 import torch
 from torch import nn
 from transformers import BatchFeature, Gemma3Config, Gemma3Processor
 from transformers.models.gemma3.processing_gemma3 import Gemma3ProcessorKwargs
 
+import vllm.envs as envs
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import GemmaRMSNorm
 from vllm.model_executor.layers.sampler import SamplerOutput
 from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
+from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
+from vllm.multimodal.inputs import MultiModalFieldConfig
 from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
                                    MultiModalDataItems)
+# yapf: disable
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        BaseProcessingInfo, PromptReplacement,
-                                        PromptUpdate, encode_tokens)
+                                        BaseProcessingInfo, BoundPromptUpdate,
+                                        PlaceholderFeaturesInfo,
+                                        PromptReplacement, PromptTargetMatch,
+                                        PromptUpdate, PromptUpdateDetails,
+                                        encode_tokens, find_mm_placeholders,
+                                        replace_token_matches)
+# yapf: enable
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
+from vllm.utils import flatten_2d_lists
 
 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
-                         SupportsMultiModal, SupportsPP, SupportsV0Only)
+                         SupportsMultiModal, SupportsPP)
 from .siglip import SiglipVisionModel
 from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
                     maybe_prefix, merge_multimodal_embeddings)
+from .vision import scatter_patch_features, select_patch_features
 
 logger = init_logger(__name__)
 
@@ -37,13 +46,25 @@ class Gemma3ImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
     pixel_values: torch.Tensor
     """
-    Shape: `(num_crops_total, num_channels, height, width)`
+    Shape: `(num_patches_total, num_channels, height, width)`
 
-    `num_crops_total` is the total number of crops
+    `num_patches_total` is the total number of patches
     over each image over each prompt in the batch.
     """
-    num_crops: torch.Tensor
-    """Shape: `(batch_size * num_images,)`"""
+
+    num_patches: torch.Tensor
+    """Shape: `(batch_size * num_images)`"""
+
+    embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
+    """
+    A boolean mask indicating which image embeddings correspond
+    to patch tokens.
+
+    Shape: `(batch_size, num_images, num_embeds)`
+    """
+
+    num_embeds: Union[torch.Tensor, list[torch.Tensor]]
+    """Shape: `(batch_size, num_images)`"""
 
 
 Gemma3ImageInputs = Gemma3ImagePixelInputs
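
How the new fields fit together: for each image, `pixel_values` carries `num_patches = num_crops + 1` entries (the full image plus its pan-and-scan crops), while `embed_is_patch` marks which of the `num_embeds` tokens in that image's prompt replacement actually receive vision embeddings. A toy, self-consistent set of values (sizes are illustrative, not Gemma3's real token counts):

    import torch

    # One image that pan-and-scan split into 2 crops -> 3 patches total.
    num_patches = torch.tensor([3])
    pixel_values = torch.randn(3, 3, 896, 896)  # (num_patches_total, C, H, W)

    # Its prompt replacement spans 10 tokens; 8 of them are patch tokens,
    # the rest are boundary/newline tokens that get no image embedding.
    num_embeds = torch.tensor([10])
    embed_is_patch = torch.tensor(
        [[False, True, True, True, True, True, True, True, True, False]])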
@@ -51,6 +72,9 @@ Gemma3ImageInputs = Gemma3ImagePixelInputs
 
 class Gemma3ProcessingInfo(BaseProcessingInfo):
 
     def get_hf_config(self):
         return self.ctx.get_hf_config(Gemma3Config)
 
+    def get_hf_processor(self, **kwargs: object):
+        return self.ctx.get_hf_processor(Gemma3Processor, **kwargs)
+
@@ -114,6 +138,11 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
         if not do_pan_and_scan:
             return 0
 
+        if envs.VLLM_USE_V1:
+            logger.warning_once(
+                "`do_pan_and_scan=True` has suboptimal results on V1 "
+                "because of the simplified attention pattern being used.")
+
         # Based on Gemma3ImageProcessor.pan_and_scan
         if image_width >= image_height:
             if image_width / image_height < pan_and_scan_min_ratio_to_activate:
@@ -154,7 +183,7 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
         image_width: int,
         image_height: int,
         processor: Optional[Gemma3Processor],
-    ) -> str:
+    ) -> PromptUpdateDetails:
         if processor is None:
             processor = self.get_hf_processor()
 
@@ -175,7 +204,11 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
             f"Here is the original image {image_token} and here are some "
             f"crops to help you see better {crops_image_tokens}")
 
-        return image_text.replace(image_token, processor.full_image_sequence)
+        repl_full = image_text.replace(image_token,
+                                       processor.full_image_sequence)
+        repl_features = repl_full.strip("\n")
+
+        return PromptUpdateDetails(full=repl_full, features=repl_features)
 
     def get_num_image_tokens(
         self,
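
The `full`/`features` split in `PromptUpdateDetails` separates what is inserted into the prompt text from what is matched against image embeddings: the HF processor wraps the image tokens in `"\n\n"` formatting, and stripping those newlines keeps them out of the feature span. Schematically (the literal sequence below is an assumption for illustration; the real one comes from `processor.full_image_sequence`):

    # Assumed shape of Gemma3's image sequence, for illustration only.
    full_image_sequence = ("\n\n<start_of_image>" + "<image_soft_token>" * 256 +
                           "<end_of_image>\n\n")

    repl_full = full_image_sequence        # inserted into the prompt text
    repl_features = repl_full.strip("\n")  # aligned with vision embeddings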
@@ -193,7 +226,7 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
 
         image_repl_tokens = encode_tokens(
             tokenizer,
-            image_repl,
+            image_repl.features,
             add_special_tokens=False,
         )
         return len(image_repl_tokens)
@@ -240,12 +273,8 @@ class Gemma3DummyInputsBuilder(BaseDummyInputsBuilder[Gemma3ProcessingInfo]):
                                    num_images=num_images)
         }
 
-        # NOTE: We need to separate the image tokens here because
-        # encode("\n\n\n\n") != encode("\n\n") * 2, which interferes
-        # with the detection of prompt updates when the image tokens are
-        # right next to each other
         return ProcessorInputs(
-            prompt_text=" ".join([image_token] * num_images),
+            prompt_text=image_token * num_images,
             mm_data=mm_data,
         )
 
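
The deleted NOTE existed because of a Gemma tokenizer quirk: runs of newlines are single tokens, so `encode("\n\n") * 2` is not the same as `encode("\n\n\n\n")`, and adjacent image sequences used to confuse placeholder detection. With the newline handling added below (`_apply_token_matches` / `_find_mm_placeholders`), the dummy prompt can simply concatenate image tokens. A quick way to observe the quirk (checkpoint id is an assumption; requires `transformers`):

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("google/gemma-3-4b-it")  # assumed id
    two = tok.encode("\n\n", add_special_tokens=False)
    four = tok.encode("\n\n\n\n", add_special_tokens=False)
    # If "\n\n\n\n" is a single token, four != two * 2, so per-image
    # replacements cannot simply be concatenated and re-detected.
    print(two, four, four == two * 2)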
@@ -278,13 +307,39 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
         ]
         hf_processor = self.info.get_hf_processor(**mm_kwargs)
 
+        image_repl_features = [
+            self.info.get_image_repl(image_width=size.width,
+                                     image_height=size.height,
+                                     processor=hf_processor).features
+            for size in image_sizes
+        ]
+
+        tokenizer = self.info.get_tokenizer()
+        image_repls_feature_tokens = [
+            tokenizer.encode(image_repl, add_special_tokens=False)
+            for image_repl in image_repl_features
+        ]
+        num_embeds = [
+            len(image_repl_feature_tokens)
+            for image_repl_feature_tokens in image_repls_feature_tokens
+        ]
+        processed_outputs["num_embeds"] = torch.tensor(num_embeds)
+
+        vocab = tokenizer.get_vocab()
+        image_token_id = vocab[tokenizer.image_token]
+
+        embed_is_patch = [
+            torch.tensor(image_repl_tokens) == image_token_id
+            for image_repl_tokens in image_repls_feature_tokens
+        ]
+        processed_outputs["embed_is_patch"] = embed_is_patch
+
         num_crops = [
             self.info.get_num_crops(image_width=size.width,
                                     image_height=size.height,
                                     processor=hf_processor)
             for size in image_sizes
         ]
-
         processed_outputs["num_crops"] = torch.tensor(num_crops)
 
         return processed_outputs
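
`embed_is_patch` above is nothing more than an elementwise comparison of the tokenized feature replacement against the image token id. A minimal standalone version (all ids invented for illustration):

    import torch

    image_token_id = 42                # invented id
    repl_tokens = [7, 42, 42, 42, 9]   # <boi>, patch tokens..., <eoi>
    embed_is_patch = torch.tensor(repl_tokens) == image_token_id
    print(embed_is_patch)  # tensor([False,  True,  True,  True, False])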
@@ -300,6 +355,8 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
             pixel_values=MultiModalFieldConfig.flat_from_sizes(
                 "image", num_crops + 1),
             num_crops=MultiModalFieldConfig.batched("image"),
+            embed_is_patch=MultiModalFieldConfig.batched("image"),
+            num_embeds=MultiModalFieldConfig.batched("image"),
         )
 
     def _get_prompt_updates(
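
`flat_from_sizes("image", num_crops + 1)` tells the multimodal field machinery how to slice the flattened `pixel_values` back into per-image groups: image `i` owns the next `num_crops[i] + 1` patches. A rough sketch of that slicing (illustrative, not vLLM's implementation):

    import torch

    pixel_values = torch.randn(8, 3, 896, 896)  # 8 patches across 3 images
    num_crops = torch.tensor([2, 0, 3])
    per_image = pixel_values.split((num_crops + 1).tolist())
    assert [p.shape[0] for p in per_image] == [3, 1, 4]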
@@ -329,6 +386,91 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
             )
         ]
 
+    def _apply_token_matches(
+        self,
+        prompt: list[int],
+        mm_matches: Mapping[str, Sequence[PromptTargetMatch]],
+        mm_item_counts: Mapping[str, int],
+    ) -> list[int]:
+        token_ids = super()._apply_token_matches(
+            prompt,
+            mm_matches,
+            mm_item_counts,
+        )
+
+        # "\n\n\n" and "\n\n\n\n" are single tokens
+        # Since our replacement can insert "\n\n" next to "\n"
+        # tokens, we have to combine them to be consistent with
+        # the output of the tokenizer
+        tokenizer = self.info.get_tokenizer()
+        vocab = tokenizer.get_vocab()
+        newline_1 = vocab["\n"]
+        newline_2 = vocab["\n\n"]
+        newline_3 = vocab["\n\n\n"]
+        newline_4 = vocab["\n\n\n\n"]
+
+        token_ids = replace_token_matches(
+            token_ids,
+            [newline_1, newline_2],
+            [newline_3],
+        )
+        token_ids = replace_token_matches(
+            token_ids,
+            [newline_2, newline_1],
+            [newline_3],
+        )
+        token_ids = replace_token_matches(
+            token_ids,
+            [newline_2, newline_2],
+            [newline_4],
+        )
+
+        return token_ids
+
+    def _find_mm_placeholders(
+        self,
+        mm_prompt_updates: Mapping[str, Sequence[BoundPromptUpdate]],
+        new_token_ids: list[int],
+        mm_item_counts: Mapping[str, int],
+    ) -> Mapping[str, list[PlaceholderFeaturesInfo]]:
+        # We need to detect "\n\n" inside "\n\n\n" and "\n\n\n\n"
+        tokenizer = self.info.get_tokenizer()
+        vocab = tokenizer.get_vocab()
+        newline_1 = vocab["\n"]
+        newline_2 = vocab["\n\n"]
+        newline_3 = vocab["\n\n\n"]
+        newline_4 = vocab["\n\n\n\n"]
+
+        def get_repl_toks(tok: int) -> list[int]:
+            if tok == newline_3:
+                return [newline_1, newline_2]
+            if tok == newline_4:
+                return [newline_2, newline_2]
+
+            return [tok]
+
+        repl_token_ids = list[int]()
+        repl_orig_idxs = list[int]()
+        for orig_idx, orig_tok in enumerate(new_token_ids):
+            repl_toks = get_repl_toks(orig_tok)
+            repl_token_ids.extend(repl_toks)
+            repl_orig_idxs.extend(orig_idx for _ in range(len(repl_toks)))
+
+        repls = find_mm_placeholders(mm_prompt_updates, repl_token_ids,
+                                     mm_item_counts)
+
+        return {
+            modality: [
+                PlaceholderFeaturesInfo(
+                    modality=p.modality,
+                    item_idx=p.item_idx,
+                    start_idx=repl_orig_idxs[p.start_idx],
+                    tokens=p.tokens,
+                ) for p in placeholders
+            ]
+            for modality, placeholders in repls.items()
+        }
+
+
 class Gemma3MultiModalProjector(nn.Module):
 
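
`replace_token_matches` (imported above from `vllm.multimodal.processing`) rewrites every occurrence of one token subsequence into another; the three calls canonicalize `"\n" + "\n\n"`, `"\n\n" + "\n"`, and `"\n\n" + "\n\n"` into the single tokens the tokenizer itself would have produced. A standalone sketch of the subsequence replacement being relied on (not vLLM's actual code):

    def replace_token_matches(token_ids: list[int], match: list[int],
                              new: list[int]) -> list[int]:
        # Replace each non-overlapping occurrence of `match`, left to right.
        out: list[int] = []
        i = 0
        while i < len(token_ids):
            if token_ids[i:i + len(match)] == match:
                out.extend(new)
                i += len(match)
            else:
                out.append(token_ids[i])
                i += 1
        return out

    # With newline_1=1, newline_2=2, newline_3=3:
    assert replace_token_matches([5, 1, 2, 7], [1, 2], [3]) == [5, 3, 7]

`_find_mm_placeholders` then performs the inverse: it re-expands the merged `"\n\n\n"`/`"\n\n\n\n"` tokens so the `"\n\n"` at the edge of an image's feature span can still be located, and maps the found placeholder positions back through `repl_orig_idxs`.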
@@ -374,7 +516,7 @@ class Gemma3MultiModalProjector(nn.Module):
                                         info=Gemma3ProcessingInfo,
                                         dummy_inputs=Gemma3DummyInputsBuilder)
 class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
-                                     SupportsLoRA, SupportsV0Only):
+                                     SupportsLoRA):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -415,6 +557,10 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors)
 
+    @property
+    def dtype(self):
+        return next(self.parameters()).dtype
+
     @property
     def sampler(self):
         return self.language_model.sampler
@@ -438,6 +584,8 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
             self, **kwargs: object) -> Optional[Gemma3ImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
         num_crops = kwargs.pop("num_crops", None)
+        embed_is_patch = kwargs.pop("embed_is_patch", None)
+        num_embeds = kwargs.pop("num_embeds", None)
         image_embeds = kwargs.pop("image_embeds", None)
         assert image_embeds is None, "Gemma3 does not support image_embeds."
         if pixel_values is None:
@@ -448,16 +596,26 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
                              f"Got type: {type(pixel_values)}")
 
         if not isinstance(num_crops, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of num_crops values. "
+            raise ValueError("Incorrect type of num_crops. "
                              f"Got type: {type(num_crops)}")
 
+        if not isinstance(embed_is_patch, (torch.Tensor, list)):
+            raise ValueError("Incorrect type of embed_is_patch. "
+                             f"Got type: {type(embed_is_patch)}")
+
+        if not isinstance(num_embeds, (torch.Tensor, list)):
+            raise ValueError("Incorrect type of num_embeds. "
+                             f"Got type: {type(num_embeds)}")
+
         pixel_values = flatten_bn(pixel_values, concat=True)
         num_crops = flatten_bn(num_crops, concat=True)
 
         return Gemma3ImagePixelInputs(
             type="pixel_values",
             pixel_values=self._validate_pixel_values(pixel_values),
-            num_crops=num_crops,
+            num_patches=num_crops + 1,
+            embed_is_patch=embed_is_patch,
+            num_embeds=num_embeds,
         )
 
     def _image_pixels_to_features(
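
Note the `num_patches=num_crops + 1` conversion: the HF processor reports one crop count per image, and each image contributes its crops plus the uncropped full image to `pixel_values`. With `flatten_bn(..., concat=True)` collapsing the per-prompt lists first, the arithmetic is just (sketch):

    import torch

    # Two prompts: one with a 2-crop image, one with a 0-crop and a 3-crop image.
    num_crops = torch.cat([torch.tensor([2]), torch.tensor([0, 3])])
    num_patches = num_crops + 1
    print(num_patches)  # tensor([3, 1, 4])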
@@ -472,36 +630,51 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
     def _process_image_input(
         self,
         image_input: Gemma3ImageInputs,
-    ) -> torch.Tensor:
+    ) -> tuple[torch.Tensor, ...]:
         assert self.vision_tower is not None
 
         pixel_values = image_input["pixel_values"]
-        vision_outputs = self._image_pixels_to_features(
+        num_patches = image_input["num_patches"]
+
+        image_features = self._image_pixels_to_features(
             self.vision_tower,
             pixel_values,
         )
-        return self.multi_modal_projector(vision_outputs)
+        image_embeds = self.multi_modal_projector(image_features)
+
+        return image_embeds.split(num_patches.tolist())
 
     def get_multimodal_embeddings(
             self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return None
-        vision_embeddings = self._process_image_input(image_input)
-        return vision_embeddings
+
+        image_features = self._process_image_input(image_input)
+
+        if kwargs.get("v0_path", False):
+            return image_features
+
+        return flatten_2d_lists(
+            scatter_patch_features(*args) for args in zip(
+                image_features,
+                image_input["num_embeds"],
+                image_input["embed_is_patch"],
+            ))
 
     def get_input_embeddings(
         self,
         input_ids: torch.Tensor,
         multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
     ) -> torch.Tensor:
-        if multimodal_embeddings is None:
-            inputs_embeds = self.language_model.get_input_embeddings(input_ids)
-        else:
-            inputs_embeds = self.language_model.get_input_embeddings(input_ids)
+        inputs_embeds = self.language_model.get_input_embeddings(input_ids)
+        if multimodal_embeddings is not None:
             inputs_embeds = merge_multimodal_embeddings(
-                input_ids, inputs_embeds, multimodal_embeddings,
-                self.config.image_token_index)
+                input_ids,
+                inputs_embeds,
+                select_patch_features(multimodal_embeddings),
+                self.config.image_token_index,
+            )
         return inputs_embeds
 
     def forward(self,
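
`_process_image_input` now returns one tensor per image (via `split(num_patches.tolist())`), and `scatter_patch_features` / `select_patch_features` from `.vision` carry the `embed_is_patch` mask through the V1 embedding path: scatter expands each image's projected features to its full `num_embeds` slots, and select recovers only the patch rows when merging into the input embeddings. A rough sketch of that round trip, assuming NaN marks non-patch slots (the real helpers may differ):

    import torch

    def scatter_patch_features(feats: torch.Tensor, num_embeds: int,
                               embed_is_patch: torch.Tensor) -> torch.Tensor:
        # Place patch features at the True positions, NaN elsewhere.
        out = feats.new_full((num_embeds, feats.shape[-1]), torch.nan)
        out[embed_is_patch] = feats
        return out

    def select_patch_features(embeds: torch.Tensor) -> torch.Tensor:
        # Keep only the rows that hold real patch features.
        return embeds[~embeds.isnan().any(dim=-1)]

    feats = torch.randn(4, 8)                       # 4 patch embeddings
    mask = torch.tensor([0, 1, 1, 1, 1, 0]).bool()  # 6 embedding slots
    full = scatter_patch_features(feats, 6, mask)
    assert torch.equal(select_patch_features(full), feats)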
@@ -516,6 +689,7 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
         # NOTE: In v1, inputs_embeds is always generated at model runner, this
         # condition is for v0 compatibility.
         elif inputs_embeds is None:
+            kwargs.update({"v0_path": True})
            vision_embeddings = self.get_multimodal_embeddings(**kwargs)
 
             inputs_embeds = self.get_input_embeddings(input_ids,
@@ -524,8 +698,9 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
             kwargs = self.prepare_attn_masks(
                 input_ids,
                 positions,
-                mask_dtype=vision_embeddings.dtype,
-                **kwargs)
+                mask_dtype=self.dtype,
+                **kwargs,
+            )
             input_ids = None
 
         hidden_states = self.language_model.model(input_ids,
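
The switch from `vision_embeddings.dtype` to `self.dtype` matters because on V1 (and for text-only prompts) no vision embeddings are available at this point, whereas the parameter dtype always is; that is what the new `dtype` property provides. The pattern in miniature:

    import torch
    from torch import nn

    class Toy(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.proj = nn.Linear(4, 4, dtype=torch.bfloat16)

        @property
        def dtype(self) -> torch.dtype:
            # Works even when no multimodal input produced any embeddings.
            return next(self.parameters()).dtype

    print(Toy().dtype)  # torch.bfloat16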