[Frontend] Add automatic language detection for Whisper transcription (#34342)
Signed-off-by: space_check <roman.vuskov@rwth-aachen.de>
Signed-off-by: Roman <45857014+spacecheck@users.noreply.github.com>
Co-authored-by: Nicolò Lucchesi <nicolo.lucchesi@gmail.com>
Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
This commit is contained in:
@@ -273,3 +273,30 @@ async def test_audio_with_max_tokens(whisper_client, mary_had_lamb):
|
||||
out_text = out["text"]
|
||||
out_tokens = tok(out_text, add_special_tokens=False)["input_ids"]
|
||||
assert len(out_tokens) < 450 # ~Whisper max output len
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize(
    ("fixture_name", "expected_lang", "expected_text"),
    [
        ("mary_had_lamb", "en", ["Mary had a little lamb"]),
        ("foscolo", "it", ["zacinto", "sacre"]),
    ],
    ids=["english", "italian"],
)
async def test_language_auto_detect(
    whisper_client, fixture_name, expected_lang, expected_text, request
):
    """Omitting the language param should trigger server-side auto-detection."""
    audio = request.getfixturevalue(fixture_name)

    # Deliberately no `language` argument: the endpoint must detect it.
    transcription = await whisper_client.audio.transcriptions.create(
        model=MODEL_NAME,
        file=audio,
        response_format="verbose_json",
        temperature=0.0,
    )

    assert transcription.language == expected_lang

    # At least one expected phrase must appear, case-insensitively.
    haystack = transcription.text.lower()
    hits = [phrase for phrase in expected_text if phrase.lower() in haystack]
    assert hits, f"Expected {expected_lang} text but got: {transcription.text}"
|
||||
|
||||
@@ -111,6 +111,47 @@ def check_model_available(model: str) -> None:
|
||||
model_info.check_transformers_version(on_fail="skip")
|
||||
|
||||
|
||||
def test_parse_language_detection_output():
    """Pure-Python checks for ``parse_language_detection_output``.

    No GPU or model loading required.
    """
    from unittest.mock import MagicMock

    from vllm.model_executor.models.whisper import (
        WhisperForConditionalGeneration,
    )

    model_cls = WhisperForConditionalGeneration

    def fake_tokenizer(decoded: str) -> MagicMock:
        # The parser only calls `.decode()`, so a MagicMock stub suffices.
        mock = MagicMock()
        mock.decode = MagicMock(return_value=decoded)
        return mock

    # Valid language tokens are mapped to their two-letter codes.
    for token_id, decoded, expected in [
        (50259, "<|en|>", "en"),
        (50261, "<|de|>", "de"),
    ]:
        result = model_cls.parse_language_detection_output(
            [token_id], fake_tokenizer(decoded)
        )
        assert result == expected

    # Unsupported language code
    with pytest.raises(AssertionError):
        model_cls.parse_language_detection_output([99999], fake_tokenizer("<|xx|>"))

    # No special token format
    with pytest.raises(AssertionError):
        model_cls.parse_language_detection_output([1], fake_tokenizer("hello"))

    # Empty token_ids
    with pytest.raises((AssertionError, IndexError)):
        model_cls.parse_language_detection_output([], fake_tokenizer("anything"))
|
||||
|
||||
|
||||
@pytest.mark.core_model
|
||||
@pytest.mark.cpu_model
|
||||
@pytest.mark.parametrize("model", ["openai/whisper-large-v3-turbo"])
|
||||
|
||||
@@ -41,7 +41,10 @@ from vllm.exceptions import VLLMValidationError
|
||||
from vllm.inputs import ProcessorInputs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logprobs import FlatLogprobs, Logprob
|
||||
from vllm.model_executor.models import SupportsTranscription, supports_transcription
|
||||
from vllm.model_executor.models import (
|
||||
SupportsTranscription,
|
||||
supports_transcription,
|
||||
)
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.renderers.inputs import DictPrompt, EncoderDecoderDictPrompt
|
||||
from vllm.renderers.inputs.preprocess import parse_enc_dec_prompt, parse_model_prompt
|
||||
@@ -242,10 +245,57 @@ class OpenAISpeechToText(OpenAIServing):
|
||||
model_cls = get_model_cls(self.model_config)
|
||||
return cast(type[SupportsTranscription], model_cls)
|
||||
|
||||
async def _detect_language(
    self,
    audio_chunk: np.ndarray,
    request_id: str,
) -> str:
    """Run one constrained generation step to identify the spoken language.

    The model class supplies the detection prompt
    (``get_language_detection_prompt``) and the result parser
    (``parse_language_detection_output``); this method only wires them
    to the engine.
    """
    from vllm.sampling_params import SamplingParams

    # Restrict decoding to valid language tokens and stop after one token.
    allowed_token_ids = self.model_cls.get_language_token_ids(
        self.tokenizer,
    )
    sampling_params = SamplingParams(
        max_tokens=1,
        temperature=0.0,
        allowed_token_ids=allowed_token_ids,
    )
    detection_prompt = self.model_cls.get_language_detection_prompt(
        audio_chunk,
        self.asr_config,
    )

    generation = self.engine_client.generate(
        detection_prompt,
        sampling_params,
        request_id,
    )

    # Drain the stream until the request finishes; the final yielded
    # output carries the generated language token.
    final_output: RequestOutput
    async for final_output in generation:
        if final_output.finished:
            break

    detected = self.model_cls.parse_language_detection_output(
        list(final_output.outputs[0].token_ids),
        self.tokenizer,
    )

    logger.info("Auto-detected language: '%s'", detected)
    return detected
|
||||
|
||||
async def _preprocess_speech_to_text(
|
||||
self,
|
||||
request: SpeechToTextRequest,
|
||||
audio_data: bytes,
|
||||
request_id: str,
|
||||
) -> tuple[list[ProcessorInputs], float]:
|
||||
# Validate request
|
||||
language = self.model_cls.validate_language(request.language)
|
||||
@@ -274,6 +324,15 @@ class OpenAISpeechToText(OpenAIServing):
|
||||
and duration > self.asr_config.max_audio_clip_s
|
||||
)
|
||||
chunks = [y] if not do_split_audio else self._split_audio(y, int(sr))
|
||||
|
||||
if language is None and getattr(
|
||||
self.model_cls, "supports_explicit_language_detection", False
|
||||
):
|
||||
language = await self._detect_language(
|
||||
chunks[0], f"{request_id}-lang_detect"
|
||||
)
|
||||
request.language = language
|
||||
|
||||
parsed_prompts: list[DictPrompt] = []
|
||||
for chunk in chunks:
|
||||
# The model has control over the construction, as long as it
|
||||
@@ -435,6 +494,7 @@ class OpenAISpeechToText(OpenAIServing):
|
||||
engine_prompts, duration_s = await self._preprocess_speech_to_text(
|
||||
request=request,
|
||||
audio_data=audio_data,
|
||||
request_id=request_id,
|
||||
)
|
||||
|
||||
except ValueError as e:
|
||||
|
||||
@@ -1111,6 +1111,16 @@ class SupportsTranscription(Protocol):
|
||||
Enables the segment timestamp option for supported models by setting this to `True`.
|
||||
"""
|
||||
|
||||
supports_explicit_language_detection: ClassVar[bool] = False
|
||||
"""
|
||||
Transcription models that require an explicit language detection step
|
||||
(e.g. Whisper needs a separate forward pass to predict the language
|
||||
token) should set this to ``True`` and implement
|
||||
:meth:`get_language_detection_prompt` and
|
||||
:meth:`parse_language_detection_output` and
|
||||
:meth:`get_language_token_ids`.
|
||||
"""
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
super().__init_subclass__(**kwargs)
|
||||
# language codes in supported_languages
|
||||
@@ -1206,6 +1216,46 @@ class SupportsTranscription(Protocol):
|
||||
"""
|
||||
return text
|
||||
|
||||
@classmethod
def get_language_detection_prompt(
    cls,
    audio: np.ndarray,
    stt_config: SpeechToTextConfig,
) -> PromptType:
    """Build the prompt used for the language-detection pass.

    Implementations are required only when
    ``supports_explicit_language_detection`` is ``True``.
    """
    raise NotImplementedError
|
||||
|
||||
@classmethod
def parse_language_detection_output(
    cls,
    token_ids: list[int],
    tokenizer: object,
) -> str:
    """Extract the detected language from generated token IDs.

    Implementations are required only when
    ``supports_explicit_language_detection`` is ``True``.
    """
    raise NotImplementedError
|
||||
|
||||
@classmethod
def get_language_token_ids(
    cls,
    tokenizer: object,
) -> list[int] | None:
    """Return the token IDs of all valid language tokens.

    Used to constrain the language-detection step so it can only emit a
    valid language token.

    Implementations are required only when
    ``supports_explicit_language_detection`` is ``True``.
    """
    raise NotImplementedError
|
||||
|
||||
|
||||
@overload
|
||||
def supports_transcription(
|
||||
|
||||
@@ -64,7 +64,11 @@ from vllm.v1.attention.backend import (
|
||||
AttentionType,
|
||||
)
|
||||
|
||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsTranscription
|
||||
from .interfaces import (
|
||||
MultiModalEmbeddings,
|
||||
SupportsMultiModal,
|
||||
SupportsTranscription,
|
||||
)
|
||||
from .utils import (
|
||||
AutoWeightsLoader,
|
||||
WeightsMapper,
|
||||
@@ -784,7 +788,9 @@ class WhisperMultiModalProcessor(EncDecMultiModalProcessor[WhisperProcessingInfo
|
||||
dummy_inputs=WhisperDummyInputsBuilder,
|
||||
)
|
||||
class WhisperForConditionalGeneration(
|
||||
nn.Module, SupportsTranscription, SupportsMultiModal
|
||||
nn.Module,
|
||||
SupportsTranscription,
|
||||
SupportsMultiModal,
|
||||
):
|
||||
packed_modules_mapping = {
|
||||
"self_attn.qkv_proj": [
|
||||
@@ -802,20 +808,18 @@ class WhisperForConditionalGeneration(
|
||||
# Whisper only supports audio-conditioned generation.
|
||||
supports_transcription_only = True
|
||||
supports_segment_timestamp = True
|
||||
supports_explicit_language_detection = True
|
||||
supported_languages = ISO639_1_SUPPORTED_LANGS
|
||||
|
||||
@classmethod
def validate_language(cls, language: str | None) -> str | None:
    """Validate the requested language, allowing ``None`` for auto-detect.

    The visible block was a corrupted merge of the old behavior
    (default to ``"en"`` with a warning) and the new behavior; this is
    the intended new version: when no language is given, return ``None``
    so the serving layer runs a dedicated language-detection pass
    instead of silently assuming English. Any explicit language is
    delegated to the base-class validation.
    """
    if language is None:
        logger.debug(
            "No language specified. Language will be auto-detected "
            "from audio. To skip detection, pass the `language` field "
            "in the TranscriptionRequest."
        )
        return None
    return super().validate_language(language)
|
||||
|
||||
@classmethod
|
||||
@@ -846,6 +850,63 @@ class WhisperForConditionalGeneration(
|
||||
decoder_prompt=TextPrompt(prompt=decoder_text),
|
||||
)
|
||||
|
||||
@classmethod
def get_language_token_ids(
    cls,
    tokenizer: object,
) -> list[int]:
    """Token IDs of every Whisper language token (``<|en|>``, ``<|de|>``, ...).

    Plugged into ``SamplingParams.allowed_token_ids`` so the language
    detection pass can only produce a valid language token.
    """
    lang_tokens = (f"<|{code}|>" for code in cls.supported_languages)
    return [tokenizer.convert_tokens_to_ids(token) for token in lang_tokens]
|
||||
|
||||
@classmethod
def get_language_detection_prompt(
    cls,
    audio: np.ndarray,
    stt_config: SpeechToTextConfig,
) -> PromptType:
    """Prompt that makes Whisper emit exactly one language token.

    The decoder is seeded with only ``<|startoftranscript|>``, so the
    first generated token is the language the model judges most likely
    (e.g. ``<|de|>``).
    """
    encoder = TextPrompt(
        prompt="",
        multi_modal_data={"audio": (audio, stt_config.sample_rate)},
    )
    decoder = TextPrompt(prompt="<|startoftranscript|>")
    return ExplicitEncoderDecoderPrompt(
        encoder_prompt=encoder,
        decoder_prompt=decoder,
    )
|
||||
|
||||
@classmethod
|
||||
def parse_language_detection_output(
|
||||
cls,
|
||||
token_ids: list[int],
|
||||
tokenizer: object,
|
||||
) -> str | None:
|
||||
"""Parse the language token predicted by Whisper.
|
||||
|
||||
Decodes the first token ID and extracts the language code from the
|
||||
``<|xx|>`` format. Expects a valid language token from constrained generation.
|
||||
"""
|
||||
|
||||
decoded = tokenizer.decode(
|
||||
[token_ids[0]],
|
||||
skip_special_tokens=False,
|
||||
)
|
||||
# Whisper language tokens have the form <|xx|>
|
||||
assert decoded.startswith("<|") and decoded.endswith("|>")
|
||||
lang_code = decoded[2:-2]
|
||||
assert lang_code in cls.supported_languages
|
||||
return lang_code
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
|
||||
if modality.startswith("audio"):
|
||||
|
||||
Reference in New Issue
Block a user