[Voxtral] Add new streaming arch (#32861)

Signed-off-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Author: Patrick von Platen
Date: 2026-01-23 12:41:52 +01:00
Committed by: GitHub
Parent: 5da4c7d789
Commit: 3f3f89529d
9 changed files with 767 additions and 313 deletions


@@ -5,7 +5,6 @@ import enum
 import math
 from collections.abc import Iterable, Mapping, Sequence
 from contextlib import nullcontext
-from functools import partial
 from typing import Annotated, Literal, cast

 import numpy as np
@@ -39,8 +38,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.whisper_utils import (
     ISO639_1_SUPPORTED_LANGS,
-    WhisperAttentionWithBlockPooling,
-    WhisperCausalConv1d,
 )
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (
@@ -78,7 +75,7 @@ logger = init_logger(__name__)
 class WhisperPosEmbedType(enum.Enum):
     SINUSOIDAL = "sinusoidal"
-    NOPE = "nope"
+    ROPE = "rope"
     LEARNED = "learned"
@@ -140,7 +137,6 @@ class WhisperAttention(nn.Module):
         bias: bool = True,
         attn_type: AttentionType = AttentionType.DECODER,
         per_layer_sliding_window: int | None = None,
-        block_pool_size: int = 1,
         cache_config: CacheConfig | None = None,
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
@@ -199,14 +195,7 @@
                 attn_type=self.attn_type,
             )
         else:  # AttentionType.DECODER (regular decoder self-attention)
-            if block_pool_size > 1:
-                attn_cls = partial(
-                    WhisperAttentionWithBlockPooling, block_pool_size=block_pool_size
-                )
-            else:
-                attn_cls = Attention
-            self.attn = attn_cls(
+            self.attn = Attention(
                 self.num_heads,
                 self.head_dim,
                 self.scaling,
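The branch removed here picked the attention class at construction time, using functools.partial to pre-bind the extra block_pool_size argument so both paths could be called with the same positional signature. A minimal standalone sketch of that pattern, using illustrative stand-in classes rather than the vLLM implementations:

```python
from functools import partial


class PlainAttention:
    def __init__(self, num_heads: int, head_dim: int):
        self.num_heads, self.head_dim = num_heads, head_dim


class BlockPoolingAttention(PlainAttention):
    def __init__(self, num_heads: int, head_dim: int, block_pool_size: int = 1):
        super().__init__(num_heads, head_dim)
        self.block_pool_size = block_pool_size


def make_attn(block_pool_size: int) -> PlainAttention:
    # Pre-bind the extra kwarg so both branches expose the same call signature.
    attn_cls = (
        partial(BlockPoolingAttention, block_pool_size=block_pool_size)
        if block_pool_size > 1
        else PlainAttention
    )
    return attn_cls(num_heads=8, head_dim=64)


print(type(make_attn(1)).__name__)  # PlainAttention
print(type(make_attn(4)).__name__)  # BlockPoolingAttention
```

With the streaming arch moved out of whisper.py, the plain Attention class is constructed directly and the indirection is no longer needed.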
@@ -351,9 +340,7 @@
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config = vllm_config.model_config.hf_config
-        is_causal = getattr(config, "is_causal", False)
         sliding_window = getattr(config, "sliding_window", None)
-        block_pool_size = getattr(config, "block_pool_size", 1)
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
@@ -361,8 +348,7 @@
         self.self_attn = WhisperAttention(
             embed_dim=self.embed_dim,
             num_heads=config.encoder_attention_heads,
-            attn_type=AttentionType.DECODER if is_causal else AttentionType.ENCODER,
-            block_pool_size=block_pool_size,
+            attn_type=AttentionType.ENCODER,
             per_layer_sliding_window=sliding_window,
             cache_config=cache_config,
             quant_config=quant_config,
@@ -470,13 +456,8 @@
         self.max_source_positions = config.max_source_positions
         self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
-        self.is_causal = getattr(config, "is_causal", False)
-        Conv1d = (
-            WhisperCausalConv1d if self.is_causal else partial(nn.Conv1d, padding=1)
-        )
-        self.conv1 = Conv1d(self.num_mel_bins, embed_dim, kernel_size=3)
-        self.conv2 = Conv1d(embed_dim, embed_dim, stride=2, kernel_size=3)
+        self.conv1 = nn.Conv1d(self.num_mel_bins, embed_dim, kernel_size=3, padding=1)
+        self.conv2 = nn.Conv1d(embed_dim, embed_dim, stride=2, kernel_size=3, padding=1)
         self.total_stride = self.conv1.stride[0] * self.conv2.stride[0]
         self.start_layer, self.end_layer, self.layers = make_layers(
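The front-end convolutions are now built directly with padding=1, and conv2 keeps stride 2, so total_stride stays 2 and the encoder halves the number of mel frames. A small standalone sketch of the resulting shapes; the dimensions below are illustrative, the real values come from the HF config:

```python
import torch
import torch.nn as nn

# Illustrative dims; in vLLM these come from config.num_mel_bins / config.d_model.
num_mel_bins, embed_dim, num_frames = 80, 512, 3000

conv1 = nn.Conv1d(num_mel_bins, embed_dim, kernel_size=3, padding=1)
conv2 = nn.Conv1d(embed_dim, embed_dim, stride=2, kernel_size=3, padding=1)
total_stride = conv1.stride[0] * conv2.stride[0]  # 1 * 2 == 2

features = torch.randn(1, num_mel_bins, num_frames)
out = nn.functional.gelu(conv2(nn.functional.gelu(conv1(features))))
print(out.shape, total_stride)  # torch.Size([1, 512, 1500]) 2
```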
@@ -488,33 +469,29 @@
         )
         self.layer_norm = nn.LayerNorm(config.d_model)

-        if self.is_causal and self.pos_embed_type != WhisperPosEmbedType.NOPE:
-            raise ValueError(
-                "Only NOPE position embeddings are supported "
-                f"for causal models, but got {self.pos_embed_type}"
-            )
-        elif self.pos_embed_type in (
+        if self.pos_embed_type not in (
             WhisperPosEmbedType.SINUSOIDAL,
             WhisperPosEmbedType.LEARNED,
         ):
-            maybe_fp32_init_ctx = (
-                set_default_torch_dtype(torch.float32)
-                if init_in_fp32
-                else nullcontext()
+            raise ValueError(
+                "Only sinusoidal or learned position embeddings are supported "
+                f"for non-causal models, but got {self.pos_embed_type}"
             )
-            with (
-                torch.no_grad(),
-                maybe_fp32_init_ctx,
-            ):
-                self.embed_positions = nn.Embedding(
-                    self.max_source_positions, embed_dim
-                )
-                self.embed_positions.weight.copy_(
-                    sinusoids(*self.embed_positions.weight.shape)
-                )
+        maybe_fp32_init_ctx = (
+            set_default_torch_dtype(torch.float32) if init_in_fp32 else nullcontext()
+        )
+        with (
+            torch.no_grad(),
+            maybe_fp32_init_ctx,
+        ):
+            self.embed_positions = nn.Embedding(self.max_source_positions, embed_dim)
+            self.embed_positions.weight.copy_(
+                sinusoids(*self.embed_positions.weight.shape)
+            )

-    def forward_conv(
+    def forward(
         self, input_features: torch.Tensor | list[torch.Tensor]
     ) -> torch.Tensor:
         hidden_states = []
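The position table is still filled in place with Whisper-style sinusoids under torch.no_grad(), optionally inside set_default_torch_dtype(torch.float32) when init_in_fp32 is set. Below is a self-contained sketch that skips the dtype context manager and assumes the common OpenAI-style sinusoids formulation; the actual sinusoids helper imported by whisper.py may differ in detail:

```python
import math

import torch
import torch.nn as nn


def sinusoids(length: int, channels: int, max_timescale: int = 10000) -> torch.Tensor:
    # Whisper-style table: sin terms in the first half, cos terms in the second.
    assert channels % 2 == 0
    log_timescale_increment = math.log(max_timescale) / (channels // 2 - 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
    scaled_time = torch.arange(length)[:, None].float() * inv_timescales[None, :]
    return torch.cat([scaled_time.sin(), scaled_time.cos()], dim=1)


max_source_positions, embed_dim = 1500, 512  # illustrative values
embed_positions = nn.Embedding(max_source_positions, embed_dim)
with torch.no_grad():
    # Mirrors the diff: the table is copied in rather than learned from scratch.
    embed_positions.weight.copy_(sinusoids(*embed_positions.weight.shape))
print(embed_positions.weight.shape)  # torch.Size([1500, 512])
```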
@@ -523,44 +500,26 @@
             embeds = nn.functional.gelu(self.conv1(features))
             embeds = nn.functional.gelu(self.conv2(embeds))
-            if self.pos_embed_type in (
-                WhisperPosEmbedType.SINUSOIDAL,
-                WhisperPosEmbedType.LEARNED,
-            ):
-                embeds = embeds.transpose(-1, -2)
-                embeds = (
-                    embeds + self.embed_positions.weight[: embeds.size(-2), :]
-                ).to(embeds.dtype)
-            elif self.pos_embed_type == WhisperPosEmbedType.NOPE:
-                embeds = embeds.transpose(-1, -2).to(embeds.dtype)
-            else:
-                raise ValueError(f"Unknown pos_embed_type: {self.pos_embed_type}")
+            embeds = embeds.transpose(-1, -2)
+            embeds = (embeds + self.embed_positions.weight[: embeds.size(-2), :]).to(
+                embeds.dtype
+            )
             hidden_states.append(embeds)

         input_is_batched = embeds.ndim > 2
         # Input to MHA must be B x T x D
-        if input_is_batched or self.is_causal:
+        if input_is_batched:
             # Models using WhisperEncoder may handle batching internally.
-            # If WhisperEncoder is causal, sequences
-            # are not padded to have identical seq length (T)
-            # => concat over feature dim
             hidden_states = torch.cat(hidden_states)
         else:
             hidden_states = torch.stack(hidden_states, dim=0)
-        return hidden_states

-    def forward_layers(self, hidden_states: torch.Tensor) -> torch.Tensor:
         for encoder_layer in self.layers:
             hidden_states = encoder_layer(hidden_states)
         hidden_states = self.layer_norm(hidden_states)
         return hidden_states

-    def forward(self, input_features: torch.Tensor | list[torch.Tensor]):
-        hidden_states = self.forward_conv(input_features)
-        return self.forward_layers(hidden_states)

 @support_torch_compile(dynamic_arg_dims={"input_ids": 0, "positions": -1})
 class WhisperDecoder(nn.Module):
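With the causal path gone, forward_conv and forward_layers collapse into a single forward, and batching depends only on whether each per-item embedding is already batched: already-batched items (B x T x D) are concatenated, unbatched items (T x D) are stacked. A minimal sketch of that combine step; combine is a hypothetical helper name used only here:

```python
import torch


def combine(hidden_states: list[torch.Tensor]) -> torch.Tensor:
    # Mirrors the batching logic in the diff: cat along dim 0 if items are already
    # batched (B x T x D), otherwise stack unbatched T x D items into a batch.
    input_is_batched = hidden_states[-1].ndim > 2
    if input_is_batched:
        return torch.cat(hidden_states)
    return torch.stack(hidden_states, dim=0)


print(combine([torch.randn(2, 10, 8), torch.randn(3, 10, 8)]).shape)  # [5, 10, 8]
print(combine([torch.randn(10, 8), torch.randn(10, 8)]).shape)        # [2, 10, 8]
```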
@@ -978,19 +937,19 @@ class WhisperForConditionalGeneration(
         loader = AutoWeightsLoader(self, skip_prefixes=["proj_out."])
         # add fake zeros bias for k_proj to state_dict
-        weights = _create_fake_bias_for_k_proj(weights)
+        weights = _create_fake_bias_for_k_proj(weights, ".k_proj.weight")
         return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)


 def _create_fake_bias_for_k_proj(
-    weights: Iterable[tuple[str, torch.Tensor]],
+    weights: Iterable[tuple[str, torch.Tensor]], fake_bias_key_name: str
 ) -> Iterable[tuple[str, torch.Tensor]]:
     """
     Create full zeros bias for k_proj weight in self-attn and x-attn layers.
     So that the bias for k_proj in qkv_proj can be initialized with zeros.
     """
     for name, weight in weights:
-        if name.endswith(".k_proj.weight"):
+        if name.endswith(fake_bias_key_name):
             bias = torch.zeros(weight.size(0))
             bias_name = name.replace("weight", "bias")
             yield from [(name, weight), (bias_name, bias)]
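_create_fake_bias_for_k_proj now takes the key suffix as a parameter instead of hard-coding ".k_proj.weight", so callers with differently named projections can reuse it. A self-contained sketch of the generator as shown in the hunk; the pass-through else branch is added here for completeness, since the visible hunk ends before the remainder of the loop body:

```python
from collections.abc import Iterable

import torch


def _create_fake_bias_for_k_proj(
    weights: Iterable[tuple[str, torch.Tensor]], fake_bias_key_name: str
) -> Iterable[tuple[str, torch.Tensor]]:
    # For every matching weight, also emit an all-zero bias so the fused qkv_proj
    # bias can be initialized even though the checkpoint ships no k_proj bias.
    for name, weight in weights:
        if name.endswith(fake_bias_key_name):
            bias = torch.zeros(weight.size(0))
            bias_name = name.replace("weight", "bias")
            yield from [(name, weight), (bias_name, bias)]
        else:
            yield name, weight  # assumption: non-matching weights pass through


weights = [("encoder.layers.0.self_attn.k_proj.weight", torch.randn(8, 8))]
print([n for n, _ in _create_fake_bias_for_k_proj(weights, ".k_proj.weight")])
# ['encoder.layers.0.self_attn.k_proj.weight', 'encoder.layers.0.self_attn.k_proj.bias']
```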