[Frontend] Abstract prompt and SpeechToTextConfig for transcription models (#20637)
Signed-off-by: NickLucche <nlucches@redhat.com>
@@ -3,8 +3,9 @@
 import math
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Optional, TypedDict, Union
+from typing import Optional, TypedDict, Union, cast

+import numpy as np
 import torch
 from torch import nn
 from transformers import (BatchFeature, WhisperConfig, WhisperFeatureExtractor,
@@ -12,8 +13,10 @@ from transformers import (BatchFeature, WhisperConfig, WhisperFeatureExtractor,
 from transformers.models.whisper.modeling_whisper import sinusoids

 from vllm.attention import Attention, AttentionType
-from vllm.config import CacheConfig, VllmConfig
+from vllm.config import (CacheConfig, ModelConfig, SpeechToTextConfig,
+                         VllmConfig)
 from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.inputs.data import PromptType
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -33,6 +36,7 @@ from vllm.multimodal.processing import (BaseProcessingInfo,
                                         EncDecMultiModalProcessor,
                                         PromptReplacement, PromptUpdate)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
+from vllm.transformers_utils.processor import cached_get_processor

 from .interfaces import (MultiModalEmbeddings, SupportsMultiModal,
                          SupportsTranscription, SupportsV0Only)
@@ -785,11 +789,24 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
                          f"or {list(ISO639_1_OTHER_LANGS.values())}")

     @classmethod
-    def get_decoder_prompt(cls, language: str, task_type: str,
-                           prompt: str) -> str:
-        return ((f"<|prev|>{prompt}" if prompt else "") +
-                f"<|startoftranscript|><|{language}|>" +
-                f"<|{task_type}|><|notimestamps|>")
+    def get_generation_prompt(cls, audio: np.ndarray,
+                              stt_config: SpeechToTextConfig, language: str,
+                              task_type: str,
+                              request_prompt: str) -> PromptType:
+        prompt = {
+            "encoder_prompt": {
+                # Whisper does not support encoder prompt.
+                "prompt": "",
+                "multi_modal_data": {
+                    "audio": (audio, stt_config.sample_rate),
+                },
+            },
+            "decoder_prompt":
+            ((f"<|prev|>{request_prompt}" if request_prompt else "") +
+             f"<|startoftranscript|><|{language}|>" +
+             f"<|{task_type}|><|notimestamps|>")
+        }
+        return cast(PromptType, prompt)

     @classmethod
     def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
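For context, a minimal sketch of how the new classmethod is meant to be called (not part of the commit; the dummy waveform, the config values, and the module path in the import are assumptions for illustration):

    import numpy as np

    from vllm.config import SpeechToTextConfig
    from vllm.model_executor.models.whisper import (
        WhisperForConditionalGeneration)

    # One second of silence at 16 kHz stands in for real audio.
    audio = np.zeros(16000, dtype=np.float32)
    stt_config = SpeechToTextConfig(max_audio_clip_s=30, sample_rate=16000)

    prompt = WhisperForConditionalGeneration.get_generation_prompt(
        audio=audio,
        stt_config=stt_config,
        language="en",
        task_type="transcribe",
        request_prompt="",
    )
    # With an empty request_prompt, the decoder prompt reduces to
    # "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>".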
@@ -798,6 +815,30 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,

         raise ValueError("Only audio modality is supported")

+    @classmethod
+    def get_speech_to_text_config(cls, model_config: ModelConfig,
+                                  task_type: str) -> SpeechToTextConfig:
+        processor = cached_get_processor(model_config.model)
+
+        return SpeechToTextConfig(
+            max_audio_clip_s=processor.feature_extractor.chunk_length,
+            sample_rate=processor.feature_extractor.sampling_rate,
+        )
+
+    @classmethod
+    def get_num_audio_tokens(cls, audio_duration_s: float,
+                             stt_config: SpeechToTextConfig,
+                             model_config: ModelConfig) -> Optional[int]:
+        processor = cached_get_processor(model_config.model)
+        hop_length = processor.feature_extractor.hop_length
+        assert hop_length is not None
+        # NOTE(NickLucche) The user can't pass encoder
+        # prompts directly, at least not to Whisper.
+        # One indicator of the amount of encoder processing
+        # is the log-mel spectrogram length.
+        return math.ceil(audio_duration_s * stt_config.sample_rate /
+                         hop_length)
+
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
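For a concrete sense of what get_speech_to_text_config returns, a hedged sketch using the stock Whisper feature-extractor defaults (30 s chunks at 16 kHz; the checkpoint name is only an example):

    from transformers import WhisperFeatureExtractor

    fe = WhisperFeatureExtractor.from_pretrained("openai/whisper-large-v3")
    # Whisper checkpoints ship 30 s chunks at a 16 kHz sampling rate.
    print(fe.chunk_length, fe.sampling_rate)  # 30 16000

    # So the method above would yield
    # SpeechToTextConfig(max_audio_clip_s=30, sample_rate=16000).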
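Likewise, plugging Whisper's standard feature-extractor numbers into get_num_audio_tokens (16 kHz sampling and hop_length=160; hard-coded here as assumptions rather than read from a checkpoint) shows where the estimate lands:

    import math

    sample_rate = 16000      # Hz, Whisper default
    hop_length = 160         # samples between successive log-mel frames
    audio_duration_s = 30.0  # one full Whisper chunk

    # 30 s of audio -> 3000 log-mel frames, which matches the fixed
    # 3000-frame spectrogram Whisper's encoder consumes per chunk.
    print(math.ceil(audio_duration_s * sample_rate / hop_length))  # 3000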