[Voxtral] Add new streaming arch (#32861)

Signed-off-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Patrick von Platen committed 2026-01-23 12:41:52 +01:00 (committed by GitHub)
parent 5da4c7d789
commit 3f3f89529d
9 changed files with 767 additions and 313 deletions

@@ -4,7 +4,7 @@
import inspect
import math
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from functools import cached_property, partial
from math import ceil
from typing import Literal, cast
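The new `partial` import is used further down to pre-bind `init_in_fp32=True` onto `WhisperEncoder`, so both encoder variants can be constructed through one call site. A minimal sketch of the pattern (constructor signatures abbreviated; not the real classes):

```python
from functools import partial

class WhisperEncoder:
    def __init__(self, vllm_config=None, prefix="", init_in_fp32=False):
        self.init_in_fp32 = init_in_fp32

class WhisperCausalEncoder:
    def __init__(self, vllm_config=None, prefix=""):
        pass

def make_encoder_cls(is_causal: bool):
    # partial pre-binds the extra kwarg so both branches expose the
    # same (vllm_config, prefix) construction interface.
    return WhisperCausalEncoder if is_causal else partial(WhisperEncoder, init_in_fp32=True)

encoder = make_encoder_cls(is_causal=False)(vllm_config=None, prefix="whisper_encoder")
assert encoder.init_in_fp32
```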
@@ -33,7 +33,11 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models import SupportsPP
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.whisper import WhisperEncoder
from vllm.model_executor.models.whisper import (
WhisperEncoder,
_create_fake_bias_for_k_proj,
)
from vllm.model_executor.models.whisper_causal import WhisperCausalEncoder
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
MultiModalDataDict,
@@ -543,6 +547,7 @@ class VoxtralForConditionalGeneration(
}
).named_parameters()
)
weights = _create_fake_bias_for_k_proj(weights, ".wk.weight")
loaded_weights = set()
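Whisper-style checkpoints store the k projection (here `wk`) without a bias, while vLLM's fused `qkv_proj` carries a single bias tensor covering q, k, and v. A hedged sketch of what a helper like `_create_fake_bias_for_k_proj` plausibly does (the actual implementation lives in `vllm/model_executor/models/whisper.py`; this generator is an assumption based on the call above):

```python
from collections.abc import Iterable

import torch

def create_fake_bias_for_k_proj(
    weights: Iterable[tuple[str, torch.Tensor]],
    suffix: str = ".wk.weight",
) -> Iterable[tuple[str, torch.Tensor]]:
    # For every k-projection weight, also yield an all-zero bias so the
    # fused qkv bias has a k shard to load (the checkpoint omits it).
    for name, tensor in weights:
        yield name, tensor
        if name.endswith(suffix):
            yield name.removesuffix("weight") + "bias", tensor.new_zeros(tensor.shape[0])
```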
@@ -730,6 +735,10 @@ class VoxtralEncoderModel(nn.Module):
r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w2\.(weight|bias)", # noqa: E501
r"whisper_encoder.layers.\1.mlp.fc2.\2",
),
(
r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w3\.(weight|bias)",  # noqa: E501
r"whisper_encoder.layers.\1.mlp.fc3.\2",
),
(
r"whisper_encoder\.transformer\.layers\.(\d+)\.ffn_norm\.(weight|bias)",
r"whisper_encoder.layers.\1.final_layer_norm.\2",
@@ -749,10 +758,15 @@ class VoxtralEncoderModel(nn.Module):
super().__init__()
self.config = cast(WhisperConfig, vllm_config.model_config.hf_config)
self.dtype: torch.dtype = vllm_config.model_config.dtype
self.whisper_encoder = WhisperEncoder(
self.is_causal = getattr(self.config, "is_causal", False)
if self.is_causal:
WhisperEncoderCls = WhisperCausalEncoder
else:
WhisperEncoderCls = partial(WhisperEncoder, init_in_fp32=True)
self.whisper_encoder = WhisperEncoderCls(
vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "whisper_encoder"),
init_in_fp32=True,
)
mel_filters = mel_filter_bank(
num_frequency_bins=1 + self.config.window_size // 2,
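Below the encoder selection, `mel_filter_bank` (from `transformers.audio_utils`) builds the log-mel front end from `1 + window_size // 2` frequency bins. A hedged, standalone example with Whisper-like values (the 400-sample window, 128 mel bins, and 16 kHz rate are assumptions for illustration; the model reads these from its own config):

```python
from transformers.audio_utils import mel_filter_bank

window_size, num_mel_bins, sampling_rate = 400, 128, 16000  # assumed values
filters = mel_filter_bank(
    num_frequency_bins=1 + window_size // 2,  # rFFT bins for one frame
    num_mel_filters=num_mel_bins,
    min_frequency=0.0,
    max_frequency=sampling_rate / 2,
    sampling_rate=sampling_rate,
    norm="slaney",
    mel_scale="slaney",
)
print(filters.shape)  # (201, 128): one triangular filter per mel bin
```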
@@ -843,6 +857,22 @@ class VoxtralEncoderModel(nn.Module):
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
params_mapping = []
if self.is_causal:
# For `WhisperCausalEncoder`, the fused MLP
# projections need additional renaming
stacked_params_mapping.extend(
[
(".mlp.gate_up_proj", ".mlp.fc1", 0),
(".mlp.gate_up_proj", ".mlp.fc3", 1),
]
)
params_mapping.extend(
[
(".mlp.down_proj", ".mlp.fc2"),
]
)
params_dict = dict(self.named_parameters())
name, loaded_weight = weight
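In the causal variant, the `fc1`/`fc3` checkpoint weights are fused into a single `gate_up_proj` parameter and loaded shard by shard, per the `(".mlp.gate_up_proj", ".mlp.fc1", 0)` and `(".mlp.gate_up_proj", ".mlp.fc3", 1)` entries above. A hedged sketch of shard-wise loading into a fused tensor (the loader signature is simplified relative to vLLM's real `weight_loader`):

```python
import torch

hidden, inter = 8, 16
# Fused gate/up projection: shard 0 holds fc1 (gate), shard 1 holds fc3 (up).
gate_up_proj = torch.empty(2 * inter, hidden)

def load_gate_up_shard(param: torch.Tensor, loaded: torch.Tensor, shard_id: int):
    # Each shard occupies a contiguous [inter]-row slice of the fused tensor.
    param[shard_id * inter : (shard_id + 1) * inter].copy_(loaded)

load_gate_up_shard(gate_up_proj, torch.randn(inter, hidden), shard_id=0)  # fc1
load_gate_up_shard(gate_up_proj, torch.randn(inter, hidden), shard_id=1)  # fc3
```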
@@ -860,6 +890,11 @@ class VoxtralEncoderModel(nn.Module):
weight_loader(param, loaded_weight, shard_id)
break
else:
for param_name, weight_name in params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
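For parameters that are not part of a stacked mapping, the new `params_mapping` pass renames e.g. `.mlp.fc2` to `.mlp.down_proj` before the plain load. A compact sketch of that fallback path under the same assumptions (dummy shapes, simplified loader):

```python
import torch

def default_weight_loader(param: torch.Tensor, loaded: torch.Tensor):
    param.copy_(loaded)

params_mapping = [(".mlp.down_proj", ".mlp.fc2")]
params_dict = {"layers.0.mlp.down_proj.weight": torch.empty(4, 4)}

name, loaded_weight = "layers.0.mlp.fc2.weight", torch.randn(4, 4)
for param_name, weight_name in params_mapping:
    if weight_name in name:
        # Rewrite the checkpoint name to the vLLM parameter name.
        name = name.replace(weight_name, param_name)
param = params_dict[name]
# Fall back to a plain copy when the parameter has no custom loader.
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
```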