[Voxtral] Add new streaming arch (#32861)
Signed-off-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
5da4c7d789
commit
3f3f89529d
@@ -4,7 +4,7 @@
|
||||
import inspect
|
||||
import math
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from functools import cached_property
|
||||
from functools import cached_property, partial
|
||||
from math import ceil
|
||||
from typing import Literal, cast
|
||||
|
||||
@@ -33,7 +33,11 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models import SupportsPP
|
||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||
from vllm.model_executor.models.whisper import WhisperEncoder
|
||||
from vllm.model_executor.models.whisper import (
|
||||
WhisperEncoder,
|
||||
_create_fake_bias_for_k_proj,
|
||||
)
|
||||
from vllm.model_executor.models.whisper_causal import WhisperCausalEncoder
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import (
|
||||
MultiModalDataDict,
|
||||
@@ -543,6 +547,7 @@ class VoxtralForConditionalGeneration(
|
||||
}
|
||||
).named_parameters()
|
||||
)
|
||||
weights = _create_fake_bias_for_k_proj(weights, ".wk.weight")
|
||||
|
||||
loaded_weights = set()
|
||||
|
||||
@@ -730,6 +735,10 @@ class VoxtralEncoderModel(nn.Module):
|
||||
r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w2\.(weight|bias)", # noqa: E501
|
||||
r"whisper_encoder.layers.\1.mlp.fc2.\2",
|
||||
),
|
||||
(
|
||||
r"whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward\.w3\.(weight|bias)",
|
||||
r"whisper_encoder.layers.\1.mlp.fc3.\2",
|
||||
), # noqa: E501
|
||||
(
|
||||
r"whisper_encoder\.transformer\.layers\.(\d+)\.ffn_norm\.(weight|bias)",
|
||||
r"whisper_encoder.layers.\1.final_layer_norm.\2",
|
||||
@@ -749,10 +758,15 @@ class VoxtralEncoderModel(nn.Module):
|
||||
super().__init__()
|
||||
self.config = cast(WhisperConfig, vllm_config.model_config.hf_config)
|
||||
self.dtype: torch.dtype = vllm_config.model_config.dtype
|
||||
self.whisper_encoder = WhisperEncoder(
|
||||
self.is_causal = getattr(self.config, "is_causal", False)
|
||||
if self.is_causal:
|
||||
WhisperEncoderCls = WhisperCausalEncoder
|
||||
else:
|
||||
WhisperEncoderCls = partial(WhisperEncoder, init_in_fp32=True)
|
||||
|
||||
self.whisper_encoder = WhisperEncoderCls(
|
||||
vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "whisper_encoder"),
|
||||
init_in_fp32=True,
|
||||
)
|
||||
mel_filters = mel_filter_bank(
|
||||
num_frequency_bins=1 + self.config.window_size // 2,
|
||||
@@ -843,6 +857,22 @@ class VoxtralEncoderModel(nn.Module):
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
]
|
||||
params_mapping = []
|
||||
|
||||
if self.is_causal:
|
||||
# For `WhisperCausalEncoder` we need
|
||||
# some more renaming
|
||||
stacked_params_mapping.extend(
|
||||
[
|
||||
(".mlp.gate_up_proj", ".mlp.fc1", 0),
|
||||
(".mlp.gate_up_proj", ".mlp.fc3", 1),
|
||||
]
|
||||
)
|
||||
params_mapping.extend(
|
||||
[
|
||||
(".mlp.down_proj", ".mlp.fc2"),
|
||||
]
|
||||
)
|
||||
params_dict = dict(self.named_parameters())
|
||||
|
||||
name, loaded_weight = weight
|
||||
@@ -860,6 +890,11 @@ class VoxtralEncoderModel(nn.Module):
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
for param_name, weight_name in params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
Reference in New Issue
Block a user