[VLM] Support pan-and-scan for Gemma3 multi-modal processor (#14672)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Roger Wang <ywang@roblox.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: Roger Wang <ywang@roblox.com>
Author: Cyrus Leung
Date:   2025-03-13 17:23:12 +08:00 (committed by GitHub)
Commit: 382403921f
Parent: a73122de96
9 changed files with 315 additions and 81 deletions


@@ -1,10 +1,12 @@
 # SPDX-License-Identifier: Apache-2.0
+import math
 from typing import (Any, Iterable, Literal, Mapping, Optional, Sequence, Set,
                     Tuple, TypedDict, Union)

 import torch
 from torch import nn
-from transformers import BatchFeature, Gemma3Config, ProcessorMixin
+from transformers import BatchFeature, Gemma3Config, Gemma3Processor
+from transformers.models.gemma3.processing_gemma3 import Gemma3ProcessorKwargs

 from vllm.config import VllmConfig
 from vllm.logger import init_logger
@@ -14,10 +16,11 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
                                     NestedTensors)
-from vllm.multimodal.parse import ImageSize, MultiModalDataItems
+from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
+                                   MultiModalDataItems)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo, PromptReplacement,
-                                        PromptUpdate, PromptUpdateDetails)
+                                        PromptUpdate, encode_tokens)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
@@ -31,8 +34,15 @@ logger = init_logger(__name__)
 class Gemma3ImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
-    data: torch.Tensor
-    """Shape: `(batch_size * num_images, num_channels, height, width)`"""
+    pixel_values: torch.Tensor
+    """
+    Shape: `(num_crops_total, num_channels, height, width)`

+    `num_crops_total` is the total number of crops, counting the original
+    image as one crop, summed over all images in the batch.
+    """
+    num_crops: torch.Tensor
+    """Shape: `(batch_size * num_images,)`"""


 Gemma3ImageInputs = Gemma3ImagePixelInputs
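
Note: as a shape sanity check, two images where the first gets 2 pan-and-scan crops and the second none yield (2 + 1) + (0 + 1) = 4 rows in the flat batch. A minimal sketch (the 896x896 resolution is the usual Gemma 3 vision input size, assumed here rather than taken from this diff):

    import torch

    # 4 flat rows: each image contributes (1 original + num_crops) crops.
    example = Gemma3ImagePixelInputs(
        type="pixel_values",
        pixel_values=torch.randn(4, 3, 896, 896),
        num_crops=torch.tensor([2, 0]),
    )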
@@ -40,6 +50,9 @@ Gemma3ImageInputs = Gemma3ImagePixelInputs
 class Gemma3ProcessingInfo(BaseProcessingInfo):

+    def get_hf_processor(self, **kwargs: object):
+        return self.ctx.get_hf_processor(Gemma3Processor, **kwargs)
+
     def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
         return {"image": None}
@@ -48,22 +61,160 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
         seq_len: int,
         mm_counts: Mapping[str, int],
     ) -> Mapping[str, int]:
-        hf_config = self.ctx.get_hf_config()
-        return {"image": hf_config.mm_tokens_per_image}
+        return {"image": self.get_max_image_tokens()}
+
+    def _resolve_image_kwargs(
+        self,
+        processor: Gemma3Processor,
+        keys: set[str],
+    ) -> dict[str, Any]:
+        image_processor = processor.image_processor
+        kwargs = processor._merge_kwargs(
+            Gemma3ProcessorKwargs,
+            tokenizer_init_kwargs=processor.tokenizer.init_kwargs,
+        )
+
+        images_kwargs = kwargs["images_kwargs"]
+
+        def _resolve_kw(key: str):
+            val = getattr(image_processor, key)
+            if val is None:
+                val = images_kwargs[key]
+
+            return val
+
+        return {k: _resolve_kw(k) for k in keys}
+
+    def get_num_crops(
+        self,
+        *,
+        image_width: int,
+        image_height: int,
+        processor: Optional[Gemma3Processor],
+    ) -> int:
+        if processor is None:
+            processor = self.get_hf_processor()
+
+        images_kwargs = self._resolve_image_kwargs(
+            processor, {
+                "do_pan_and_scan", "pan_and_scan_min_crop_size",
+                "pan_and_scan_max_num_crops",
+                "pan_and_scan_min_ratio_to_activate"
+            })
+
+        do_pan_and_scan = images_kwargs["do_pan_and_scan"]
+        pan_and_scan_min_crop_size = images_kwargs[
+            "pan_and_scan_min_crop_size"]
+        pan_and_scan_max_num_crops = images_kwargs[
+            "pan_and_scan_max_num_crops"]
+        pan_and_scan_min_ratio_to_activate = images_kwargs[
+            "pan_and_scan_min_ratio_to_activate"]
+
+        if not do_pan_and_scan:
+            return 0
+
+        # Based on Gemma3ImageProcessor.pan_and_scan
+        if image_width >= image_height:
+            if image_width / image_height < pan_and_scan_min_ratio_to_activate:
+                return 0
+
+            num_crops_w = min(
+                int(math.floor(image_width / pan_and_scan_min_crop_size)),
+                int(math.floor(image_width / image_height + 0.5)),
+            )
+
+            num_crops_w = max(2, num_crops_w)
+            num_crops_w = min(pan_and_scan_max_num_crops, num_crops_w)
+            num_crops_h = 1
+        else:
+            if image_height / image_width < pan_and_scan_min_ratio_to_activate:
+                return 0
+
+            num_crops_h = min(
+                int(math.floor(image_height / pan_and_scan_min_crop_size)),
+                int(math.floor(image_height / image_width + 0.5)),
+            )
+
+            num_crops_h = max(2, num_crops_h)
+            num_crops_h = min(pan_and_scan_max_num_crops, num_crops_h)
+            num_crops_w = 1
+
+        crop_size_w = int(math.ceil(image_width / num_crops_w))
+        crop_size_h = int(math.ceil(image_height / num_crops_h))
+
+        if min(crop_size_w, crop_size_h) < pan_and_scan_min_crop_size:
+            return 0
+
+        return num_crops_w * num_crops_h
+
+    def get_image_repl(
+        self,
+        *,
+        image_width: int,
+        image_height: int,
+        processor: Optional[Gemma3Processor],
+    ) -> str:
+        if processor is None:
+            processor = self.get_hf_processor()
+
+        image_token = processor.boi_token
+
+        num_crops = self.get_num_crops(
+            image_width=image_width,
+            image_height=image_height,
+            processor=processor,
+        )
+
+        if num_crops == 0:
+            image_text = image_token
+        else:
+            crops_image_tokens = " ".join(image_token
+                                          for _ in range(num_crops))
+            image_text = (
+                f"Here is the original image {image_token} and here are some "
+                f"crops to help you see better {crops_image_tokens}")
+
+        return image_text.replace(image_token, processor.full_image_sequence)
+
     def get_num_image_tokens(
         self,
         *,
         image_width: int,
         image_height: int,
-        processor: Optional[ProcessorMixin],
+        processor: Optional[Gemma3Processor],
     ) -> int:
-        hf_config = self.ctx.get_hf_config()
-        return hf_config.mm_tokens_per_image
+        tokenizer = self.get_tokenizer()
+        image_repl = self.get_image_repl(
+            image_width=image_width,
+            image_height=image_height,
+            processor=processor,
+        )
+
+        image_repl_tokens = encode_tokens(
+            tokenizer,
+            image_repl,
+            add_special_tokens=False,
+        )
+
+        return len(image_repl_tokens)

     def get_image_size_with_most_features(self) -> ImageSize:
-        # Results in the max possible feature size (h:w = 16:1)
-        return ImageSize(height=8000, width=50)
+        processor = self.get_hf_processor()
+        images_kwargs = self._resolve_image_kwargs(
+            processor, {"pan_and_scan_max_num_crops"})
+        max_num_crops = images_kwargs["pan_and_scan_max_num_crops"]
+
+        # Results in the max possible feature size (h:w = max_num_crops:1)
+        return ImageSize(height=50 * max_num_crops, width=50)
+
+    def get_max_image_tokens(self) -> int:
+        target_width, target_height = self.get_image_size_with_most_features()
+
+        return self.get_num_image_tokens(
+            image_width=target_width,
+            image_height=target_height,
+            processor=None,
+        )


 class Gemma3DummyInputsBuilder(BaseDummyInputsBuilder[Gemma3ProcessingInfo]):
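
Note: the crop-count logic above mirrors Gemma3ImageProcessor.pan_and_scan. A self-contained sketch that replays it with assumed hyperparameters (256 minimum crop size, 4 maximum crops, 1.2 activation ratio; the real values are resolved from the HF processor via _resolve_image_kwargs):

    import math

    MIN_CROP_SIZE = 256  # pan_and_scan_min_crop_size (assumed)
    MAX_NUM_CROPS = 4    # pan_and_scan_max_num_crops (assumed)
    MIN_RATIO = 1.2      # pan_and_scan_min_ratio_to_activate (assumed)

    def num_crops(width: int, height: int) -> int:
        if width >= height:
            if width / height < MIN_RATIO:
                return 0
            # Wide images split horizontally into 2..MAX_NUM_CROPS crops
            nw = min(math.floor(width / MIN_CROP_SIZE),
                     math.floor(width / height + 0.5))
            nw = min(MAX_NUM_CROPS, max(2, nw))
            nh = 1
        else:
            if height / width < MIN_RATIO:
                return 0
            # Tall images split vertically instead
            nh = min(math.floor(height / MIN_CROP_SIZE),
                     math.floor(height / width + 0.5))
            nh = min(MAX_NUM_CROPS, max(2, nh))
            nw = 1
        # Crops that would shrink below the minimum size disable pan-and-scan
        if min(math.ceil(width / nw), math.ceil(height / nh)) < MIN_CROP_SIZE:
            return 0
        return nw * nh

    assert num_crops(512, 512) == 0    # square: ratio below 1.2
    assert num_crops(1024, 512) == 2   # 2:1 landscape: two crops
    assert num_crops(500, 300) == 0    # crops would be under 256 px wide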
@@ -73,10 +224,11 @@ class Gemma3DummyInputsBuilder(BaseDummyInputsBuilder[Gemma3ProcessingInfo]):
         seq_len: int,
         mm_counts: Mapping[str, int],
     ) -> ProcessorInputs:
-        tokenizer = self.info.get_tokenizer()
-        boi_token = tokenizer.boi_token
+        processor = self.info.get_hf_processor()
+        image_token = processor.boi_token

         num_images = mm_counts.get("image", 0)

         target_width, target_height = \
             self.info.get_image_size_with_most_features()
@@ -86,8 +238,13 @@ class Gemma3DummyInputsBuilder(BaseDummyInputsBuilder[Gemma3ProcessingInfo]):
                                    height=target_height,
                                    num_images=num_images)
         }

+        # NOTE: We need to separate the image tokens here because
+        # encode("\n\n\n\n") != encode("\n\n") * 2, which interferes
+        # with the detection of prompt updates when the image tokens are
+        # right next to each other.
         return ProcessorInputs(
-            prompt_text=" ".join([boi_token] * num_images),
+            prompt_text=" ".join([image_token] * num_images),
             mm_data=mm_data,
         )
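
Note: the NOTE above is easy to verify; a minimal check, assuming access to a Gemma 3 tokenizer on the Hub (any tokenizer that merges newline runs behaves the same way):

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("google/gemma-3-4b-it")

    # Tokenization is not compositional: two adjacent "\n\n" runs do not
    # encode to the concatenation of their separate encodings, so adjacent
    # image placeholders (whose expansions begin and end with newlines)
    # would break exact-match prompt-update detection.
    two = tok.encode("\n\n", add_special_tokens=False)
    four = tok.encode("\n\n\n\n", add_special_tokens=False)
    assert four != two * 2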
@@ -100,22 +257,49 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
         mm_data: Mapping[str, object],
         mm_kwargs: Mapping[str, object],
     ) -> BatchFeature:
-        # TODO(woosuk): Support pan-and-scan.
-        img_kwargs = mm_kwargs.get("images_kwargs", {})
-        img_kwargs["do_pan_and_scan"] = False
-        mm_kwargs["images_kwargs"] = img_kwargs
-
-        return super()._call_hf_processor(
-            prompt=prompt,
-            mm_data=mm_data,
-            mm_kwargs=mm_kwargs,
+        processed_outputs = super()._call_hf_processor(
+            prompt,
+            mm_data,
+            mm_kwargs,
         )

+        # HF processor pops the `num_crops` kwarg, which is needed by vLLM
+        if (images := mm_data.get("images")) is not None:
+            assert isinstance(images, list)
+
+            parsed_images = (self._get_data_parser().parse_mm_data({
+                "image": images
+            }).get_items("image", ImageProcessorItems))
+            image_sizes = [
+                parsed_images.get_image_size(i)
+                for i in range(len(parsed_images))
+            ]
+            hf_processor = self.info.get_hf_processor(**mm_kwargs)
+
+            num_crops = [
+                self.info.get_num_crops(image_width=size.width,
+                                        image_height=size.height,
+                                        processor=hf_processor)
+                for size in image_sizes
+            ]
+            processed_outputs["num_crops"] = torch.tensor(num_crops)
+
+        return processed_outputs

     def _get_mm_fields_config(
         self,
         hf_inputs: BatchFeature,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
-        return dict(pixel_values=MultiModalFieldConfig.batched("image"))
+        num_crops = hf_inputs.get("num_crops", torch.empty(0))
+
+        return dict(
+            pixel_values=MultiModalFieldConfig.flat_from_sizes(
+                "image", num_crops + 1),
+            num_crops=MultiModalFieldConfig.batched("image"),
+        )

     def _get_prompt_updates(
         self,
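
Note: flat_from_sizes declares that the flat pixel_values batch is owned per image in slices of num_crops + 1 rows (the original image plus its crops). A sketch of the resulting grouping, with an assumed 896x896 input resolution:

    import torch

    pixel_values = torch.randn(5, 3, 896, 896)  # flat batch: 3 + 2 rows
    num_crops = torch.tensor([2, 1])            # per-image crop counts

    # Each image owns a contiguous slice of (num_crops[i] + 1) rows
    per_image = torch.split(pixel_values, (num_crops + 1).tolist())
    assert [t.shape[0] for t in per_image] == [3, 2]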
@@ -123,25 +307,23 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
         hf_processor_mm_kwargs: Mapping[str, Any],
         out_mm_kwargs: MultiModalKwargs,
     ) -> Sequence[PromptUpdate]:
-        tokenizer = self.info.get_tokenizer()
         hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
-        hf_config = self.info.get_hf_config()
-
-        boi_token = tokenizer.boi_token
-        image_token = tokenizer.image_token
-        mm_tokens_per_image = hf_config.mm_tokens_per_image
-
-        image_tokens_expanded = "".join([image_token] * mm_tokens_per_image)
+        image_token = hf_processor.boi_token

         def get_replacement_gemma3(item_idx: int):
-            return PromptUpdateDetails(
-                full=hf_processor.full_image_sequence,
-                features=image_tokens_expanded,
+            images = mm_items.get_items("image", ImageProcessorItems)
+
+            image_size = images.get_image_size(item_idx)
+            return self.info.get_image_repl(
+                image_width=image_size.width,
+                image_height=image_size.height,
+                processor=hf_processor,
             )

         return [
             PromptReplacement(
                 modality="image",
-                target=boi_token,
+                target=image_token,
                 replacement=get_replacement_gemma3,
             )
         ]
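
Note: roughly how the replacement text expands, assuming HF Gemma3Processor conventions (the exact special-token strings and the 256 soft tokens per image are assumptions, not taken from this diff):

    BOI, EOI = "<start_of_image>", "<end_of_image>"
    SOFT = "<image_soft_token>"
    FULL_IMAGE_SEQUENCE = f"\n\n{BOI}{SOFT * 256}{EOI}\n\n"

    def image_repl(num_crops: int) -> str:
        # Mirrors get_image_repl: one placeholder for the original image
        # plus one per crop, each expanded to the full per-image sequence.
        if num_crops == 0:
            text = BOI
        else:
            crops = " ".join([BOI] * num_crops)
            text = (f"Here is the original image {BOI} and here are some "
                    f"crops to help you see better {crops}")
        return text.replace(BOI, FULL_IMAGE_SEQUENCE)

    assert image_repl(0).count(EOI) == 1  # no crops: one image sequence
    assert image_repl(2).count(EOI) == 3  # two crops: original + 2 crops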
@@ -254,19 +436,27 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal,
     def _parse_and_validate_image_input(
             self, **kwargs: object) -> Optional[Gemma3ImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
+        num_crops = kwargs.pop("num_crops", None)
         image_embeds = kwargs.pop("image_embeds", None)
         assert image_embeds is None, "Gemma3 does not support image_embeds."
         if pixel_values is None:
             return None

-        if not isinstance(pixel_values, (torch.Tensor, list[torch.Tensor])):
+        if not isinstance(pixel_values, (torch.Tensor, list)):
             raise ValueError("Incorrect type of pixel values. "
                              f"Got type: {type(pixel_values)}")

+        if not isinstance(num_crops, (torch.Tensor, list)):
+            raise ValueError("Incorrect type of num_crops values. "
+                             f"Got type: {type(num_crops)}")
+
         pixel_values = flatten_bn(pixel_values, concat=True)
+        num_crops = flatten_bn(num_crops, concat=True)

         return Gemma3ImagePixelInputs(
             type="pixel_values",
-            data=self._validate_pixel_values(pixel_values),
+            pixel_values=self._validate_pixel_values(pixel_values),
+            num_crops=num_crops,
         )

     def _image_pixels_to_features(
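
Note: flatten_bn merges the per-request batch dimension before validation. A simplified sketch of the concat=True path used here (an assumption about the helper in vllm.model_executor.models.utils, not its actual implementation):

    import torch

    def flatten_bn_concat(x) -> torch.Tensor:
        # A list of per-request tensors becomes one flat batch along dim 0;
        # an already-batched tensor collapses its first two dims instead.
        if isinstance(x, torch.Tensor):
            return x.flatten(0, 1)
        return torch.cat(x)

    # e.g. two requests contributing 3 and 2 crop rows -> a (5, ...) batch
    a, b = torch.randn(3, 3, 8, 8), torch.randn(2, 3, 8, 8)
    assert flatten_bn_concat([a, b]).shape[0] == 5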
@@ -283,7 +473,8 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal,
         image_input: Gemma3ImageInputs,
     ) -> torch.Tensor:
         assert self.vision_tower is not None
-        pixel_values = image_input["data"]
+        pixel_values = image_input["pixel_values"]
+
         vision_outputs = self._image_pixels_to_features(
             self.vision_tower,
             pixel_values,