Improve HF qwen3_omni: preserve audio_sample_rate in kwargs restructuring (#29255)
Signed-off-by: Jeremy Teboul <jeremyteboul@fb.com> Co-authored-by: Jeremy Teboul <jeremyteboul@fb.com>
This commit is contained in:
285
tests/models/multimodal/processing/test_qwen3_omni.py
Normal file
285
tests/models/multimodal/processing/test_qwen3_omni.py
Normal file
@@ -0,0 +1,285 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Tests for Qwen3 Omni audio processing and sample rate handling."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
|
||||
from ...utils import build_model_context
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_id", ["Qwen/Qwen3-Omni-30B-A3B-Instruct"])
|
||||
@pytest.mark.parametrize(
|
||||
("audio_sample_rate", "audio_duration_sec"),
|
||||
[
|
||||
(16000, 1.0), # Native Whisper sample rate, 1 second
|
||||
(16000, 2.0), # Native Whisper sample rate, 2 seconds
|
||||
],
|
||||
)
|
||||
def test_processor_with_audio_sample_rate(
|
||||
model_id: str,
|
||||
audio_sample_rate: int,
|
||||
audio_duration_sec: float,
|
||||
) -> None:
|
||||
"""
|
||||
Test that vLLM's processor generates expected outputs with audio_sample_rate.
|
||||
|
||||
This validates the reviewer's request that we test the actual processor
|
||||
can handle different audio_sample_rate values and generate audio tokens.
|
||||
"""
|
||||
# Setup: Build model context and processor
|
||||
ctx = build_model_context(
|
||||
model_id,
|
||||
limit_mm_per_prompt={"audio": 1, "image": 0, "video": 0},
|
||||
)
|
||||
processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
|
||||
tokenizer = processor.info.get_tokenizer()
|
||||
|
||||
# Create audio data at the specified sample rate
|
||||
audio_length = int(audio_sample_rate * audio_duration_sec)
|
||||
rng = np.random.RandomState(42)
|
||||
audio_data = rng.rand(audio_length).astype(np.float32)
|
||||
|
||||
# Build prompt with audio placeholder
|
||||
prompt = "<|audio_start|><|audio_pad|><|audio_end|>"
|
||||
mm_data = {"audio": [(audio_data, audio_sample_rate)]}
|
||||
|
||||
# Execute: Apply processor with audio_sample_rate in mm_kwargs
|
||||
hf_processor_mm_kwargs: dict[str, Any] = {
|
||||
"audio_sample_rate": audio_sample_rate,
|
||||
}
|
||||
processed_inputs = processor.apply(prompt, mm_data, hf_processor_mm_kwargs)
|
||||
|
||||
# Assert: Verify audio tokens are generated
|
||||
hf_processor = processor.info.get_hf_processor(**hf_processor_mm_kwargs)
|
||||
audio_token_id = tokenizer.convert_tokens_to_ids(hf_processor.audio_token)
|
||||
aud_tok_count = processed_inputs["prompt_token_ids"].count(audio_token_id)
|
||||
|
||||
# Audio should generate at least 1 token
|
||||
assert aud_tok_count >= 1, (
|
||||
f"Expected at least 1 audio token but got {aud_tok_count}. "
|
||||
f"sample_rate: {audio_sample_rate}Hz, duration: {audio_duration_sec}s"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_id", ["Qwen/Qwen3-Omni-30B-A3B-Instruct"])
|
||||
def test_longer_audio_generates_more_tokens(model_id: str) -> None:
|
||||
"""
|
||||
Test that longer audio generates more tokens than shorter audio.
|
||||
|
||||
This validates that audio_sample_rate is being used correctly by checking
|
||||
that audio duration affects token count as expected.
|
||||
"""
|
||||
ctx = build_model_context(
|
||||
model_id,
|
||||
limit_mm_per_prompt={"audio": 1, "image": 0, "video": 0},
|
||||
)
|
||||
processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
|
||||
tokenizer = processor.info.get_tokenizer()
|
||||
|
||||
audio_sample_rate = 16000
|
||||
rng = np.random.RandomState(42)
|
||||
|
||||
def get_token_count(duration: float) -> int:
|
||||
audio_length = int(audio_sample_rate * duration)
|
||||
audio_data = rng.rand(audio_length).astype(np.float32)
|
||||
prompt = "<|audio_start|><|audio_pad|><|audio_end|>"
|
||||
mm_data = {"audio": [(audio_data, audio_sample_rate)]}
|
||||
hf_processor_mm_kwargs: dict[str, Any] = {
|
||||
"audio_sample_rate": audio_sample_rate,
|
||||
}
|
||||
processed = processor.apply(prompt, mm_data, hf_processor_mm_kwargs)
|
||||
hf_proc = processor.info.get_hf_processor(**hf_processor_mm_kwargs)
|
||||
audio_token_id = tokenizer.convert_tokens_to_ids(hf_proc.audio_token)
|
||||
return processed["prompt_token_ids"].count(audio_token_id)
|
||||
|
||||
# Get token counts for different durations
|
||||
short_tokens = get_token_count(1.0)
|
||||
long_tokens = get_token_count(2.0)
|
||||
|
||||
# Longer audio should produce more tokens
|
||||
assert long_tokens > short_tokens, (
|
||||
f"Expected longer audio (2s) to have more tokens than shorter (1s). "
|
||||
f"Got short={short_tokens}, long={long_tokens}"
|
||||
)
|
||||
|
||||
|
||||
class TestQwen3OmniAudioSampleRatePreservation:
|
||||
"""Test that audio_sample_rate is preserved during kwargs restructuring.
|
||||
|
||||
These tests validate the fix for the audio_sample_rate bug in Qwen3 Omni
|
||||
where the parameter was lost during kwargs restructuring.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _process_kwargs(
|
||||
mm_kwargs: dict[str, Any],
|
||||
tok_kwargs: dict[str, Any],
|
||||
transformers_version: str = "4.57.0",
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Helper method to simulate kwargs processing logic from production code.
|
||||
|
||||
This method simulates the kwargs restructuring that happens in the
|
||||
Qwen3 Omni model when transformers < 4.58.0. By centralizing this
|
||||
logic, we make tests easier to maintain if the production logic changes.
|
||||
|
||||
Args:
|
||||
mm_kwargs: Multimodal kwargs (e.g., audio_sample_rate, truncation)
|
||||
tok_kwargs: Tokenizer kwargs (e.g., truncation)
|
||||
transformers_version: Version string to test against (default: "4.57.0")
|
||||
|
||||
Returns:
|
||||
Processed kwargs dictionary with restructured audio_kwargs and text_kwargs
|
||||
"""
|
||||
from packaging.version import Version
|
||||
|
||||
mm_kwargs_copy = dict(mm_kwargs)
|
||||
tok_kwargs_copy = dict(tok_kwargs)
|
||||
|
||||
if Version(transformers_version) < Version("4.58.0"):
|
||||
# Extract audio_sample_rate before restructuring (THE FIX)
|
||||
audio_sample_rate = mm_kwargs_copy.pop("audio_sample_rate", None)
|
||||
|
||||
# Restructure kwargs
|
||||
mm_kwargs_copy["audio_kwargs"] = {
|
||||
"truncation": mm_kwargs_copy.pop("truncation", False)
|
||||
}
|
||||
mm_kwargs_copy["text_kwargs"] = {
|
||||
"truncation": tok_kwargs_copy.pop("truncation", False)
|
||||
}
|
||||
|
||||
# Put audio_sample_rate into audio_kwargs (THE FIX)
|
||||
if audio_sample_rate is not None:
|
||||
mm_kwargs_copy["audio_kwargs"]["audio_sample_rate"] = audio_sample_rate
|
||||
|
||||
return mm_kwargs_copy
|
||||
|
||||
def test_audio_sample_rate_preserved_in_audio_kwargs(self) -> None:
|
||||
"""
|
||||
Test that audio_sample_rate is moved from top-level mm_kwargs
|
||||
into audio_kwargs during kwargs restructuring.
|
||||
|
||||
This is the core fix: when transformers < 4.58.0, the code
|
||||
restructures kwargs into audio_kwargs and text_kwargs, and
|
||||
audio_sample_rate must be preserved in audio_kwargs.
|
||||
"""
|
||||
# Setup: Create mm_kwargs with audio_sample_rate at top level
|
||||
mm_kwargs: dict[str, Any] = {
|
||||
"audio_sample_rate": 16000,
|
||||
"truncation": True,
|
||||
}
|
||||
tok_kwargs: dict[str, Any] = {
|
||||
"truncation": False,
|
||||
}
|
||||
|
||||
# Execute: Process kwargs using helper method
|
||||
result = self._process_kwargs(mm_kwargs, tok_kwargs)
|
||||
|
||||
# Assert: Verify audio_sample_rate is in audio_kwargs
|
||||
assert "audio_kwargs" in result
|
||||
assert "audio_sample_rate" in result["audio_kwargs"]
|
||||
assert result["audio_kwargs"]["audio_sample_rate"] == 16000
|
||||
|
||||
# Assert: Verify truncation is also in audio_kwargs
|
||||
assert result["audio_kwargs"]["truncation"] is True
|
||||
|
||||
# Assert: Verify text_kwargs is created correctly
|
||||
assert "text_kwargs" in result
|
||||
assert result["text_kwargs"]["truncation"] is False
|
||||
|
||||
def test_audio_sample_rate_absent_when_not_provided(self) -> None:
|
||||
"""
|
||||
Test that when audio_sample_rate is not provided in mm_kwargs,
|
||||
the restructured audio_kwargs doesn't contain it.
|
||||
"""
|
||||
# Setup: Create mm_kwargs WITHOUT audio_sample_rate
|
||||
mm_kwargs: dict[str, Any] = {
|
||||
"truncation": True,
|
||||
}
|
||||
tok_kwargs: dict[str, Any] = {
|
||||
"truncation": False,
|
||||
}
|
||||
|
||||
# Execute: Process kwargs using helper method
|
||||
result = self._process_kwargs(mm_kwargs, tok_kwargs)
|
||||
|
||||
# Assert: Verify audio_sample_rate is NOT in audio_kwargs
|
||||
assert "audio_kwargs" in result
|
||||
assert "audio_sample_rate" not in result["audio_kwargs"]
|
||||
|
||||
# Assert: Verify truncation is still in audio_kwargs
|
||||
assert result["audio_kwargs"]["truncation"] is True
|
||||
|
||||
@pytest.mark.parametrize("sample_rate", [8000, 16000, 22050, 24000, 44100, 48000])
|
||||
def test_various_audio_sample_rates_preserved(self, sample_rate: int) -> None:
|
||||
"""
|
||||
Test that various common audio sample rates are preserved.
|
||||
|
||||
Common sample rates:
|
||||
- 8000: Telephone quality
|
||||
- 16000: Wideband speech (Qwen3 Omni default)
|
||||
- 22050: Low-quality audio
|
||||
- 24000: High-quality speech
|
||||
- 44100: CD quality
|
||||
- 48000: Professional audio
|
||||
"""
|
||||
# Setup: Create mm_kwargs with specific sample rate
|
||||
mm_kwargs: dict[str, Any] = {
|
||||
"audio_sample_rate": sample_rate,
|
||||
"truncation": True,
|
||||
}
|
||||
tok_kwargs: dict[str, Any] = {"truncation": False}
|
||||
|
||||
# Execute: Process kwargs using helper method
|
||||
result = self._process_kwargs(mm_kwargs, tok_kwargs)
|
||||
|
||||
# Assert: Verify the specific sample rate is preserved
|
||||
assert result["audio_kwargs"]["audio_sample_rate"] == sample_rate
|
||||
|
||||
def test_kwargs_unchanged_for_newer_transformers_version(self) -> None:
|
||||
"""
|
||||
Test that kwargs structure remains unchanged for transformers >= 4.58.0.
|
||||
|
||||
This test ensures that when transformers version is 4.58.0 or higher,
|
||||
the kwargs restructuring is bypassed and audio_sample_rate remains
|
||||
at the top level as originally passed.
|
||||
"""
|
||||
from packaging.version import Version
|
||||
|
||||
# Setup: Create mm_kwargs with audio_sample_rate at top level
|
||||
mm_kwargs: dict[str, Any] = {
|
||||
"audio_sample_rate": 16000,
|
||||
"truncation": True,
|
||||
}
|
||||
tok_kwargs: dict[str, Any] = {
|
||||
"truncation": False,
|
||||
}
|
||||
|
||||
# Execute: Simulate with transformers >= 4.58.0
|
||||
mm_kwargs_copy = dict(mm_kwargs)
|
||||
tok_kwargs_copy = dict(tok_kwargs)
|
||||
|
||||
transformers_ver = "4.58.0" # Version that bypasses restructuring
|
||||
if Version(transformers_ver) < Version("4.58.0"):
|
||||
# This block should NOT execute for >= 4.58.0
|
||||
audio_sample_rate = mm_kwargs_copy.pop("audio_sample_rate", None)
|
||||
mm_kwargs_copy["audio_kwargs"] = {
|
||||
"truncation": mm_kwargs_copy.pop("truncation", False)
|
||||
}
|
||||
mm_kwargs_copy["text_kwargs"] = {
|
||||
"truncation": tok_kwargs_copy.pop("truncation", False)
|
||||
}
|
||||
if audio_sample_rate is not None:
|
||||
mm_kwargs_copy["audio_kwargs"]["audio_sample_rate"] = audio_sample_rate
|
||||
|
||||
# Assert: Verify kwargs structure is unchanged
|
||||
assert "audio_kwargs" not in mm_kwargs_copy
|
||||
assert "text_kwargs" not in mm_kwargs_copy
|
||||
assert mm_kwargs_copy["audio_sample_rate"] == 16000
|
||||
assert mm_kwargs_copy["truncation"] is True
|
||||
assert tok_kwargs_copy["truncation"] is False
|
||||
@@ -1021,9 +1021,8 @@ def test_hf_processor_init_kwargs(
|
||||
DummyProcessor, # type: ignore[arg-type]
|
||||
**inference_kwargs,
|
||||
)
|
||||
|
||||
for k, v in expected_kwargs.items():
|
||||
assert getattr(processor, k) == v
|
||||
assert processor.a == expected_kwargs["a"]
|
||||
assert processor.b == expected_kwargs["b"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_id", ["Qwen/Qwen2-VL-2B-Instruct"]) # Dummy
|
||||
|
||||
@@ -751,6 +751,9 @@ class Qwen3OmniMoeThinkerMultiModalProcessor(
|
||||
mm_kwargs = dict(mm_kwargs)
|
||||
tok_kwargs = dict(tok_kwargs)
|
||||
if Version(TRANSFORMERS_VERSION) < Version("4.58.0"):
|
||||
# Extract audio_sample_rate before restructuring
|
||||
audio_sample_rate = mm_kwargs.pop("audio_sample_rate", None)
|
||||
|
||||
# move truncation to audio_kwargs level to avoid conflict
|
||||
# with tok_kwargs
|
||||
mm_kwargs["audio_kwargs"] = {
|
||||
@@ -760,6 +763,28 @@ class Qwen3OmniMoeThinkerMultiModalProcessor(
|
||||
"truncation": tok_kwargs.pop("truncation", False)
|
||||
}
|
||||
|
||||
# Validate and conditionally pass audio_sample_rate
|
||||
# WhisperFeatureExtractor has a fixed sampling rate, and vLLM's
|
||||
# audio loader already resamples audio to the target rate.
|
||||
# Only pass the value if it matches to avoid unexpected behavior.
|
||||
if audio_sample_rate is not None:
|
||||
expected_sr = feature_extractor.sampling_rate
|
||||
if audio_sample_rate != expected_sr:
|
||||
logger.warning(
|
||||
"[%s] audio_sample_rate mismatch: user provided %dHz "
|
||||
"but model expects %dHz. Ignoring user value. "
|
||||
"vLLM's audio loader already resampled to %dHz.",
|
||||
self.__class__.__name__,
|
||||
audio_sample_rate,
|
||||
expected_sr,
|
||||
expected_sr,
|
||||
)
|
||||
else:
|
||||
# Sample rate matches, safe to pass
|
||||
mm_kwargs["audio_kwargs"]["audio_sample_rate"] = (
|
||||
audio_sample_rate
|
||||
)
|
||||
|
||||
hf_inputs = super()._call_hf_processor(
|
||||
prompt=prompt,
|
||||
mm_data=mm_data,
|
||||
|
||||
Reference in New Issue
Block a user