[Frontend] Abstract prompt and SpeechToTextConfig for transcription models (#20637)

Signed-off-by: NickLucche <nlucches@redhat.com>
Author: Nicolò Lucchesi
Date: 2025-07-12 06:33:26 +02:00
Committed by: GitHub
Parent: 890323dc1b
Commit: 3c7d942da8
4 changed files with 141 additions and 60 deletions


@@ -3,8 +3,9 @@
 import math
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Optional, TypedDict, Union
+from typing import Optional, TypedDict, Union, cast
 
+import numpy as np
 import torch
 from torch import nn
 from transformers import (BatchFeature, WhisperConfig, WhisperFeatureExtractor,
@@ -12,8 +13,10 @@ from transformers import (BatchFeature, WhisperConfig, WhisperFeatureExtractor,
 from transformers.models.whisper.modeling_whisper import sinusoids
 
 from vllm.attention import Attention, AttentionType
-from vllm.config import CacheConfig, VllmConfig
+from vllm.config import (CacheConfig, ModelConfig, SpeechToTextConfig,
+                         VllmConfig)
 from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.inputs.data import PromptType
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -33,6 +36,7 @@ from vllm.multimodal.processing import (BaseProcessingInfo,
                                         EncDecMultiModalProcessor,
                                         PromptReplacement, PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
+from vllm.transformers_utils.processor import cached_get_processor
 
 from .interfaces import (MultiModalEmbeddings, SupportsMultiModal,
                          SupportsTranscription, SupportsV0Only)
@@ -785,11 +789,24 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
                 f"or {list(ISO639_1_OTHER_LANGS.values())}")
 
     @classmethod
-    def get_decoder_prompt(cls, language: str, task_type: str,
-                           prompt: str) -> str:
-        return ((f"<|prev|>{prompt}" if prompt else "") +
-                f"<|startoftranscript|><|{language}|>" +
-                f"<|{task_type}|><|notimestamps|>")
+    def get_generation_prompt(cls, audio: np.ndarray,
+                              stt_config: SpeechToTextConfig, language: str,
+                              task_type: str,
+                              request_prompt: str) -> PromptType:
+        prompt = {
+            "encoder_prompt": {
+                # Whisper does not support encoder prompt.
+                "prompt": "",
+                "multi_modal_data": {
+                    "audio": (audio, stt_config.sample_rate),
+                },
+            },
+            "decoder_prompt":
+            ((f"<|prev|>{request_prompt}" if request_prompt else "") +
+             f"<|startoftranscript|><|{language}|>" +
+             f"<|{task_type}|><|notimestamps|>")
+        }
+        return cast(PromptType, prompt)
 
     @classmethod
     def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
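
The new `get_generation_prompt` hook moves prompt construction behind the model class: the serving frontend hands it raw audio plus the per-model `SpeechToTextConfig` and gets back a ready-to-use `PromptType`. A minimal sketch of calling it (synthetic audio; assumes vLLM with this commit applied, and that the model class is importable from its usual module path):

```python
# Sketch: building a transcription prompt via the new classmethod.
# The audio array is synthetic; any float32 waveform at the configured
# sample rate works here.
import numpy as np

from vllm.config import SpeechToTextConfig
from vllm.model_executor.models.whisper import WhisperForConditionalGeneration

stt_config = SpeechToTextConfig(max_audio_clip_s=30, sample_rate=16000)
audio = np.zeros(16000, dtype=np.float32)  # 1 s of silence at 16 kHz

prompt = WhisperForConditionalGeneration.get_generation_prompt(
    audio=audio,
    stt_config=stt_config,
    language="en",
    task_type="transcribe",
    request_prompt="",
)
# prompt["decoder_prompt"] is
#   "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>"
# and the waveform travels as multi-modal data on the encoder side.
```
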
@@ -798,6 +815,30 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
         raise ValueError("Only audio modality is supported")
 
     @classmethod
+    def get_speech_to_text_config(cls, model_config: ModelConfig,
+                                  task_type: str) -> SpeechToTextConfig:
+        processor = cached_get_processor(model_config.model)
+
+        return SpeechToTextConfig(
+            max_audio_clip_s=processor.feature_extractor.chunk_length,
+            sample_rate=processor.feature_extractor.sampling_rate,
+        )
+
+    @classmethod
+    def get_num_audio_tokens(cls, audio_duration_s: float,
+                             stt_config: SpeechToTextConfig,
+                             model_config: ModelConfig) -> Optional[int]:
+        processor = cached_get_processor(model_config.model)
+        hop_length = processor.feature_extractor.hop_length
+        assert hop_length is not None
+        # NOTE(NickLucche) Users can't pass encoder
+        # prompts directly, at least not to Whisper.
+        # One indicator of the amount of encoder
+        # processing is the log-mel spectrogram length.
+        return math.ceil(audio_duration_s * stt_config.sample_rate /
+                         hop_length)
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
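
Together, `get_speech_to_text_config` and `get_num_audio_tokens` let the frontend derive audio limits and usage accounting from the model's own feature extractor instead of hardcoding them. A sketch of the frame-count arithmetic, using Whisper's usual extractor settings (sampling_rate=16000, hop_length=160, chunk_length=30; other models may differ):

```python
# Sketch of the log-mel frame count that get_num_audio_tokens() returns.
# The constants below are Whisper's usual feature-extractor settings
# (assumed here); in the commit they are read from the cached processor.
import math

sample_rate = 16000  # Hz, processor.feature_extractor.sampling_rate
hop_length = 160     # samples per mel frame, feature_extractor.hop_length

def num_audio_tokens(audio_duration_s: float) -> int:
    # One spectrogram frame per hop -> 100 frames per second here.
    return math.ceil(audio_duration_s * sample_rate / hop_length)

print(num_audio_tokens(1.0))   # 100
print(num_audio_tokens(30.0))  # 3000, a full 30 s chunk
```
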