[Model] Use explicit types in get_generation_prompt (#33551)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -5,7 +5,7 @@ import enum
|
||||
import math
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from contextlib import nullcontext
|
||||
from typing import Annotated, Literal, cast
|
||||
from typing import Annotated, Literal
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -21,7 +21,7 @@ from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.inputs.data import PromptType
|
||||
from vllm.inputs.data import ExplicitEncoderDecoderPrompt, PromptType, TextPrompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.attention import (
|
||||
@@ -815,21 +815,18 @@ class WhisperForConditionalGeneration(
|
||||
raise ValueError(
|
||||
"Language must be specified when creating the Whisper prompt"
|
||||
)
|
||||
prompt = {
|
||||
"encoder_prompt": {
|
||||
# Whisper does not support encoder prompt.
|
||||
"prompt": "",
|
||||
"multi_modal_data": {
|
||||
"audio": (audio, stt_config.sample_rate),
|
||||
},
|
||||
},
|
||||
"decoder_prompt": (
|
||||
(f"<|prev|>{request_prompt}" if request_prompt else "")
|
||||
+ f"<|startoftranscript|><|{language}|>"
|
||||
+ f"<|{task_type}|><|notimestamps|>"
|
||||
|
||||
decoder_text = (
|
||||
f"<|prev|>{request_prompt}" if request_prompt else ""
|
||||
) + f"<|startoftranscript|><|{language}|><|{task_type}|><|notimestamps|>"
|
||||
|
||||
return ExplicitEncoderDecoderPrompt(
|
||||
encoder_prompt=TextPrompt(
|
||||
prompt="", # Whisper does not support encoder prompt.
|
||||
multi_modal_data={"audio": (audio, stt_config.sample_rate)},
|
||||
),
|
||||
}
|
||||
return cast(PromptType, prompt)
|
||||
decoder_prompt=TextPrompt(prompt=decoder_text),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
|
||||
|
||||
Reference in New Issue
Block a user