[Bugfix] Fix integer overflow in Gemma3n audio processing (#31657)
Signed-off-by: Jeremy Teboul <jeremyte@meta.com>
vllm/model_executor/models/gemma3n_audio_utils.py (new file, 57 lines)
@@ -0,0 +1,57 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Lightweight utility functions for Gemma3n audio processing.
+
+This module is separate from gemma3n_mm.py to avoid heavy CUDA dependencies,
+making it testable without a full vLLM build.
+"""
+
+import torch
+
+
+def adjust_audio_features_to_expected_length(
+    audio_features: torch.Tensor,
+    expected_tokens: int,
+    audio_padding_embs: torch.Tensor,
+) -> tuple[torch.Tensor, int]:
+    """Adjust audio features to expected token length via padding or truncation.
+
+    The Gemma3nProcessor expects all audio will be ~30s in length and inserts
+    a fixed number of audio soft tokens into the text. However, the audio
+    preprocessing and encoder do not guarantee they will produce exactly that
+    many soft tokens; they may produce fewer tokens (for shorter audio) or more
+    tokens (for longer audio or due to BOA/EOA special tokens).
+
+    This function handles both cases:
+    - If fewer tokens: pad with the provided padding embeddings
+    - If more tokens: truncate to the expected count
+
+    Args:
+        audio_features: Audio embeddings tensor of shape
+            (batch_size, seq_len, embed_dim)
+        expected_tokens: The expected number of audio tokens (e.g., 188)
+        audio_padding_embs: Padding embeddings tensor of shape (1, 1, embed_dim)
+
+    Returns:
+        Tuple of:
+        - adjusted_features: Audio features adjusted to expected_tokens length
+        - tokens_truncated: Number of tokens truncated (0 if padding was applied)
+    """
+    audio_batch_size, audio_seq_len, audio_embed_dim = audio_features.shape
+    tokens_truncated = 0
+
+    if audio_seq_len < expected_tokens:
+        # Pad to expected length with padding embeddings
+        extra_padding_tokens = expected_tokens - audio_seq_len
+        extra_padding_features = audio_padding_embs.expand(
+            audio_batch_size, extra_padding_tokens, audio_embed_dim
+        )
+        audio_features = torch.cat((audio_features, extra_padding_features), dim=1)
+    elif audio_seq_len > expected_tokens:
+        # Truncate to expected length (audio encoder produced more tokens
+        # than expected, e.g., due to longer audio or placeholder mismatch)
+        tokens_truncated = audio_seq_len - expected_tokens
+        audio_features = audio_features[:, :expected_tokens, :]
+
+    return audio_features, tokens_truncated
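For reference, a minimal standalone sketch of how the new helper behaves (not part of the commit; tensor sizes here are made up for illustration, and it needs nothing beyond torch, in line with the module's stated goal of being testable without a full vLLM build):

import torch

from vllm.model_executor.models.gemma3n_audio_utils import (
    adjust_audio_features_to_expected_length,
)

# Stand-in for the padding embedding; the real call site uses the
# embedding of the last token in the embed_audio vocab.
padding_embs = torch.zeros(1, 1, 8)

# Fewer tokens than expected: features are padded out to 188.
short = torch.randn(2, 100, 8)
out, truncated = adjust_audio_features_to_expected_length(short, 188, padding_embs)
assert out.shape == (2, 188, 8) and truncated == 0

# More tokens than expected: features are truncated down to 188.
long_feats = torch.randn(2, 190, 8)
out, truncated = adjust_audio_features_to_expected_length(long_feats, 188, padding_embs)
assert out.shape == (2, 188, 8) and truncated == 2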
vllm/model_executor/models/gemma3n_mm.py
@@ -1,11 +1,10 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Annotated, Any, Literal, Optional, Union, cast
+from typing import Annotated, Any, Literal, cast

 import numpy as np
 import torch
-
 from torch import nn
 from transformers import AutoModel, BatchFeature
 from transformers.models.gemma3n import (
@@ -26,6 +25,9 @@ from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import RowParallelLinear
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.models.gemma3n import Gemma3nForCausalLM
+from vllm.model_executor.models.gemma3n_audio_utils import (
+    adjust_audio_features_to_expected_length,
+)
 from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.model_executor.models.whisper import ISO639_1_SUPPORTED_LANGS
 from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -105,12 +107,12 @@ class Gemma3nProcessingInfo(BaseProcessingInfo):
     def get_hf_processor(self, **kwargs: object):
         return self.ctx.get_hf_processor(Gemma3nProcessor, **kwargs)

-    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
         return {"image": None, "audio": None}

     def get_max_tokens_per_item(
         self, seq_len: int, mm_counts: Mapping[str, int]
-    ) -> Optional[Mapping[str, int]]:
+    ) -> Mapping[str, int] | None:
         return {"image": TOKENS_PER_IMAGE, "audio": TOKENS_PER_AUDIO}

     def get_image_repl(
@@ -118,7 +120,7 @@ class Gemma3nProcessingInfo(BaseProcessingInfo):
         *,
         image_width: int,
         image_height: int,
-        processor: Optional[Gemma3nProcessor],
+        processor: Gemma3nProcessor | None,
     ) -> str:
         """
         Get the replacement text for image tokens.
@@ -136,7 +138,7 @@ class Gemma3nProcessingInfo(BaseProcessingInfo):
     def get_audio_repl(
         self,
         *,
-        processor: Optional[Gemma3nProcessor],
+        processor: Gemma3nProcessor | None,
     ) -> str:
         """
         Get the replacement text for audio tokens.
@@ -168,7 +170,7 @@ class Gemma3nDummyInputsBuilder(BaseDummyInputsBuilder[Gemma3nProcessingInfo]):
         self,
         seq_len: int,
         mm_counts: Mapping[str, int],
-        mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
+        mm_options: Mapping[str, BaseDummyOptions] | None = None,
     ) -> MultiModalDataDict:
         num_images = mm_counts.get("image", 0)
         num_audios = mm_counts.get("audio", 0)
@@ -387,7 +389,7 @@ class Gemma3nMultimodalEmbedder(nn.Module):

     def __init__(
         self,
-        multimodal_config: Union[Gemma3nAudioConfig, Gemma3nVisionConfig],
+        multimodal_config: Gemma3nAudioConfig | Gemma3nVisionConfig,
         text_config: Gemma3nTextConfig,
     ):
         super().__init__()
@@ -427,8 +429,8 @@ class Gemma3nMultimodalEmbedder(nn.Module):

     def forward(
         self,
-        input_ids: Optional[torch.LongTensor] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
+        input_ids: torch.LongTensor | None = None,
+        inputs_embeds: torch.Tensor | None = None,
     ) -> torch.Tensor:
         """Embeds token ids or soft tokens for multimodal content into language model space.

@@ -529,7 +531,7 @@ class Gemma3nForConditionalGeneration(

     def _parse_and_validate_image_input(
         self, **kwargs: object
-    ) -> Optional[Gemma3nImageInputs]:
+    ) -> Gemma3nImageInputs | None:
         pixel_values = kwargs.pop("pixel_values", None)
         image_embeds = kwargs.pop("image_embeds", None)
         # TODO is this the case?
@@ -541,7 +543,7 @@ class Gemma3nForConditionalGeneration(

     def _parse_and_validate_audio_input(
         self, **kwargs: object
-    ) -> Optional[Gemma3nAudioInputs]:
+    ) -> Gemma3nAudioInputs | None:
         input_features_padded = kwargs.pop("input_features_padded", None)
         if input_features_padded is None:
             return None
@@ -616,12 +618,15 @@ class Gemma3nForConditionalGeneration(
         )
         audio_features = self.embed_audio(inputs_embeds=audio_outputs)

-        # ruff: noqa
-        # The Gemma3nProcessor expects all audio will be 30s in length and inserts 188 audio soft tokens into the
-        # text to account for this. However, the audio preprocessing and encoder do not guarantee they will
-        # produce 188 soft tokens; they will produce at most that many tokens, but they may produce fewer tokens
-        # depending on the length of the longest audio input in the batch. When we encounter this situation, we pad
-        # the audio feature out to 188 soft tokens with the embedding of the last token in the embed_audio vocab.
+        # The Gemma3nProcessor expects all audio will be 30s in length and
+        # inserts 188 audio soft tokens into the text to account for this.
+        # However, the audio preprocessing and encoder do not guarantee they
+        # will produce exactly 188 soft tokens; they may produce fewer tokens
+        # (for shorter audio) or more tokens (for longer audio or due to
+        # BOA/EOA special tokens in the placeholder sequence).
+        # We handle both cases:
+        # - If fewer tokens: pad with the embedding of the last vocab token
+        # - If more tokens: truncate to the expected count
         # TODO precompute and cache padding
         audio_padding_toks = torch.tensor(
             [[self.vocab_size - 1]], dtype=torch.long, device=audio_features.device
@@ -631,13 +636,18 @@ class Gemma3nForConditionalGeneration(
             audio_mask.unsqueeze(-1), audio_padding_embs, audio_features
         )

-        audio_batch_size, audio_seq_len, audio_embed_dim = audio_features.shape
-        extra_padding_tokens = self.config.audio_soft_tokens_per_image - audio_seq_len  # noqa: E501
-        extra_padding_features = audio_padding_embs.expand(
-            audio_batch_size, extra_padding_tokens, audio_embed_dim
+        expected_tokens = self.config.audio_soft_tokens_per_image
+        audio_features, tokens_truncated = adjust_audio_features_to_expected_length(
+            audio_features, expected_tokens, audio_padding_embs
         )
+        if tokens_truncated > 0:
+            logger.warning(
+                "Gemma3n audio encoder produced %d extra tokens. "
+                "Truncating to match placeholder count of %d.",
+                tokens_truncated,
+                expected_tokens,
+            )

-        audio_features = torch.cat((audio_features, extra_padding_features), dim=1)
         # Return a list of embeddings instead of a batched tensor
         return audio_features.unbind(0)

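To see what the replaced lines above fixed: with the old arithmetic, an encoder that over-produced tokens made extra_padding_tokens negative, and Tensor.expand rejects negative sizes (other than -1), which is presumably the overflow the commit title refers to. A minimal repro sketch of the old computation, with made-up sizes (not part of the commit):

import torch

audio_padding_embs = torch.zeros(1, 1, 8)  # (1, 1, embed_dim)
expected_tokens, audio_seq_len = 188, 190  # encoder produced 2 extra tokens
extra_padding_tokens = expected_tokens - audio_seq_len  # -2 with the old code
# The old call site then effectively ran:
#   audio_padding_embs.expand(batch_size, -2, embed_dim)
# which raises a RuntimeError because -2 is not a valid expanded size.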
@@ -666,9 +676,9 @@ class Gemma3nForConditionalGeneration(
     def embed_input_ids(
         self,
         input_ids: torch.Tensor,
-        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
+        multimodal_embeddings: MultiModalEmbeddings | None = None,
         *,
-        is_multimodal: Optional[torch.Tensor] = None,
+        is_multimodal: torch.Tensor | None = None,
         handle_oov_mm_token: bool = False,
     ) -> torch.Tensor:
         # NOTE (NickLucche) Each pass needs tokens to compute PLE so we cache
@@ -701,8 +711,8 @@ class Gemma3nForConditionalGeneration(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
+        intermediate_tensors: IntermediateTensors | None = None,
+        inputs_embeds: torch.Tensor | None = None,
         **kwargs: object,
     ) -> IntermediateTensors:
         if intermediate_tensors is not None:
@@ -729,7 +739,7 @@ class Gemma3nForConditionalGeneration(
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-    ) -> Optional[torch.Tensor]:
+    ) -> torch.Tensor | None:
         return self.language_model.compute_logits(hidden_states)

     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
@@ -747,7 +757,7 @@ class Gemma3nForConditionalGeneration(
         )

     @classmethod
-    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
+    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
         if modality == "image":
             return "<image_soft_token>"
         elif modality == "audio":
@@ -761,10 +771,10 @@ class Gemma3nForConditionalGeneration(
         audio: np.ndarray,
         stt_config: SpeechToTextConfig,
         model_config: ModelConfig,
-        language: Optional[str],
+        language: str | None,
         task_type: Literal["transcribe", "translate"],
         request_prompt: str,
-        to_language: Optional[str],
+        to_language: str | None,
     ) -> PromptType:
         """
         Gemma3n supports "free-form" transcription.
