Allow Gemma3 to take image embeddings (#28483)
Signed-off-by: tingtinggithub <streamttt@gmail.com>
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import math
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from typing import Annotated, Any, Literal
|
||||
from typing import Annotated, Any, Literal, TypeAlias
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -20,7 +20,12 @@ from vllm.multimodal.inputs import (
|
||||
MultiModalFieldConfig,
|
||||
MultiModalKwargsItems,
|
||||
)
|
||||
from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems
|
||||
from vllm.multimodal.parse import (
|
||||
ImageEmbeddingItems,
|
||||
ImageProcessorItems,
|
||||
ImageSize,
|
||||
MultiModalDataItems,
|
||||
)
|
||||
from vllm.multimodal.processing import (
|
||||
BaseMultiModalProcessor,
|
||||
BaseProcessingInfo,
|
||||
@@ -71,7 +76,15 @@ class Gemma3ImagePixelInputs(TensorSchema):
|
||||
num_patches: Annotated[torch.Tensor, TensorShape("bn")]
|
||||
|
||||
|
||||
class Gemma3ImageEmbeddingInputs(TensorSchema):
    """Image inputs supplied as precomputed embeddings.

    These bypass the vision tower entirely: elsewhere in this file,
    `_process_image_input` returns `image_embeds` directly when
    `type == "image_embeds"`.

    Dimensions:
        - ni: number of images
        - nf: number of features (image tokens) per image
        - hs: hidden size
          NOTE(review): axis meanings inferred from the dimension names
          used in `TensorShape`; confirm against the schema's definition.
    """

    # Discriminator distinguishing this variant from Gemma3ImagePixelInputs
    # (which uses type == "pixel_values").
    type: Literal["image_embeds"] = "image_embeds"

    image_embeds: Annotated[
        torch.Tensor,
        TensorShape("ni", "nf", "hs"),
    ]


# Either raw pixels (to be run through the vision tower) or precomputed
# embeddings. The superseded `Gemma3ImageInputs = Gemma3ImagePixelInputs`
# binding is removed: keeping both would silently shadow the union.
Gemma3ImageInputs: TypeAlias = Gemma3ImagePixelInputs | Gemma3ImageEmbeddingInputs
|
||||
|
||||
|
||||
class Gemma3ProcessingInfo(BaseProcessingInfo):
|
||||
@@ -178,8 +191,9 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
|
||||
def get_image_repl(
|
||||
self,
|
||||
*,
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
image_width: int | None,
|
||||
image_height: int | None,
|
||||
num_crops: int | None = None,
|
||||
processor: Gemma3Processor | None,
|
||||
) -> PromptUpdateDetails[str]:
|
||||
if processor is None:
|
||||
@@ -187,11 +201,13 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
|
||||
|
||||
boi_token = processor.boi_token
|
||||
|
||||
num_crops = self.get_num_crops(
|
||||
image_width=image_width,
|
||||
image_height=image_height,
|
||||
processor=processor,
|
||||
)
|
||||
if num_crops is None:
|
||||
assert image_width is not None and image_height is not None
|
||||
num_crops = self.get_num_crops(
|
||||
image_width=image_width,
|
||||
image_height=image_height,
|
||||
processor=processor,
|
||||
)
|
||||
|
||||
if num_crops == 0:
|
||||
image_text = boi_token
|
||||
@@ -321,6 +337,7 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
|
||||
return dict(
|
||||
pixel_values=MultiModalFieldConfig.flat_from_sizes("image", num_patches),
|
||||
num_patches=MultiModalFieldConfig.batched("image"),
|
||||
image_embeds=MultiModalFieldConfig.batched("image"),
|
||||
)
|
||||
|
||||
def _get_prompt_updates(
|
||||
@@ -333,7 +350,19 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
|
||||
image_token = hf_processor.boi_token
|
||||
|
||||
def get_replacement_gemma3(item_idx: int):
|
||||
images = mm_items.get_items("image", ImageProcessorItems)
|
||||
images = mm_items.get_items(
|
||||
"image", (ImageEmbeddingItems, ImageProcessorItems)
|
||||
)
|
||||
|
||||
if isinstance(images, ImageEmbeddingItems):
|
||||
# For image embedding inputs, only support no crops cases
|
||||
# since it's not supported in hf processor anyway
|
||||
return self.info.get_image_repl(
|
||||
image_width=None,
|
||||
image_height=None,
|
||||
num_crops=0,
|
||||
processor=hf_processor,
|
||||
)
|
||||
|
||||
image_size = images.get_image_size(item_idx)
|
||||
return self.info.get_image_repl(
|
||||
@@ -557,17 +586,19 @@ class Gemma3ForConditionalGeneration(
|
||||
pixel_values = kwargs.pop("pixel_values", None)
|
||||
num_patches = kwargs.pop("num_patches", None)
|
||||
image_embeds = kwargs.pop("image_embeds", None)
|
||||
assert image_embeds is None, "Gemma3 does not support image_embeds."
|
||||
if pixel_values is None:
|
||||
return None
|
||||
|
||||
image_size = self.config.vision_config.image_size
|
||||
|
||||
return Gemma3ImagePixelInputs(
|
||||
pixel_values=pixel_values,
|
||||
num_patches=num_patches,
|
||||
resolve_bindings={"h": image_size, "w": image_size},
|
||||
)
|
||||
if pixel_values is not None:
|
||||
image_size = self.config.vision_config.image_size
|
||||
return Gemma3ImagePixelInputs(
|
||||
pixel_values=pixel_values,
|
||||
num_patches=num_patches,
|
||||
resolve_bindings={"h": image_size, "w": image_size},
|
||||
)
|
||||
elif image_embeds is not None:
|
||||
return Gemma3ImageEmbeddingInputs(
|
||||
image_embeds=image_embeds,
|
||||
type="image_embeds",
|
||||
)
|
||||
|
||||
def _image_pixels_to_features(
|
||||
self,
|
||||
@@ -579,7 +610,9 @@ class Gemma3ForConditionalGeneration(
|
||||
def _process_image_input(
|
||||
self,
|
||||
image_input: Gemma3ImageInputs,
|
||||
) -> list[torch.Tensor]:
|
||||
) -> torch.Tensor | list[torch.Tensor]:
|
||||
if image_input["type"] == "image_embeds":
|
||||
return image_input["image_embeds"]
|
||||
assert self.vision_tower is not None
|
||||
|
||||
pixel_values = image_input["pixel_values"]
|
||||
|
||||
Reference in New Issue
Block a user