[Bugfix] Fix integer overflow in Gemma3n audio processing (#31657)

Signed-off-by: Jeremy Teboul <jeremyte@meta.com>

Author: Jeremy Teboul
Date: 2026-01-10 01:52:53 -08:00
Committed by: GitHub
Parent: 14fc7a68c7
Commit: 07286ec5a6
3 changed files with 239 additions and 32 deletions
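The diff below adds CPU-only tests that pin down the contract of the new helper. For orientation, here is a minimal sketch of adjust_audio_features_to_expected_length consistent with those tests; the name and signature come from the import in the test file, but the body is an inference from the assertions, not the actual vLLM implementation:

import torch


def adjust_audio_features_to_expected_length(
    audio_features: torch.Tensor,  # (batch, seq_len, embed_dim)
    expected_tokens: int,
    padding_embs: torch.Tensor,  # (1, 1, embed_dim)
) -> tuple[torch.Tensor, int]:
    """Pad or truncate audio features to exactly expected_tokens.

    Sketch inferred from the tests in this commit, not the vLLM source.
    Returns the adjusted features and the number of truncated tokens.
    """
    seq_len = audio_features.shape[1]
    if seq_len > expected_tokens:
        # Truncate. The pre-fix code always padded, so expand() received
        # a negative size here and raised
        # "RuntimeError: numel: integer multiplication overflow".
        return audio_features[:, :expected_tokens, :], seq_len - expected_tokens
    if seq_len < expected_tokens:
        # Pad with the padding embedding, broadcast across the batch.
        batch_size, _, embed_dim = audio_features.shape
        padding = padding_embs.expand(
            batch_size, expected_tokens - seq_len, embed_dim
        )
        return torch.cat([audio_features, padding], dim=1), 0
    return audio_features, 0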


@@ -2,14 +2,154 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch

from vllm.model_executor.models.gemma3n_audio_utils import (
    adjust_audio_features_to_expected_length,
)
from vllm.multimodal import MULTIMODAL_REGISTRY

from ....conftest import ImageTestAssets
from ...utils import build_model_context

# Gemma3 (image) model
GEMMA3_MODEL_ID = "google/gemma-3-4b-it"

# Gemma3n (multimodal with audio) model
GEMMA3N_MODEL_ID = "google/gemma-3n-E2B-it"

# Expected audio tokens for Gemma3n (audio_soft_tokens_per_image)
GEMMA3N_EXPECTED_AUDIO_TOKENS = 188

class TestGemma3nAudioTensorLogic:
    """CPU-based tests for Gemma3n audio feature tensor manipulation.

    These tests validate the padding/truncation logic in
    adjust_audio_features_to_expected_length(), which fixes the
    integer overflow in _process_audio_input when audio_seq_len > 188.
    """

    def test_padding_when_audio_short(self):
"""Test that short audio is padded to expected length."""
batch_size, seq_len, embed_dim = 1, 100, 256
expected_tokens = GEMMA3N_EXPECTED_AUDIO_TOKENS
audio_features = torch.randn(batch_size, seq_len, embed_dim)
padding_embs = torch.zeros(1, 1, embed_dim)
result, tokens_truncated = adjust_audio_features_to_expected_length(
audio_features, expected_tokens, padding_embs
)
assert result.shape == (batch_size, expected_tokens, embed_dim)
assert tokens_truncated == 0
# First 100 tokens should be original, rest should be padding (zeros)
assert torch.allclose(result[:, :seq_len, :], audio_features)
assert torch.allclose(
result[:, seq_len:, :],
torch.zeros(batch_size, expected_tokens - seq_len, embed_dim),
)

    def test_truncation_when_audio_long(self):
        """Test that long audio is truncated to expected length.

        This is the key test for the overflow fix. Previously, when
        audio_seq_len > expected_tokens, the code computed a negative
        padding size, causing:
        RuntimeError: numel: integer multiplication overflow
        """
batch_size, seq_len, embed_dim = 1, 192, 256 # 192 > 188
expected_tokens = GEMMA3N_EXPECTED_AUDIO_TOKENS
audio_features = torch.randn(batch_size, seq_len, embed_dim)
padding_embs = torch.zeros(1, 1, embed_dim)
result, tokens_truncated = adjust_audio_features_to_expected_length(
audio_features, expected_tokens, padding_embs
)
assert result.shape == (batch_size, expected_tokens, embed_dim)
assert tokens_truncated == seq_len - expected_tokens # 192 - 188 = 4
# Result should be first 188 tokens of original
assert torch.allclose(result, audio_features[:, :expected_tokens, :])

    def test_no_change_when_exact_length(self):
"""Test that exact-length audio passes through unchanged."""
batch_size, embed_dim = 1, 256
expected_tokens = GEMMA3N_EXPECTED_AUDIO_TOKENS
audio_features = torch.randn(batch_size, expected_tokens, embed_dim)
padding_embs = torch.zeros(1, 1, embed_dim)
result, tokens_truncated = adjust_audio_features_to_expected_length(
audio_features, expected_tokens, padding_embs
)
assert result.shape == audio_features.shape
assert tokens_truncated == 0
assert torch.allclose(result, audio_features)

    def test_original_bug_would_fail(self):
        """Verify that the original buggy implementation would overflow.

        The original code always tried to pad, which fails when
        audio_seq_len > expected_tokens because expand() receives a
        negative size.
        """
batch_size, seq_len, embed_dim = 1, 192, 256
expected_tokens = GEMMA3N_EXPECTED_AUDIO_TOKENS
padding_embs = torch.zeros(1, 1, embed_dim)
# Original buggy logic (always pads, never truncates)
extra_padding_tokens = expected_tokens - seq_len # = -4 (negative!)
with pytest.raises(RuntimeError):
# This should fail with negative size error
padding_embs.expand(batch_size, extra_padding_tokens, embed_dim)

    @pytest.mark.parametrize(
"seq_len",
[50, 100, 150, 187, 188, 189, 192, 200, 300],
)
def test_various_audio_lengths(self, seq_len: int):
"""Test padding/truncation with various audio lengths."""
batch_size, embed_dim = 1, 256
expected_tokens = GEMMA3N_EXPECTED_AUDIO_TOKENS
audio_features = torch.randn(batch_size, seq_len, embed_dim)
padding_embs = torch.zeros(1, 1, embed_dim)
# Should not raise any errors
result, tokens_truncated = adjust_audio_features_to_expected_length(
audio_features, expected_tokens, padding_embs
)
# Output should always be expected_tokens length
assert result.shape == (batch_size, expected_tokens, embed_dim)
# Verify truncation count is correct
if seq_len > expected_tokens:
assert tokens_truncated == seq_len - expected_tokens
else:
assert tokens_truncated == 0

    def test_batch_processing(self):
"""Test that batch processing works correctly."""
batch_size, seq_len, embed_dim = 4, 192, 256
expected_tokens = GEMMA3N_EXPECTED_AUDIO_TOKENS
audio_features = torch.randn(batch_size, seq_len, embed_dim)
padding_embs = torch.zeros(1, 1, embed_dim)
result, tokens_truncated = adjust_audio_features_to_expected_length(
audio_features, expected_tokens, padding_embs
)
assert result.shape == (batch_size, expected_tokens, embed_dim)
assert tokens_truncated == seq_len - expected_tokens
@pytest.mark.parametrize("model_id", [GEMMA3_MODEL_ID])
def test_get_image_size_with_most_features(
image_assets: ImageTestAssets, model_id: str
):