adapt voxtral (#31095)

Signed-off-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Patrick von Platen
2025-12-23 14:31:55 +01:00
committed by GitHub
parent b10d47e0e0
commit 3faa8bee57
12 changed files with 739 additions and 98 deletions

View File

@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import inspect
import math
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
@@ -116,10 +117,7 @@ class VoxtralProcessorAdapter:
self,
audio_length: int,
) -> int:
pad_audio_length = self._audio_processor.next_multiple_of_chunk_frames(
audio_length, self.sampling_rate
)
return ceil(pad_audio_length / (self.sampling_rate // self.frame_rate))
return ceil(audio_length / (self.sampling_rate // self.frame_rate))
def __call__(
self,
@@ -158,7 +156,14 @@ class VoxtralProcessorAdapter:
assert audio.ndim == 1
# pad if necessary
audio = self._audio_processor.pad(audio, self.sampling_rate)
# TODO(Patrick) - remove once mistral-common is bumped
sig = inspect.signature(self._audio_processor.pad)
if "is_online_streaming" in sig.parameters:
audio = self._audio_processor.pad(
audio, self.sampling_rate, is_online_streaming=False
)
else:
audio = self._audio_processor.pad(audio, self.sampling_rate)
audio_tokens = [self.begin_audio_token_id] + [
self.audio_token_id
@@ -510,6 +515,7 @@ class VoxtralForConditionalGeneration(
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
remapping_rules = [
(r"mm_streams_embeddings.embedding_module\.(.*)", r"\1"),
(r"mm_whisper_embeddings\.(.*)", r"\1"),
(r"audio_language_projection\.(.*)", r"audio_language_adapter.\1"),
(
@@ -535,13 +541,16 @@ class VoxtralForConditionalGeneration(
def llm_weights_generator():
nonlocal loaded_weights
for name, w in weights:
is_encoder = (
name.startswith("mm_whisper_embeddings")
and not name.startswith("mm_whisper_embeddings.tok_embeddings")
and not name.startswith(
"mm_whisper_embeddings.audio_language_projection"
is_encoder = False
for k in [
"mm_whisper_embeddings",
"mm_streams_embeddings.embedding_module",
]:
is_encoder |= (
name.startswith(k)
and not name.startswith(f"{k}.tok_embeddings")
and not name.startswith(f"{k}.audio_language_projection")
)
)
for pattern, repl in remapping_rules:
if re.fullmatch(pattern, name):
@@ -676,6 +685,7 @@ class VoxtralEncoderModel(nn.Module):
packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
mistral_remapping = [
(r"mm_streams_embeddings.embedding_module\.(.*)", r"\1"),
(
r"whisper_encoder\.conv_layers\.0\.(weight|bias)",
r"whisper_encoder.conv1.\1",
@@ -684,6 +694,14 @@ class VoxtralEncoderModel(nn.Module):
r"whisper_encoder\.conv_layers\.1\.(weight|bias)",
r"whisper_encoder.conv2.\1",
),
(
r"whisper_encoder\.conv_layers\.0\.conv\.(weight|bias)",
r"whisper_encoder.conv1.\1",
), # noqa: E501
(
r"whisper_encoder\.conv_layers\.1\.conv\.(weight|bias)",
r"whisper_encoder.conv2.\1",
), # noqa: E501
(
r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.w([qkv])\.(weight|bias)", # noqa: E501
r"whisper_encoder.layers.\1.self_attn.\2_proj.\3",