[Frontend] Remove torchcodec from audio dependency (#37061)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
Isotr0py
2026-03-15 13:15:59 +08:00
committed by GitHub
parent b3debb7e77
commit 6590a3ecda
3 changed files with 39 additions and 109 deletions

View File

@@ -977,7 +977,6 @@ setup(
"soundfile",
"mistral_common[audio]",
"av",
"torchcodec",
], # Required for audio processing
"video": [], # Kept for backwards compatibility
"flashinfer": [], # Kept for backwards compatibility

View File

@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import io
import math
import time
import zlib
@@ -35,7 +36,6 @@ from vllm.entrypoints.openai.speech_to_text.protocol import (
TranslationSegment,
TranslationStreamResponse,
)
from vllm.entrypoints.openai.speech_to_text.utils import load_audio_bytes
from vllm.entrypoints.utils import get_max_tokens
from vllm.exceptions import VLLMValidationError
from vllm.inputs import EncoderDecoderInputs, ProcessorInputs
@@ -43,6 +43,7 @@ from vllm.logger import init_logger
from vllm.logprobs import FlatLogprobs, Logprob
from vllm.model_executor.models import SupportsTranscription
from vllm.multimodal.audio import split_audio
from vllm.multimodal.media.audio import extract_audio_from_video_bytes
from vllm.outputs import RequestOutput
from vllm.renderers.inputs import DictPrompt, EncoderDecoderDictPrompt
from vllm.renderers.inputs.preprocess import parse_enc_dec_prompt, parse_model_prompt
@@ -55,6 +56,19 @@ try:
except ImportError:
librosa = PlaceholderModule("librosa") # type: ignore[assignment]
try:
import soundfile as sf
except ImportError:
sf = PlaceholderModule("soundfile") # type: ignore[assignment]
# Public libsndfile error codes exposed via `soundfile.LibsndfileError.code`
# (soundfile is librosa's main backend). They are used to decide whether an audio
# loading failure is a client error (invalid audio file) rather than a server error.
# 1 = unrecognised format (file is not a supported audio container)
# 3 = malformed file (corrupt or structurally invalid audio)
# 4 = unsupported encoding (codec not supported by this libsndfile build)
_BAD_SF_CODES = {1, 3, 4}
SpeechToTextResponse: TypeAlias = TranscriptionResponse | TranslationResponse
SpeechToTextResponseVerbose: TypeAlias = (
TranscriptionResponseVerbose | TranslationResponseVerbose
@@ -198,7 +212,30 @@ class OpenAISpeechToText(OpenAIServing):
# transparently falls back to ffmpeg via an in-memory fd.
# NOTE resample to model SR here for efficiency. This is also a
# pre-requisite for chunking, as it assumes Whisper SR.
y, sr = load_audio_bytes(audio_data, sr=self.asr_config.sample_rate)
try:
with io.BytesIO(audio_data) as buf:
y, sr = librosa.load(buf, sr=self.asr_config.sample_rate) # type: ignore[return-value]
except sf.LibsndfileError as exc:
# Only fall back for known format-detection failures.
# Re-raise anything else (e.g. corrupt but recognised format).
if exc.code not in _BAD_SF_CODES:
raise
logger.debug(
"librosa/soundfile could not decode audio from BytesIO "
"(code=%s: %s); falling back to pyav in-process decode",
exc.code,
exc,
)
try:
native_y, native_sr = extract_audio_from_video_bytes(audio_data)
sr = self.asr_config.sample_rate
y = librosa.resample(native_y, orig_sr=native_sr, target_sr=sr)
except Exception as pyav_exc:
logger.debug(
"pyAV fallback also failed: %s",
pyav_exc,
)
raise ValueError("Invalid or unsupported audio file.") from pyav_exc
duration = librosa.get_duration(y=y, sr=sr)
do_split_audio = (

View File

@@ -1,106 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Audio decoding utilities for the speech-to-text endpoints."""
import io
import numpy as np
import torchaudio
from vllm.logger import init_logger
from vllm.utils.import_utils import PlaceholderModule
try:
import librosa
except ImportError:
librosa = PlaceholderModule("librosa") # type: ignore[assignment]
try:
import soundfile as sf
except ImportError:
sf = PlaceholderModule("soundfile") # type: ignore[assignment]
# Module-level logger, created with this module's ``__name__``.
logger = init_logger(__name__)
# Public libsndfile error codes exposed via ``soundfile.LibsndfileError.code``.
# soundfile is librosa's primary backend. These codes indicate that the audio
# data itself is problematic (unrecognised container, corrupt file, or
# unsupported encoding) rather than a transient server error. They are the
# only codes for which ``load_audio_bytes`` attempts its fallback decoder.
# 1 = unrecognised format, 3 = malformed file, 4 = unsupported encoding
_BAD_SF_CODES = {1, 3, 4}
def _decode_audio_bytes_torchaudio(
    audio_data: bytes,
    sr: int,
) -> tuple[np.ndarray, int]:
    """Decode raw audio bytes in-process with torchaudio.

    The bytes are wrapped in an in-memory ``BytesIO`` buffer and handed to
    ``torchaudio.load``, so no subprocess is spawned. Multi-channel audio is
    averaged down to a single channel, and the result is resampled to *sr*
    when the decoded rate differs, mirroring the ``(samples, rate)`` return
    convention of ``librosa.load``.

    Args:
        audio_data: Encoded audio file contents.
        sr: Target sample rate in Hz.

    Returns:
        A tuple ``(y, sr)`` where ``y`` is the decoded waveform as a numpy
        array with the channel dimension squeezed out, and ``sr`` is the
        requested sample rate.

    Raises:
        RuntimeError: If decoding yields zero samples.
    """
    samples, native_sr = torchaudio.load(io.BytesIO(audio_data))

    # Collapse multiple channels into one by averaging across dim 0.
    num_channels = samples.shape[0]
    if num_channels > 1:
        samples = samples.mean(dim=0, keepdim=True)

    # Resampling is skipped when the file is already at the target rate.
    needs_resample = native_sr != sr
    if needs_resample:
        samples = torchaudio.functional.resample(
            samples, orig_freq=native_sr, new_freq=sr
        )

    # Drop the channel dimension and hand back a numpy array.
    pcm = samples.squeeze(0).numpy()
    if pcm.size != 0:
        return pcm, sr

    raise RuntimeError(
        "torchaudio produced no audio samples (file may be empty or corrupt)"
    )
def load_audio_bytes(
    audio_data: bytes,
    sr: int | float,
) -> tuple[np.ndarray, int]:
    """Decode audio from raw bytes, falling back to torchaudio if needed.

    The primary path is ``librosa.load`` on an in-memory ``BytesIO`` buffer.
    When that raises a ``soundfile.LibsndfileError`` whose code is listed in
    ``_BAD_SF_CODES`` (unrecognised format, malformed file, or unsupported
    encoding), the same bytes are decoded in-process through
    ``_decode_audio_bytes_torchaudio`` instead.

    Args:
        audio_data: Encoded audio file contents.
        sr: Target sample rate in Hz; truncated to ``int`` before use.

    Returns:
        A tuple ``(y, sr)`` of decoded samples and the sample rate.

    Raises:
        soundfile.LibsndfileError: If librosa/soundfile fails with a code
            that is not in ``_BAD_SF_CODES``.
        ValueError: If the torchaudio fallback fails as well; the underlying
            exception is attached as the cause.
    """
    target_sr = int(sr)

    # Primary path: librosa backed by soundfile.
    try:
        with io.BytesIO(audio_data) as stream:
            return librosa.load(stream, sr=target_sr)  # type: ignore[return-value]
    except sf.LibsndfileError as sf_err:
        # Only the codes in _BAD_SF_CODES trigger the fallback; any other
        # libsndfile failure is propagated to the caller unchanged.
        recoverable = sf_err.code in _BAD_SF_CODES
        if not recoverable:
            raise
        logger.debug(
            "librosa/soundfile could not decode audio from BytesIO "
            "(code=%s: %s); falling back to torchaudio in-process decode",
            sf_err.code,
            sf_err,
        )

    # Secondary path: in-process decode through torchaudio.
    try:
        decoded = _decode_audio_bytes_torchaudio(audio_data, target_sr)
    except Exception as ta_exc:
        logger.debug(
            "torchaudio fallback also failed: %s",
            ta_exc,
        )
        raise ValueError("Invalid or unsupported audio file.") from ta_exc
    return decoded