[Fix] Introduce audio channels spec (#31595)
Signed-off-by: Jeremy Teboul <jeremyte@meta.com>
This commit is contained in:
@@ -356,6 +356,44 @@ You can pass a tuple `(array, sampling_rate)` to the `'audio'` field of the mult
|
||||
|
||||
Full example: [examples/offline_inference/audio_language.py](../../examples/offline_inference/audio_language.py)
|
||||
|
||||
#### Automatic Audio Channel Normalization
|
||||
|
||||
vLLM automatically normalizes audio channels for models that require specific audio formats. When loading audio with libraries like `torchaudio`, stereo files return shape `[channels, time]`, but many audio models (particularly Whisper-based models) expect mono audio with shape `[time]`.
|
||||
|
||||
**Supported models with automatic mono conversion:**
|
||||
|
||||
- **Whisper** and all Whisper-based models
|
||||
- **Qwen2-Audio**
|
||||
- **Qwen2.5-Omni** / **Qwen3-Omni** (inherits from Qwen2.5-Omni)
|
||||
- **Ultravox**
|
||||
|
||||
For these models, vLLM automatically:
|
||||
|
||||
1. Detects if the model requires mono audio via the feature extractor
|
||||
2. Converts multi-channel audio to mono using channel averaging
|
||||
3. Handles both `(channels, time)` format (torchaudio) and `(time, channels)` format (soundfile)
|
||||
|
||||
**Example with stereo audio:**
|
||||
|
||||
```python
|
||||
import torchaudio
|
||||
from vllm import LLM
|
||||
|
||||
# Load stereo audio file - returns (channels, time) shape
|
||||
audio, sr = torchaudio.load("stereo_audio.wav")
|
||||
print(f"Original shape: {audio.shape}") # e.g., torch.Size([2, 16000])
|
||||
|
||||
# vLLM automatically converts to mono for Whisper-based models
|
||||
llm = LLM(model="openai/whisper-large-v3")
|
||||
|
||||
outputs = llm.generate({
|
||||
"prompt": "",
|
||||
"multi_modal_data": {"audio": (audio.numpy(), sr)},
|
||||
})
|
||||
```
|
||||
|
||||
No manual conversion is needed - vLLM handles the channel normalization automatically based on the model's requirements.
|
||||
|
||||
### Embedding Inputs
|
||||
|
||||
To input pre-computed embeddings belonging to a data type (i.e. image, video, or audio) directly to the language model,
|
||||
|
||||
@@ -28,10 +28,9 @@ def test_processor_with_audio_sample_rate(
|
||||
"""
|
||||
Test that vLLM's processor generates expected outputs with audio_sample_rate.
|
||||
|
||||
This validates the reviewer's request that we test the actual processor
|
||||
can handle different audio_sample_rate values and generate audio tokens.
|
||||
This validates that the processor correctly handles audio_sample_rate
|
||||
passed via hf_processor_mm_kwargs and generates audio tokens.
|
||||
"""
|
||||
# Setup: Build model context and processor
|
||||
ctx = build_model_context(
|
||||
model_id,
|
||||
limit_mm_per_prompt={"audio": 1, "image": 0, "video": 0},
|
||||
@@ -48,18 +47,17 @@ def test_processor_with_audio_sample_rate(
|
||||
prompt = "<|audio_start|><|audio_pad|><|audio_end|>"
|
||||
mm_data = {"audio": [(audio_data, audio_sample_rate)]}
|
||||
|
||||
# Execute: Apply processor with audio_sample_rate in mm_kwargs
|
||||
# Apply processor with audio_sample_rate in mm_kwargs
|
||||
hf_processor_mm_kwargs: dict[str, Any] = {
|
||||
"audio_sample_rate": audio_sample_rate,
|
||||
}
|
||||
processed_inputs = processor.apply(prompt, mm_data, hf_processor_mm_kwargs)
|
||||
|
||||
# Assert: Verify audio tokens are generated
|
||||
# Verify audio tokens are generated
|
||||
hf_processor = processor.info.get_hf_processor(**hf_processor_mm_kwargs)
|
||||
audio_token_id = tokenizer.convert_tokens_to_ids(hf_processor.audio_token)
|
||||
aud_tok_count = processed_inputs["prompt_token_ids"].count(audio_token_id)
|
||||
|
||||
# Audio should generate at least 1 token
|
||||
assert aud_tok_count >= 1, (
|
||||
f"Expected at least 1 audio token but got {aud_tok_count}. "
|
||||
f"sample_rate: {audio_sample_rate}Hz, duration: {audio_duration_sec}s"
|
||||
@@ -97,189 +95,10 @@ def test_longer_audio_generates_more_tokens(model_id: str) -> None:
|
||||
audio_token_id = tokenizer.convert_tokens_to_ids(hf_proc.audio_token)
|
||||
return processed["prompt_token_ids"].count(audio_token_id)
|
||||
|
||||
# Get token counts for different durations
|
||||
short_tokens = get_token_count(1.0)
|
||||
long_tokens = get_token_count(2.0)
|
||||
|
||||
# Longer audio should produce more tokens
|
||||
assert long_tokens > short_tokens, (
|
||||
f"Expected longer audio (2s) to have more tokens than shorter (1s). "
|
||||
f"Got short={short_tokens}, long={long_tokens}"
|
||||
)
|
||||
|
||||
|
||||
class TestQwen3OmniAudioSampleRatePreservation:
|
||||
"""Test that audio_sample_rate is preserved during kwargs restructuring.
|
||||
|
||||
These tests validate the fix for the audio_sample_rate bug in Qwen3 Omni
|
||||
where the parameter was lost during kwargs restructuring.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _process_kwargs(
|
||||
mm_kwargs: dict[str, Any],
|
||||
tok_kwargs: dict[str, Any],
|
||||
transformers_version: str = "4.57.0",
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Helper method to simulate kwargs processing logic from production code.
|
||||
|
||||
This method simulates the kwargs restructuring that happens in the
|
||||
Qwen3 Omni model when transformers < 4.58.0. By centralizing this
|
||||
logic, we make tests easier to maintain if the production logic changes.
|
||||
|
||||
Args:
|
||||
mm_kwargs: Multimodal kwargs (e.g., audio_sample_rate, truncation)
|
||||
tok_kwargs: Tokenizer kwargs (e.g., truncation)
|
||||
transformers_version: Version string to test against (default: "4.57.0")
|
||||
|
||||
Returns:
|
||||
Processed kwargs dictionary with restructured audio_kwargs and text_kwargs
|
||||
"""
|
||||
from packaging.version import Version
|
||||
|
||||
mm_kwargs_copy = dict(mm_kwargs)
|
||||
tok_kwargs_copy = dict(tok_kwargs)
|
||||
|
||||
if Version(transformers_version) < Version("4.58.0"):
|
||||
# Extract audio_sample_rate before restructuring (THE FIX)
|
||||
audio_sample_rate = mm_kwargs_copy.pop("audio_sample_rate", None)
|
||||
|
||||
# Restructure kwargs
|
||||
mm_kwargs_copy["audio_kwargs"] = {
|
||||
"truncation": mm_kwargs_copy.pop("truncation", False)
|
||||
}
|
||||
mm_kwargs_copy["text_kwargs"] = {
|
||||
"truncation": tok_kwargs_copy.pop("truncation", False)
|
||||
}
|
||||
|
||||
# Put audio_sample_rate into audio_kwargs (THE FIX)
|
||||
if audio_sample_rate is not None:
|
||||
mm_kwargs_copy["audio_kwargs"]["audio_sample_rate"] = audio_sample_rate
|
||||
|
||||
return mm_kwargs_copy
|
||||
|
||||
def test_audio_sample_rate_preserved_in_audio_kwargs(self) -> None:
|
||||
"""
|
||||
Test that audio_sample_rate is moved from top-level mm_kwargs
|
||||
into audio_kwargs during kwargs restructuring.
|
||||
|
||||
This is the core fix: when transformers < 4.58.0, the code
|
||||
restructures kwargs into audio_kwargs and text_kwargs, and
|
||||
audio_sample_rate must be preserved in audio_kwargs.
|
||||
"""
|
||||
# Setup: Create mm_kwargs with audio_sample_rate at top level
|
||||
mm_kwargs: dict[str, Any] = {
|
||||
"audio_sample_rate": 16000,
|
||||
"truncation": True,
|
||||
}
|
||||
tok_kwargs: dict[str, Any] = {
|
||||
"truncation": False,
|
||||
}
|
||||
|
||||
# Execute: Process kwargs using helper method
|
||||
result = self._process_kwargs(mm_kwargs, tok_kwargs)
|
||||
|
||||
# Assert: Verify audio_sample_rate is in audio_kwargs
|
||||
assert "audio_kwargs" in result
|
||||
assert "audio_sample_rate" in result["audio_kwargs"]
|
||||
assert result["audio_kwargs"]["audio_sample_rate"] == 16000
|
||||
|
||||
# Assert: Verify truncation is also in audio_kwargs
|
||||
assert result["audio_kwargs"]["truncation"] is True
|
||||
|
||||
# Assert: Verify text_kwargs is created correctly
|
||||
assert "text_kwargs" in result
|
||||
assert result["text_kwargs"]["truncation"] is False
|
||||
|
||||
def test_audio_sample_rate_absent_when_not_provided(self) -> None:
|
||||
"""
|
||||
Test that when audio_sample_rate is not provided in mm_kwargs,
|
||||
the restructured audio_kwargs doesn't contain it.
|
||||
"""
|
||||
# Setup: Create mm_kwargs WITHOUT audio_sample_rate
|
||||
mm_kwargs: dict[str, Any] = {
|
||||
"truncation": True,
|
||||
}
|
||||
tok_kwargs: dict[str, Any] = {
|
||||
"truncation": False,
|
||||
}
|
||||
|
||||
# Execute: Process kwargs using helper method
|
||||
result = self._process_kwargs(mm_kwargs, tok_kwargs)
|
||||
|
||||
# Assert: Verify audio_sample_rate is NOT in audio_kwargs
|
||||
assert "audio_kwargs" in result
|
||||
assert "audio_sample_rate" not in result["audio_kwargs"]
|
||||
|
||||
# Assert: Verify truncation is still in audio_kwargs
|
||||
assert result["audio_kwargs"]["truncation"] is True
|
||||
|
||||
@pytest.mark.parametrize("sample_rate", [8000, 16000, 22050, 24000, 44100, 48000])
|
||||
def test_various_audio_sample_rates_preserved(self, sample_rate: int) -> None:
|
||||
"""
|
||||
Test that various common audio sample rates are preserved.
|
||||
|
||||
Common sample rates:
|
||||
- 8000: Telephone quality
|
||||
- 16000: Wideband speech (Qwen3 Omni default)
|
||||
- 22050: Low-quality audio
|
||||
- 24000: High-quality speech
|
||||
- 44100: CD quality
|
||||
- 48000: Professional audio
|
||||
"""
|
||||
# Setup: Create mm_kwargs with specific sample rate
|
||||
mm_kwargs: dict[str, Any] = {
|
||||
"audio_sample_rate": sample_rate,
|
||||
"truncation": True,
|
||||
}
|
||||
tok_kwargs: dict[str, Any] = {"truncation": False}
|
||||
|
||||
# Execute: Process kwargs using helper method
|
||||
result = self._process_kwargs(mm_kwargs, tok_kwargs)
|
||||
|
||||
# Assert: Verify the specific sample rate is preserved
|
||||
assert result["audio_kwargs"]["audio_sample_rate"] == sample_rate
|
||||
|
||||
def test_kwargs_unchanged_for_newer_transformers_version(self) -> None:
|
||||
"""
|
||||
Test that kwargs structure remains unchanged for transformers >= 4.58.0.
|
||||
|
||||
This test ensures that when transformers version is 4.58.0 or higher,
|
||||
the kwargs restructuring is bypassed and audio_sample_rate remains
|
||||
at the top level as originally passed.
|
||||
"""
|
||||
from packaging.version import Version
|
||||
|
||||
# Setup: Create mm_kwargs with audio_sample_rate at top level
|
||||
mm_kwargs: dict[str, Any] = {
|
||||
"audio_sample_rate": 16000,
|
||||
"truncation": True,
|
||||
}
|
||||
tok_kwargs: dict[str, Any] = {
|
||||
"truncation": False,
|
||||
}
|
||||
|
||||
# Execute: Simulate with transformers >= 4.58.0
|
||||
mm_kwargs_copy = dict(mm_kwargs)
|
||||
tok_kwargs_copy = dict(tok_kwargs)
|
||||
|
||||
transformers_ver = "4.58.0" # Version that bypasses restructuring
|
||||
if Version(transformers_ver) < Version("4.58.0"):
|
||||
# This block should NOT execute for >= 4.58.0
|
||||
audio_sample_rate = mm_kwargs_copy.pop("audio_sample_rate", None)
|
||||
mm_kwargs_copy["audio_kwargs"] = {
|
||||
"truncation": mm_kwargs_copy.pop("truncation", False)
|
||||
}
|
||||
mm_kwargs_copy["text_kwargs"] = {
|
||||
"truncation": tok_kwargs_copy.pop("truncation", False)
|
||||
}
|
||||
if audio_sample_rate is not None:
|
||||
mm_kwargs_copy["audio_kwargs"]["audio_sample_rate"] = audio_sample_rate
|
||||
|
||||
# Assert: Verify kwargs structure is unchanged
|
||||
assert "audio_kwargs" not in mm_kwargs_copy
|
||||
assert "text_kwargs" not in mm_kwargs_copy
|
||||
assert mm_kwargs_copy["audio_sample_rate"] == 16000
|
||||
assert mm_kwargs_copy["truncation"] is True
|
||||
assert tok_kwargs_copy["truncation"] is False
|
||||
|
||||
@@ -7,10 +7,16 @@ from unittest.mock import patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.multimodal.audio import (
|
||||
MONO_AUDIO_SPEC,
|
||||
PASSTHROUGH_AUDIO_SPEC,
|
||||
AudioMediaIO,
|
||||
AudioResampler,
|
||||
AudioSpec,
|
||||
ChannelReduction,
|
||||
normalize_audio,
|
||||
resample_audio_librosa,
|
||||
resample_audio_scipy,
|
||||
)
|
||||
@@ -137,3 +143,500 @@ def test_audio_media_io_encode_base64(dummy_audio):
|
||||
decoded = base64.b64decode(out)
|
||||
assert decoded == b"dummy_wav_data"
|
||||
mock_write.assert_called_once()
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Tests for normalize_audio function
|
||||
# ============================================================
|
||||
|
||||
|
||||
class TestNormalizeAudio:
|
||||
"""Tests for normalize_audio function with different specs."""
|
||||
|
||||
def test_passthrough_preserves_audio(self):
|
||||
"""Passthrough spec should not modify audio."""
|
||||
stereo = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32)
|
||||
result = normalize_audio(stereo, PASSTHROUGH_AUDIO_SPEC)
|
||||
np.testing.assert_array_equal(result, stereo)
|
||||
|
||||
def test_mono_spec_with_numpy_stereo(self):
|
||||
"""Mono spec should reduce stereo numpy array to 1D."""
|
||||
stereo = np.array([[1.0, 2.0], [-1.0, 0.0]], dtype=np.float32)
|
||||
result = normalize_audio(stereo, MONO_AUDIO_SPEC)
|
||||
assert result.ndim == 1
|
||||
np.testing.assert_array_almost_equal(result, [0.0, 1.0])
|
||||
|
||||
def test_mono_spec_with_torch_stereo(self):
|
||||
"""Mono spec should reduce stereo torch tensor to 1D."""
|
||||
stereo = torch.tensor([[1.0, 2.0], [-1.0, 0.0]])
|
||||
result = normalize_audio(stereo, MONO_AUDIO_SPEC)
|
||||
assert result.ndim == 1
|
||||
torch.testing.assert_close(result, torch.tensor([0.0, 1.0]))
|
||||
|
||||
def test_mono_passthrough_for_1d_numpy(self):
|
||||
"""1D numpy array should pass through unchanged with mono spec."""
|
||||
mono = np.array([1.0, 2.0, 3.0], dtype=np.float32)
|
||||
result = normalize_audio(mono, MONO_AUDIO_SPEC)
|
||||
assert result.ndim == 1
|
||||
np.testing.assert_array_equal(result, mono)
|
||||
|
||||
def test_mono_passthrough_for_1d_torch(self):
|
||||
"""1D torch tensor should pass through unchanged with mono spec."""
|
||||
mono = torch.tensor([1.0, 2.0, 3.0])
|
||||
result = normalize_audio(mono, MONO_AUDIO_SPEC)
|
||||
assert result.ndim == 1
|
||||
torch.testing.assert_close(result, mono)
|
||||
|
||||
def test_first_channel_reduction(self):
|
||||
"""FIRST reduction should take only the first channel."""
|
||||
spec = AudioSpec(target_channels=1, channel_reduction=ChannelReduction.FIRST)
|
||||
stereo = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
|
||||
result = normalize_audio(stereo, spec)
|
||||
np.testing.assert_array_equal(result, [1.0, 2.0])
|
||||
|
||||
def test_max_channel_reduction(self):
|
||||
"""MAX reduction should take max across channels."""
|
||||
spec = AudioSpec(target_channels=1, channel_reduction=ChannelReduction.MAX)
|
||||
stereo = np.array([[1.0, 4.0], [3.0, 2.0]], dtype=np.float32)
|
||||
result = normalize_audio(stereo, spec)
|
||||
np.testing.assert_array_equal(result, [3.0, 4.0])
|
||||
|
||||
def test_sum_channel_reduction(self):
|
||||
"""SUM reduction should sum across channels."""
|
||||
spec = AudioSpec(target_channels=1, channel_reduction=ChannelReduction.SUM)
|
||||
stereo = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)
|
||||
result = normalize_audio(stereo, spec)
|
||||
np.testing.assert_array_equal(result, [4.0, 6.0])
|
||||
|
||||
def test_invalid_3d_array_raises(self):
|
||||
"""3D arrays should raise ValueError."""
|
||||
audio_3d = np.random.randn(2, 3, 4).astype(np.float32)
|
||||
with pytest.raises(ValueError, match="Unsupported audio"):
|
||||
normalize_audio(audio_3d, MONO_AUDIO_SPEC)
|
||||
|
||||
def test_channel_expansion_raises(self):
|
||||
"""Expanding from mono to stereo should raise ValueError."""
|
||||
mono = np.array([1.0, 2.0, 3.0], dtype=np.float32)
|
||||
spec = AudioSpec(target_channels=2)
|
||||
with pytest.raises(ValueError, match="Cannot expand"):
|
||||
normalize_audio(mono, spec)
|
||||
|
||||
def test_time_channels_format_numpy(self):
|
||||
"""Audio in (time, channels) format should be transposed to (channels, time).
|
||||
|
||||
This handles the case where audio loaders like soundfile return
|
||||
(time, channels) format instead of (channels, time) like torchaudio.
|
||||
"""
|
||||
# Create audio in (time, channels) format: 1000 samples, 2 channels
|
||||
audio_time_channels = np.array(
|
||||
[[1.0, -1.0]] * 1000, # 1000 time steps, 2 channels
|
||||
dtype=np.float32,
|
||||
)
|
||||
assert audio_time_channels.shape == (1000, 2) # (time, channels)
|
||||
|
||||
result = normalize_audio(audio_time_channels, MONO_AUDIO_SPEC)
|
||||
|
||||
# Should be reduced to mono 1D
|
||||
assert result.ndim == 1
|
||||
assert result.shape == (1000,)
|
||||
# Mean of [1.0, -1.0] at each time step should be 0.0
|
||||
np.testing.assert_array_almost_equal(result, np.zeros(1000))
|
||||
|
||||
def test_time_channels_format_torch(self):
|
||||
"""Torch tensor in (time, channels) format should be transposed."""
|
||||
# Create audio in (time, channels) format: 1000 samples, 2 channels
|
||||
audio_time_channels = torch.tensor(
|
||||
[[1.0, -1.0]] * 1000, # 1000 time steps, 2 channels
|
||||
)
|
||||
assert audio_time_channels.shape == (1000, 2) # (time, channels)
|
||||
|
||||
result = normalize_audio(audio_time_channels, MONO_AUDIO_SPEC)
|
||||
|
||||
# Should be reduced to mono 1D
|
||||
assert result.ndim == 1
|
||||
assert result.shape == (1000,)
|
||||
# Mean of [1.0, -1.0] at each time step should be 0.0
|
||||
torch.testing.assert_close(result, torch.zeros(1000))
|
||||
|
||||
def test_channels_time_format_preserved(self):
|
||||
"""Audio already in (channels, time) format should work correctly."""
|
||||
# Create audio in standard (channels, time) format: 2 channels, 1000 samples
|
||||
audio_channels_time = np.array(
|
||||
[[1.0] * 1000, [-1.0] * 1000], # 2 channels, 1000 time steps
|
||||
dtype=np.float32,
|
||||
)
|
||||
assert audio_channels_time.shape == (2, 1000) # (channels, time)
|
||||
|
||||
result = normalize_audio(audio_channels_time, MONO_AUDIO_SPEC)
|
||||
|
||||
# Should be reduced to mono 1D
|
||||
assert result.ndim == 1
|
||||
assert result.shape == (1000,)
|
||||
# Mean of [1.0, -1.0] at each time step should be 0.0
|
||||
np.testing.assert_array_almost_equal(result, np.zeros(1000))
|
||||
|
||||
def test_ambiguous_square_audio_numpy(self):
|
||||
"""Square audio arrays (N, N) should use shape[0] > shape[1] heuristic.
|
||||
|
||||
For a square array, shape[0] == shape[1], so no transpose happens
|
||||
and we assume (channels, time) format.
|
||||
"""
|
||||
# Create square audio: 4 channels, 4 samples
|
||||
audio_square = np.array(
|
||||
[
|
||||
[1.0, 2.0, 3.0, 4.0],
|
||||
[5.0, 6.0, 7.0, 8.0],
|
||||
[9.0, 10.0, 11.0, 12.0],
|
||||
[13.0, 14.0, 15.0, 16.0],
|
||||
],
|
||||
dtype=np.float32,
|
||||
)
|
||||
assert audio_square.shape == (4, 4)
|
||||
|
||||
result = normalize_audio(audio_square, MONO_AUDIO_SPEC)
|
||||
|
||||
# Should be reduced to mono 1D with mean across channels (axis 0)
|
||||
assert result.ndim == 1
|
||||
assert result.shape == (4,)
|
||||
# Mean across 4 channels: [1+5+9+13, 2+6+10+14, ...] / 4
|
||||
expected = np.array([7.0, 8.0, 9.0, 10.0])
|
||||
np.testing.assert_array_almost_equal(result, expected)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Tests for MultiModalDataParser integration with target_channels
|
||||
# ============================================================
|
||||
|
||||
|
||||
class TestMultiModalDataParserChannelNormalization:
|
||||
"""Tests for MultiModalDataParser.target_channels integration.
|
||||
|
||||
These tests verify that the target_channels parameter is properly used
|
||||
in the _parse_audio_data method to normalize audio channels.
|
||||
"""
|
||||
|
||||
def test_parser_normalizes_stereo_to_mono(self):
|
||||
"""Parser should normalize stereo to mono when target_channels=1."""
|
||||
from vllm.multimodal.parse import MultiModalDataParser
|
||||
|
||||
# Create parser with mono normalization enabled
|
||||
parser = MultiModalDataParser(
|
||||
target_sr=16000,
|
||||
target_channels=1,
|
||||
)
|
||||
|
||||
# Create stereo audio (simulating torchaudio output)
|
||||
stereo_audio = np.array(
|
||||
[[1.0, 1.0, 1.0], [-1.0, -1.0, -1.0]], # 2 channels, 3 samples
|
||||
dtype=np.float32,
|
||||
)
|
||||
|
||||
# Parse audio data
|
||||
result = parser._parse_audio_data((stereo_audio, 16000))
|
||||
|
||||
# Check that result is mono (1D)
|
||||
audio_item = result.get(0)
|
||||
assert audio_item.ndim == 1, f"Expected 1D mono audio, got {audio_item.ndim}D"
|
||||
assert audio_item.shape == (3,), f"Expected shape (3,), got {audio_item.shape}"
|
||||
# Channel average of [1, 1, 1] and [-1, -1, -1] should be [0, 0, 0]
|
||||
np.testing.assert_array_almost_equal(audio_item, np.zeros(3))
|
||||
|
||||
def test_parser_preserves_stereo_when_target_channels_none(self):
|
||||
"""Parser should preserve stereo when target_channels=None."""
|
||||
from vllm.multimodal.parse import MultiModalDataParser
|
||||
|
||||
# Create parser without channel normalization
|
||||
parser = MultiModalDataParser(
|
||||
target_sr=16000,
|
||||
target_channels=None,
|
||||
)
|
||||
|
||||
# Create stereo audio
|
||||
stereo_audio = np.array(
|
||||
[[1.0, 1.0, 1.0], [-1.0, -1.0, -1.0]],
|
||||
dtype=np.float32,
|
||||
)
|
||||
|
||||
# Parse audio data
|
||||
result = parser._parse_audio_data((stereo_audio, 16000))
|
||||
|
||||
# Check that result preserves original shape (after resampling)
|
||||
audio_item = result.get(0)
|
||||
# When target_channels=None, stereo audio should be preserved
|
||||
assert audio_item.ndim == 2, f"Expected 2D stereo audio, got {audio_item.ndim}D"
|
||||
|
||||
def test_parser_mono_passthrough_when_target_channels_1(self):
|
||||
"""Parser should pass through mono audio unchanged when target_channels=1."""
|
||||
from vllm.multimodal.parse import MultiModalDataParser
|
||||
|
||||
# Create parser with mono normalization enabled
|
||||
parser = MultiModalDataParser(
|
||||
target_sr=16000,
|
||||
target_channels=1,
|
||||
)
|
||||
|
||||
# Create mono audio (already 1D)
|
||||
mono_audio = np.random.randn(16000).astype(np.float32)
|
||||
|
||||
# Parse audio data
|
||||
result = parser._parse_audio_data((mono_audio, 16000))
|
||||
|
||||
# Check that result is still mono (1D)
|
||||
audio_item = result.get(0)
|
||||
assert audio_item.ndim == 1
|
||||
assert audio_item.shape == (16000,)
|
||||
|
||||
def test_parser_with_target_channels_2(self):
|
||||
"""Parser should reduce 6-channel to 2-channel when target_channels=2."""
|
||||
from vllm.multimodal.parse import MultiModalDataParser
|
||||
|
||||
# Create parser with stereo target
|
||||
parser = MultiModalDataParser(
|
||||
target_sr=16000,
|
||||
target_channels=2,
|
||||
)
|
||||
|
||||
# Create 6-channel audio (5.1 surround)
|
||||
surround_audio = np.random.randn(6, 1000).astype(np.float32)
|
||||
|
||||
# Parse audio data
|
||||
result = parser._parse_audio_data((surround_audio, 16000))
|
||||
|
||||
# Check that result is stereo (2 channels)
|
||||
audio_item = result.get(0)
|
||||
assert audio_item.ndim == 2
|
||||
assert audio_item.shape[0] == 2 # 2 channels
|
||||
|
||||
|
||||
# ============================================================
|
||||
# End-to-End Audio Pipeline Tests
|
||||
# ============================================================
|
||||
|
||||
|
||||
class TestAudioPipelineE2E:
|
||||
"""End-to-end tests for audio normalization in the full pipeline.
|
||||
|
||||
These tests verify the complete flow from raw audio input through
|
||||
the MultiModalDataParser, simulating different audio loader formats.
|
||||
"""
|
||||
|
||||
def test_stereo_audio_normalized_to_mono_e2e(self):
|
||||
"""Full pipeline: stereo audio (torchaudio format) → mono output."""
|
||||
from vllm.multimodal.parse import MultiModalDataParser
|
||||
|
||||
# Simulate torchaudio output: (channels, time) format
|
||||
# Stereo audio with left channel = 1.0, right channel = -1.0
|
||||
stereo_torchaudio = np.array(
|
||||
[[1.0] * 16000, [-1.0] * 16000], # 2 channels, 1 second at 16kHz
|
||||
dtype=np.float32,
|
||||
)
|
||||
assert stereo_torchaudio.shape == (2, 16000)
|
||||
|
||||
# Create parser with mono normalization (like Whisper models)
|
||||
parser = MultiModalDataParser(
|
||||
target_sr=16000,
|
||||
target_channels=1,
|
||||
)
|
||||
|
||||
# Process audio through the parser
|
||||
result = parser._parse_audio_data((stereo_torchaudio, 16000))
|
||||
audio_output = result.get(0)
|
||||
|
||||
# Verify output is mono 1D
|
||||
assert audio_output.ndim == 1, f"Expected 1D, got {audio_output.ndim}D"
|
||||
assert audio_output.shape == (16000,)
|
||||
|
||||
# Verify channel averaging: mean of [1.0, -1.0] = 0.0
|
||||
np.testing.assert_array_almost_equal(audio_output, np.zeros(16000), decimal=5)
|
||||
|
||||
def test_soundfile_format_normalized_to_mono_e2e(self):
|
||||
"""Full pipeline: soundfile format (time, channels) → mono output."""
|
||||
from vllm.multimodal.parse import MultiModalDataParser
|
||||
|
||||
# Simulate soundfile output: (time, channels) format
|
||||
# 16000 samples, 2 channels
|
||||
stereo_soundfile = np.array(
|
||||
[[0.5, -0.5]] * 16000, # Each row is [left, right]
|
||||
dtype=np.float32,
|
||||
)
|
||||
assert stereo_soundfile.shape == (16000, 2)
|
||||
|
||||
# Create parser with mono normalization
|
||||
parser = MultiModalDataParser(
|
||||
target_sr=16000,
|
||||
target_channels=1,
|
||||
)
|
||||
|
||||
# Process audio through the parser
|
||||
result = parser._parse_audio_data((stereo_soundfile, 16000))
|
||||
audio_output = result.get(0)
|
||||
|
||||
# Verify output is mono 1D
|
||||
assert audio_output.ndim == 1, f"Expected 1D, got {audio_output.ndim}D"
|
||||
assert audio_output.shape == (16000,)
|
||||
|
||||
# Verify channel averaging: mean of [0.5, -0.5] = 0.0
|
||||
np.testing.assert_array_almost_equal(audio_output, np.zeros(16000), decimal=5)
|
||||
|
||||
def test_librosa_mono_passthrough_e2e(self):
|
||||
"""Full pipeline: librosa mono format → preserved as mono."""
|
||||
from vllm.multimodal.parse import MultiModalDataParser
|
||||
|
||||
# Simulate librosa output: already mono (time,) format
|
||||
mono_librosa = np.random.randn(16000).astype(np.float32)
|
||||
assert mono_librosa.shape == (16000,)
|
||||
|
||||
# Create parser with mono normalization
|
||||
parser = MultiModalDataParser(
|
||||
target_sr=16000,
|
||||
target_channels=1,
|
||||
)
|
||||
|
||||
# Process audio through the parser
|
||||
result = parser._parse_audio_data((mono_librosa, 16000))
|
||||
audio_output = result.get(0)
|
||||
|
||||
# Verify output is still mono 1D
|
||||
assert audio_output.ndim == 1
|
||||
assert audio_output.shape == (16000,)
|
||||
|
||||
# Verify audio content is preserved
|
||||
np.testing.assert_array_almost_equal(audio_output, mono_librosa)
|
||||
|
||||
def test_multichannel_5_1_surround_to_mono_e2e(self):
|
||||
"""Full pipeline: 5.1 surround (6 channels) → mono output."""
|
||||
from vllm.multimodal.parse import MultiModalDataParser
|
||||
|
||||
# Simulate 5.1 surround audio: 6 channels
|
||||
surround_audio = np.array(
|
||||
[
|
||||
[1.0] * 8000, # Front Left
|
||||
[2.0] * 8000, # Front Right
|
||||
[3.0] * 8000, # Center
|
||||
[4.0] * 8000, # LFE (subwoofer)
|
||||
[5.0] * 8000, # Rear Left
|
||||
[6.0] * 8000, # Rear Right
|
||||
],
|
||||
dtype=np.float32,
|
||||
)
|
||||
assert surround_audio.shape == (6, 8000)
|
||||
|
||||
# Create parser with mono normalization
|
||||
parser = MultiModalDataParser(
|
||||
target_sr=16000,
|
||||
target_channels=1,
|
||||
)
|
||||
|
||||
# Process audio through the parser
|
||||
result = parser._parse_audio_data((surround_audio, 16000))
|
||||
audio_output = result.get(0)
|
||||
|
||||
# Verify output is mono 1D
|
||||
assert audio_output.ndim == 1
|
||||
|
||||
# Verify channel averaging: mean of [1,2,3,4,5,6] = 3.5
|
||||
expected_value = (1.0 + 2.0 + 3.0 + 4.0 + 5.0 + 6.0) / 6
|
||||
np.testing.assert_array_almost_equal(
|
||||
audio_output, np.full(8000, expected_value), decimal=5
|
||||
)
|
||||
|
||||
def test_torch_tensor_input_e2e(self):
|
||||
"""Full pipeline: torch.Tensor stereo input → mono numpy output."""
|
||||
from vllm.multimodal.parse import MultiModalDataParser
|
||||
|
||||
# Simulate torch tensor input (from torchaudio)
|
||||
stereo_torch = torch.tensor(
|
||||
[[1.0] * 8000, [-1.0] * 8000], # 2 channels
|
||||
dtype=torch.float32,
|
||||
)
|
||||
assert stereo_torch.shape == (2, 8000)
|
||||
|
||||
# Create parser with mono normalization
|
||||
parser = MultiModalDataParser(
|
||||
target_sr=16000,
|
||||
target_channels=1,
|
||||
)
|
||||
|
||||
# Process audio through the parser
|
||||
# Note: Parser expects numpy, so we convert first (simulating real usage)
|
||||
result = parser._parse_audio_data((stereo_torch.numpy(), 16000))
|
||||
audio_output = result.get(0)
|
||||
|
||||
# Verify output is mono 1D numpy array
|
||||
assert audio_output.ndim == 1
|
||||
assert isinstance(audio_output, np.ndarray)
|
||||
|
||||
# Verify channel averaging
|
||||
np.testing.assert_array_almost_equal(audio_output, np.zeros(8000), decimal=5)
|
||||
|
||||
def test_passthrough_preserves_stereo_e2e(self):
|
||||
"""Full pipeline: stereo with target_channels=None → stereo preserved."""
|
||||
from vllm.multimodal.parse import MultiModalDataParser
|
||||
|
||||
# Stereo audio
|
||||
stereo_audio = np.array(
|
||||
[[1.0] * 8000, [-1.0] * 8000],
|
||||
dtype=np.float32,
|
||||
)
|
||||
|
||||
# Create parser WITHOUT mono normalization (passthrough)
|
||||
parser = MultiModalDataParser(
|
||||
target_sr=16000,
|
||||
target_channels=None, # Passthrough - no normalization
|
||||
)
|
||||
|
||||
# Process audio through the parser
|
||||
result = parser._parse_audio_data((stereo_audio, 16000))
|
||||
audio_output = result.get(0)
|
||||
|
||||
# Verify output preserves stereo (2D)
|
||||
assert audio_output.ndim == 2
|
||||
assert audio_output.shape == (2, 8000)
|
||||
|
||||
def test_resampling_with_channel_normalization_e2e(self):
|
||||
"""Full pipeline: resample + channel normalize in single pass."""
|
||||
from vllm.multimodal.parse import MultiModalDataParser
|
||||
|
||||
# Stereo audio at 48kHz (common recording rate)
|
||||
stereo_48k = np.array(
|
||||
[[1.0] * 48000, [-1.0] * 48000], # 1 second at 48kHz
|
||||
dtype=np.float32,
|
||||
)
|
||||
|
||||
# Create parser with both resampling and mono normalization
|
||||
parser = MultiModalDataParser(
|
||||
target_sr=16000, # Resample to 16kHz
|
||||
target_channels=1, # Normalize to mono
|
||||
)
|
||||
|
||||
# Process audio through the parser
|
||||
result = parser._parse_audio_data((stereo_48k, 48000))
|
||||
audio_output = result.get(0)
|
||||
|
||||
# Verify output is mono 1D at target sample rate
|
||||
assert audio_output.ndim == 1
|
||||
# After resampling from 48kHz to 16kHz, length should be ~16000
|
||||
assert audio_output.shape[0] == 16000
|
||||
|
||||
def test_very_short_audio_e2e(self):
|
||||
"""Full pipeline: very short audio (< 1 frame) handled correctly."""
|
||||
from vllm.multimodal.parse import MultiModalDataParser
|
||||
|
||||
# Very short stereo audio (10 samples)
|
||||
short_stereo = np.array(
|
||||
[[1.0] * 10, [-1.0] * 10],
|
||||
dtype=np.float32,
|
||||
)
|
||||
|
||||
parser = MultiModalDataParser(
|
||||
target_sr=16000,
|
||||
target_channels=1,
|
||||
)
|
||||
|
||||
result = parser._parse_audio_data((short_stereo, 16000))
|
||||
audio_output = result.get(0)
|
||||
|
||||
# Should still produce mono output
|
||||
assert audio_output.ndim == 1
|
||||
assert audio_output.shape == (10,)
|
||||
np.testing.assert_array_almost_equal(audio_output, np.zeros(10))
|
||||
|
||||
@@ -226,6 +226,10 @@ class Qwen2_5OmniThinkerProcessingInfo(
|
||||
assert isinstance(feature_extractor, WhisperFeatureExtractor)
|
||||
return feature_extractor
|
||||
|
||||
def get_target_channels(self) -> int:
|
||||
"""Return target audio channels for Qwen2.5 Omni models (mono)."""
|
||||
return 1
|
||||
|
||||
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
|
||||
return {"audio": None, "image": None, "video": None}
|
||||
|
||||
@@ -310,6 +314,7 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
|
||||
return Qwen2_5OmniThinkerMultiModalDataParser(
|
||||
spatial_merge_size=self.info.get_hf_config().vision_config.spatial_merge_size,
|
||||
target_sr=feature_extractor.sampling_rate,
|
||||
target_channels=self.info.get_target_channels(),
|
||||
)
|
||||
|
||||
def _call_hf_processor(
|
||||
|
||||
@@ -140,6 +140,10 @@ class Qwen2AudioProcessingInfo(BaseProcessingInfo):
|
||||
assert isinstance(feature_extractor, WhisperFeatureExtractor)
|
||||
return feature_extractor
|
||||
|
||||
def get_target_channels(self) -> int:
|
||||
"""Return target audio channels for Qwen2 Audio models (mono)."""
|
||||
return 1
|
||||
|
||||
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
|
||||
return {"audio": None}
|
||||
|
||||
@@ -201,7 +205,10 @@ class Qwen2AudioMultiModalDataParser(MultiModalDataParser):
|
||||
class Qwen2AudioMultiModalProcessor(BaseMultiModalProcessor[Qwen2AudioProcessingInfo]):
|
||||
def _get_data_parser(self) -> MultiModalDataParser:
|
||||
feature_extractor = self.info.get_feature_extractor()
|
||||
return Qwen2AudioMultiModalDataParser(target_sr=feature_extractor.sampling_rate)
|
||||
return Qwen2AudioMultiModalDataParser(
|
||||
target_sr=feature_extractor.sampling_rate,
|
||||
target_channels=self.info.get_target_channels(),
|
||||
)
|
||||
|
||||
def _call_hf_processor(
|
||||
self,
|
||||
|
||||
@@ -133,6 +133,10 @@ class UltravoxProcessingInfo(BaseProcessingInfo):
|
||||
assert isinstance(feature_extractor, WhisperFeatureExtractor)
|
||||
return feature_extractor
|
||||
|
||||
def get_target_channels(self) -> int:
|
||||
"""Return target audio channels for Ultravox models (mono)."""
|
||||
return 1
|
||||
|
||||
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
|
||||
return {"audio": None}
|
||||
|
||||
@@ -169,7 +173,10 @@ class UltravoxDummyInputsBuilder(BaseDummyInputsBuilder[UltravoxProcessingInfo])
|
||||
class UltravoxMultiModalProcessor(BaseMultiModalProcessor[UltravoxProcessingInfo]):
|
||||
def _get_data_parser(self) -> MultiModalDataParser:
|
||||
feature_extractor = self.info.get_feature_extractor()
|
||||
return MultiModalDataParser(target_sr=feature_extractor.sampling_rate)
|
||||
return MultiModalDataParser(
|
||||
target_sr=feature_extractor.sampling_rate,
|
||||
target_channels=self.info.get_target_channels(),
|
||||
)
|
||||
|
||||
def _call_hf_processor(
|
||||
self,
|
||||
|
||||
@@ -690,6 +690,10 @@ class WhisperProcessingInfo(BaseProcessingInfo):
|
||||
assert isinstance(feature_extractor, WhisperFeatureExtractor)
|
||||
return feature_extractor
|
||||
|
||||
def get_target_channels(self) -> int:
|
||||
"""Return target audio channels for Whisper models (mono)."""
|
||||
return 1
|
||||
|
||||
def get_num_audio_tokens(self) -> int:
|
||||
return self.get_hf_config().max_source_positions
|
||||
|
||||
@@ -724,7 +728,10 @@ class WhisperDummyInputsBuilder(BaseDummyInputsBuilder[WhisperProcessingInfo]):
|
||||
class WhisperMultiModalProcessor(EncDecMultiModalProcessor[WhisperProcessingInfo]):
|
||||
def _get_data_parser(self) -> MultiModalDataParser:
|
||||
feature_extractor = self.info.get_feature_extractor()
|
||||
return MultiModalDataParser(target_sr=feature_extractor.sampling_rate)
|
||||
return MultiModalDataParser(
|
||||
target_sr=feature_extractor.sampling_rate,
|
||||
target_channels=self.info.get_target_channels(),
|
||||
)
|
||||
|
||||
@property
|
||||
def pad_dummy_encoder_prompt(self) -> bool:
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import base64
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
@@ -25,6 +27,136 @@ try:
|
||||
except ImportError:
|
||||
soundfile = PlaceholderModule("soundfile") # type: ignore[assignment]
|
||||
|
||||
# ============================================================
|
||||
|
||||
|
||||
class ChannelReduction(str, Enum):
|
||||
"""Method to reduce multi-channel audio to target channels."""
|
||||
|
||||
MEAN = "mean" # Average across channels (default, preserves energy balance)
|
||||
FIRST = "first" # Take first channel only
|
||||
MAX = "max" # Take max value across channels
|
||||
SUM = "sum" # Sum across channels
|
||||
|
||||
|
||||
@dataclass
|
||||
class AudioSpec:
|
||||
"""Specification for target audio format.
|
||||
|
||||
This dataclass defines the expected audio format for a model's feature
|
||||
extractor. It is used to normalize audio data before processing.
|
||||
|
||||
Attributes:
|
||||
target_channels: Number of output channels. None means passthrough
|
||||
(no normalization). 1 = mono, 2 = stereo, etc.
|
||||
channel_reduction: Method to reduce channels when input has more
|
||||
channels than target. Only used when reducing channels.
|
||||
"""
|
||||
|
||||
target_channels: int | None = 1
|
||||
channel_reduction: ChannelReduction = ChannelReduction.MEAN
|
||||
|
||||
@property
|
||||
def needs_normalization(self) -> bool:
|
||||
"""Whether audio normalization is needed."""
|
||||
return self.target_channels is not None
|
||||
|
||||
def __repr__(self) -> str:
|
||||
if self.target_channels is None:
|
||||
return "AudioSpec(passthrough)"
|
||||
return (
|
||||
f"AudioSpec(channels={self.target_channels}, "
|
||||
f"reduction={self.channel_reduction.value})"
|
||||
)
|
||||
|
||||
|
||||
# Pre-defined specs for common use cases
|
||||
MONO_AUDIO_SPEC = AudioSpec(target_channels=1, channel_reduction=ChannelReduction.MEAN)
|
||||
PASSTHROUGH_AUDIO_SPEC = AudioSpec(target_channels=None)
|
||||
|
||||
|
||||
def normalize_audio(
|
||||
audio: npt.NDArray[np.floating] | torch.Tensor,
|
||||
spec: AudioSpec,
|
||||
) -> npt.NDArray[np.floating] | torch.Tensor:
|
||||
"""Normalize audio to the specified format.
|
||||
|
||||
This function handles channel reduction for multi-channel audio,
|
||||
supporting both numpy arrays and torch tensors.
|
||||
|
||||
Args:
|
||||
audio: Input audio data. Can be:
|
||||
- 1D array/tensor: (time,) - already mono
|
||||
- 2D array/tensor: (channels, time) - standard format from torchaudio
|
||||
- 2D array/tensor: (time, channels) - format from soundfile
|
||||
(will be auto-detected and transposed if time > channels)
|
||||
spec: AudioSpec defining the target format.
|
||||
|
||||
Returns:
|
||||
Normalized audio in the same type as input (numpy or torch).
|
||||
For mono output (target_channels=1), returns 1D array/tensor.
|
||||
|
||||
Raises:
|
||||
ValueError: If audio has unsupported dimensions or channel expansion
|
||||
is requested (e.g., mono to stereo).
|
||||
"""
|
||||
if not spec.needs_normalization:
|
||||
return audio
|
||||
|
||||
# Handle 1D audio (already mono)
|
||||
if audio.ndim == 1:
|
||||
if spec.target_channels == 1:
|
||||
return audio
|
||||
raise ValueError(f"Cannot expand mono audio to {spec.target_channels} channels")
|
||||
|
||||
# Handle 2D audio
|
||||
if audio.ndim != 2:
|
||||
raise ValueError(f"Unsupported audio shape: {audio.shape}. Expected 1D or 2D.")
|
||||
|
||||
# Auto-detect format: if shape[0] > shape[1], assume (time, channels)
|
||||
# This handles soundfile format where time dimension is typically much larger
|
||||
if audio.shape[0] > audio.shape[1]:
|
||||
# Transpose from (time, channels) to (channels, time)
|
||||
audio = audio.T if isinstance(audio, np.ndarray) else audio.T
|
||||
|
||||
num_channels = audio.shape[0]
|
||||
|
||||
# No reduction needed if already at target
|
||||
if num_channels == spec.target_channels:
|
||||
return audio
|
||||
|
||||
# Cannot expand channels
|
||||
if num_channels < spec.target_channels:
|
||||
raise ValueError(
|
||||
f"Cannot expand {num_channels} channels to {spec.target_channels}"
|
||||
)
|
||||
|
||||
# Reduce channels
|
||||
is_numpy = isinstance(audio, np.ndarray)
|
||||
|
||||
if spec.target_channels == 1:
|
||||
# Reduce to mono
|
||||
if spec.channel_reduction == ChannelReduction.MEAN:
|
||||
result = np.mean(audio, axis=0) if is_numpy else audio.mean(dim=0)
|
||||
elif spec.channel_reduction == ChannelReduction.FIRST:
|
||||
result = audio[0]
|
||||
elif spec.channel_reduction == ChannelReduction.MAX:
|
||||
result = np.max(audio, axis=0) if is_numpy else audio.max(dim=0).values
|
||||
elif spec.channel_reduction == ChannelReduction.SUM:
|
||||
result = np.sum(audio, axis=0) if is_numpy else audio.sum(dim=0)
|
||||
else:
|
||||
raise ValueError(f"Unknown reduction method: {spec.channel_reduction}")
|
||||
return result
|
||||
else:
|
||||
# Reduce to N channels (take first N and apply reduction if needed)
|
||||
# For now, just take first N channels
|
||||
return audio[: spec.target_channels]
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Audio Resampling
|
||||
# ============================================================
|
||||
|
||||
|
||||
def resample_audio_librosa(
|
||||
audio: npt.NDArray[np.floating],
|
||||
|
||||
@@ -22,7 +22,7 @@ from typing_extensions import assert_never
|
||||
from vllm.utils.collection_utils import is_list_of
|
||||
from vllm.utils.import_utils import LazyLoader
|
||||
|
||||
from .audio import AudioResampler
|
||||
from .audio import AudioResampler, AudioSpec, normalize_audio
|
||||
from .base import MediaWithBytes
|
||||
from .inputs import (
|
||||
AudioItem,
|
||||
@@ -456,6 +456,9 @@ class MultiModalDataParser:
|
||||
Args:
|
||||
target_sr (float, optional): Enables automatic resampling of audio
|
||||
items to the model's expected sampling rate.
|
||||
target_channels (int, optional): Target number of audio channels.
|
||||
If provided, normalizes audio to this many channels (e.g., 1 for mono).
|
||||
If None, audio channels are passed through unchanged.
|
||||
expected_hidden_size (int, optional): Expected hidden dimension for
|
||||
embedding inputs. If provided, validates that user-supplied
|
||||
embeddings have the correct hidden size to prevent crashes
|
||||
@@ -466,6 +469,7 @@ class MultiModalDataParser:
|
||||
self,
|
||||
*,
|
||||
target_sr: float | None = None,
|
||||
target_channels: int | None = None,
|
||||
audio_resample_method: Literal["librosa", "scipy"] = "librosa",
|
||||
video_needs_metadata: bool = False,
|
||||
expected_hidden_size: int | None = None,
|
||||
@@ -476,6 +480,7 @@ class MultiModalDataParser:
|
||||
target_sr=target_sr,
|
||||
method=audio_resample_method,
|
||||
)
|
||||
self.target_channels = target_channels
|
||||
self.video_needs_metadata = video_needs_metadata
|
||||
self.expected_hidden_size = expected_hidden_size
|
||||
|
||||
@@ -565,6 +570,11 @@ class MultiModalDataParser:
|
||||
else:
|
||||
new_audio = self.audio_resampler.resample(audio, orig_sr=orig_sr)
|
||||
|
||||
# Apply channel normalization if target_channels is set
|
||||
if self.target_channels is not None:
|
||||
spec = AudioSpec(target_channels=self.target_channels)
|
||||
new_audio = normalize_audio(new_audio, spec)
|
||||
|
||||
new_audios.append(new_audio)
|
||||
|
||||
return AudioProcessorItems(new_audios)
|
||||
|
||||
Reference in New Issue
Block a user