[Bugfix] Relax lang pin for voxtral (#21833)

Signed-off-by: Sanchit Gandhi <sgandhi3141@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
Sanchit Gandhi
2025-07-31 04:38:52 +01:00
committed by GitHub
parent 9cb497bfa3
commit ec02e536df
4 changed files with 80 additions and 80 deletions

View File

@@ -26,8 +26,7 @@ from vllm.logger import init_logger
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models import SupportsPP
# yapf: disable
from vllm.model_executor.models.whisper import (
WhisperEncoder, WhisperForConditionalGeneration)
from vllm.model_executor.models.whisper import WhisperEncoder
# yapf: enable
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -50,6 +49,18 @@ from .utils import (flatten_bn, init_vllm_registered_model, maybe_prefix,
logger = init_logger(__name__)
# Two-letter language codes mapped to their English language names.
# Exposed as `supported_languages` on VoxtralForConditionalGeneration.
ISO639_1_SUPPORTED_LANGS = dict(
    ar="Arabic",
    nl="Dutch",
    en="English",
    fr="French",
    de="German",
    hi="Hindi",
    it="Italian",
    pt="Portuguese",
    es="Spanish",
)
class VoxtralProcessorAdapter:
"""
@@ -301,6 +312,7 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo]
dummy_inputs=VoxtralDummyInputsBuilder)
class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
SupportsPP, SupportsTranscription):
supported_languages = ISO639_1_SUPPORTED_LANGS
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
@@ -441,8 +453,8 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
# for speech-to-text transcription
def get_generation_prompt(cls, audio: np.ndarray,
model_config: ModelConfig,
stt_config: SpeechToTextConfig, language: str,
task_type: str,
stt_config: SpeechToTextConfig,
language: Optional[str], task_type: str,
request_prompt: str) -> PromptType:
tokenizer = cached_tokenizer_from_config(model_config)
audio = Audio(audio, int(stt_config.sample_rate),
@@ -457,11 +469,6 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
prompts_dict["prompt_token_ids"] = tokenized.tokens
return cast(PromptType, prompts_dict)
@classmethod
def validate_language(cls, language: str) -> bool:
# Language validation is not implemented here: the call is forwarded
# unchanged to WhisperForConditionalGeneration.validate_language and its
# result is returned as-is (same behavior as whisper).
return WhisperForConditionalGeneration.validate_language(language)
@classmethod
def get_num_audio_tokens(cls, audio_duration_s: float,
stt_config: SpeechToTextConfig,