[Voxtral] Add new streaming arch (#32861)
Signed-off-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
5da4c7d789
commit
3f3f89529d
@@ -5,7 +5,6 @@ import enum
|
||||
import math
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from contextlib import nullcontext
|
||||
from functools import partial
|
||||
from typing import Annotated, Literal, cast
|
||||
|
||||
import numpy as np
|
||||
@@ -39,8 +38,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.whisper_utils import (
|
||||
ISO639_1_SUPPORTED_LANGS,
|
||||
WhisperAttentionWithBlockPooling,
|
||||
WhisperCausalConv1d,
|
||||
)
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import (
|
||||
@@ -78,7 +75,7 @@ logger = init_logger(__name__)
|
||||
|
||||
class WhisperPosEmbedType(enum.Enum):
|
||||
SINUSOIDAL = "sinusoidal"
|
||||
NOPE = "nope"
|
||||
ROPE = "rope"
|
||||
LEARNED = "learned"
|
||||
|
||||
|
||||
@@ -140,7 +137,6 @@ class WhisperAttention(nn.Module):
|
||||
bias: bool = True,
|
||||
attn_type: AttentionType = AttentionType.DECODER,
|
||||
per_layer_sliding_window: int | None = None,
|
||||
block_pool_size: int = 1,
|
||||
cache_config: CacheConfig | None = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
@@ -199,14 +195,7 @@ class WhisperAttention(nn.Module):
|
||||
attn_type=self.attn_type,
|
||||
)
|
||||
else: # AttentionType.DECODER (regular decoder self-attention)
|
||||
if block_pool_size > 1:
|
||||
attn_cls = partial(
|
||||
WhisperAttentionWithBlockPooling, block_pool_size=block_pool_size
|
||||
)
|
||||
else:
|
||||
attn_cls = Attention
|
||||
|
||||
self.attn = attn_cls(
|
||||
self.attn = Attention(
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
@@ -351,9 +340,7 @@ class WhisperEncoderLayer(nn.Module):
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
is_causal = getattr(config, "is_causal", False)
|
||||
sliding_window = getattr(config, "sliding_window", None)
|
||||
block_pool_size = getattr(config, "block_pool_size", 1)
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
@@ -361,8 +348,7 @@ class WhisperEncoderLayer(nn.Module):
|
||||
self.self_attn = WhisperAttention(
|
||||
embed_dim=self.embed_dim,
|
||||
num_heads=config.encoder_attention_heads,
|
||||
attn_type=AttentionType.DECODER if is_causal else AttentionType.ENCODER,
|
||||
block_pool_size=block_pool_size,
|
||||
attn_type=AttentionType.ENCODER,
|
||||
per_layer_sliding_window=sliding_window,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
@@ -470,13 +456,8 @@ class WhisperEncoder(nn.Module):
|
||||
self.max_source_positions = config.max_source_positions
|
||||
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
|
||||
|
||||
self.is_causal = getattr(config, "is_causal", False)
|
||||
Conv1d = (
|
||||
WhisperCausalConv1d if self.is_causal else partial(nn.Conv1d, padding=1)
|
||||
)
|
||||
|
||||
self.conv1 = Conv1d(self.num_mel_bins, embed_dim, kernel_size=3)
|
||||
self.conv2 = Conv1d(embed_dim, embed_dim, stride=2, kernel_size=3)
|
||||
self.conv1 = nn.Conv1d(self.num_mel_bins, embed_dim, kernel_size=3, padding=1)
|
||||
self.conv2 = nn.Conv1d(embed_dim, embed_dim, stride=2, kernel_size=3, padding=1)
|
||||
|
||||
self.total_stride = self.conv1.stride[0] * self.conv2.stride[0]
|
||||
self.start_layer, self.end_layer, self.layers = make_layers(
|
||||
@@ -488,33 +469,29 @@ class WhisperEncoder(nn.Module):
|
||||
)
|
||||
self.layer_norm = nn.LayerNorm(config.d_model)
|
||||
|
||||
if self.is_causal and self.pos_embed_type != WhisperPosEmbedType.NOPE:
|
||||
raise ValueError(
|
||||
"Only NOPE position embeddings are supported "
|
||||
f"for causal models, but got {self.pos_embed_type}"
|
||||
)
|
||||
elif self.pos_embed_type in (
|
||||
if self.pos_embed_type not in (
|
||||
WhisperPosEmbedType.SINUSOIDAL,
|
||||
WhisperPosEmbedType.LEARNED,
|
||||
):
|
||||
maybe_fp32_init_ctx = (
|
||||
set_default_torch_dtype(torch.float32)
|
||||
if init_in_fp32
|
||||
else nullcontext()
|
||||
raise ValueError(
|
||||
"Only sinusoidal or learned position embeddings are supported "
|
||||
f"for non-causal models, but got {self.pos_embed_type}"
|
||||
)
|
||||
|
||||
with (
|
||||
torch.no_grad(),
|
||||
maybe_fp32_init_ctx,
|
||||
):
|
||||
self.embed_positions = nn.Embedding(
|
||||
self.max_source_positions, embed_dim
|
||||
)
|
||||
self.embed_positions.weight.copy_(
|
||||
sinusoids(*self.embed_positions.weight.shape)
|
||||
)
|
||||
maybe_fp32_init_ctx = (
|
||||
set_default_torch_dtype(torch.float32) if init_in_fp32 else nullcontext()
|
||||
)
|
||||
|
||||
def forward_conv(
|
||||
with (
|
||||
torch.no_grad(),
|
||||
maybe_fp32_init_ctx,
|
||||
):
|
||||
self.embed_positions = nn.Embedding(self.max_source_positions, embed_dim)
|
||||
self.embed_positions.weight.copy_(
|
||||
sinusoids(*self.embed_positions.weight.shape)
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, input_features: torch.Tensor | list[torch.Tensor]
|
||||
) -> torch.Tensor:
|
||||
hidden_states = []
|
||||
@@ -523,44 +500,26 @@ class WhisperEncoder(nn.Module):
|
||||
embeds = nn.functional.gelu(self.conv1(features))
|
||||
embeds = nn.functional.gelu(self.conv2(embeds))
|
||||
|
||||
if self.pos_embed_type in (
|
||||
WhisperPosEmbedType.SINUSOIDAL,
|
||||
WhisperPosEmbedType.LEARNED,
|
||||
):
|
||||
embeds = embeds.transpose(-1, -2)
|
||||
embeds = (
|
||||
embeds + self.embed_positions.weight[: embeds.size(-2), :]
|
||||
).to(embeds.dtype)
|
||||
elif self.pos_embed_type == WhisperPosEmbedType.NOPE:
|
||||
embeds = embeds.transpose(-1, -2).to(embeds.dtype)
|
||||
else:
|
||||
raise ValueError(f"Unknown pos_embed_type: {self.pos_embed_type}")
|
||||
embeds = embeds.transpose(-1, -2)
|
||||
embeds = (embeds + self.embed_positions.weight[: embeds.size(-2), :]).to(
|
||||
embeds.dtype
|
||||
)
|
||||
|
||||
hidden_states.append(embeds)
|
||||
input_is_batched = embeds.ndim > 2
|
||||
# Input to MHA must be B x T x D
|
||||
if input_is_batched or self.is_causal:
|
||||
if input_is_batched:
|
||||
# Models using WhisperEncoder may handle batching internally.
|
||||
# If WhisperEncoder is causal, sequences
|
||||
# are not padded to have identical seq length (T)
|
||||
# => concat over feature dim
|
||||
hidden_states = torch.cat(hidden_states)
|
||||
else:
|
||||
hidden_states = torch.stack(hidden_states, dim=0)
|
||||
|
||||
return hidden_states
|
||||
|
||||
def forward_layers(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
for encoder_layer in self.layers:
|
||||
hidden_states = encoder_layer(hidden_states)
|
||||
|
||||
hidden_states = self.layer_norm(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
def forward(self, input_features: torch.Tensor | list[torch.Tensor]):
|
||||
hidden_states = self.forward_conv(input_features)
|
||||
return self.forward_layers(hidden_states)
|
||||
|
||||
|
||||
@support_torch_compile(dynamic_arg_dims={"input_ids": 0, "positions": -1})
|
||||
class WhisperDecoder(nn.Module):
|
||||
@@ -978,19 +937,19 @@ class WhisperForConditionalGeneration(
|
||||
loader = AutoWeightsLoader(self, skip_prefixes=["proj_out."])
|
||||
|
||||
# add fake zeros bias for k_proj to state_dict
|
||||
weights = _create_fake_bias_for_k_proj(weights)
|
||||
weights = _create_fake_bias_for_k_proj(weights, ".k_proj.weight")
|
||||
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
|
||||
|
||||
|
||||
def _create_fake_bias_for_k_proj(
|
||||
weights: Iterable[tuple[str, torch.Tensor]],
|
||||
weights: Iterable[tuple[str, torch.Tensor]], fake_bias_key_name: str
|
||||
) -> Iterable[tuple[str, torch.Tensor]]:
|
||||
"""
|
||||
Create full zeros bias for k_proj weight in self-attn and x-attn layers.
|
||||
So that the bias for k_proj in qkv_proj can be initialized with zeros.
|
||||
"""
|
||||
for name, weight in weights:
|
||||
if name.endswith(".k_proj.weight"):
|
||||
if name.endswith(fake_bias_key_name):
|
||||
bias = torch.zeros(weight.size(0))
|
||||
bias_name = name.replace("weight", "bias")
|
||||
yield from [(name, weight), (bias_name, bias)]
|
||||
|
||||
Reference in New Issue
Block a user