[Bugfix] Relax lang pin for voxtral (#21833)

Signed-off-by: Sanchit Gandhi <sgandhi3141@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
Sanchit Gandhi
2025-07-31 04:38:52 +01:00
committed by GitHub
parent 9cb497bfa3
commit ec02e536df
4 changed files with 80 additions and 80 deletions

View File

@@ -86,11 +86,7 @@ class OpenAISpeechToText(OpenAIServing):
audio_data: bytes, audio_data: bytes,
) -> tuple[list[PromptType], float]: ) -> tuple[list[PromptType], float]:
# Validate request # Validate request
# TODO language should be optional and can be guessed. language = self.model_cls.validate_language(request.language)
# For now we default to en. See
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/generation_whisper.py#L1520
lang = request.language or "en"
self.model_cls.validate_language(lang)
if len(audio_data) / 1024**2 > self.max_audio_filesize_mb: if len(audio_data) / 1024**2 > self.max_audio_filesize_mb:
raise ValueError("Maximum file size exceeded.") raise ValueError("Maximum file size exceeded.")
@@ -112,7 +108,7 @@ class OpenAISpeechToText(OpenAIServing):
audio=chunk, audio=chunk,
stt_config=self.asr_config, stt_config=self.asr_config,
model_config=self.model_config, model_config=self.model_config,
language=lang, language=language,
task_type=self.task_type, task_type=self.task_type,
request_prompt=request.prompt) request_prompt=request.prompt)
prompts.append(prompt) prompts.append(prompt)

View File

@@ -1,13 +1,14 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable, MutableSequence from collections.abc import Iterable, Mapping, MutableSequence
from typing import (TYPE_CHECKING, ClassVar, Literal, Optional, Protocol, from typing import (TYPE_CHECKING, ClassVar, Literal, Optional, Protocol,
Union, overload, runtime_checkable) Union, overload, runtime_checkable)
import numpy as np import numpy as np
import torch import torch
from torch import Tensor from torch import Tensor
from transformers.models.whisper.tokenization_whisper import LANGUAGES
from typing_extensions import Self, TypeIs from typing_extensions import Self, TypeIs
from vllm.config import ModelConfig, SpeechToTextConfig from vllm.config import ModelConfig, SpeechToTextConfig
@@ -685,6 +686,8 @@ class SupportsQuant:
@runtime_checkable @runtime_checkable
class SupportsTranscription(Protocol): class SupportsTranscription(Protocol):
"""The interface required for all models that support transcription.""" """The interface required for all models that support transcription."""
# Mapping from ISO 639-1 language codes to language names
supported_languages: ClassVar[Mapping[str, str]]
supports_transcription: ClassVar[Literal[True]] = True supports_transcription: ClassVar[Literal[True]] = True
@@ -694,11 +697,22 @@ class SupportsTranscription(Protocol):
`True`. `True`.
""" """
def __init_subclass__(cls, **kwargs):
    """Validate `supported_languages` when a subclass is defined.

    Every key of the subclass's `supported_languages` mapping must also be
    a key of the Whisper tokenizer `LANGUAGES` map imported by this module.

    Raises:
        ValueError: If `supported_languages` contains a code that is not a
            key of `LANGUAGES`.
    """
    super().__init_subclass__(**kwargs)
    # Language codes in supported_languages that don't exist in the
    # full language map.
    invalid = set(cls.supported_languages) - set(LANGUAGES.keys())
    if invalid:
        # End the first sentence before the line break so the second line
        # of the message does not start with a stray ". ".
        raise ValueError(
            f"{cls.__name__}.supported_languages contains invalid "
            f"language codes: {sorted(invalid)}.\n"
            f"Valid choices are: {sorted(LANGUAGES.keys())}")
@classmethod @classmethod
def get_generation_prompt(cls, audio: np.ndarray, def get_generation_prompt(cls, audio: np.ndarray,
stt_config: SpeechToTextConfig, stt_config: SpeechToTextConfig,
model_config: ModelConfig, language: str, model_config: ModelConfig,
task_type: str, language: Optional[str], task_type: str,
request_prompt: str) -> PromptType: request_prompt: str) -> PromptType:
"""Get the prompt for the ASR model. """Get the prompt for the ASR model.
The model has control over the construction, as long as it The model has control over the construction, as long as it
@@ -706,9 +720,36 @@ class SupportsTranscription(Protocol):
... ...
@classmethod @classmethod
def validate_language(cls, language: str) -> bool: def get_other_languages(cls) -> Mapping[str, str]:
"""Check if the model supports a specific ISO639_1 language.""" # other possible language codes from the whisper map
... return {
k: v
for k, v in LANGUAGES.items() if k not in cls.supported_languages
}
@classmethod
def validate_language(cls, language: Optional[str]) -> Optional[str]:
    """Check the language code of a transcription request.

    Returns `language` unchanged when it is `None` or a key of
    `cls.supported_languages`. A code that is only found in
    `cls.get_other_languages()` is also returned unchanged, after a
    warning is logged. Any other value is rejected.

    Raises:
        ValueError: If `language` is in neither `cls.supported_languages`
            nor `cls.get_other_languages()`.
    """
    natively_supported = (language is None
                          or language in cls.supported_languages)
    if natively_supported:
        return language

    # Reject codes that are unknown to both language maps.
    if language not in cls.get_other_languages():
        raise ValueError(
            f"Unsupported language: {language!r}. Must be one of "
            f"{list(cls.supported_languages.keys())}.")

    # Known code, but not natively supported: warn and let it through.
    logger.warning(
        "Language %r is not natively supported by %s; "
        "results may be less accurate. Supported languages: %r",
        language,
        cls.__name__,
        list(cls.supported_languages.keys()),
    )
    return language
@classmethod @classmethod
def get_speech_to_text_config( def get_speech_to_text_config(

View File

@@ -26,8 +26,7 @@ from vllm.logger import init_logger
from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models import SupportsPP from vllm.model_executor.models import SupportsPP
# yapf: disable # yapf: disable
from vllm.model_executor.models.whisper import ( from vllm.model_executor.models.whisper import WhisperEncoder
WhisperEncoder, WhisperForConditionalGeneration)
# yapf: enable # yapf: enable
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -50,6 +49,18 @@ from .utils import (flatten_bn, init_vllm_registered_model, maybe_prefix,
logger = init_logger(__name__) logger = init_logger(__name__)
# ISO 639-1 language code -> language name. Assigned below as
# `VoxtralForConditionalGeneration.supported_languages`.
ISO639_1_SUPPORTED_LANGS = {
"ar": "Arabic",
"nl": "Dutch",
"en": "English",
"fr": "French",
"de": "German",
"hi": "Hindi",
"it": "Italian",
"pt": "Portuguese",
"es": "Spanish",
}
class VoxtralProcessorAdapter: class VoxtralProcessorAdapter:
""" """
@@ -301,6 +312,7 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo]
dummy_inputs=VoxtralDummyInputsBuilder) dummy_inputs=VoxtralDummyInputsBuilder)
class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal, class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP, SupportsTranscription): SupportsPP, SupportsTranscription):
supported_languages = ISO639_1_SUPPORTED_LANGS
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__() super().__init__()
@@ -441,8 +453,8 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
# for speech-to-text transcription # for speech-to-text transcription
def get_generation_prompt(cls, audio: np.ndarray, def get_generation_prompt(cls, audio: np.ndarray,
model_config: ModelConfig, model_config: ModelConfig,
stt_config: SpeechToTextConfig, language: str, stt_config: SpeechToTextConfig,
task_type: str, language: Optional[str], task_type: str,
request_prompt: str) -> PromptType: request_prompt: str) -> PromptType:
tokenizer = cached_tokenizer_from_config(model_config) tokenizer = cached_tokenizer_from_config(model_config)
audio = Audio(audio, int(stt_config.sample_rate), audio = Audio(audio, int(stt_config.sample_rate),
@@ -457,11 +469,6 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
prompts_dict["prompt_token_ids"] = tokenized.tokens prompts_dict["prompt_token_ids"] = tokenized.tokens
return cast(PromptType, prompts_dict) return cast(PromptType, prompts_dict)
@classmethod
def validate_language(cls, language: str) -> bool:
# same as whisper
return WhisperForConditionalGeneration.validate_language(language)
@classmethod @classmethod
def get_num_audio_tokens(cls, audio_duration_s: float, def get_num_audio_tokens(cls, audio_duration_s: float,
stt_config: SpeechToTextConfig, stt_config: SpeechToTextConfig,

View File

@@ -109,51 +109,6 @@ ISO639_1_SUPPORTED_LANGS = {
"vi": "Vietnamese", "vi": "Vietnamese",
"cy": "Welsh" "cy": "Welsh"
} }
ISO639_1_OTHER_LANGS = {
"lo": "Lao",
"jw": "Javanese",
"tk": "Turkmen",
"yi": "Yiddish",
"so": "Somali",
"bn": "Bengali",
"nn": "Norwegian Nynorsk",
"si": "Sinhala",
"yo": "Yoruba",
"sa": "Sanskrit",
"mi": "Māori",
"fo": "Faroese", # codespell:ignore
"mt": "Maltese",
"tg": "Tajik",
"mg": "Malagasy",
"haw": "Hawaiian",
"km": "Khmer",
"br": "Breton",
"ps": "Pashto",
"ln": "Lingala",
"la": "Latin",
"ml": "Malayalam",
"sq": "Albanian",
"su": "Sundanese",
"eu": "Basque",
"ka": "Georgian",
"uz": "Uzbek",
"sn": "Shona",
"ht": "Haitian",
"as": "Assamese",
"mn": "Mongolian",
"te": "Telugu",
"pa": "Panjabi",
"tt": "Tatar",
"gu": "Gujarati",
"oc": "Occitan",
"ha": "Hausa",
"ba": "Bashkir",
"my": "Burmese",
"sd": "Sindhi",
"am": "Amharic",
"lb": "Luxembourgish",
"bo": "Tibetan"
}
class WhisperAudioInputs(TypedDict): class WhisperAudioInputs(TypedDict):
@@ -807,22 +762,20 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
# Whisper only supports audio-conditioned generation. # Whisper only supports audio-conditioned generation.
supports_transcription_only = True supports_transcription_only = True
supported_languages = ISO639_1_SUPPORTED_LANGS
@classmethod @classmethod
def validate_language(cls, language: str) -> bool: def validate_language(cls, language: Optional[str]) -> Optional[str]:
if language in ISO639_1_SUPPORTED_LANGS: if language is None:
return True # TODO language should be optional and can be guessed.
elif language in ISO639_1_OTHER_LANGS: # For now we default to en. See
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/generation_whisper.py#L1520
logger.warning( logger.warning(
"The selected language %s has limited accuracy with" "Defaulting to language='en'. If you wish to transcribe "
" reported WER>=0.5. Results may be less accurate " "audio in a different language, pass the `language` field "
"for this choice.", language) "in the TranscriptionRequest.")
return True language = "en"
else: return super().validate_language(language)
raise ValueError(f"Unsupported language: {language}."
"Language should be one of:" +
f" {list(ISO639_1_SUPPORTED_LANGS.values())}" +
f"or {list(ISO639_1_OTHER_LANGS.values())}")
@classmethod @classmethod
def get_generation_prompt( def get_generation_prompt(
@@ -830,9 +783,12 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
audio: np.ndarray, audio: np.ndarray,
model_config: ModelConfig, # not needed here model_config: ModelConfig, # not needed here
stt_config: SpeechToTextConfig, stt_config: SpeechToTextConfig,
language: str, language: Optional[str],
task_type: str, task_type: str,
request_prompt: str) -> PromptType: request_prompt: str) -> PromptType:
if language is None:
raise ValueError(
"Language must be specified when creating the Whisper prompt")
prompt = { prompt = {
"encoder_prompt": { "encoder_prompt": {
# Whisper does not support encoder prompt. # Whisper does not support encoder prompt.