[Bugfix] Relax lang pin for voxtral (#21833)
Signed-off-by: Sanchit Gandhi <sgandhi3141@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@@ -86,11 +86,7 @@ class OpenAISpeechToText(OpenAIServing):
|
|||||||
audio_data: bytes,
|
audio_data: bytes,
|
||||||
) -> tuple[list[PromptType], float]:
|
) -> tuple[list[PromptType], float]:
|
||||||
# Validate request
|
# Validate request
|
||||||
# TODO language should be optional and can be guessed.
|
language = self.model_cls.validate_language(request.language)
|
||||||
# For now we default to en. See
|
|
||||||
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/generation_whisper.py#L1520
|
|
||||||
lang = request.language or "en"
|
|
||||||
self.model_cls.validate_language(lang)
|
|
||||||
|
|
||||||
if len(audio_data) / 1024**2 > self.max_audio_filesize_mb:
|
if len(audio_data) / 1024**2 > self.max_audio_filesize_mb:
|
||||||
raise ValueError("Maximum file size exceeded.")
|
raise ValueError("Maximum file size exceeded.")
|
||||||
@@ -112,7 +108,7 @@ class OpenAISpeechToText(OpenAIServing):
|
|||||||
audio=chunk,
|
audio=chunk,
|
||||||
stt_config=self.asr_config,
|
stt_config=self.asr_config,
|
||||||
model_config=self.model_config,
|
model_config=self.model_config,
|
||||||
language=lang,
|
language=language,
|
||||||
task_type=self.task_type,
|
task_type=self.task_type,
|
||||||
request_prompt=request.prompt)
|
request_prompt=request.prompt)
|
||||||
prompts.append(prompt)
|
prompts.append(prompt)
|
||||||
|
|||||||
@@ -1,13 +1,14 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
from collections.abc import Iterable, MutableSequence
|
from collections.abc import Iterable, Mapping, MutableSequence
|
||||||
from typing import (TYPE_CHECKING, ClassVar, Literal, Optional, Protocol,
|
from typing import (TYPE_CHECKING, ClassVar, Literal, Optional, Protocol,
|
||||||
Union, overload, runtime_checkable)
|
Union, overload, runtime_checkable)
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
from transformers.models.whisper.tokenization_whisper import LANGUAGES
|
||||||
from typing_extensions import Self, TypeIs
|
from typing_extensions import Self, TypeIs
|
||||||
|
|
||||||
from vllm.config import ModelConfig, SpeechToTextConfig
|
from vllm.config import ModelConfig, SpeechToTextConfig
|
||||||
@@ -685,6 +686,8 @@ class SupportsQuant:
|
|||||||
@runtime_checkable
|
@runtime_checkable
|
||||||
class SupportsTranscription(Protocol):
|
class SupportsTranscription(Protocol):
|
||||||
"""The interface required for all models that support transcription."""
|
"""The interface required for all models that support transcription."""
|
||||||
|
# Mapping from ISO 639-1 language codes to language names
|
||||||
|
supported_languages: ClassVar[Mapping[str, str]]
|
||||||
|
|
||||||
supports_transcription: ClassVar[Literal[True]] = True
|
supports_transcription: ClassVar[Literal[True]] = True
|
||||||
|
|
||||||
@@ -694,11 +697,22 @@ class SupportsTranscription(Protocol):
|
|||||||
`True`.
|
`True`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def __init_subclass__(cls, **kwargs):
|
||||||
|
super().__init_subclass__(**kwargs)
|
||||||
|
# language codes in supported_languages
|
||||||
|
# that don't exist in the full language map
|
||||||
|
invalid = set(cls.supported_languages) - set(LANGUAGES.keys())
|
||||||
|
if invalid:
|
||||||
|
raise ValueError(
|
||||||
|
f"{cls.__name__}.supported_languages contains invalid "
|
||||||
|
f"language codes: {sorted(invalid)}\n. "
|
||||||
|
f"Valid choices are: {sorted(LANGUAGES.keys())}")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_generation_prompt(cls, audio: np.ndarray,
|
def get_generation_prompt(cls, audio: np.ndarray,
|
||||||
stt_config: SpeechToTextConfig,
|
stt_config: SpeechToTextConfig,
|
||||||
model_config: ModelConfig, language: str,
|
model_config: ModelConfig,
|
||||||
task_type: str,
|
language: Optional[str], task_type: str,
|
||||||
request_prompt: str) -> PromptType:
|
request_prompt: str) -> PromptType:
|
||||||
"""Get the prompt for the ASR model.
|
"""Get the prompt for the ASR model.
|
||||||
The model has control over the construction, as long as it
|
The model has control over the construction, as long as it
|
||||||
@@ -706,9 +720,36 @@ class SupportsTranscription(Protocol):
|
|||||||
...
|
...
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_language(cls, language: str) -> bool:
|
def get_other_languages(cls) -> Mapping[str, str]:
|
||||||
"""Check if the model supports a specific ISO639_1 language."""
|
# other possible language codes from the whisper map
|
||||||
...
|
return {
|
||||||
|
k: v
|
||||||
|
for k, v in LANGUAGES.items() if k not in cls.supported_languages
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def validate_language(cls, language: Optional[str]) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Ensure the language specified in the transcription request
|
||||||
|
is a valid ISO 639-1 language code. If the request language is
|
||||||
|
valid, but not natively supported by the model, trigger a
|
||||||
|
warning (but not an exception).
|
||||||
|
"""
|
||||||
|
if language is None or language in cls.supported_languages:
|
||||||
|
return language
|
||||||
|
elif language in cls.get_other_languages():
|
||||||
|
logger.warning(
|
||||||
|
"Language %r is not natively supported by %s; "
|
||||||
|
"results may be less accurate. Supported languages: %r",
|
||||||
|
language,
|
||||||
|
cls.__name__,
|
||||||
|
list(cls.supported_languages.keys()),
|
||||||
|
)
|
||||||
|
return language
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported language: {language!r}. Must be one of "
|
||||||
|
f"{list(cls.supported_languages.keys())}.")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_speech_to_text_config(
|
def get_speech_to_text_config(
|
||||||
|
|||||||
@@ -26,8 +26,7 @@ from vllm.logger import init_logger
|
|||||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||||
from vllm.model_executor.models import SupportsPP
|
from vllm.model_executor.models import SupportsPP
|
||||||
# yapf: disable
|
# yapf: disable
|
||||||
from vllm.model_executor.models.whisper import (
|
from vllm.model_executor.models.whisper import WhisperEncoder
|
||||||
WhisperEncoder, WhisperForConditionalGeneration)
|
|
||||||
# yapf: enable
|
# yapf: enable
|
||||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||||
@@ -50,6 +49,18 @@ from .utils import (flatten_bn, init_vllm_registered_model, maybe_prefix,
|
|||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
ISO639_1_SUPPORTED_LANGS = {
|
||||||
|
"ar": "Arabic",
|
||||||
|
"nl": "Dutch",
|
||||||
|
"en": "English",
|
||||||
|
"fr": "French",
|
||||||
|
"de": "German",
|
||||||
|
"hi": "Hindi",
|
||||||
|
"it": "Italian",
|
||||||
|
"pt": "Portuguese",
|
||||||
|
"es": "Spanish",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class VoxtralProcessorAdapter:
|
class VoxtralProcessorAdapter:
|
||||||
"""
|
"""
|
||||||
@@ -301,6 +312,7 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo]
|
|||||||
dummy_inputs=VoxtralDummyInputsBuilder)
|
dummy_inputs=VoxtralDummyInputsBuilder)
|
||||||
class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
|
class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||||
SupportsPP, SupportsTranscription):
|
SupportsPP, SupportsTranscription):
|
||||||
|
supported_languages = ISO639_1_SUPPORTED_LANGS
|
||||||
|
|
||||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -441,8 +453,8 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
# for speech-to-text transcription
|
# for speech-to-text transcription
|
||||||
def get_generation_prompt(cls, audio: np.ndarray,
|
def get_generation_prompt(cls, audio: np.ndarray,
|
||||||
model_config: ModelConfig,
|
model_config: ModelConfig,
|
||||||
stt_config: SpeechToTextConfig, language: str,
|
stt_config: SpeechToTextConfig,
|
||||||
task_type: str,
|
language: Optional[str], task_type: str,
|
||||||
request_prompt: str) -> PromptType:
|
request_prompt: str) -> PromptType:
|
||||||
tokenizer = cached_tokenizer_from_config(model_config)
|
tokenizer = cached_tokenizer_from_config(model_config)
|
||||||
audio = Audio(audio, int(stt_config.sample_rate),
|
audio = Audio(audio, int(stt_config.sample_rate),
|
||||||
@@ -457,11 +469,6 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
|
|||||||
prompts_dict["prompt_token_ids"] = tokenized.tokens
|
prompts_dict["prompt_token_ids"] = tokenized.tokens
|
||||||
return cast(PromptType, prompts_dict)
|
return cast(PromptType, prompts_dict)
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def validate_language(cls, language: str) -> bool:
|
|
||||||
# same as whisper
|
|
||||||
return WhisperForConditionalGeneration.validate_language(language)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_num_audio_tokens(cls, audio_duration_s: float,
|
def get_num_audio_tokens(cls, audio_duration_s: float,
|
||||||
stt_config: SpeechToTextConfig,
|
stt_config: SpeechToTextConfig,
|
||||||
|
|||||||
@@ -109,51 +109,6 @@ ISO639_1_SUPPORTED_LANGS = {
|
|||||||
"vi": "Vietnamese",
|
"vi": "Vietnamese",
|
||||||
"cy": "Welsh"
|
"cy": "Welsh"
|
||||||
}
|
}
|
||||||
ISO639_1_OTHER_LANGS = {
|
|
||||||
"lo": "Lao",
|
|
||||||
"jw": "Javanese",
|
|
||||||
"tk": "Turkmen",
|
|
||||||
"yi": "Yiddish",
|
|
||||||
"so": "Somali",
|
|
||||||
"bn": "Bengali",
|
|
||||||
"nn": "Norwegian Nynorsk",
|
|
||||||
"si": "Sinhala",
|
|
||||||
"yo": "Yoruba",
|
|
||||||
"sa": "Sanskrit",
|
|
||||||
"mi": "Māori",
|
|
||||||
"fo": "Faroese", # codespell:ignore
|
|
||||||
"mt": "Maltese",
|
|
||||||
"tg": "Tajik",
|
|
||||||
"mg": "Malagasy",
|
|
||||||
"haw": "Hawaiian",
|
|
||||||
"km": "Khmer",
|
|
||||||
"br": "Breton",
|
|
||||||
"ps": "Pashto",
|
|
||||||
"ln": "Lingala",
|
|
||||||
"la": "Latin",
|
|
||||||
"ml": "Malayalam",
|
|
||||||
"sq": "Albanian",
|
|
||||||
"su": "Sundanese",
|
|
||||||
"eu": "Basque",
|
|
||||||
"ka": "Georgian",
|
|
||||||
"uz": "Uzbek",
|
|
||||||
"sn": "Shona",
|
|
||||||
"ht": "Haitian",
|
|
||||||
"as": "Assamese",
|
|
||||||
"mn": "Mongolian",
|
|
||||||
"te": "Telugu",
|
|
||||||
"pa": "Panjabi",
|
|
||||||
"tt": "Tatar",
|
|
||||||
"gu": "Gujarati",
|
|
||||||
"oc": "Occitan",
|
|
||||||
"ha": "Hausa",
|
|
||||||
"ba": "Bashkir",
|
|
||||||
"my": "Burmese",
|
|
||||||
"sd": "Sindhi",
|
|
||||||
"am": "Amharic",
|
|
||||||
"lb": "Luxembourgish",
|
|
||||||
"bo": "Tibetan"
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class WhisperAudioInputs(TypedDict):
|
class WhisperAudioInputs(TypedDict):
|
||||||
@@ -807,22 +762,20 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
|
|||||||
|
|
||||||
# Whisper only supports audio-conditioned generation.
|
# Whisper only supports audio-conditioned generation.
|
||||||
supports_transcription_only = True
|
supports_transcription_only = True
|
||||||
|
supported_languages = ISO639_1_SUPPORTED_LANGS
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_language(cls, language: str) -> bool:
|
def validate_language(cls, language: Optional[str]) -> Optional[str]:
|
||||||
if language in ISO639_1_SUPPORTED_LANGS:
|
if language is None:
|
||||||
return True
|
# TODO language should be optional and can be guessed.
|
||||||
elif language in ISO639_1_OTHER_LANGS:
|
# For now we default to en. See
|
||||||
|
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/generation_whisper.py#L1520
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"The selected language %s has limited accuracy with"
|
"Defaulting to language='en'. If you wish to transcribe "
|
||||||
" reported WER>=0.5. Results may be less accurate "
|
"audio in a different language, pass the `language` field "
|
||||||
"for this choice.", language)
|
"in the TranscriptionRequest.")
|
||||||
return True
|
language = "en"
|
||||||
else:
|
return super().validate_language(language)
|
||||||
raise ValueError(f"Unsupported language: {language}."
|
|
||||||
"Language should be one of:" +
|
|
||||||
f" {list(ISO639_1_SUPPORTED_LANGS.values())}" +
|
|
||||||
f"or {list(ISO639_1_OTHER_LANGS.values())}")
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_generation_prompt(
|
def get_generation_prompt(
|
||||||
@@ -830,9 +783,12 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
|
|||||||
audio: np.ndarray,
|
audio: np.ndarray,
|
||||||
model_config: ModelConfig, # not needed here
|
model_config: ModelConfig, # not needed here
|
||||||
stt_config: SpeechToTextConfig,
|
stt_config: SpeechToTextConfig,
|
||||||
language: str,
|
language: Optional[str],
|
||||||
task_type: str,
|
task_type: str,
|
||||||
request_prompt: str) -> PromptType:
|
request_prompt: str) -> PromptType:
|
||||||
|
if language is None:
|
||||||
|
raise ValueError(
|
||||||
|
"Language must be specified when creating the Whisper prompt")
|
||||||
prompt = {
|
prompt = {
|
||||||
"encoder_prompt": {
|
"encoder_prompt": {
|
||||||
# Whisper does not support encoder prompt.
|
# Whisper does not support encoder prompt.
|
||||||
|
|||||||
Reference in New Issue
Block a user