[Bugfix] Relax lang pin for voxtral (#21833)
Signed-off-by: Sanchit Gandhi <sgandhi3141@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@@ -26,8 +26,7 @@ from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models import SupportsPP
|
||||
# yapf: disable
|
||||
from vllm.model_executor.models.whisper import (
|
||||
WhisperEncoder, WhisperForConditionalGeneration)
|
||||
from vllm.model_executor.models.whisper import WhisperEncoder
|
||||
# yapf: enable
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
@@ -50,6 +49,18 @@ from .utils import (flatten_bn, init_vllm_registered_model, maybe_prefix,
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
ISO639_1_SUPPORTED_LANGS = {
|
||||
"ar": "Arabic",
|
||||
"nl": "Dutch",
|
||||
"en": "English",
|
||||
"fr": "French",
|
||||
"de": "German",
|
||||
"hi": "Hindi",
|
||||
"it": "Italian",
|
||||
"pt": "Portuguese",
|
||||
"es": "Spanish",
|
||||
}
|
||||
|
||||
|
||||
class VoxtralProcessorAdapter:
|
||||
"""
|
||||
@@ -301,6 +312,7 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo]
|
||||
dummy_inputs=VoxtralDummyInputsBuilder)
|
||||
class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
SupportsPP, SupportsTranscription):
|
||||
supported_languages = ISO639_1_SUPPORTED_LANGS
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
@@ -441,8 +453,8 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
# for speech-to-text transcription
|
||||
def get_generation_prompt(cls, audio: np.ndarray,
|
||||
model_config: ModelConfig,
|
||||
stt_config: SpeechToTextConfig, language: str,
|
||||
task_type: str,
|
||||
stt_config: SpeechToTextConfig,
|
||||
language: Optional[str], task_type: str,
|
||||
request_prompt: str) -> PromptType:
|
||||
tokenizer = cached_tokenizer_from_config(model_config)
|
||||
audio = Audio(audio, int(stt_config.sample_rate),
|
||||
@@ -457,11 +469,6 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
prompts_dict["prompt_token_ids"] = tokenized.tokens
|
||||
return cast(PromptType, prompts_dict)
|
||||
|
||||
@classmethod
|
||||
def validate_language(cls, language: str) -> bool:
|
||||
# same as whisper
|
||||
return WhisperForConditionalGeneration.validate_language(language)
|
||||
|
||||
@classmethod
|
||||
def get_num_audio_tokens(cls, audio_duration_s: float,
|
||||
stt_config: SpeechToTextConfig,
|
||||
|
||||
Reference in New Issue
Block a user