[Model] Use explicit types in get_generation_prompt (#33551)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Cyrus Leung
Date: 2026-02-02 20:38:49 +08:00
Committed by: GitHub
Parent: b398e5c819
Commit: b10d05b8a8
8 changed files with 82 additions and 66 deletions
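Every file in this diff follows the same pattern: get_generation_prompt previously built a plain dict and cast() it to PromptType, which silences the type checker without validating a single key. The prompt types in vllm.inputs.data are TypedDicts, so constructing them directly (TextPrompt, TokensPrompt, ExplicitEncoderDecoderPrompt) gets key-by-key checking at no runtime cost. A minimal sketch of the difference, using simplified stand-in definitions (assumption: the real TypedDicts carry additional optional fields):

from typing import Any, NotRequired, TypedDict, cast

# Simplified stand-ins for the TypedDicts in vllm.inputs.data
# (assumption: the real definitions have more optional fields).
class TextPrompt(TypedDict):
    prompt: str
    multi_modal_data: NotRequired[dict[str, Any]]

class TokensPrompt(TypedDict):
    prompt_token_ids: list[int]
    multi_modal_data: NotRequired[dict[str, Any]]

PromptType = TextPrompt | TokensPrompt

# Before: cast() asserts the type and checks nothing, so a misspelled
# key sails straight through mypy/pyright.
old = cast(PromptType, {"prompt_tokens_ids": [1, 2, 3]})  # typo, undetected

# After: TypedDict construction is validated key by key by the type
# checker, yet still produces a plain dict at runtime.
new = TokensPrompt(prompt_token_ids=[1, 2, 3])
assert isinstance(new, dict)

Since TypedDict constructors are erased at runtime, each changed method returns the same dict shape as before; PromptType stays imported in every file, presumably because it remains the declared return type.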

View File

@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Annotated, Any, Literal, cast
+from typing import Annotated, Any, Literal

 import numpy as np
 import torch
@@ -19,7 +19,7 @@ from transformers.models.siglip import SiglipImageProcessorFast
 from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
-from vllm.inputs.data import PromptType
+from vllm.inputs.data import PromptType, TextPrompt
 from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import RowParallelLinear
@@ -807,9 +807,10 @@ class Gemma3nForConditionalGeneration(
         prompt += ": <audio_soft_token><end_of_turn>\n<start_of_turn>model\n"
-        audio = (audio, stt_config.sample_rate)
-        prompts_dict = {"multi_modal_data": {"audio": audio}, "prompt": prompt}
-        return cast(PromptType, prompts_dict)
+        return TextPrompt(
+            prompt=prompt,
+            multi_modal_data={"audio": (audio, stt_config.sample_rate)},
+        )

     @classmethod
     def get_speech_to_text_config(
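Downstream, nothing changes for callers: a TextPrompt is the same dict the engine already accepted. A hedged sketch of how such a prompt is consumed (the model id, prompt template, audio array, and sample rate below are illustrative placeholders, not taken from this diff):

import numpy as np

from vllm import LLM
from vllm.inputs.data import TextPrompt

llm = LLM(model="google/gemma-3n-E2B-it")  # placeholder model id
audio = np.zeros(16000, dtype=np.float32)  # placeholder: 1 s of silence

outputs = llm.generate(
    TextPrompt(
        prompt="<start_of_turn>user\nTranscribe"
        ": <audio_soft_token><end_of_turn>\n<start_of_turn>model\n",
        multi_modal_data={"audio": (audio, 16000)},
    )
)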

View File

@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Annotated, Any, Literal, TypeAlias, cast
+from typing import Annotated, Any, Literal, TypeAlias

 import numpy as np
 import torch
@@ -14,7 +14,7 @@ from transformers.models.whisper import WhisperFeatureExtractor
 from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
 from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
-from vllm.inputs.data import PromptType
+from vllm.inputs.data import PromptType, TokensPrompt
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.attention import MMEncoderAttention
 from vllm.model_executor.layers.linear import (
@@ -1159,8 +1159,8 @@ class GlmAsrForConditionalGeneration(
         )
         prompt_token_ids = tokenizer.encode(prompt)
-        prompt_dict = {
-            "prompt_token_ids": prompt_token_ids,
-            "multi_modal_data": {"audio": audio},
-        }
-        return cast(PromptType, prompt_dict)
+        return TokensPrompt(
+            prompt_token_ids=prompt_token_ids,
+            multi_modal_data={"audio": audio},
+        )

View File

@@ -26,7 +26,7 @@
 import math
 from collections.abc import Iterable, Mapping
-from typing import Annotated, Literal, cast
+from typing import Annotated, Literal

 import numpy as np
 import torch
@@ -36,7 +36,7 @@ from transformers import BatchFeature, PretrainedConfig
 from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
-from vllm.inputs.data import PromptType
+from vllm.inputs.data import PromptType, TokensPrompt
 from vllm.model_executor.layers.linear import ColumnParallelLinear, RowParallelLinear
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.models.module_mapping import MultiModelKeys
@@ -879,11 +879,11 @@ class GraniteSpeechForConditionalGeneration(
         )
         prompt_token_ids = tokenizer.encode(prompt)
-        prompt = {
-            "prompt_token_ids": prompt_token_ids,
-            "multi_modal_data": {"audio": audio},
-        }
-        return cast(PromptType, prompt)
+        return TokensPrompt(
+            prompt_token_ids=prompt_token_ids,
+            multi_modal_data={"audio": audio},
+        )

     # Adapted from https://github.com/huggingface/transformers/blob/v4.56.0/src/transformers/models/granite_speech/feature_extraction_granite_speech.py#L122 # noqa: E501
     @classmethod

View File

@@ -23,7 +23,7 @@
 """Inference-only Qwen3-ASR model."""
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Any, Literal, cast
+from typing import Any, Literal

 import numpy as np
 import torch
@@ -33,7 +33,7 @@ from transformers.models.whisper import WhisperFeatureExtractor
 from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
-from vllm.inputs.data import PromptType
+from vllm.inputs.data import PromptType, TokensPrompt
 from vllm.logger import init_logger
 from vllm.model_executor.models.interfaces import (
     MultiModalEmbeddings,
@@ -561,11 +561,11 @@ class Qwen3ASRForConditionalGeneration(
         )
         prompt_token_ids = tokenizer.encode(prompt)
-        prompt_dict = {
-            "prompt_token_ids": prompt_token_ids,
-            "multi_modal_data": {"audio": audio},
-        }
-        return cast(PromptType, prompt_dict)
+        return TokensPrompt(
+            prompt_token_ids=prompt_token_ids,
+            multi_modal_data={"audio": audio},
+        )

     @classmethod
     def post_process_output(cls, text: str) -> str:

View File

@@ -25,7 +25,7 @@ from transformers.tokenization_utils_base import TextInput
 from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
-from vllm.inputs.data import PromptType
+from vllm.inputs.data import PromptType, TokensPrompt
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -488,10 +488,13 @@ class VoxtralForConditionalGeneration(
         )
         tokenized = tokenizer.instruct.encode_transcription(req)
-        audio = (tokenized.audios[0].audio_array, stt_config.sample_rate)
-        prompts_dict = {"multi_modal_data": {"audio": audio}}
-        prompts_dict["prompt_token_ids"] = tokenized.tokens
-        return cast(PromptType, prompts_dict)
+        return TokensPrompt(
+            prompt_token_ids=tokenized.tokens,
+            multi_modal_data={
+                "audio": (tokenized.audios[0].audio_array, stt_config.sample_rate)
+            },
+        )

     @classmethod
     def get_num_audio_tokens(

View File

@@ -4,7 +4,7 @@
 import asyncio
 import math
 from collections.abc import AsyncGenerator, Mapping
-from typing import Literal, cast
+from typing import Literal

 import numpy as np
 import torch
@@ -453,7 +453,10 @@ class VoxtralRealtimeGeneration(VoxtralForConditionalGeneration, SupportsRealtim
         )
         tokenized = tokenizer.instruct.encode_transcription(req)
-        audio = (tokenized.audios[0].audio_array, stt_config.sample_rate)
-        prompts_dict = {"multi_modal_data": {"audio": audio}}
-        prompts_dict["prompt_token_ids"] = tokenized.tokens
-        return cast(PromptType, prompts_dict)
+        return TokensPrompt(
+            prompt_token_ids=tokenized.tokens,
+            multi_modal_data={
+                "audio": (tokenized.audios[0].audio_array, stt_config.sample_rate)
+            },
+        )

View File

@@ -5,7 +5,7 @@ import enum
 import math
 from collections.abc import Iterable, Mapping, Sequence
 from contextlib import nullcontext
-from typing import Annotated, Literal, cast
+from typing import Annotated, Literal

 import numpy as np
 import torch
@@ -21,7 +21,7 @@ from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.inputs.data import PromptType
+from vllm.inputs.data import ExplicitEncoderDecoderPrompt, PromptType, TextPrompt
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.attention import (
@@ -815,21 +815,18 @@ class WhisperForConditionalGeneration(
             raise ValueError(
                 "Language must be specified when creating the Whisper prompt"
             )
-        prompt = {
-            "encoder_prompt": {
-                # Whisper does not support encoder prompt.
-                "prompt": "",
-                "multi_modal_data": {
-                    "audio": (audio, stt_config.sample_rate),
-                },
-            },
-            "decoder_prompt": (
-                (f"<|prev|>{request_prompt}" if request_prompt else "")
-                + f"<|startoftranscript|><|{language}|>"
-                + f"<|{task_type}|><|notimestamps|>"
-            ),
-        }
-        return cast(PromptType, prompt)
+        decoder_text = (
+            f"<|prev|>{request_prompt}" if request_prompt else ""
+        ) + f"<|startoftranscript|><|{language}|><|{task_type}|><|notimestamps|>"
+        return ExplicitEncoderDecoderPrompt(
+            encoder_prompt=TextPrompt(
+                prompt="",  # Whisper does not support encoder prompt.
+                multi_modal_data={"audio": (audio, stt_config.sample_rate)},
+            ),
+            decoder_prompt=TextPrompt(prompt=decoder_text),
+        )

     @classmethod
     def get_placeholder_str(cls, modality: str, i: int) -> str | None:
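Whisper is the one encoder-decoder model in this diff, so its prompt becomes an ExplicitEncoderDecoderPrompt wrapping two TextPrompts rather than a single dict. A sketch of the resulting structure (the audio array, sample rate, and decoder token string below are illustrative placeholders):

import numpy as np

from vllm.inputs.data import ExplicitEncoderDecoderPrompt, TextPrompt

audio = np.zeros(16000, dtype=np.float32)  # placeholder: 1 s of silence

# Mirrors what get_generation_prompt now returns: the audio rides on the
# (empty) encoder prompt, while the decoder prompt carries the task tokens.
prompt = ExplicitEncoderDecoderPrompt(
    encoder_prompt=TextPrompt(
        prompt="",
        multi_modal_data={"audio": (audio, 16000)},
    ),
    decoder_prompt=TextPrompt(
        prompt="<|startoftranscript|><|en|><|transcribe|><|notimestamps|>"
    ),
)

# Still a plain dict at runtime; only the static key checking is new.
assert prompt["decoder_prompt"]["prompt"].startswith("<|startoftranscript|>")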