[Bugfix] Fix integer overflow in Gemma3n audio processing (#31657)
Signed-off-by: Jeremy Teboul <jeremyte@meta.com>
57 vllm/model_executor/models/gemma3n_audio_utils.py Normal file
@@ -0,0 +1,57 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Lightweight utility functions for Gemma3n audio processing.

This module is separate from gemma3n_mm.py to avoid heavy CUDA dependencies,
making it testable without a full vLLM build.
"""

import torch


def adjust_audio_features_to_expected_length(
    audio_features: torch.Tensor,
    expected_tokens: int,
    audio_padding_embs: torch.Tensor,
) -> tuple[torch.Tensor, int]:
    """Adjust audio features to expected token length via padding or truncation.

    The Gemma3nProcessor expects all audio to be ~30s in length and inserts
    a fixed number of audio soft tokens into the text. However, the audio
    preprocessing and encoder do not guarantee they will produce exactly that
    many soft tokens; they may produce fewer tokens (for shorter audio) or more
    tokens (for longer audio or due to BOA/EOA special tokens).

    This function handles both cases:
    - If fewer tokens: pad with the provided padding embeddings
    - If more tokens: truncate to the expected count

    Args:
        audio_features: Audio embeddings tensor of shape
            (batch_size, seq_len, embed_dim)
        expected_tokens: The expected number of audio tokens (e.g., 188)
        audio_padding_embs: Padding embeddings tensor of shape (1, 1, embed_dim)

    Returns:
        Tuple of:
        - adjusted_features: Audio features adjusted to expected_tokens length
        - tokens_truncated: Number of tokens truncated (0 if padding was applied)
    """
    audio_batch_size, audio_seq_len, audio_embed_dim = audio_features.shape
    tokens_truncated = 0

    if audio_seq_len < expected_tokens:
        # Pad to expected length with padding embeddings
        extra_padding_tokens = expected_tokens - audio_seq_len
        extra_padding_features = audio_padding_embs.expand(
            audio_batch_size, extra_padding_tokens, audio_embed_dim
        )
        audio_features = torch.cat((audio_features, extra_padding_features), dim=1)
    elif audio_seq_len > expected_tokens:
        # Truncate to expected length (audio encoder produced more tokens
        # than expected, e.g., due to longer audio or placeholder mismatch)
        tokens_truncated = audio_seq_len - expected_tokens
        audio_features = audio_features[:, :expected_tokens, :]

    return audio_features, tokens_truncated
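A minimal usage sketch of the new helper, assuming the import path added by this commit; the tensor shapes and the 188-token expectation come from the docstring above, while the dummy tensors and sizes here are purely illustrative:

import torch

from vllm.model_executor.models.gemma3n_audio_utils import (
    adjust_audio_features_to_expected_length,
)

# Illustrative dummy inputs: batch of 2 clips, 200 audio soft tokens each,
# embedding dim 32 (real models use the encoder's hidden size).
audio_features = torch.randn(2, 200, 32)
# A single padding embedding, broadcastable to (batch, n_pad, embed_dim).
audio_padding_embs = torch.zeros(1, 1, 32)

adjusted, truncated = adjust_audio_features_to_expected_length(
    audio_features,
    expected_tokens=188,
    audio_padding_embs=audio_padding_embs,
)
assert adjusted.shape == (2, 188, 32)
assert truncated == 12  # 200 - 188 tokens dropped from the tail

If the input instead had fewer than 188 tokens, the helper would append copies of audio_padding_embs and report truncated == 0.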