diff --git a/vllm/entrypoints/openai/translations/speech_to_text.py b/vllm/entrypoints/openai/translations/speech_to_text.py
index c993e1ebd..e80d6b9a2 100644
--- a/vllm/entrypoints/openai/translations/speech_to_text.py
+++ b/vllm/entrypoints/openai/translations/speech_to_text.py
@@ -37,7 +37,8 @@ from vllm.entrypoints.openai.translations.protocol import (
     TranslationStreamResponse,
 )
 from vllm.exceptions import VLLMValidationError
-from vllm.inputs.data import PromptType
+from vllm.inputs.data import ExplicitEncoderDecoderPrompt, PromptType
+from vllm.inputs.parse import is_explicit_encoder_decoder_prompt
 from vllm.logger import init_logger
 from vllm.logprobs import FlatLogprobs, Logprob
 from vllm.model_executor.models import SupportsTranscription, supports_transcription
@@ -296,26 +297,37 @@ class OpenAISpeechToText(OpenAIServing):
                 to_language=to_language,
             )
             if request.response_format == "verbose_json":
-                if not isinstance(prompt, dict):
+                if not is_explicit_encoder_decoder_prompt(prompt):
                     raise VLLMValidationError(
-                        "Expected prompt to be a dict",
+                        "Expected prompt to be an encoder-decoder prompt",
                         parameter="prompt",
                         value=type(prompt).__name__,
                     )
-                prompt_dict = cast(dict, prompt)
-                decoder_prompt = prompt.get("decoder_prompt")
-                if not isinstance(decoder_prompt, str):
-                    raise VLLMValidationError(
-                        "Expected decoder_prompt to be str",
-                        parameter="decoder_prompt",
-                        value=type(decoder_prompt).__name__,
-                    )
-                prompt_dict["decoder_prompt"] = decoder_prompt.replace(
-                    "<|notimestamps|>", "<|0.00|>"
-                )
+
+                prompt = self._preprocess_verbose_prompt(prompt)
+
             prompts.append(prompt)
         return prompts, duration
 
+    def _repl_verbose_text(self, text: str):
+        return text.replace("<|notimestamps|>", "<|0.00|>")
+
+    def _preprocess_verbose_prompt(self, prompt: ExplicitEncoderDecoderPrompt):
+        dec_prompt = prompt["decoder_prompt"]
+
+        if isinstance(dec_prompt, str):
+            prompt["decoder_prompt"] = self._repl_verbose_text(dec_prompt)
+        elif isinstance(dec_prompt, dict) and "prompt" in dec_prompt:
+            dec_prompt["prompt"] = self._repl_verbose_text(dec_prompt["prompt"])
+        else:
+            raise VLLMValidationError(
+                "Expected decoder_prompt to contain text",
+                parameter="decoder_prompt",
+                value=type(dec_prompt).__name__,
+            )
+
+        return prompt
+
     def _get_verbose_segments(
         self,
         tokens: tuple,
diff --git a/vllm/model_executor/models/gemma3n_mm.py b/vllm/model_executor/models/gemma3n_mm.py
index 4b39877bb..1460a4586 100644
--- a/vllm/model_executor/models/gemma3n_mm.py
+++ b/vllm/model_executor/models/gemma3n_mm.py
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Annotated, Any, Literal, cast
+from typing import Annotated, Any, Literal
 
 import numpy as np
 import torch
@@ -19,7 +19,7 @@ from transformers.models.siglip import SiglipImageProcessorFast
 
 from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
-from vllm.inputs.data import PromptType
+from vllm.inputs.data import PromptType, TextPrompt
 from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import RowParallelLinear
@@ -807,9 +807,10 @@ class Gemma3nForConditionalGeneration(
 
         prompt += ": \nmodel\n"
 
-        audio = (audio, stt_config.sample_rate)
-        prompts_dict = {"multi_modal_data": {"audio": audio}, "prompt": prompt}
-        return cast(PromptType, prompts_dict)
+        return TextPrompt(
+            prompt=prompt,
+            multi_modal_data={"audio": (audio, stt_config.sample_rate)},
+        )
 
     @classmethod
     def get_speech_to_text_config(
diff --git a/vllm/model_executor/models/glmasr.py b/vllm/model_executor/models/glmasr.py
index 2651540d2..b9bdb3aa2 100644
--- a/vllm/model_executor/models/glmasr.py
+++ b/vllm/model_executor/models/glmasr.py
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Annotated, Any, Literal, TypeAlias, cast
+from typing import Annotated, Any, Literal, TypeAlias
 
 import numpy as np
 import torch
@@ -14,7 +14,7 @@ from transformers.models.whisper import WhisperFeatureExtractor
 from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
 from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
-from vllm.inputs.data import PromptType
+from vllm.inputs.data import PromptType, TokensPrompt
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.attention import MMEncoderAttention
 from vllm.model_executor.layers.linear import (
@@ -1159,8 +1159,8 @@ class GlmAsrForConditionalGeneration(
         )
 
         prompt_token_ids = tokenizer.encode(prompt)
-        prompt_dict = {
-            "prompt_token_ids": prompt_token_ids,
-            "multi_modal_data": {"audio": audio},
-        }
-        return cast(PromptType, prompt_dict)
+
+        return TokensPrompt(
+            prompt_token_ids=prompt_token_ids,
+            multi_modal_data={"audio": audio},
+        )
diff --git a/vllm/model_executor/models/granite_speech.py b/vllm/model_executor/models/granite_speech.py
index 1f9b9d2c8..6956f92ee 100644
--- a/vllm/model_executor/models/granite_speech.py
+++ b/vllm/model_executor/models/granite_speech.py
@@ -26,7 +26,7 @@
 
 import math
 from collections.abc import Iterable, Mapping
-from typing import Annotated, Literal, cast
+from typing import Annotated, Literal
 
 import numpy as np
 import torch
@@ -36,7 +36,7 @@ from transformers import BatchFeature, PretrainedConfig
 
 from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
-from vllm.inputs.data import PromptType
+from vllm.inputs.data import PromptType, TokensPrompt
 from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.models.module_mapping import MultiModelKeys
@@ -879,11 +879,11 @@ class GraniteSpeechForConditionalGeneration(
         )
 
         prompt_token_ids = tokenizer.encode(prompt)
-        prompt = {
-            "prompt_token_ids": prompt_token_ids,
-            "multi_modal_data": {"audio": audio},
-        }
-        return cast(PromptType, prompt)
+
+        return TokensPrompt(
+            prompt_token_ids=prompt_token_ids,
+            multi_modal_data={"audio": audio},
+        )
 
     # Adapted from https://github.com/huggingface/transformers/blob/v4.56.0/src/transformers/models/granite_speech/feature_extraction_granite_speech.py#L122  # noqa: E501
     @classmethod
diff --git a/vllm/model_executor/models/qwen3_asr.py b/vllm/model_executor/models/qwen3_asr.py
index 605ccee48..b27f710db 100644
--- a/vllm/model_executor/models/qwen3_asr.py
+++ b/vllm/model_executor/models/qwen3_asr.py
@@ -23,7 +23,7 @@
 """Inference-only Qwen3-ASR model."""
 
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Any, Literal, cast
+from typing import Any, Literal
 
 import numpy as np
 import torch
@@ -33,7 +33,7 @@ from transformers.models.whisper import WhisperFeatureExtractor
 
 from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
-from vllm.inputs.data import PromptType
+from vllm.inputs.data import PromptType, TokensPrompt
 from vllm.logger import init_logger
 from vllm.model_executor.models.interfaces import (
     MultiModalEmbeddings,
@@ -561,11 +561,11 @@ class Qwen3ASRForConditionalGeneration(
         )
 
         prompt_token_ids = tokenizer.encode(prompt)
-        prompt_dict = {
-            "prompt_token_ids": prompt_token_ids,
-            "multi_modal_data": {"audio": audio},
-        }
-        return cast(PromptType, prompt_dict)
+
+        return TokensPrompt(
+            prompt_token_ids=prompt_token_ids,
+            multi_modal_data={"audio": audio},
+        )
 
     @classmethod
     def post_process_output(cls, text: str) -> str:
diff --git a/vllm/model_executor/models/voxtral.py b/vllm/model_executor/models/voxtral.py
index c187dba14..2fc987682 100644
--- a/vllm/model_executor/models/voxtral.py
+++ b/vllm/model_executor/models/voxtral.py
@@ -25,7 +25,7 @@ from transformers.tokenization_utils_base import TextInput
 
 from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
-from vllm.inputs.data import PromptType
+from vllm.inputs.data import PromptType, TokensPrompt
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -488,10 +488,13 @@ class VoxtralForConditionalGeneration(
         )
 
         tokenized = tokenizer.instruct.encode_transcription(req)
-        audio = (tokenized.audios[0].audio_array, stt_config.sample_rate)
-        prompts_dict = {"multi_modal_data": {"audio": audio}}
-        prompts_dict["prompt_token_ids"] = tokenized.tokens
-        return cast(PromptType, prompts_dict)
+
+        return TokensPrompt(
+            prompt_token_ids=tokenized.tokens,
+            multi_modal_data={
+                "audio": (tokenized.audios[0].audio_array, stt_config.sample_rate)
+            },
+        )
 
     @classmethod
     def get_num_audio_tokens(
diff --git a/vllm/model_executor/models/voxtral_realtime.py b/vllm/model_executor/models/voxtral_realtime.py
index cbd3f73ae..82801b6eb 100644
--- a/vllm/model_executor/models/voxtral_realtime.py
+++ b/vllm/model_executor/models/voxtral_realtime.py
@@ -4,7 +4,7 @@
 import asyncio
 import math
 from collections.abc import AsyncGenerator, Mapping
-from typing import Literal, cast
+from typing import Literal
 
 import numpy as np
 import torch
@@ -453,7 +453,10 @@ class VoxtralRealtimeGeneration(VoxtralForConditionalGeneration, SupportsRealtim
         )
 
         tokenized = tokenizer.instruct.encode_transcription(req)
-        audio = (tokenized.audios[0].audio_array, stt_config.sample_rate)
-        prompts_dict = {"multi_modal_data": {"audio": audio}}
-        prompts_dict["prompt_token_ids"] = tokenized.tokens
-        return cast(PromptType, prompts_dict)
+
+        return TokensPrompt(
+            prompt_token_ids=tokenized.tokens,
+            multi_modal_data={
+                "audio": (tokenized.audios[0].audio_array, stt_config.sample_rate)
+            },
+        )
diff --git a/vllm/model_executor/models/whisper.py b/vllm/model_executor/models/whisper.py
index 58d24d0c9..f62bffada 100644
--- a/vllm/model_executor/models/whisper.py
+++ b/vllm/model_executor/models/whisper.py
@@ -5,7 +5,7 @@ import enum
 import math
 from collections.abc import Iterable, Mapping, Sequence
 from contextlib import nullcontext
-from typing import Annotated, Literal, cast
+from typing import Annotated, Literal
 
 import numpy as np
 import torch
@@ -21,7 +21,7 @@ from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.inputs.data import PromptType
+from vllm.inputs.data import ExplicitEncoderDecoderPrompt, PromptType, TextPrompt
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.attention import (
@@ -815,21 +815,18 @@ class WhisperForConditionalGeneration(
             raise ValueError(
                 "Language must be specified when creating the Whisper prompt"
             )
-        prompt = {
-            "encoder_prompt": {
-                # Whisper does not support encoder prompt.
-                "prompt": "",
-                "multi_modal_data": {
-                    "audio": (audio, stt_config.sample_rate),
-                },
-            },
-            "decoder_prompt": (
-                (f"<|prev|>{request_prompt}" if request_prompt else "")
-                + f"<|startoftranscript|><|{language}|>"
-                + f"<|{task_type}|><|notimestamps|>"
+
+        decoder_text = (
+            f"<|prev|>{request_prompt}" if request_prompt else ""
+        ) + f"<|startoftranscript|><|{language}|><|{task_type}|><|notimestamps|>"
+
+        return ExplicitEncoderDecoderPrompt(
+            encoder_prompt=TextPrompt(
+                prompt="",  # Whisper does not support encoder prompt.
+                multi_modal_data={"audio": (audio, stt_config.sample_rate)},
             ),
-        }
-        return cast(PromptType, prompt)
+            decoder_prompt=TextPrompt(prompt=decoder_text),
+        )
 
     @classmethod
     def get_placeholder_str(cls, modality: str, i: int) -> str | None:
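
For reference, a minimal usage sketch (not part of the diff) of the typed prompt constructors that replace the former `cast(PromptType, ...)` dicts. It only uses names the diff itself imports (`TextPrompt`, `TokensPrompt`, `ExplicitEncoderDecoderPrompt`, `is_explicit_encoder_decoder_prompt`); the audio array, token ids, and decoder text are placeholder values.

```python
# Illustrative sketch only; mirrors the construction patterns used in the diff.
import numpy as np

from vllm.inputs.data import ExplicitEncoderDecoderPrompt, TextPrompt, TokensPrompt
from vllm.inputs.parse import is_explicit_encoder_decoder_prompt

audio = np.zeros(16_000, dtype=np.float32)  # placeholder: 1 s of silence
sample_rate = 16_000

# Decoder-only ASR models (Gemma3n, GLM-ASR, Granite Speech, Qwen3-ASR, Voxtral)
# now return TextPrompt / TokensPrompt directly instead of a plain dict + cast.
text_prompt = TextPrompt(
    prompt="transcribe: \nmodel\n",  # placeholder prompt text
    multi_modal_data={"audio": (audio, sample_rate)},
)
tokens_prompt = TokensPrompt(
    prompt_token_ids=[1, 2, 3],  # placeholder token ids
    multi_modal_data={"audio": audio},
)

# Whisper builds an explicit encoder/decoder pair; the verbose_json path in
# speech_to_text.py now detects it with is_explicit_encoder_decoder_prompt.
enc_dec_prompt = ExplicitEncoderDecoderPrompt(
    encoder_prompt=TextPrompt(
        prompt="",  # Whisper does not support an encoder text prompt.
        multi_modal_data={"audio": (audio, sample_rate)},
    ),
    decoder_prompt=TextPrompt(
        prompt="<|startoftranscript|><|en|><|transcribe|><|notimestamps|>"
    ),
)

assert is_explicit_encoder_decoder_prompt(enc_dec_prompt)
assert not is_explicit_encoder_decoder_prompt(text_prompt)
```

Since these classes are TypedDicts, the returned objects are still plain dicts at runtime; the change removes the need for `cast(PromptType, ...)` while letting type checkers and `is_explicit_encoder_decoder_prompt` distinguish encoder-decoder prompts from decoder-only ones.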