adapt voxtral (#31095)

Signed-off-by: Patrick von Platen <patrick.v.platen@gmail.com>

commit 3faa8bee57
parent b10d47e0e0
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+import inspect
 import math
 from collections.abc import Iterable, Mapping, Sequence
 from functools import cached_property
@@ -116,10 +117,7 @@ class VoxtralProcessorAdapter:
         self,
         audio_length: int,
     ) -> int:
-        pad_audio_length = self._audio_processor.next_multiple_of_chunk_frames(
-            audio_length, self.sampling_rate
-        )
-        return ceil(pad_audio_length / (self.sampling_rate // self.frame_rate))
+        return ceil(audio_length / (self.sampling_rate // self.frame_rate))
 
     def __call__(
         self,
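Review note: the token count is now derived from the raw sample count instead of the chunk-padded length; padding itself still happens later, in `__call__`. A minimal sketch of the arithmetic, using illustrative rates rather than Voxtral's actual config values (those come from the mistral-common audio processor):

```python
from math import ceil

# Illustrative values only, not necessarily Voxtral's real config.
sampling_rate = 16_000  # input samples per second
frame_rate = 12.5       # audio tokens emitted per second

samples_per_token = sampling_rate // frame_rate  # 1280.0

audio_length = 48_000  # a 3-second clip
num_audio_tokens = ceil(audio_length / samples_per_token)
print(num_audio_tokens)  # ceil(37.5) == 38
```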
@@ -158,7 +156,14 @@ class VoxtralProcessorAdapter:
         assert audio.ndim == 1
 
         # pad if necessary
-        audio = self._audio_processor.pad(audio, self.sampling_rate)
+        # TODO(Patrick) - remove once mistral-common is bumped
+        sig = inspect.signature(self._audio_processor.pad)
+        if "is_online_streaming" in sig.parameters:
+            audio = self._audio_processor.pad(
+                audio, self.sampling_rate, is_online_streaming=False
+            )
+        else:
+            audio = self._audio_processor.pad(audio, self.sampling_rate)
 
         audio_tokens = [self.begin_audio_token_id] + [
             self.audio_token_id
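Review note: the `is_online_streaming` keyword only exists in newer mistral-common releases, so the shim probes the callee's signature at runtime instead of pinning a version. A generic sketch of that feature-detection pattern (the helper name is mine, not vLLM's):

```python
import inspect

def call_with_supported_kwargs(fn, *args, **optional_kwargs):
    """Forward keyword arguments only if fn's signature accepts them."""
    params = inspect.signature(fn).parameters
    # Note: this simple membership test ignores functions that take
    # **kwargs; the shim above does not need that case.
    accepted = {k: v for k, v in optional_kwargs.items() if k in params}
    return fn(*args, **accepted)

# Hypothetical usage mirroring the shim: works whether or not the
# installed mistral-common pad() knows about is_online_streaming.
# audio = call_with_supported_kwargs(
#     processor.pad, audio, sampling_rate, is_online_streaming=False
# )
```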
@@ -510,6 +515,7 @@ class VoxtralForConditionalGeneration(
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         remapping_rules = [
+            (r"mm_streams_embeddings.embedding_module\.(.*)", r"\1"),
            (r"mm_whisper_embeddings\.(.*)", r"\1"),
            (r"audio_language_projection\.(.*)", r"audio_language_adapter.\1"),
            (
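Review note: the new first rule strips the `mm_streams_embeddings.embedding_module.` prefix, so checkpoints with the streams layout load the same way as the whisper layout. A standalone sketch of how these full-match rules rewrite parameter names (first match wins in this sketch; the real loop appears further down in `load_weights`):

```python
import re

remapping_rules = [
    (r"mm_streams_embeddings.embedding_module\.(.*)", r"\1"),
    (r"mm_whisper_embeddings\.(.*)", r"\1"),
    (r"audio_language_projection\.(.*)", r"audio_language_adapter.\1"),
]

def remap(name: str) -> str:
    # Apply the first rule whose pattern matches the whole name.
    for pattern, repl in remapping_rules:
        if re.fullmatch(pattern, name):
            return re.sub(pattern, repl, name)
    return name

print(remap("mm_streams_embeddings.embedding_module.conv1.weight"))  # conv1.weight
print(remap("audio_language_projection.0.weight"))  # audio_language_adapter.0.weight
```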
@@ -535,13 +541,16 @@ class VoxtralForConditionalGeneration(
         def llm_weights_generator():
             nonlocal loaded_weights
             for name, w in weights:
-                is_encoder = (
-                    name.startswith("mm_whisper_embeddings")
-                    and not name.startswith("mm_whisper_embeddings.tok_embeddings")
-                    and not name.startswith(
-                        "mm_whisper_embeddings.audio_language_projection"
-                    )
-                )
+                is_encoder = False
+                for k in [
+                    "mm_whisper_embeddings",
+                    "mm_streams_embeddings.embedding_module",
+                ]:
+                    is_encoder |= (
+                        name.startswith(k)
+                        and not name.startswith(f"{k}.tok_embeddings")
+                        and not name.startswith(f"{k}.audio_language_projection")
+                    )
 
                 for pattern, repl in remapping_rules:
                     if re.fullmatch(pattern, name):
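Review note: the single-prefix check becomes a loop so encoder weights are recognized under either checkpoint prefix, while token embeddings and the audio-language projection are still routed elsewhere. Equivalent standalone logic, with a couple of illustrative keys:

```python
ENCODER_PREFIXES = [
    "mm_whisper_embeddings",
    "mm_streams_embeddings.embedding_module",
]

def is_encoder_weight(name: str) -> bool:
    # A weight is an encoder weight if it lives under any known prefix,
    # excluding the tok_embeddings / audio_language_projection subtrees.
    is_encoder = False
    for k in ENCODER_PREFIXES:
        is_encoder |= (
            name.startswith(k)
            and not name.startswith(f"{k}.tok_embeddings")
            and not name.startswith(f"{k}.audio_language_projection")
        )
    return is_encoder

assert is_encoder_weight("mm_streams_embeddings.embedding_module.conv1.weight")
assert not is_encoder_weight("mm_whisper_embeddings.tok_embeddings.weight")
```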
@@ -676,6 +685,7 @@ class VoxtralEncoderModel(nn.Module):
     packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
 
     mistral_remapping = [
+        (r"mm_streams_embeddings.embedding_module\.(.*)", r"\1"),
         (
             r"whisper_encoder\.conv_layers\.0\.(weight|bias)",
             r"whisper_encoder.conv1.\1",
@@ -684,6 +694,14 @@ class VoxtralEncoderModel(nn.Module):
             r"whisper_encoder\.conv_layers\.1\.(weight|bias)",
             r"whisper_encoder.conv2.\1",
         ),
+        (
+            r"whisper_encoder\.conv_layers\.0\.conv\.(weight|bias)",
+            r"whisper_encoder.conv1.\1",
+        ),  # noqa: E501
+        (
+            r"whisper_encoder\.conv_layers\.1\.conv\.(weight|bias)",
+            r"whisper_encoder.conv2.\1",
+        ),  # noqa: E501
         (
             r"whisper_encoder\.transformer\.layers\.(\d+)\.attention\.w([qkv])\.(weight|bias)",  # noqa: E501
             r"whisper_encoder.layers.\1.self_attn.\2_proj.\3",
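Review note: the two added rules accept a newer checkpoint layout in which each conv layer nests its weights under a `.conv` attribute; both layouts resolve to the same vLLM parameter. A quick check with the conv1 rules copied from above (old-layout keys like `conv_layers.0.weight` don't full-match the `.conv` pattern, and vice versa):

```python
import re

rules = [
    (r"whisper_encoder\.conv_layers\.0\.(weight|bias)", r"whisper_encoder.conv1.\1"),
    (r"whisper_encoder\.conv_layers\.0\.conv\.(weight|bias)", r"whisper_encoder.conv1.\1"),
]

for key in (
    "whisper_encoder.conv_layers.0.weight",       # old layout
    "whisper_encoder.conv_layers.0.conv.weight",  # new layout, nested .conv
):
    for pattern, repl in rules:
        if re.fullmatch(pattern, key):
            print(f"{key} -> {re.sub(pattern, repl, key)}")
# Both keys map to whisper_encoder.conv1.weight.
```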