[Bugfix] Fix integer overflow in Gemma3n audio processing (#31657)

Signed-off-by: Jeremy Teboul <jeremyte@meta.com>
Authored by Jeremy Teboul on 2026-01-10 01:52:53 -08:00, committed by GitHub
parent 14fc7a68c7
commit 07286ec5a6
3 changed files with 239 additions and 32 deletions

vllm/model_executor/models/gemma3n_audio_utils.py (new file)

@@ -0,0 +1,57 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Lightweight utility functions for Gemma3n audio processing.
This module is separate from gemma3n_mm.py to avoid heavy CUDA dependencies,
making it testable without a full vLLM build.
"""
import torch
def adjust_audio_features_to_expected_length(
audio_features: torch.Tensor,
expected_tokens: int,
audio_padding_embs: torch.Tensor,
) -> tuple[torch.Tensor, int]:
"""Adjust audio features to expected token length via padding or truncation.
The Gemma3nProcessor expects all audio will be ~30s in length and inserts
a fixed number of audio soft tokens into the text. However, the audio
preprocessing and encoder do not guarantee they will produce exactly that
many soft tokens; they may produce fewer tokens (for shorter audio) or more
tokens (for longer audio or due to BOA/EOA special tokens).
This function handles both cases:
- If fewer tokens: pad with the provided padding embeddings
- If more tokens: truncate to the expected count
Args:
audio_features: Audio embeddings tensor of shape
(batch_size, seq_len, embed_dim)
expected_tokens: The expected number of audio tokens (e.g., 188)
audio_padding_embs: Padding embeddings tensor of shape (1, 1, embed_dim)
Returns:
Tuple of:
- adjusted_features: Audio features adjusted to expected_tokens length
- tokens_truncated: Number of tokens truncated (0 if padding was applied)
"""
audio_batch_size, audio_seq_len, audio_embed_dim = audio_features.shape
tokens_truncated = 0
if audio_seq_len < expected_tokens:
# Pad to expected length with padding embeddings
extra_padding_tokens = expected_tokens - audio_seq_len
extra_padding_features = audio_padding_embs.expand(
audio_batch_size, extra_padding_tokens, audio_embed_dim
)
audio_features = torch.cat((audio_features, extra_padding_features), dim=1)
elif audio_seq_len > expected_tokens:
# Truncate to expected length (audio encoder produced more tokens
# than expected, e.g., due to longer audio or placeholder mismatch)
tokens_truncated = audio_seq_len - expected_tokens
audio_features = audio_features[:, :expected_tokens, :]
return audio_features, tokens_truncated
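A quick way to exercise both branches of the new helper in isolation (a rough sketch, not part of the commit: it assumes the module is importable and only depends on `torch`; the 188-token target matches the docstring, while the batch size and embedding dim below are arbitrary):

```python
import torch

from vllm.model_executor.models.gemma3n_audio_utils import (
    adjust_audio_features_to_expected_length,
)

expected_tokens = 188            # audio_soft_tokens_per_image for Gemma3n (per docstring)
embed_dim = 16                   # arbitrary small dim for illustration
padding_embs = torch.zeros(1, 1, embed_dim)

# Shorter-than-expected audio: padded up to 188 tokens, nothing truncated.
short_feats = torch.randn(2, 150, embed_dim)
padded, truncated = adjust_audio_features_to_expected_length(
    short_feats, expected_tokens, padding_embs
)
assert padded.shape == (2, 188, embed_dim) and truncated == 0

# Longer-than-expected audio: truncated back to 188 tokens, overflow reported.
long_feats = torch.randn(2, 190, embed_dim)
clipped, truncated = adjust_audio_features_to_expected_length(
    long_feats, expected_tokens, padding_embs
)
assert clipped.shape == (2, 188, embed_dim) and truncated == 2
```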

vllm/model_executor/models/gemma3n_mm.py

@@ -1,11 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable, Mapping, Sequence
-from typing import Annotated, Any, Literal, Optional, Union, cast
+from typing import Annotated, Any, Literal, cast
import numpy as np
import torch
from torch import nn
from transformers import AutoModel, BatchFeature
from transformers.models.gemma3n import (
@@ -26,6 +25,9 @@ from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import RowParallelLinear
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.models.gemma3n import Gemma3nForCausalLM
+from vllm.model_executor.models.gemma3n_audio_utils import (
+    adjust_audio_features_to_expected_length,
+)
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.whisper import ISO639_1_SUPPORTED_LANGS
from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -105,12 +107,12 @@ class Gemma3nProcessingInfo(BaseProcessingInfo):
def get_hf_processor(self, **kwargs: object):
return self.ctx.get_hf_processor(Gemma3nProcessor, **kwargs)
-def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"image": None, "audio": None}
def get_max_tokens_per_item(
self, seq_len: int, mm_counts: Mapping[str, int]
-) -> Optional[Mapping[str, int]]:
+) -> Mapping[str, int] | None:
return {"image": TOKENS_PER_IMAGE, "audio": TOKENS_PER_AUDIO}
def get_image_repl(
@@ -118,7 +120,7 @@ class Gemma3nProcessingInfo(BaseProcessingInfo):
*,
image_width: int,
image_height: int,
-processor: Optional[Gemma3nProcessor],
+processor: Gemma3nProcessor | None,
) -> str:
"""
Get the replacement text for image tokens.
@@ -136,7 +138,7 @@ class Gemma3nProcessingInfo(BaseProcessingInfo):
def get_audio_repl(
self,
*,
-processor: Optional[Gemma3nProcessor],
+processor: Gemma3nProcessor | None,
) -> str:
"""
Get the replacement text for audio tokens.
@@ -168,7 +170,7 @@ class Gemma3nDummyInputsBuilder(BaseDummyInputsBuilder[Gemma3nProcessingInfo]):
self,
seq_len: int,
mm_counts: Mapping[str, int],
-mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
+mm_options: Mapping[str, BaseDummyOptions] | None = None,
) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0)
num_audios = mm_counts.get("audio", 0)
@@ -387,7 +389,7 @@ class Gemma3nMultimodalEmbedder(nn.Module):
def __init__(
self,
-multimodal_config: Union[Gemma3nAudioConfig, Gemma3nVisionConfig],
+multimodal_config: Gemma3nAudioConfig | Gemma3nVisionConfig,
text_config: Gemma3nTextConfig,
):
super().__init__()
@@ -427,8 +429,8 @@ class Gemma3nMultimodalEmbedder(nn.Module):
def forward(
self,
-input_ids: Optional[torch.LongTensor] = None,
-inputs_embeds: Optional[torch.Tensor] = None,
+input_ids: torch.LongTensor | None = None,
+inputs_embeds: torch.Tensor | None = None,
) -> torch.Tensor:
"""Embeds token ids or soft tokens for multimodal content into language model space.
@@ -529,7 +531,7 @@ class Gemma3nForConditionalGeneration(
def _parse_and_validate_image_input(
self, **kwargs: object
-) -> Optional[Gemma3nImageInputs]:
+) -> Gemma3nImageInputs | None:
pixel_values = kwargs.pop("pixel_values", None)
image_embeds = kwargs.pop("image_embeds", None)
# TODO is this the case?
@@ -541,7 +543,7 @@ class Gemma3nForConditionalGeneration(
def _parse_and_validate_audio_input(
self, **kwargs: object
-) -> Optional[Gemma3nAudioInputs]:
+) -> Gemma3nAudioInputs | None:
input_features_padded = kwargs.pop("input_features_padded", None)
if input_features_padded is None:
return None
@@ -616,12 +618,15 @@ class Gemma3nForConditionalGeneration(
)
audio_features = self.embed_audio(inputs_embeds=audio_outputs)
# ruff: noqa
-# The Gemma3nProcessor expects all audio will be 30s in length and inserts 188 audio soft tokens into the
-# text to account for this. However, the audio preprocessing and encoder do not guarantee they will
-# produce 188 soft tokens; they will produce at most that many tokens, but they may produce fewer tokens
-# depending on the length of the longest audio input in the batch. When we encounter this situation, we pad
-# the audio feature out to 188 soft tokens with the embedding of the last token in the embed_audio vocab.
+# The Gemma3nProcessor expects all audio will be 30s in length and
+# inserts 188 audio soft tokens into the text to account for this.
+# However, the audio preprocessing and encoder do not guarantee they
+# will produce exactly 188 soft tokens; they may produce fewer tokens
+# (for shorter audio) or more tokens (for longer audio or due to
+# BOA/EOA special tokens in the placeholder sequence).
+# We handle both cases:
+# - If fewer tokens: pad with the embedding of the last vocab token
+# - If more tokens: truncate to the expected count
# TODO precompute and cache padding
audio_padding_toks = torch.tensor(
[[self.vocab_size - 1]], dtype=torch.long, device=audio_features.device
@@ -631,13 +636,18 @@ class Gemma3nForConditionalGeneration(
audio_mask.unsqueeze(-1), audio_padding_embs, audio_features
)
audio_batch_size, audio_seq_len, audio_embed_dim = audio_features.shape
-extra_padding_tokens = self.config.audio_soft_tokens_per_image - audio_seq_len # noqa: E501
-extra_padding_features = audio_padding_embs.expand(
-    audio_batch_size, extra_padding_tokens, audio_embed_dim
+expected_tokens = self.config.audio_soft_tokens_per_image
+audio_features, tokens_truncated = adjust_audio_features_to_expected_length(
+    audio_features, expected_tokens, audio_padding_embs
)
+if tokens_truncated > 0:
+    logger.warning(
+        "Gemma3n audio encoder produced %d extra tokens. "
+        "Truncating to match placeholder count of %d.",
+        tokens_truncated,
+        expected_tokens,
+    )
-audio_features = torch.cat((audio_features, extra_padding_features), dim=1)
# Return a list of embeddings instead of a batched tensor
return audio_features.unbind(0)
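For context on what went wrong before this change: when the encoder overshoots the placeholder count, the old `extra_padding_tokens = audio_soft_tokens_per_image - audio_seq_len` goes negative, and `Tensor.expand` rejects negative target sizes (other than -1), so the padding path had no way to absorb the extra tokens. A minimal sketch of that arithmetic, with illustrative sizes:

```python
import torch

expected_tokens = 188            # placeholder count inserted by the processor
audio_seq_len = 190              # encoder produced two extra soft tokens
audio_padding_embs = torch.zeros(1, 1, 16)

extra_padding_tokens = expected_tokens - audio_seq_len   # -2 with the old code
try:
    audio_padding_embs.expand(1, extra_padding_tokens, 16)
except RuntimeError as err:
    print("old padding path fails:", err)
```

The new path instead truncates to `expected_tokens` and logs the warning shown above.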
@@ -666,9 +676,9 @@ class Gemma3nForConditionalGeneration(
def embed_input_ids(
self,
input_ids: torch.Tensor,
-multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
+multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
-is_multimodal: Optional[torch.Tensor] = None,
+is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
# NOTE (NickLucche) Each pass needs tokens to compute PLE so we cache
@@ -701,8 +711,8 @@ class Gemma3nForConditionalGeneration(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
-intermediate_tensors: Optional[IntermediateTensors] = None,
-inputs_embeds: Optional[torch.Tensor] = None,
+intermediate_tensors: IntermediateTensors | None = None,
+inputs_embeds: torch.Tensor | None = None,
**kwargs: object,
) -> IntermediateTensors:
if intermediate_tensors is not None:
@@ -729,7 +739,7 @@ class Gemma3nForConditionalGeneration(
def compute_logits(
self,
hidden_states: torch.Tensor,
-) -> Optional[torch.Tensor]:
+) -> torch.Tensor | None:
return self.language_model.compute_logits(hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
@@ -747,7 +757,7 @@ class Gemma3nForConditionalGeneration(
)
@classmethod
-def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
+def get_placeholder_str(cls, modality: str, i: int) -> str | None:
if modality == "image":
return "<image_soft_token>"
elif modality == "audio":
@@ -761,10 +771,10 @@ class Gemma3nForConditionalGeneration(
audio: np.ndarray,
stt_config: SpeechToTextConfig,
model_config: ModelConfig,
-language: Optional[str],
+language: str | None,
task_type: Literal["transcribe", "translate"],
request_prompt: str,
-to_language: Optional[str],
+to_language: str | None,
) -> PromptType:
"""
Gemma3n supports "free-form" transcription.