diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md
index 9cdf644c3..6eb0947fe 100644
--- a/docs/models/supported_models.md
+++ b/docs/models/supported_models.md
@@ -669,7 +669,7 @@ These models primarily accept the [`LLM.generate`](./generative_models.md#llmgen
 | `DeepseekOCRForCausalLM` | DeepSeek-OCR | T + I+ | `deepseek-ai/DeepSeek-OCR`, etc. | | ✅︎ |
 | `Ernie4_5_VLMoeForConditionalGeneration` | Ernie4.5-VL | T + I+/ V+ | `baidu/ERNIE-4.5-VL-28B-A3B-PT`, `baidu/ERNIE-4.5-VL-424B-A47B-PT` | | ✅︎ |
 | `FuyuForCausalLM` | Fuyu | T + I | `adept/fuyu-8b`, etc. | | ✅︎ |
-| `Gemma3ForConditionalGeneration` | Gemma 3 | T + I+ | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ |
+| `Gemma3ForConditionalGeneration` | Gemma 3 | T + IE+ | `google/gemma-3-4b-it`, `google/gemma-3-27b-it`, etc. | ✅︎ | ✅︎ |
 | `Gemma3nForConditionalGeneration` | Gemma 3n | T + I + A | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | |
 | `GLM4VForCausalLM`^ | GLM-4V | T + I | `zai-org/glm-4v-9b`, `zai-org/cogagent-9b-20241220`, etc. | ✅︎ | ✅︎ |
 | `Glm4vForConditionalGeneration` | GLM-4.1V-Thinking | T + IE+ + VE+ | `zai-org/GLM-4.1V-9B-Thinking`, etc. | ✅︎ | ✅︎ |
diff --git a/vllm/model_executor/models/gemma3_mm.py b/vllm/model_executor/models/gemma3_mm.py
index 02fb7ef31..8e2bbe8f7 100644
--- a/vllm/model_executor/models/gemma3_mm.py
+++ b/vllm/model_executor/models/gemma3_mm.py
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import math
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Annotated, Any, Literal
+from typing import Annotated, Any, Literal, TypeAlias
 
 import torch
 from torch import nn
@@ -20,7 +20,12 @@ from vllm.multimodal.inputs import (
     MultiModalFieldConfig,
     MultiModalKwargsItems,
 )
-from vllm.multimodal.parse import ImageProcessorItems, ImageSize, MultiModalDataItems
+from vllm.multimodal.parse import (
+    ImageEmbeddingItems,
+    ImageProcessorItems,
+    ImageSize,
+    MultiModalDataItems,
+)
 from vllm.multimodal.processing import (
     BaseMultiModalProcessor,
     BaseProcessingInfo,
@@ -71,7 +76,15 @@ class Gemma3ImagePixelInputs(TensorSchema):
     num_patches: Annotated[torch.Tensor, TensorShape("bn")]
 
 
-Gemma3ImageInputs = Gemma3ImagePixelInputs
+class Gemma3ImageEmbeddingInputs(TensorSchema):
+    type: Literal["image_embeds"] = "image_embeds"
+    image_embeds: Annotated[
+        torch.Tensor,
+        TensorShape("ni", "nf", "hs"),
+    ]
+
+
+Gemma3ImageInputs: TypeAlias = Gemma3ImagePixelInputs | Gemma3ImageEmbeddingInputs
 
 
 class Gemma3ProcessingInfo(BaseProcessingInfo):
@@ -178,8 +191,9 @@
     def get_image_repl(
         self,
         *,
-        image_width: int,
-        image_height: int,
+        image_width: int | None,
+        image_height: int | None,
+        num_crops: int | None = None,
         processor: Gemma3Processor | None,
     ) -> PromptUpdateDetails[str]:
         if processor is None:
@@ -187,11 +201,13 @@
 
         boi_token = processor.boi_token
 
-        num_crops = self.get_num_crops(
-            image_width=image_width,
-            image_height=image_height,
-            processor=processor,
-        )
+        if num_crops is None:
+            assert image_width is not None and image_height is not None
+            num_crops = self.get_num_crops(
+                image_width=image_width,
+                image_height=image_height,
+                processor=processor,
+            )
 
         if num_crops == 0:
             image_text = boi_token
@@ -321,6 +337,7 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
         return dict(
             pixel_values=MultiModalFieldConfig.flat_from_sizes("image", num_patches),
             num_patches=MultiModalFieldConfig.batched("image"),
+            image_embeds=MultiModalFieldConfig.batched("image"),
         )
 
     def _get_prompt_updates(
@@ -333,7 +350,19 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
         image_token = hf_processor.boi_token
 
         def get_replacement_gemma3(item_idx: int):
-            images = mm_items.get_items("image", ImageProcessorItems)
+            images = mm_items.get_items(
+                "image", (ImageEmbeddingItems, ImageProcessorItems)
+            )
+
+            if isinstance(images, ImageEmbeddingItems):
+                # For image embedding inputs, only the no-crops case is
+                # supported, since crops are not supported by the HF
+                # processor anyway.
+                return self.info.get_image_repl(
+                    image_width=None,
+                    image_height=None,
+                    num_crops=0,
+                    processor=hf_processor,
+                )
 
             image_size = images.get_image_size(item_idx)
             return self.info.get_image_repl(
@@ -557,17 +586,19 @@ class Gemma3ForConditionalGeneration(
         pixel_values = kwargs.pop("pixel_values", None)
         num_patches = kwargs.pop("num_patches", None)
         image_embeds = kwargs.pop("image_embeds", None)
 
-        assert image_embeds is None, "Gemma3 does not support image_embeds."
-        if pixel_values is None:
-            return None
-        image_size = self.config.vision_config.image_size
-
-        return Gemma3ImagePixelInputs(
-            pixel_values=pixel_values,
-            num_patches=num_patches,
-            resolve_bindings={"h": image_size, "w": image_size},
-        )
+        if pixel_values is not None:
+            image_size = self.config.vision_config.image_size
+            return Gemma3ImagePixelInputs(
+                pixel_values=pixel_values,
+                num_patches=num_patches,
+                resolve_bindings={"h": image_size, "w": image_size},
+            )
+        elif image_embeds is not None:
+            return Gemma3ImageEmbeddingInputs(
+                image_embeds=image_embeds,
+                type="image_embeds",
+            )
 
     def _image_pixels_to_features(
         self,
@@ -579,7 +610,9 @@
     def _process_image_input(
         self,
         image_input: Gemma3ImageInputs,
-    ) -> list[torch.Tensor]:
+    ) -> torch.Tensor | list[torch.Tensor]:
+        if image_input["type"] == "image_embeds":
+            return image_input["image_embeds"]
         assert self.vision_tower is not None
 
         pixel_values = image_input["pixel_values"]
diff --git a/vllm/multimodal/parse.py b/vllm/multimodal/parse.py
index 2fa3f6ebc..810f29072 100644
--- a/vllm/multimodal/parse.py
+++ b/vllm/multimodal/parse.py
@@ -359,8 +359,9 @@ class MultiModalDataParser:
         )
         self.video_needs_metadata = video_needs_metadata
 
-    def _is_embeddings(
-        self, data: object
+    @classmethod
+    def is_embeddings(
+        cls, data: object
     ) -> TypeGuard[torch.Tensor | list[torch.Tensor]]:
         if isinstance(data, torch.Tensor):
             return data.ndim == 3
@@ -420,7 +421,7 @@ class MultiModalDataParser:
         ):
             return None
 
-        if self._is_embeddings(data):
+        if self.is_embeddings(data):
             return AudioEmbeddingItems(data)
 
         data_items: list[AudioItem]
@@ -458,7 +459,7 @@ class MultiModalDataParser:
         if self._is_empty(data):
             return None
 
-        if self._is_embeddings(data):
+        if self.is_embeddings(data):
             return ImageEmbeddingItems(data)
 
         if (
@@ -484,7 +485,7 @@ class MultiModalDataParser:
         if self._is_empty(data):
             return None
 
-        if self._is_embeddings(data):
+        if self.is_embeddings(data):
             return VideoEmbeddingItems(data)
 
         data_items: list[VideoItem]
diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py
index 0404f6ff2..fffd075a5 100644
--- a/vllm/v1/engine/processor.py
+++ b/vllm/v1/engine/processor.py
@@ -14,6 +14,7 @@ from vllm.lora.request import LoRARequest
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
 from vllm.multimodal.cache import processor_cache_from_config
 from vllm.multimodal.inputs import MultiModalFeatureSpec, MultiModalUUIDDict
+from vllm.multimodal.parse import MultiModalDataParser
 from vllm.multimodal.processing import EncDecMultiModalProcessor
 from vllm.multimodal.utils import argsort_mm_positions
 from vllm.pooling_params import PoolingParams
@@ -340,7 +341,12 @@
         mm_uuids: dict[str, list[str | None] | str] = {}
         for modality, data in mm_data.items():
-            n = len(data) if isinstance(data, list) else 1
+            # Embedding inputs arrive as a batched tensor rather than a list,
+            # so count their items too and hash each one individually.
+            n = (
+                len(data)
+                if isinstance(data, list) or MultiModalDataParser.is_embeddings(data)
+                else 1
+            )
             mm_uuids[modality] = [f"{request_id}-{modality}-{i}" for i in range(n)]
 
         return mm_uuids
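
For context, a minimal usage sketch of what this change enables: passing precomputed image embeddings to Gemma 3 through `LLM.generate` via `multi_modal_data`, matching the `TensorShape("ni", "nf", "hs")` schema introduced in `Gemma3ImageEmbeddingInputs`. This snippet is not part of the diff; the 256 feature tokens per image and hidden size of 2560 are assumed values for `google/gemma-3-4b-it` (check the checkpoint's config before relying on them), and the raw `<start_of_image>` placeholder is used for brevity instead of the full chat template.

```python
# Hypothetical sketch (not from this diff): feed precomputed image embeddings
# to Gemma 3. Shape is (ni, nf, hs) = (num images, feature tokens per image,
# language-model hidden size), per Gemma3ImageEmbeddingInputs above.
import torch

from vllm import LLM

llm = LLM(model="google/gemma-3-4b-it")

# Assumed: 256 feature tokens per image and hidden size 2560 for the 4B model.
image_embeds = torch.randn(1, 256, 2560)

outputs = llm.generate(
    {
        # <start_of_image> is the boi_token; get_image_repl expands it with
        # num_crops=0, since crops are unavailable for embedding inputs.
        "prompt": "<start_of_image>Describe this image.",
        "multi_modal_data": {"image": image_embeds},
    }
)
print(outputs[0].outputs[0].text)
```

Because `is_embeddings` recognizes a 3-D tensor, the request processor in `vllm/v1/engine/processor.py` now also assigns one UUID per embedded image rather than treating the whole tensor as a single item.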