commit e7e3e6d263 (parent 4ffd963fa0)
Author: Patrick von Platen <patrick.v.platen@gmail.com>
Date:   2025-07-15 16:35:30 +02:00 (committed on GitHub)

Signed-off-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>

14 changed files with 913 additions and 47 deletions


@@ -3,6 +3,7 @@
 import math
 from collections.abc import Iterable, Mapping, Sequence
+from contextlib import nullcontext
 from typing import Optional, TypedDict, Union, cast
 
 import numpy as np
@@ -13,6 +14,7 @@ from transformers import (BatchFeature, WhisperConfig, WhisperFeatureExtractor,
 from transformers.models.whisper.modeling_whisper import sinusoids
 
 from vllm.attention import Attention, AttentionType
+from vllm.attention.layer import MultiHeadAttention
 from vllm.config import (CacheConfig, ModelConfig, SpeechToTextConfig,
                          VllmConfig)
 from vllm.distributed import get_tensor_model_parallel_world_size
@@ -26,6 +28,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
+from vllm.model_executor.model_loader.utils import set_default_torch_dtype
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors
@@ -178,6 +181,7 @@ class WhisperAttention(nn.Module):
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        standalone_encoder: bool = False,
     ):
         super().__init__()
         self.embed_dim = embed_dim
@@ -213,16 +217,24 @@
             quant_config=quant_config,
             prefix=f"{prefix}.out_proj",
         )
-        self.attn = Attention(
-            self.num_heads,
-            self.head_dim,
-            self.scaling,
-            num_kv_heads=self.num_kv_heads,
-            cache_config=cache_config,
-            quant_config=quant_config,
-            prefix=f"{prefix}.attn",
-            attn_type=self.attn_type,
-        )
+        if standalone_encoder:
+            self.attn = MultiHeadAttention(
+                self.num_heads,
+                self.head_dim,
+                self.scaling,
+                num_kv_heads=self.num_kv_heads,
+            )
+        else:
+            self.attn = Attention(
+                self.num_heads,
+                self.head_dim,
+                self.scaling,
+                num_kv_heads=self.num_kv_heads,
+                cache_config=cache_config,
+                quant_config=quant_config,
+                prefix=f"{prefix}.attn",
+                attn_type=self.attn_type,
+            )
 
     def _init_qkv(
         self,
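
The hunk above adds a `standalone_encoder` switch to `WhisperAttention`: in the usual encoder-decoder pipeline, attention goes through vLLM's cache-backed `Attention`; when the encoder runs on its own, there is no KV cache to manage and full bidirectional attention over the audio features is all that is needed, so the lighter `MultiHeadAttention` suffices. As a rough, hedged sketch of what that cache-free path boils down to in plain PyTorch (the function name and shapes below are illustrative, not vLLM's internal API):

    import torch
    import torch.nn.functional as F

    def encoder_self_attention(q: torch.Tensor, k: torch.Tensor,
                               v: torch.Tensor, num_heads: int) -> torch.Tensor:
        # q, k, v: (seq_len, num_heads * head_dim), already projected.
        seq_len, hidden = q.shape
        head_dim = hidden // num_heads
        # Split heads: (num_heads, seq_len, head_dim).
        q, k, v = (t.view(seq_len, num_heads, head_dim).transpose(0, 1)
                   for t in (q, k, v))
        # Full bidirectional attention: no causal mask, no KV cache.
        out = F.scaled_dot_product_attention(q, k, v)
        return out.transpose(0, 1).reshape(seq_len, hidden)

    x = torch.randn(1500, 6 * 64)
    out = encoder_self_attention(x, x, x, num_heads=6)  # -> (1500, 384)

Routing both cases through one constructor keeps the call sites simple: `WhisperEncoderLayer` only threads a boolean through rather than knowing about two attention classes.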
@@ -357,7 +369,11 @@ class WhisperMLP(nn.Module):
 class WhisperEncoderLayer(nn.Module):
 
-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+    def __init__(self,
+                 *,
+                 vllm_config: VllmConfig,
+                 prefix: str = "",
+                 is_standalone_encoder: bool = False):
         super().__init__()
         config = vllm_config.model_config.hf_config
         cache_config = vllm_config.cache_config
@@ -371,6 +387,7 @@
             cache_config=cache_config,
             quant_config=quant_config,
             prefix=f"{prefix}.self_attn",
+            standalone_encoder=is_standalone_encoder,
         )
         self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
         self.mlp = WhisperMLP(
@@ -462,10 +479,16 @@ class WhisperDecoderLayer(nn.Module):
 class WhisperEncoder(nn.Module):
 
-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+    def __init__(self,
+                 *,
+                 vllm_config: VllmConfig,
+                 prefix: str = "",
+                 is_standalone_encoder: bool = False,
+                 init_in_fp32: bool = False):
         super().__init__()
         config = vllm_config.model_config.hf_config
         embed_dim = config.d_model
+        self.is_standalone_encoder = is_standalone_encoder
         self.num_mel_bins = config.num_mel_bins
         self.max_source_positions = config.max_source_positions
         self.embed_scale = (math.sqrt(embed_dim)
@@ -480,17 +503,25 @@
                                kernel_size=3,
                                stride=2,
                                padding=1)
-        self.embed_positions = nn.Embedding(self.max_source_positions,
-                                            embed_dim)
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.encoder_layers,
             lambda prefix: WhisperEncoderLayer(vllm_config=vllm_config,
-                                               prefix=f"{prefix}.layers"),
+                                               prefix=f"{prefix}.layers",
+                                               is_standalone_encoder=
+                                               is_standalone_encoder),
             prefix=f"{prefix}.layers",
         )
         self.layer_norm = nn.LayerNorm(config.d_model)
 
-        with torch.no_grad():
+        maybe_fp32_init_ctx = set_default_torch_dtype(
+            torch.float32) if init_in_fp32 else nullcontext()
+        with (
+                torch.no_grad(),
+                maybe_fp32_init_ctx,
+        ):
+            self.embed_positions = nn.Embedding(self.max_source_positions,
+                                                embed_dim)
             self.embed_positions.weight.copy_(
                 sinusoids(*self.embed_positions.weight.shape))
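
The hunk above does two things: it threads `is_standalone_encoder` into each encoder layer, and it moves the `embed_positions` construction under an optional fp32 context. The point of `init_in_fp32` is that the sinusoidal position table is generated once at init time, and generating it under an fp16/bf16 default dtype quietly degrades the high-frequency channels. A hedged sketch of the computation, mirroring what the `sinusoids` helper from `transformers` produces, with an explicit float32 intermediate (exact upstream details may differ):

    import math

    import torch

    def sinusoid_table(length: int, channels: int,
                       out_dtype: torch.dtype = torch.float16) -> torch.Tensor:
        # Build the table in float32 and cast only at the end.
        log_timescale = math.log(10000) / (channels // 2 - 1)
        inv_timescales = torch.exp(
            -log_timescale * torch.arange(channels // 2, dtype=torch.float32))
        scaled_time = (torch.arange(length, dtype=torch.float32)[:, None] *
                       inv_timescales[None, :])
        return torch.cat([scaled_time.sin(), scaled_time.cos()],
                         dim=1).to(out_dtype)

    table = sinusoid_table(1500, 384)  # Whisper-base-like dimensions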
@@ -499,8 +530,10 @@
         for features in input_features:
             embeds = nn.functional.gelu(self.conv1(features))
             embeds = nn.functional.gelu(self.conv2(embeds))
-            embeds = embeds.permute(1, 0)
-            embeds = embeds + self.embed_positions.weight[:embeds.size(0), :]
+            embeds = embeds.transpose(-1, -2)
+            embeds = (embeds +
+                      self.embed_positions.weight[:embeds.size(-2), :]).to(
+                          embeds.dtype)
             hidden_states.append(embeds)
         hidden_states = torch.cat(hidden_states)
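
The hunk above generalizes the shape handling in the encoder forward: `permute(1, 0)` only worked for a single unbatched `(channels, time)` conv output, while `transpose(-1, -2)` also accepts batched `(batch, channels, time)` inputs, so the position slice now indexes `size(-2)` (the time axis). The trailing `.to(embeds.dtype)` keeps half-precision activations from being silently upcast when the position table is stored in float32. A small self-contained illustration (the dimensions below are made up):

    import torch

    pos = torch.zeros(1500, 384, dtype=torch.float32)  # position table
    for shape in [(384, 100), (2, 384, 100)]:          # unbatched and batched
        embeds = torch.randn(*shape, dtype=torch.float16)
        embeds = embeds.transpose(-1, -2)              # time axis -> dim -2
        embeds = (embeds + pos[:embeds.size(-2), :]).to(embeds.dtype)
        print(embeds.shape, embeds.dtype)              # (..., 100, 384) float16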
@@ -792,10 +825,14 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
                              f"or {list(ISO639_1_OTHER_LANGS.values())}")
 
     @classmethod
-    def get_generation_prompt(cls, audio: np.ndarray,
-                              stt_config: SpeechToTextConfig, language: str,
-                              task_type: str,
-                              request_prompt: str) -> PromptType:
+    def get_generation_prompt(
+            cls,
+            audio: np.ndarray,
+            model_config: ModelConfig,  # not needed here
+            stt_config: SpeechToTextConfig,
+            language: str,
+            task_type: str,
+            request_prompt: str) -> PromptType:
         prompt = {
             "encoder_prompt": {
                 # Whisper does not support encoder prompt.