diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index 9cdf644c3..6eb0947fe 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -669,7 +669,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
| `DeepseekOCRForCausalLM` | DeepSeek-OCR | T + I+ | `deepseek-ai/DeepSeek-OCR`, etc. | | ✅︎ |
| `Ernie4_5_VLMoeForConditionalGeneration` | Ernie4.5-VL | T + I+/ V+ | `baidu/ERNIE-4.5-VL-28B-A3B-PT`, `baidu/ERNIE-4.5-VL-424B-A47B-PT` | | ✅︎ |
| `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b`, etc. | | ✅︎ |
-| `Gemma3ForConditionalGeneration` | Gemma 3 | T + I+ | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ |
+| `Gemma3ForConditionalGeneration` | Gemma 3 | T + IE+ | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ |
| `Gemma3nForConditionalGeneration` | Gemma 3n | T + I + A | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | |
| `GLM4VForCausalLM`^ | GLM-4V | T + I | `zai-org/glm-4v-9b`, `zai-org/cogagent-9b-20241220`, etc. | ✅︎ | ✅︎ |
| `Glm4vForConditionalGeneration` | GLM-4.1V-Thinking | T + IE+ + VE+ | `zai-org/GLM-4.1V-9B-Thinking`, etc. | ✅︎ | ✅︎ |
diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py
index 02fb7ef31..8e2bbe8f7 100644
--- a/vllm/model_executor/models/gemma3_mm.py
+++ b/vllm/model_executor/models/gemma3_mm.py
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from collections.abc import Iterable, Mapping, Sequence
-from typing import Annotated, Any, Literal
+from typing import Annotated, Any, Literal, TypeAlias
import torch
from torch import nn
@@ -20,7 +20,12 @@ from vllm.multimodal.inputs import (
MultiModalFieldConfig,
MultiModalKwargsItems,
)
-from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems
+from vllm.multimodal.parse import (
+ ImageEmbeddingItems,
+ ImageProcessorItems,
+ ImageSize,
+ MultiModalDataItems,
+)
from vllm.multimodal.processing import (
BaseMultiModalProcessor,
BaseProcessingInfo,
@@ -71,7 +76,15 @@ class Gemma3ImagePixelInputs(TensorSchema):
num_patches: Annotated[torch.Tensor, TensorShape("bn")]
-Gemma3ImageInputs = Gemma3ImagePixelInputs
+class Gemma3ImageEmbeddingInputs(TensorSchema):
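+ """
+ Dimensions:
+ - ni: Number of images
+ - nf: Number of image features per image
+ - hs: Hidden size (must match the language model backbone)
+ """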
+ type: Literal["image_embeds"] = "image_embeds"
+ image_embeds: Annotated[
+ torch.Tensor,
+ TensorShape("ni", "nf", "hs"),
+ ]
+
+
+Gemma3ImageInputs: TypeAlias = Gemma3ImagePixelInputs | Gemma3ImageEmbeddingInputs
class Gemma3ProcessingInfo(BaseProcessingInfo):
@@ -178,8 +191,9 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
def get_image_repl(
self,
*,
- image_width: int,
- image_height: int,
+ image_width: int | None,
+ image_height: int | None,
+ num_crops: int | None = None,
processor: Gemma3Processor | None,
) -> PromptUpdateDetails[str]:
if processor is None:
@@ -187,11 +201,13 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
boi_token = processor.boi_token
- num_crops = self.get_num_crops(
- image_width=image_width,
- image_height=image_height,
- processor=processor,
- )
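+ # num_crops may be passed directly (it is 0 for embedding inputs, where
+ # the original image size is unknown); otherwise derive it from the size.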
+ if num_crops is None:
+ assert image_width is not None and image_height is not None
+ num_crops = self.get_num_crops(
+ image_width=image_width,
+ image_height=image_height,
+ processor=processor,
+ )
if num_crops == 0:
image_text = boi_token
@@ -321,6 +337,7 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
return dict(
pixel_values=MultiModalFieldConfig.flat_from_sizes("image", num_patches),
num_patches=MultiModalFieldConfig.batched("image"),
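+ # One precomputed embedding tensor per image item.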
+ image_embeds=MultiModalFieldConfig.batched("image"),
)
def _get_prompt_updates(
@@ -333,7 +350,19 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
image_token = hf_processor.boi_token
def get_replacement_gemma3(item_idx: int):
- images = mm_items.get_items("image", ImageProcessorItems)
+ images = mm_items.get_items(
+ "image", (ImageEmbeddingItems, ImageProcessorItems)
+ )
+
+ if isinstance(images, ImageEmbeddingItems):
+ # For image embedding inputs, only the no-crops case is supported,
+ # since the HF processor does not support crops for embeddings anyway.
+ return self.info.get_image_repl(
+ image_width=None,
+ image_height=None,
+ num_crops=0,
+ processor=hf_processor,
+ )
image_size = images.get_image_size(item_idx)
return self.info.get_image_repl(
@@ -557,17 +586,19 @@ class Gemma3ForConditionalGeneration(
pixel_values = kwargs.pop("pixel_values", None)
num_patches = kwargs.pop("num_patches", None)
image_embeds = kwargs.pop("image_embeds", None)
- assert image_embeds is None, "Gemma3 does not support image_embeds."
- if pixel_values is None:
- return None
- image_size = self.config.vision_config.image_size
-
- return Gemma3ImagePixelInputs(
- pixel_values=pixel_values,
- num_patches=num_patches,
- resolve_bindings={"h": image_size, "w": image_size},
- )
+ if pixel_values is not None:
+ image_size = self.config.vision_config.image_size
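+ # Bind the schema's "h"/"w" dims to the model's fixed image size
+ # so TensorSchema can validate the pixel tensor shape.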
+ return Gemma3ImagePixelInputs(
+ pixel_values=pixel_values,
+ num_patches=num_patches,
+ resolve_bindings={"h": image_size, "w": image_size},
+ )
+ elif image_embeds is not None:
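+ # Precomputed embeddings bypass the vision tower entirely.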
+ return Gemma3ImageEmbeddingInputs(
+ image_embeds=image_embeds,
+ type="image_embeds",
+ )
def _image_pixels_to_features(
self,
@@ -579,7 +610,9 @@ class Gemma3ForConditionalGeneration(
def _process_image_input(
self,
image_input: Gemma3ImageInputs,
- ) -> list[torch.Tensor]:
+ ) -> torch.Tensor | list[torch.Tensor]:
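+ # Embedding inputs are already image features; return them as-is
+ # instead of running the vision tower.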
+ if image_input["type"] == "image_embeds":
+ return image_input["image_embeds"]
assert self.vision_tower is not None
pixel_values = image_input["pixel_values"]
diff --git a/vllm/multimodal/parse.py b/vllm/multimodal/parse.py
index 2fa3f6ebc..810f29072 100644
--- a/vllm/multimodal/parse.py
+++ b/vllm/multimodal/parse.py
@@ -359,8 +359,9 @@ class MultiModalDataParser:
)
self.video_needs_metadata = video_needs_metadata
- def _is_embeddings(
- self, data: object
+ @classmethod
+ def is_embeddings(
+ cls, data: object
) -> TypeGuard[torch.Tensor | list[torch.Tensor]]:
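+ # Now public so callers outside the parser (e.g. the engine processor)
+ # can reuse the same embeddings check; a 3-D tensor is treated as a
+ # batch of per-item embeddings.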
if isinstance(data, torch.Tensor):
return data.ndim == 3
@@ -420,7 +421,7 @@ class MultiModalDataParser:
):
return None
- if self._is_embeddings(data):
+ if self.is_embeddings(data):
return AudioEmbeddingItems(data)
data_items: list[AudioItem]
@@ -458,7 +459,7 @@ class MultiModalDataParser:
if self._is_empty(data):
return None
- if self._is_embeddings(data):
+ if self.is_embeddings(data):
return ImageEmbeddingItems(data)
if (
@@ -484,7 +485,7 @@ class MultiModalDataParser:
if self._is_empty(data):
return None
- if self._is_embeddings(data):
+ if self.is_embeddings(data):
return VideoEmbeddingItems(data)
data_items: list[VideoItem]
diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py
index 0404f6ff2..fffd075a5 100644
--- a/vllm/v1/engine/processor.py
+++ b/vllm/v1/engine/processor.py
@@ -14,6 +14,7 @@ from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.cache import processor_cache_from_config
from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalUUIDDict
+from vllm.multimodal.parse import MultiModalDataParser
from vllm.multimodal.processing import EncDecMultiModalProcessor
from vllm.multimodal.utils import argsort_mm_positions
from vllm.pooling_params import PoolingParams
@@ -340,7 +341,12 @@ class Processor:
mm_uuids: dict[str, list[str | None] | str] = {}
for modality, data in mm_data.items():
- n = len(data) if isinstance(data, list) else 1
+ # Embedding inputs are batched per item, so hash each item separately.
+ n = (
+ len(data)
+ if isinstance(data, list) or MultiModalDataParser.is_embeddings(data)
+ else 1
+ )
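+ # Deterministic UUIDs: one per item, derived from request id,
+ # modality, and item index.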
mm_uuids[modality] = [f"{request_id}-{modality}-{i}" for i in range(n)]
return mm_uuids