[Frontend] Add automatic language detection for Whisper transcription (#34342)

Signed-off-by: space_check <roman.vuskov@rwth-aachen.de>
Signed-off-by: Roman <45857014+spacecheck@users.noreply.github.com>
Co-authored-by: Nicolò Lucchesi <nicolo.lucchesi@gmail.com>
Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
This commit is contained in:
Roman
2026-02-21 13:49:41 +01:00
committed by GitHub
parent 272b535ab3
commit 98b0205c3c
5 changed files with 249 additions and 10 deletions

View File

@@ -273,3 +273,30 @@ async def test_audio_with_max_tokens(whisper_client, mary_had_lamb):
out_text = out["text"]
out_tokens = tok(out_text, add_special_tokens=False)["input_ids"]
assert len(out_tokens) < 450 # ~Whisper max output len
@pytest.mark.asyncio
@pytest.mark.parametrize(
    ("fixture_name", "expected_lang", "expected_text"),
    [
        ("mary_had_lamb", "en", ["Mary had a little lamb"]),
        ("foscolo", "it", ["zacinto", "sacre"]),
    ],
    ids=["english", "italian"],
)
async def test_language_auto_detect(
    whisper_client, fixture_name, expected_lang, expected_text, request
):
    """Auto-detect language when no language param is provided."""
    # Resolve the audio fixture lazily so one test covers several clips.
    audio_file = request.getfixturevalue(fixture_name)
    transcription = await whisper_client.audio.transcriptions.create(
        model=MODEL_NAME,
        file=audio_file,
        response_format="verbose_json",
        temperature=0.0,
    )
    # verbose_json exposes the detected language code directly.
    assert transcription.language == expected_lang
    # The transcript must contain at least one expected marker word
    # (case-insensitive) to confirm it is in the right language.
    haystack = transcription.text.lower()
    found = False
    for word in expected_text:
        if word.lower() in haystack:
            found = True
            break
    assert found, f"Expected {expected_lang} text but got: {transcription.text}"

View File

@@ -111,6 +111,47 @@ def check_model_available(model: str) -> None:
model_info.check_transformers_version(on_fail="skip")
def test_parse_language_detection_output():
    """Unit test for WhisperForConditionalGeneration.parse_language_detection_output.
    No GPU or model loading required.
    """
    from unittest.mock import MagicMock

    from vllm.model_executor.models.whisper import (
        WhisperForConditionalGeneration,
    )

    cls = WhisperForConditionalGeneration

    def fake_tokenizer(decoded: str) -> MagicMock:
        # Only .decode is exercised by the parser, so a stub suffices.
        tok = MagicMock()
        tok.decode = MagicMock(return_value=decoded)
        return tok

    # Valid language tokens decode to their ISO 639-1 codes.
    for token_id, token_str, expected in [
        (50259, "<|en|>", "en"),
        (50261, "<|de|>", "de"),
    ]:
        assert (
            cls.parse_language_detection_output([token_id], fake_tokenizer(token_str))
            == expected
        )

    # A well-formed token with an unsupported language code is rejected.
    with pytest.raises(AssertionError):
        cls.parse_language_detection_output([99999], fake_tokenizer("<|xx|>"))

    # Output that lacks the <|..|> special-token format is rejected.
    with pytest.raises(AssertionError):
        cls.parse_language_detection_output([1], fake_tokenizer("hello"))

    # An empty token_ids list cannot be parsed at all.
    with pytest.raises((AssertionError, IndexError)):
        cls.parse_language_detection_output([], fake_tokenizer("anything"))
@pytest.mark.core_model
@pytest.mark.cpu_model
@pytest.mark.parametrize("model", ["openai/whisper-large-v3-turbo"])

View File

@@ -41,7 +41,10 @@ from vllm.exceptions import VLLMValidationError
from vllm.inputs import ProcessorInputs
from vllm.logger import init_logger
from vllm.logprobs import FlatLogprobs, Logprob
from vllm.model_executor.models import SupportsTranscription, supports_transcription
from vllm.model_executor.models import (
SupportsTranscription,
supports_transcription,
)
from vllm.outputs import RequestOutput
from vllm.renderers.inputs import DictPrompt, EncoderDecoderDictPrompt
from vllm.renderers.inputs.preprocess import parse_enc_dec_prompt, parse_model_prompt
@@ -242,10 +245,57 @@ class OpenAISpeechToText(OpenAIServing):
model_cls = get_model_cls(self.model_config)
return cast(type[SupportsTranscription], model_cls)
async def _detect_language(
    self,
    audio_chunk: np.ndarray,
    request_id: str,
) -> str:
    """Auto-detect the spoken language from an audio chunk.

    Delegates prompt construction and output parsing to the model class
    via ``get_language_detection_prompt`` and
    ``parse_language_detection_output``.
    """
    # Local import keeps serving-layer import time down.
    from vllm.sampling_params import SamplingParams

    detection_prompt = self.model_cls.get_language_detection_prompt(
        audio_chunk,
        self.asr_config,
    )
    # Constrain decoding so only valid language tokens can be produced.
    lang_token_ids = self.model_cls.get_language_token_ids(self.tokenizer)
    params = SamplingParams(
        max_tokens=1,
        temperature=0.0,
        allowed_token_ids=lang_token_ids,
    )
    generator = self.engine_client.generate(
        detection_prompt,
        params,
        request_id,
    )
    # Drain the stream until the single-token request finishes.
    final: RequestOutput
    async for final in generator:
        if final.finished:
            break
    detected = self.model_cls.parse_language_detection_output(
        list(final.outputs[0].token_ids),
        self.tokenizer,
    )
    logger.info("Auto-detected language: '%s'", detected)
    return detected
async def _preprocess_speech_to_text(
self,
request: SpeechToTextRequest,
audio_data: bytes,
request_id: str,
) -> tuple[list[ProcessorInputs], float]:
# Validate request
language = self.model_cls.validate_language(request.language)
@@ -274,6 +324,15 @@ class OpenAISpeechToText(OpenAIServing):
and duration > self.asr_config.max_audio_clip_s
)
chunks = [y] if not do_split_audio else self._split_audio(y, int(sr))
if language is None and getattr(
self.model_cls, "supports_explicit_language_detection", False
):
language = await self._detect_language(
chunks[0], f"{request_id}-lang_detect"
)
request.language = language
parsed_prompts: list[DictPrompt] = []
for chunk in chunks:
# The model has control over the construction, as long as it
@@ -435,6 +494,7 @@ class OpenAISpeechToText(OpenAIServing):
engine_prompts, duration_s = await self._preprocess_speech_to_text(
request=request,
audio_data=audio_data,
request_id=request_id,
)
except ValueError as e:

View File

@@ -1111,6 +1111,16 @@ class SupportsTranscription(Protocol):
Enables the segment timestamp option for supported models by setting this to `True`.
"""
supports_explicit_language_detection: ClassVar[bool] = False
"""
Transcription models that require an explicit language detection step
(e.g. Whisper needs a separate forward pass to predict the language
token) should set this to ``True`` and implement
:meth:`get_language_detection_prompt` and
:meth:`parse_language_detection_output` and
:meth:`get_language_token_ids`.
"""
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
# language codes in supported_languages
@@ -1206,6 +1216,46 @@ class SupportsTranscription(Protocol):
"""
return text
@classmethod
def get_language_detection_prompt(
    cls,
    audio: np.ndarray,
    stt_config: SpeechToTextConfig,
) -> PromptType:
    """Build the prompt used for a dedicated language-detection pass.

    Implementations are only required when
    ``supports_explicit_language_detection`` is ``True``; the default
    raises :exc:`NotImplementedError`.
    """
    raise NotImplementedError
@classmethod
def parse_language_detection_output(
    cls,
    token_ids: list[int],
    tokenizer: object,
) -> str:
    """Turn detection-pass output token IDs into a language code.

    Implementations are only required when
    ``supports_explicit_language_detection`` is ``True``; the default
    raises :exc:`NotImplementedError`.
    """
    raise NotImplementedError
@classmethod
def get_language_token_ids(
    cls,
    tokenizer: object,
) -> list[int] | None:
    """Return the token IDs of all valid language tokens.

    Used to restrict the detection pass to valid language tokens only.
    Implementations are only required when
    ``supports_explicit_language_detection`` is ``True``; the default
    raises :exc:`NotImplementedError`.
    """
    raise NotImplementedError
@overload
def supports_transcription(

View File

@@ -64,7 +64,11 @@ from vllm.v1.attention.backend import (
AttentionType,
)
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsTranscription
from .interfaces import (
MultiModalEmbeddings,
SupportsMultiModal,
SupportsTranscription,
)
from .utils import (
AutoWeightsLoader,
WeightsMapper,
@@ -784,7 +788,9 @@ class WhisperMultiModalProcessor(EncDecMultiModalProcessor[WhisperProcessingInfo
dummy_inputs=WhisperDummyInputsBuilder,
)
class WhisperForConditionalGeneration(
nn.Module, SupportsTranscription, SupportsMultiModal
nn.Module,
SupportsTranscription,
SupportsMultiModal,
):
packed_modules_mapping = {
"self_attn.qkv_proj": [
@@ -802,20 +808,18 @@ class WhisperForConditionalGeneration(
# Whisper only supports audio-conditioned generation.
supports_transcription_only = True
supports_segment_timestamp = True
supports_explicit_language_detection = True
supported_languages = ISO639_1_SUPPORTED_LANGS
@classmethod
def validate_language(cls, language: str | None) -> str | None:
    """Validate the requested transcription language.

    Returns ``None`` when no language was supplied so the serving layer
    runs Whisper's explicit language-detection pass instead of silently
    defaulting to English; otherwise defers to the base-class validation.
    (Reconstructed post-commit form: the rendered diff interleaved the
    removed ``logger.warning``/``language = "en"`` lines with the added
    ones, leaving an unbalanced call in the displayed text.)
    """
    if language is None:
        logger.debug(
            "No language specified. Language will be auto-detected "
            "from audio. To skip detection, pass the `language` field "
            "in the TranscriptionRequest."
        )
        return None
    return super().validate_language(language)
@classmethod
@@ -846,6 +850,63 @@ class WhisperForConditionalGeneration(
decoder_prompt=TextPrompt(prompt=decoder_text),
)
@classmethod
def get_language_token_ids(
    cls,
    tokenizer: object,
) -> list[int]:
    """Token IDs of every supported ``<|xx|>`` language token.

    Handed to ``SamplingParams.allowed_token_ids`` so the detection
    pass can only emit a valid language token.
    """
    # Build the special-token strings first, then map them to IDs in
    # the same order as cls.supported_languages.
    language_tokens = [f"<|{lang_code}|>" for lang_code in cls.supported_languages]
    return [tokenizer.convert_tokens_to_ids(token) for token in language_tokens]
@classmethod
def get_language_detection_prompt(
    cls,
    audio: np.ndarray,
    stt_config: SpeechToTextConfig,
) -> PromptType:
    """Prompt that makes Whisper emit a single language token.

    The decoder is seeded with only ``<|startoftranscript|>``, so the
    next token the model predicts is the most likely language token
    (e.g. ``<|de|>``).
    """
    encoder = TextPrompt(
        prompt="",
        multi_modal_data={"audio": (audio, stt_config.sample_rate)},
    )
    decoder = TextPrompt(prompt="<|startoftranscript|>")
    return ExplicitEncoderDecoderPrompt(
        encoder_prompt=encoder,
        decoder_prompt=decoder,
    )
@classmethod
def parse_language_detection_output(
    cls,
    token_ids: list[int],
    tokenizer: object,
) -> str | None:
    """Extract the ISO 639-1 code from a detected language token.

    Only the first output token is inspected; constrained generation is
    expected to guarantee it is one of the ``<|xx|>`` language tokens,
    and the asserts enforce that contract.
    """
    first_token = tokenizer.decode(
        [token_ids[0]],
        skip_special_tokens=False,
    )
    # Language tokens are wrapped as <|xx|>; strip the delimiters.
    assert first_token.startswith("<|") and first_token.endswith("|>")
    lang_code = first_token[2:-2]
    assert lang_code in cls.supported_languages
    return lang_code
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
if modality.startswith("audio"):