diff --git a/tests/models/multimodal/processing/test_gemma3.py b/tests/models/multimodal/processing/test_gemma3.py
index 32a459ee8..e252be894 100644
--- a/tests/models/multimodal/processing/test_gemma3.py
+++ b/tests/models/multimodal/processing/test_gemma3.py
@@ -2,14 +2,154 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import pytest
+import torch
+
+from vllm.model_executor.models.gemma3n_audio_utils import (
+    adjust_audio_features_to_expected_length,
+)
 from vllm.multimodal import MULTIMODAL_REGISTRY
 
 from ....conftest import ImageTestAssets
 from ...utils import build_model_context
 
+# Gemma3 (image) model
+GEMMA3_MODEL_ID = "google/gemma-3-4b-it"
 
-@pytest.mark.parametrize("model_id", ["google/gemma-3-4b-it"])
+# Gemma3n (multimodal with audio) model
+GEMMA3N_MODEL_ID = "google/gemma-3n-E2B-it"
+
+# Expected number of audio tokens for Gemma3n (audio_soft_tokens_per_image)
+GEMMA3N_EXPECTED_AUDIO_TOKENS = 188
+
+
+class TestGemma3nAudioTensorLogic:
+    """CPU-based tests for Gemma3n audio feature tensor manipulation.
+
+    These tests validate the padding/truncation logic in
+    adjust_audio_features_to_expected_length(), which fixes the
+    integer overflow in _process_audio_input when audio_seq_len > 188.
+    """
+
+    def test_padding_when_audio_short(self):
+        """Test that short audio is padded to the expected length."""
+        batch_size, seq_len, embed_dim = 1, 100, 256
+        expected_tokens = GEMMA3N_EXPECTED_AUDIO_TOKENS
+
+        audio_features = torch.randn(batch_size, seq_len, embed_dim)
+        padding_embs = torch.zeros(1, 1, embed_dim)
+
+        result, tokens_truncated = adjust_audio_features_to_expected_length(
+            audio_features, expected_tokens, padding_embs
+        )
+
+        assert result.shape == (batch_size, expected_tokens, embed_dim)
+        assert tokens_truncated == 0
+        # First 100 tokens should be the originals; the rest should be
+        # padding (zeros)
+        assert torch.allclose(result[:, :seq_len, :], audio_features)
+        assert torch.allclose(
+            result[:, seq_len:, :],
+            torch.zeros(batch_size, expected_tokens - seq_len, embed_dim),
+        )
+
+    def test_truncation_when_audio_long(self):
+        """Test that long audio is truncated to the expected length.
+
+        This is the key test for the overflow fix.
+        Previously, when audio_seq_len > expected_tokens, the code would
+        compute a negative padding size, causing:
+        RuntimeError: numel: integer multiplication overflow
+        """
+        batch_size, seq_len, embed_dim = 1, 192, 256  # 192 > 188
+        expected_tokens = GEMMA3N_EXPECTED_AUDIO_TOKENS
+
+        audio_features = torch.randn(batch_size, seq_len, embed_dim)
+        padding_embs = torch.zeros(1, 1, embed_dim)
+
+        result, tokens_truncated = adjust_audio_features_to_expected_length(
+            audio_features, expected_tokens, padding_embs
+        )
+
+        assert result.shape == (batch_size, expected_tokens, embed_dim)
+        assert tokens_truncated == seq_len - expected_tokens  # 192 - 188 = 4
+        # The result should be the first 188 tokens of the original
+        assert torch.allclose(result, audio_features[:, :expected_tokens, :])
+
+    def test_no_change_when_exact_length(self):
+        """Test that exact-length audio passes through unchanged."""
+        batch_size, embed_dim = 1, 256
+        expected_tokens = GEMMA3N_EXPECTED_AUDIO_TOKENS
+
+        audio_features = torch.randn(batch_size, expected_tokens, embed_dim)
+        padding_embs = torch.zeros(1, 1, embed_dim)
+
+        result, tokens_truncated = adjust_audio_features_to_expected_length(
+            audio_features, expected_tokens, padding_embs
+        )
+
+        assert result.shape == audio_features.shape
+        assert tokens_truncated == 0
+        assert torch.allclose(result, audio_features)
+
+    def test_original_bug_would_fail(self):
+        """Verify that the original buggy implementation would overflow.
+
+        The original code always tried to pad, which fails when
+        audio_seq_len > expected_tokens because expand() receives a
+        negative size.
+        """
+        batch_size, seq_len, embed_dim = 1, 192, 256
+        expected_tokens = GEMMA3N_EXPECTED_AUDIO_TOKENS
+
+        padding_embs = torch.zeros(1, 1, embed_dim)
+
+        # Original buggy logic (always pads, never truncates)
+        extra_padding_tokens = expected_tokens - seq_len  # = -4 (negative!)
+
+        with pytest.raises(RuntimeError):
+            # This should fail with a negative-size error
+            padding_embs.expand(batch_size, extra_padding_tokens, embed_dim)
+
+    @pytest.mark.parametrize(
+        "seq_len",
+        [50, 100, 150, 187, 188, 189, 192, 200, 300],
+    )
+    def test_various_audio_lengths(self, seq_len: int):
+        """Test padding/truncation with various audio lengths."""
+        batch_size, embed_dim = 1, 256
+        expected_tokens = GEMMA3N_EXPECTED_AUDIO_TOKENS
+
+        audio_features = torch.randn(batch_size, seq_len, embed_dim)
+        padding_embs = torch.zeros(1, 1, embed_dim)
+
+        # Should not raise any errors
+        result, tokens_truncated = adjust_audio_features_to_expected_length(
+            audio_features, expected_tokens, padding_embs
+        )
+
+        # Output should always be expected_tokens long
+        assert result.shape == (batch_size, expected_tokens, embed_dim)
+
+        # Verify the truncation count is correct
+        if seq_len > expected_tokens:
+            assert tokens_truncated == seq_len - expected_tokens
+        else:
+            assert tokens_truncated == 0
+
+    def test_batch_processing(self):
+        """Test that batch processing works correctly."""
+        batch_size, seq_len, embed_dim = 4, 192, 256
+        expected_tokens = GEMMA3N_EXPECTED_AUDIO_TOKENS
+
+        audio_features = torch.randn(batch_size, seq_len, embed_dim)
+        padding_embs = torch.zeros(1, 1, embed_dim)
+
+        result, tokens_truncated = adjust_audio_features_to_expected_length(
+            audio_features, expected_tokens, padding_embs
+        )
+
+        assert result.shape == (batch_size, expected_tokens, embed_dim)
+        assert tokens_truncated == seq_len - expected_tokens
+
+
+@pytest.mark.parametrize("model_id", [GEMMA3_MODEL_ID])
 def test_get_image_size_with_most_features(
     image_assets: ImageTestAssets, model_id: str
 ):
diff --git a/vllm/model_executor/models/gemma3n_audio_utils.py b/vllm/model_executor/models/gemma3n_audio_utils.py
new file mode 100644
index 000000000..bef9bb9a0
--- /dev/null
+++ b/vllm/model_executor/models/gemma3n_audio_utils.py
@@ -0,0 +1,57 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Lightweight utility functions for Gemma3n audio processing.
+
+This module is separate from gemma3n_mm.py to avoid heavy CUDA dependencies,
+making it testable without a full vLLM build.
+"""
+
+import torch
+
+
+def adjust_audio_features_to_expected_length(
+    audio_features: torch.Tensor,
+    expected_tokens: int,
+    audio_padding_embs: torch.Tensor,
+) -> tuple[torch.Tensor, int]:
+    """Adjust audio features to the expected token length via padding or truncation.
+
+    The Gemma3nProcessor expects all audio will be ~30s in length and inserts
+    a fixed number of audio soft tokens into the text. However, the audio
+    preprocessing and encoder do not guarantee they will produce exactly that
+    many soft tokens; they may produce fewer tokens (for shorter audio) or more
+    tokens (for longer audio or due to BOA/EOA special tokens).
+
+    This function handles both cases:
+    - If fewer tokens: pad with the provided padding embeddings
+    - If more tokens: truncate to the expected count
+
+    Args:
+        audio_features: Audio embeddings tensor of shape
+            (batch_size, seq_len, embed_dim)
+        expected_tokens: The expected number of audio tokens (e.g., 188)
+        audio_padding_embs: Padding embeddings tensor of shape (1, 1, embed_dim)
+
+    Returns:
+        Tuple of:
+        - adjusted_features: Audio features adjusted to expected_tokens length
+        - tokens_truncated: Number of tokens truncated (0 if padding was applied)
+    """
+    audio_batch_size, audio_seq_len, audio_embed_dim = audio_features.shape
+    tokens_truncated = 0
+
+    if audio_seq_len < expected_tokens:
+        # Pad to the expected length with padding embeddings
+        extra_padding_tokens = expected_tokens - audio_seq_len
+        extra_padding_features = audio_padding_embs.expand(
+            audio_batch_size, extra_padding_tokens, audio_embed_dim
+        )
+        audio_features = torch.cat((audio_features, extra_padding_features), dim=1)
+    elif audio_seq_len > expected_tokens:
+        # Truncate to the expected length (the audio encoder produced more
+        # tokens than expected, e.g., due to longer audio or a placeholder mismatch)
+        tokens_truncated = audio_seq_len - expected_tokens
+        audio_features = audio_features[:, :expected_tokens, :]
+
+    return audio_features, tokens_truncated
diff --git a/vllm/model_executor/models/gemma3n_mm.py b/vllm/model_executor/models/gemma3n_mm.py
index 7036118ad..acb0d7399 100644
--- a/vllm/model_executor/models/gemma3n_mm.py
+++ b/vllm/model_executor/models/gemma3n_mm.py
@@ -1,11 +1,10 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Annotated, Any, Literal, Optional, Union, cast
+from typing import Annotated, Any, Literal, cast
 
 import numpy as np
 import torch
-
 from torch import nn
 from transformers import AutoModel, BatchFeature
 from transformers.models.gemma3n import (
@@ -26,6 +25,9 @@ from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import RowParallelLinear
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.models.gemma3n import Gemma3nForCausalLM
+from vllm.model_executor.models.gemma3n_audio_utils import (
+    adjust_audio_features_to_expected_length,
+)
 from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.model_executor.models.whisper import ISO639_1_SUPPORTED_LANGS
 from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -105,12 +107,12 @@ class Gemma3nProcessingInfo(BaseProcessingInfo):
     def get_hf_processor(self, **kwargs: object):
         return self.ctx.get_hf_processor(Gemma3nProcessor, **kwargs)
 
-    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
         return {"image": None, "audio": None}
 
     def get_max_tokens_per_item(
         self, seq_len: int, mm_counts: Mapping[str, int]
-    ) -> Optional[Mapping[str, int]]:
+    ) -> Mapping[str, int] | None:
         return {"image": TOKENS_PER_IMAGE, "audio": TOKENS_PER_AUDIO}
 
     def get_image_repl(
@@ -118,7 +120,7 @@ class Gemma3nProcessingInfo(BaseProcessingInfo):
         *,
         image_width: int,
         image_height: int,
-        processor: Optional[Gemma3nProcessor],
+        processor: Gemma3nProcessor | None,
     ) -> str:
         """
         Get the replacement text for image tokens.
@@ -136,7 +138,7 @@ class Gemma3nProcessingInfo(BaseProcessingInfo):
     def get_audio_repl(
         self,
         *,
-        processor: Optional[Gemma3nProcessor],
+        processor: Gemma3nProcessor | None,
     ) -> str:
         """
         Get the replacement text for audio tokens.
@@ -168,7 +170,7 @@ class Gemma3nDummyInputsBuilder(BaseDummyInputsBuilder[Gemma3nProcessingInfo]):
         self,
         seq_len: int,
         mm_counts: Mapping[str, int],
-        mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
+        mm_options: Mapping[str, BaseDummyOptions] | None = None,
     ) -> MultiModalDataDict:
         num_images = mm_counts.get("image", 0)
         num_audios = mm_counts.get("audio", 0)
@@ -387,7 +389,7 @@ class Gemma3nMultimodalEmbedder(nn.Module):
 
     def __init__(
         self,
-        multimodal_config: Union[Gemma3nAudioConfig, Gemma3nVisionConfig],
+        multimodal_config: Gemma3nAudioConfig | Gemma3nVisionConfig,
         text_config: Gemma3nTextConfig,
     ):
         super().__init__()
@@ -427,8 +429,8 @@ class Gemma3nMultimodalEmbedder(nn.Module):
 
     def forward(
         self,
-        input_ids: Optional[torch.LongTensor] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
+        input_ids: torch.LongTensor | None = None,
+        inputs_embeds: torch.Tensor | None = None,
     ) -> torch.Tensor:
         """Embeds token ids or soft tokens for multimodal content into language model space.
@@ -529,7 +531,7 @@ class Gemma3nForConditionalGeneration(
 
     def _parse_and_validate_image_input(
         self, **kwargs: object
-    ) -> Optional[Gemma3nImageInputs]:
+    ) -> Gemma3nImageInputs | None:
         pixel_values = kwargs.pop("pixel_values", None)
         image_embeds = kwargs.pop("image_embeds", None)
         # TODO is this the case?
@@ -541,7 +543,7 @@ class Gemma3nForConditionalGeneration(
 
     def _parse_and_validate_audio_input(
         self, **kwargs: object
-    ) -> Optional[Gemma3nAudioInputs]:
+    ) -> Gemma3nAudioInputs | None:
         input_features_padded = kwargs.pop("input_features_padded", None)
         if input_features_padded is None:
             return None
@@ -616,12 +618,15 @@ class Gemma3nForConditionalGeneration(
             )
         audio_features = self.embed_audio(inputs_embeds=audio_outputs)
 
-        # ruff: noqa
-        # The Gemma3nProcessor expects all audio will be 30s in length and inserts 188 audio soft tokens into the
-        # text to account for this. However, the audio preprocessing and encoder do not guarantee they will
-        # produce 188 soft tokens; they will produce at most that many tokens, but they may produce fewer tokens
-        # depending on the length of the longest audio input in the batch. When we encounter this situation, we pad
-        # the audio feature out to 188 soft tokens with the embedding of the last token in the embed_audio vocab.
+        # The Gemma3nProcessor expects all audio will be 30s in length and
+        # inserts 188 audio soft tokens into the text to account for this.
+        # However, the audio preprocessing and encoder do not guarantee they
+        # will produce exactly 188 soft tokens; they may produce fewer tokens
+        # (for shorter audio) or more tokens (for longer audio or due to
+        # BOA/EOA special tokens in the placeholder sequence).
+        # We handle both cases:
+        # - If fewer tokens: pad with the embedding of the last vocab token
+        # - If more tokens: truncate to the expected count
         # TODO precompute and cache padding
         audio_padding_toks = torch.tensor(
             [[self.vocab_size - 1]], dtype=torch.long, device=audio_features.device
@@ -631,13 +636,18 @@ class Gemma3nForConditionalGeneration(
             audio_mask.unsqueeze(-1), audio_padding_embs, audio_features
         )
 
-        audio_batch_size, audio_seq_len, audio_embed_dim = audio_features.shape
-        extra_padding_tokens = self.config.audio_soft_tokens_per_image - audio_seq_len  # noqa: E501
-        extra_padding_features = audio_padding_embs.expand(
-            audio_batch_size, extra_padding_tokens, audio_embed_dim
+        expected_tokens = self.config.audio_soft_tokens_per_image
+        audio_features, tokens_truncated = adjust_audio_features_to_expected_length(
+            audio_features, expected_tokens, audio_padding_embs
         )
+        if tokens_truncated > 0:
+            logger.warning(
+                "Gemma3n audio encoder produced %d extra tokens. "
+                "Truncating to match placeholder count of %d.",
+                tokens_truncated,
+                expected_tokens,
+            )
 
-        audio_features = torch.cat((audio_features, extra_padding_features), dim=1)
         # Return a list of embeddings instead of a batched tensor
         return audio_features.unbind(0)
@@ -666,9 +676,9 @@ class Gemma3nForConditionalGeneration(
     def embed_input_ids(
         self,
         input_ids: torch.Tensor,
-        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
+        multimodal_embeddings: MultiModalEmbeddings | None = None,
         *,
-        is_multimodal: Optional[torch.Tensor] = None,
+        is_multimodal: torch.Tensor | None = None,
         handle_oov_mm_token: bool = False,
     ) -> torch.Tensor:
         # NOTE (NickLucche) Each pass needs tokens to compute PLE so we cache
@@ -701,8 +711,8 @@ class Gemma3nForConditionalGeneration(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
+        intermediate_tensors: IntermediateTensors | None = None,
+        inputs_embeds: torch.Tensor | None = None,
         **kwargs: object,
     ) -> IntermediateTensors:
         if intermediate_tensors is not None:
@@ -729,7 +739,7 @@ class Gemma3nForConditionalGeneration(
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-    ) -> Optional[torch.Tensor]:
+    ) -> torch.Tensor | None:
         return self.language_model.compute_logits(hidden_states)
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
@@ -747,7 +757,7 @@ class Gemma3nForConditionalGeneration(
         )
 
     @classmethod
-    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
+    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
         if modality == "image":
             return "<image_soft_token>"
         elif modality == "audio":
@@ -761,10 +771,10 @@ class Gemma3nForConditionalGeneration(
         audio: np.ndarray,
         stt_config: SpeechToTextConfig,
         model_config: ModelConfig,
-        language: Optional[str],
+        language: str | None,
         task_type: Literal["transcribe", "translate"],
         request_prompt: str,
-        to_language: Optional[str],
+        to_language: str | None,
     ) -> PromptType:
         """
         Gemma3n supports "free-form" transcription.
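
Usage sketch (not part of the patch): a minimal CPU-only example of the
helper added in gemma3n_audio_utils.py above. The shapes (batch 1, 256-dim
embeddings) are illustrative; both branches mirror the tests in this patch.

    import torch
    from vllm.model_executor.models.gemma3n_audio_utils import (
        adjust_audio_features_to_expected_length,
    )

    # Stand-in for the padding embedding; in gemma3n_mm.py this is the
    # embedding of the last vocab token.
    padding = torch.zeros(1, 1, 256)

    # Encoder produced 192 tokens but the prompt holds 188 placeholders:
    # the helper truncates and reports how many tokens were dropped.
    long_out, dropped = adjust_audio_features_to_expected_length(
        torch.randn(1, 192, 256), 188, padding
    )
    assert long_out.shape == (1, 188, 256) and dropped == 4

    # Encoder produced only 100 tokens: the helper pads out to 188.
    short_out, dropped = adjust_audio_features_to_expected_length(
        torch.randn(1, 100, 256), 188, padding
    )
    assert short_out.shape == (1, 188, 256) and dropped == 0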