[Bugfix] Fix integer overflow in Gemma3n audio processing (#31657)
Signed-off-by: Jeremy Teboul <jeremyte@meta.com>
This commit is contained in:
@@ -2,14 +2,154 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.models.gemma3n_audio_utils import (
|
||||
adjust_audio_features_to_expected_length,
|
||||
)
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
|
||||
from ....conftest import ImageTestAssets
|
||||
from ...utils import build_model_context
|
||||
|
||||
# Gemma3 (image) model
|
||||
GEMMA3_MODEL_ID = "google/gemma-3-4b-it"
|
||||
|
||||
@pytest.mark.parametrize("model_id", ["google/gemma-3-4b-it"])
|
||||
# Gemma3n (multimodal with audio) model
|
||||
GEMMA3N_MODEL_ID = "google/gemma-3n-E2B-it"
|
||||
|
||||
# Expected audio tokens for Gemma3n (audio_soft_tokens_per_image)
|
||||
GEMMA3N_EXPECTED_AUDIO_TOKENS = 188
|
||||
|
||||
|
||||
class TestGemma3nAudioTensorLogic:
|
||||
"""CPU-based tests for Gemma3n audio feature tensor manipulation.
|
||||
|
||||
These tests validate the padding/truncation logic in
|
||||
adjust_audio_features_to_expected_length() which fixes the
|
||||
integer overflow in _process_audio_input when audio_seq_len > 188.
|
||||
"""
|
||||
|
||||
def test_padding_when_audio_short(self):
|
||||
"""Test that short audio is padded to expected length."""
|
||||
batch_size, seq_len, embed_dim = 1, 100, 256
|
||||
expected_tokens = GEMMA3N_EXPECTED_AUDIO_TOKENS
|
||||
|
||||
audio_features = torch.randn(batch_size, seq_len, embed_dim)
|
||||
padding_embs = torch.zeros(1, 1, embed_dim)
|
||||
|
||||
result, tokens_truncated = adjust_audio_features_to_expected_length(
|
||||
audio_features, expected_tokens, padding_embs
|
||||
)
|
||||
|
||||
assert result.shape == (batch_size, expected_tokens, embed_dim)
|
||||
assert tokens_truncated == 0
|
||||
# First 100 tokens should be original, rest should be padding (zeros)
|
||||
assert torch.allclose(result[:, :seq_len, :], audio_features)
|
||||
assert torch.allclose(
|
||||
result[:, seq_len:, :],
|
||||
torch.zeros(batch_size, expected_tokens - seq_len, embed_dim),
|
||||
)
|
||||
|
||||
def test_truncation_when_audio_long(self):
|
||||
"""Test that long audio is truncated to expected length.
|
||||
|
||||
This is the key test for the overflow fix. Previously, when
|
||||
audio_seq_len > expected_tokens, the code would compute a negative
|
||||
padding value causing: RuntimeError: numel: integer multiplication overflow
|
||||
"""
|
||||
batch_size, seq_len, embed_dim = 1, 192, 256 # 192 > 188
|
||||
expected_tokens = GEMMA3N_EXPECTED_AUDIO_TOKENS
|
||||
|
||||
audio_features = torch.randn(batch_size, seq_len, embed_dim)
|
||||
padding_embs = torch.zeros(1, 1, embed_dim)
|
||||
|
||||
result, tokens_truncated = adjust_audio_features_to_expected_length(
|
||||
audio_features, expected_tokens, padding_embs
|
||||
)
|
||||
|
||||
assert result.shape == (batch_size, expected_tokens, embed_dim)
|
||||
assert tokens_truncated == seq_len - expected_tokens # 192 - 188 = 4
|
||||
# Result should be first 188 tokens of original
|
||||
assert torch.allclose(result, audio_features[:, :expected_tokens, :])
|
||||
|
||||
def test_no_change_when_exact_length(self):
|
||||
"""Test that exact-length audio passes through unchanged."""
|
||||
batch_size, embed_dim = 1, 256
|
||||
expected_tokens = GEMMA3N_EXPECTED_AUDIO_TOKENS
|
||||
|
||||
audio_features = torch.randn(batch_size, expected_tokens, embed_dim)
|
||||
padding_embs = torch.zeros(1, 1, embed_dim)
|
||||
|
||||
result, tokens_truncated = adjust_audio_features_to_expected_length(
|
||||
audio_features, expected_tokens, padding_embs
|
||||
)
|
||||
|
||||
assert result.shape == audio_features.shape
|
||||
assert tokens_truncated == 0
|
||||
assert torch.allclose(result, audio_features)
|
||||
|
||||
def test_original_bug_would_fail(self):
|
||||
"""Verify the original buggy implementation would cause overflow.
|
||||
|
||||
The original code always tried to pad, which fails when
|
||||
audio_seq_len > expected_tokens because expand() gets negative size.
|
||||
"""
|
||||
batch_size, seq_len, embed_dim = 1, 192, 256
|
||||
expected_tokens = GEMMA3N_EXPECTED_AUDIO_TOKENS
|
||||
|
||||
padding_embs = torch.zeros(1, 1, embed_dim)
|
||||
|
||||
# Original buggy logic (always pads, never truncates)
|
||||
extra_padding_tokens = expected_tokens - seq_len # = -4 (negative!)
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
# This should fail with negative size error
|
||||
padding_embs.expand(batch_size, extra_padding_tokens, embed_dim)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"seq_len",
|
||||
[50, 100, 150, 187, 188, 189, 192, 200, 300],
|
||||
)
|
||||
def test_various_audio_lengths(self, seq_len: int):
|
||||
"""Test padding/truncation with various audio lengths."""
|
||||
batch_size, embed_dim = 1, 256
|
||||
expected_tokens = GEMMA3N_EXPECTED_AUDIO_TOKENS
|
||||
|
||||
audio_features = torch.randn(batch_size, seq_len, embed_dim)
|
||||
padding_embs = torch.zeros(1, 1, embed_dim)
|
||||
|
||||
# Should not raise any errors
|
||||
result, tokens_truncated = adjust_audio_features_to_expected_length(
|
||||
audio_features, expected_tokens, padding_embs
|
||||
)
|
||||
|
||||
# Output should always be expected_tokens length
|
||||
assert result.shape == (batch_size, expected_tokens, embed_dim)
|
||||
|
||||
# Verify truncation count is correct
|
||||
if seq_len > expected_tokens:
|
||||
assert tokens_truncated == seq_len - expected_tokens
|
||||
else:
|
||||
assert tokens_truncated == 0
|
||||
|
||||
def test_batch_processing(self):
|
||||
"""Test that batch processing works correctly."""
|
||||
batch_size, seq_len, embed_dim = 4, 192, 256
|
||||
expected_tokens = GEMMA3N_EXPECTED_AUDIO_TOKENS
|
||||
|
||||
audio_features = torch.randn(batch_size, seq_len, embed_dim)
|
||||
padding_embs = torch.zeros(1, 1, embed_dim)
|
||||
|
||||
result, tokens_truncated = adjust_audio_features_to_expected_length(
|
||||
audio_features, expected_tokens, padding_embs
|
||||
)
|
||||
|
||||
assert result.shape == (batch_size, expected_tokens, embed_dim)
|
||||
assert tokens_truncated == seq_len - expected_tokens
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_id", [GEMMA3_MODEL_ID])
|
||||
def test_get_image_size_with_most_features(
|
||||
image_assets: ImageTestAssets, model_id: str
|
||||
):
|
||||
|
||||
57
vllm/model_executor/models/gemma3n_audio_utils.py
Normal file
57
vllm/model_executor/models/gemma3n_audio_utils.py
Normal file
@@ -0,0 +1,57 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Lightweight utility functions for Gemma3n audio processing.
|
||||
|
||||
This module is separate from gemma3n_mm.py to avoid heavy CUDA dependencies,
|
||||
making it testable without a full vLLM build.
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def adjust_audio_features_to_expected_length(
|
||||
audio_features: torch.Tensor,
|
||||
expected_tokens: int,
|
||||
audio_padding_embs: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, int]:
|
||||
"""Adjust audio features to expected token length via padding or truncation.
|
||||
|
||||
The Gemma3nProcessor expects all audio will be ~30s in length and inserts
|
||||
a fixed number of audio soft tokens into the text. However, the audio
|
||||
preprocessing and encoder do not guarantee they will produce exactly that
|
||||
many soft tokens; they may produce fewer tokens (for shorter audio) or more
|
||||
tokens (for longer audio or due to BOA/EOA special tokens).
|
||||
|
||||
This function handles both cases:
|
||||
- If fewer tokens: pad with the provided padding embeddings
|
||||
- If more tokens: truncate to the expected count
|
||||
|
||||
Args:
|
||||
audio_features: Audio embeddings tensor of shape
|
||||
(batch_size, seq_len, embed_dim)
|
||||
expected_tokens: The expected number of audio tokens (e.g., 188)
|
||||
audio_padding_embs: Padding embeddings tensor of shape (1, 1, embed_dim)
|
||||
|
||||
Returns:
|
||||
Tuple of:
|
||||
- adjusted_features: Audio features adjusted to expected_tokens length
|
||||
- tokens_truncated: Number of tokens truncated (0 if padding was applied)
|
||||
"""
|
||||
audio_batch_size, audio_seq_len, audio_embed_dim = audio_features.shape
|
||||
tokens_truncated = 0
|
||||
|
||||
if audio_seq_len < expected_tokens:
|
||||
# Pad to expected length with padding embeddings
|
||||
extra_padding_tokens = expected_tokens - audio_seq_len
|
||||
extra_padding_features = audio_padding_embs.expand(
|
||||
audio_batch_size, extra_padding_tokens, audio_embed_dim
|
||||
)
|
||||
audio_features = torch.cat((audio_features, extra_padding_features), dim=1)
|
||||
elif audio_seq_len > expected_tokens:
|
||||
# Truncate to expected length (audio encoder produced more tokens
|
||||
# than expected, e.g., due to longer audio or placeholder mismatch)
|
||||
tokens_truncated = audio_seq_len - expected_tokens
|
||||
audio_features = audio_features[:, :expected_tokens, :]
|
||||
|
||||
return audio_features, tokens_truncated
|
||||
@@ -1,11 +1,10 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from typing import Annotated, Any, Literal, Optional, Union, cast
|
||||
from typing import Annotated, Any, Literal, cast
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from torch import nn
|
||||
from transformers import AutoModel, BatchFeature
|
||||
from transformers.models.gemma3n import (
|
||||
@@ -26,6 +25,9 @@ from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import RowParallelLinear
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||
from vllm.model_executor.models.gemma3n import Gemma3nForCausalLM
|
||||
from vllm.model_executor.models.gemma3n_audio_utils import (
|
||||
adjust_audio_features_to_expected_length,
|
||||
)
|
||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||
from vllm.model_executor.models.whisper import ISO639_1_SUPPORTED_LANGS
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
@@ -105,12 +107,12 @@ class Gemma3nProcessingInfo(BaseProcessingInfo):
|
||||
def get_hf_processor(self, **kwargs: object):
|
||||
return self.ctx.get_hf_processor(Gemma3nProcessor, **kwargs)
|
||||
|
||||
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
||||
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
|
||||
return {"image": None, "audio": None}
|
||||
|
||||
def get_max_tokens_per_item(
|
||||
self, seq_len: int, mm_counts: Mapping[str, int]
|
||||
) -> Optional[Mapping[str, int]]:
|
||||
) -> Mapping[str, int] | None:
|
||||
return {"image": TOKENS_PER_IMAGE, "audio": TOKENS_PER_AUDIO}
|
||||
|
||||
def get_image_repl(
|
||||
@@ -118,7 +120,7 @@ class Gemma3nProcessingInfo(BaseProcessingInfo):
|
||||
*,
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
processor: Optional[Gemma3nProcessor],
|
||||
processor: Gemma3nProcessor | None,
|
||||
) -> str:
|
||||
"""
|
||||
Get the replacement text for image tokens.
|
||||
@@ -136,7 +138,7 @@ class Gemma3nProcessingInfo(BaseProcessingInfo):
|
||||
def get_audio_repl(
|
||||
self,
|
||||
*,
|
||||
processor: Optional[Gemma3nProcessor],
|
||||
processor: Gemma3nProcessor | None,
|
||||
) -> str:
|
||||
"""
|
||||
Get the replacement text for audio tokens.
|
||||
@@ -168,7 +170,7 @@ class Gemma3nDummyInputsBuilder(BaseDummyInputsBuilder[Gemma3nProcessingInfo]):
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
|
||||
mm_options: Mapping[str, BaseDummyOptions] | None = None,
|
||||
) -> MultiModalDataDict:
|
||||
num_images = mm_counts.get("image", 0)
|
||||
num_audios = mm_counts.get("audio", 0)
|
||||
@@ -387,7 +389,7 @@ class Gemma3nMultimodalEmbedder(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
multimodal_config: Union[Gemma3nAudioConfig, Gemma3nVisionConfig],
|
||||
multimodal_config: Gemma3nAudioConfig | Gemma3nVisionConfig,
|
||||
text_config: Gemma3nTextConfig,
|
||||
):
|
||||
super().__init__()
|
||||
@@ -427,8 +429,8 @@ class Gemma3nMultimodalEmbedder(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
inputs_embeds: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""Embeds token ids or soft tokens for multimodal content into language model space.
|
||||
|
||||
@@ -529,7 +531,7 @@ class Gemma3nForConditionalGeneration(
|
||||
|
||||
def _parse_and_validate_image_input(
|
||||
self, **kwargs: object
|
||||
) -> Optional[Gemma3nImageInputs]:
|
||||
) -> Gemma3nImageInputs | None:
|
||||
pixel_values = kwargs.pop("pixel_values", None)
|
||||
image_embeds = kwargs.pop("image_embeds", None)
|
||||
# TODO is this the case?
|
||||
@@ -541,7 +543,7 @@ class Gemma3nForConditionalGeneration(
|
||||
|
||||
def _parse_and_validate_audio_input(
|
||||
self, **kwargs: object
|
||||
) -> Optional[Gemma3nAudioInputs]:
|
||||
) -> Gemma3nAudioInputs | None:
|
||||
input_features_padded = kwargs.pop("input_features_padded", None)
|
||||
if input_features_padded is None:
|
||||
return None
|
||||
@@ -616,12 +618,15 @@ class Gemma3nForConditionalGeneration(
|
||||
)
|
||||
audio_features = self.embed_audio(inputs_embeds=audio_outputs)
|
||||
|
||||
# ruff: noqa
|
||||
# The Gemma3nProcessor expects all audio will be 30s in length and inserts 188 audio soft tokens into the
|
||||
# text to account for this. However, the audio preprocessing and encoder do not guarantee they will
|
||||
# produce 188 soft tokens; they will produce at most that many tokens, but they may produce fewer tokens
|
||||
# depending on the length of the longest audio input in the batch. When we encounter this situation, we pad
|
||||
# the audio feature out to 188 soft tokens with the embedding of the last token in the embed_audio vocab.
|
||||
# The Gemma3nProcessor expects all audio will be 30s in length and
|
||||
# inserts 188 audio soft tokens into the text to account for this.
|
||||
# However, the audio preprocessing and encoder do not guarantee they
|
||||
# will produce exactly 188 soft tokens; they may produce fewer tokens
|
||||
# (for shorter audio) or more tokens (for longer audio or due to
|
||||
# BOA/EOA special tokens in the placeholder sequence).
|
||||
# We handle both cases:
|
||||
# - If fewer tokens: pad with the embedding of the last vocab token
|
||||
# - If more tokens: truncate to the expected count
|
||||
# TODO precompute and cache padding
|
||||
audio_padding_toks = torch.tensor(
|
||||
[[self.vocab_size - 1]], dtype=torch.long, device=audio_features.device
|
||||
@@ -631,13 +636,18 @@ class Gemma3nForConditionalGeneration(
|
||||
audio_mask.unsqueeze(-1), audio_padding_embs, audio_features
|
||||
)
|
||||
|
||||
audio_batch_size, audio_seq_len, audio_embed_dim = audio_features.shape
|
||||
extra_padding_tokens = self.config.audio_soft_tokens_per_image - audio_seq_len # noqa: E501
|
||||
extra_padding_features = audio_padding_embs.expand(
|
||||
audio_batch_size, extra_padding_tokens, audio_embed_dim
|
||||
expected_tokens = self.config.audio_soft_tokens_per_image
|
||||
audio_features, tokens_truncated = adjust_audio_features_to_expected_length(
|
||||
audio_features, expected_tokens, audio_padding_embs
|
||||
)
|
||||
if tokens_truncated > 0:
|
||||
logger.warning(
|
||||
"Gemma3n audio encoder produced %d extra tokens. "
|
||||
"Truncating to match placeholder count of %d.",
|
||||
tokens_truncated,
|
||||
expected_tokens,
|
||||
)
|
||||
|
||||
audio_features = torch.cat((audio_features, extra_padding_features), dim=1)
|
||||
# Return a list of embeddings instead of a batched tensor
|
||||
return audio_features.unbind(0)
|
||||
|
||||
@@ -666,9 +676,9 @@ class Gemma3nForConditionalGeneration(
|
||||
def embed_input_ids(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
||||
multimodal_embeddings: MultiModalEmbeddings | None = None,
|
||||
*,
|
||||
is_multimodal: Optional[torch.Tensor] = None,
|
||||
is_multimodal: torch.Tensor | None = None,
|
||||
handle_oov_mm_token: bool = False,
|
||||
) -> torch.Tensor:
|
||||
# NOTE (NickLucche) Each pass needs tokens to compute PLE so we cache
|
||||
@@ -701,8 +711,8 @@ class Gemma3nForConditionalGeneration(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
intermediate_tensors: IntermediateTensors | None = None,
|
||||
inputs_embeds: torch.Tensor | None = None,
|
||||
**kwargs: object,
|
||||
) -> IntermediateTensors:
|
||||
if intermediate_tensors is not None:
|
||||
@@ -729,7 +739,7 @@ class Gemma3nForConditionalGeneration(
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> Optional[torch.Tensor]:
|
||||
) -> torch.Tensor | None:
|
||||
return self.language_model.compute_logits(hidden_states)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
@@ -747,7 +757,7 @@ class Gemma3nForConditionalGeneration(
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
|
||||
if modality == "image":
|
||||
return "<image_soft_token>"
|
||||
elif modality == "audio":
|
||||
@@ -761,10 +771,10 @@ class Gemma3nForConditionalGeneration(
|
||||
audio: np.ndarray,
|
||||
stt_config: SpeechToTextConfig,
|
||||
model_config: ModelConfig,
|
||||
language: Optional[str],
|
||||
language: str | None,
|
||||
task_type: Literal["transcribe", "translate"],
|
||||
request_prompt: str,
|
||||
to_language: Optional[str],
|
||||
to_language: str | None,
|
||||
) -> PromptType:
|
||||
"""
|
||||
Gemma3n supports "free-form" transcription.
|
||||
|
||||
Reference in New Issue
Block a user