Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
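
For anyone reproducing the tooling swap locally, here is a minimal sketch of a ruff setup that covers what yapf + isort previously did. The exact options vLLM adopted live in pyproject.toml and are not part of this diff, so every value below is an illustrative assumption:

    # Hypothetical pyproject.toml fragment -- not the settings from this PR.
    [tool.ruff]
    line-length = 88      # ruff format's black-compatible default

    [tool.ruff.lint]
    select = ["I"]        # the "I" rules handle import sorting (replaces isort)

    # Then:
    #   ruff check --fix .   # sorts imports
    #   ruff format .        # reformats code (replaces yapf)

The ruff formatter's equivalents of yapf's `# yapf: disable` / `# yapf: enable` pragmas are `# fmt: off` / `# fmt: on`, which is why the hand-aligned remapping_rules table below keeps `# fmt: off` markers.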
@@ -12,8 +12,12 @@ import regex as re
 import torch
 import torch.nn as nn
 from mistral_common.audio import mel_filter_bank
-from mistral_common.protocol.instruct.messages import (AudioChunk, RawAudio,
-                                                       TextChunk, UserMessage)
+from mistral_common.protocol.instruct.messages import (
+    AudioChunk,
+    RawAudio,
+    TextChunk,
+    UserMessage,
+)
 from mistral_common.protocol.instruct.request import ChatCompletionRequest
 from mistral_common.protocol.transcription.request import TranscriptionRequest
 from mistral_common.tokens.tokenizers.audio import Audio, AudioEncoder
@@ -28,23 +32,37 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models import SupportsPP
 from vllm.model_executor.models.module_mapping import MultiModelKeys
-# yapf: disable
 from vllm.model_executor.models.whisper import WhisperEncoder
-# yapf: enable
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
-                                    MultiModalKwargsItems, MultiModalUUIDDict,
-                                    NestedTensors)
-from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems,
-                                   MultiModalDataParser)
-from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        BaseProcessingInfo,
-                                        MultiModalProcessingInfo,
-                                        PromptReplacement, PromptUpdate)
+from vllm.multimodal.inputs import (
+    MultiModalDataDict,
+    MultiModalFieldConfig,
+    MultiModalKwargsItems,
+    MultiModalUUIDDict,
+    NestedTensors,
+)
+from vllm.multimodal.parse import (
+    AudioProcessorItems,
+    MultiModalDataItems,
+    MultiModalDataParser,
+)
+from vllm.multimodal.processing import (
+    BaseMultiModalProcessor,
+    BaseProcessingInfo,
+    MultiModalProcessingInfo,
+    PromptReplacement,
+    PromptUpdate,
+)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
-from vllm.transformers_utils.tokenizer import (MistralTokenizer,
-                                               cached_tokenizer_from_config)
+from vllm.transformers_utils.tokenizer import (
+    MistralTokenizer,
+    cached_tokenizer_from_config,
+)

 from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsTranscription
 from .utils import flatten_bn, init_vllm_registered_model, maybe_prefix
@@ -109,7 +127,8 @@ class VoxtralProcessorAdapter:
         audio_length: int,
     ) -> int:
         pad_audio_length = self._audio_processor.next_multiple_of_chunk_frames(
-            audio_length, self.sampling_rate)
+            audio_length, self.sampling_rate
+        )
         return ceil(pad_audio_length / (self.sampling_rate // self.frame_rate))

     def __call__(
@@ -139,7 +158,8 @@ class VoxtralProcessorAdapter:
             "Make sure to process your input via `mistral_common`'s "
             "tokenizer or pass a chat completion request. "
             "For more info, see: "
-            "https://github.com/vllm-project/vllm/issues/8411.")
+            "https://github.com/vllm-project/vllm/issues/8411."
+        )

         audios_tokens = list[torch.Tensor]()
         audios_processed = list[torch.Tensor]()
@@ -150,23 +170,22 @@ class VoxtralProcessorAdapter:
             # pad if necessary
             audio = self._audio_processor.pad(audio, self.sampling_rate)

-            audio_tokens = [
-                self.begin_audio_token_id
-            ] + [self.audio_token_id] * self.get_num_audio_tokens(len(audio))
+            audio_tokens = [self.begin_audio_token_id] + [
+                self.audio_token_id
+            ] * self.get_num_audio_tokens(len(audio))

             audios_tokens.append(torch.tensor(audio_tokens))
             audios_processed.append(torch.tensor(audio))

-        return BatchFeature({
-            "input_ids":
-            torch.cat(audios_tokens)[None].expand(len(text), -1),
-            "audio_arrays":
-            audios_processed,
-        })
+        return BatchFeature(
+            {
+                "input_ids": torch.cat(audios_tokens)[None].expand(len(text), -1),
+                "audio_arrays": audios_processed,
+            }
+        )


 class VoxtralProcessingInfo(BaseProcessingInfo):
-
     def get_tokenizer(self) -> MistralTokenizer:
         tokenizer = cached_tokenizer_from_config(self.ctx.model_config)
         if not isinstance(tokenizer, MistralTokenizer):
@@ -193,11 +212,11 @@ class VoxtralProcessingInfo(BaseProcessingInfo):
     def get_max_audio_array_len(self) -> int:
         processor = self.get_hf_processor()
         return self.get_max_audio_tokens() * int(
-            processor.sampling_rate // processor.frame_rate)
+            processor.sampling_rate // processor.frame_rate
+        )


 class VoxtralDummyInputsBuilder(BaseDummyInputsBuilder[VoxtralProcessingInfo]):
-
     def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
         return ""

@@ -214,10 +233,9 @@ class VoxtralDummyInputsBuilder(BaseDummyInputsBuilder[VoxtralProcessingInfo]):
         audio_overrides = mm_options.get("audio") if mm_options else None

         return {
-            "audio":
-            self._get_dummy_audios(length=target_length,
-                                   num_audios=num_audios,
-                                   overrides=audio_overrides)
+            "audio": self._get_dummy_audios(
+                length=target_length, num_audios=num_audios, overrides=audio_overrides
+            )
         }

     def get_dummy_processor_inputs(
@@ -243,9 +261,11 @@ class VoxtralDummyInputsBuilder(BaseDummyInputsBuilder[VoxtralProcessingInfo]):
             chunk = AudioChunk(input_audio=RawAudio.from_audio(audio_item))
             audio_chunks.append(chunk)

-        request = ChatCompletionRequest(messages=[
-            UserMessage(content=[TextChunk(text=dummy_text), *audio_chunks]),
-        ])
+        request = ChatCompletionRequest(
+            messages=[
+                UserMessage(content=[TextChunk(text=dummy_text), *audio_chunks]),
+            ]
+        )
         res = tokenizer.mistral.encode_chat_completion(request)
         dummy_tokens = res.tokens
         # whixtral tokenizer adds padding to the audio
@@ -255,9 +275,7 @@ class VoxtralDummyInputsBuilder(BaseDummyInputsBuilder[VoxtralProcessingInfo]):
         return ProcessorInputs(prompt=dummy_tokens, mm_data=dummy_mm_data)


-class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo]
-                                 ):
-
+class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo]):
     def _get_mm_fields_config(
         self,
         hf_inputs: Mapping[str, NestedTensors],
@@ -315,17 +333,19 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo]
         return MultiModalDataParser(target_sr=sampling_rate)


-@MULTIMODAL_REGISTRY.register_processor(VoxtralMultiModalProcessor,
-                                        info=VoxtralProcessingInfo,
-                                        dummy_inputs=VoxtralDummyInputsBuilder)
-class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
-                                      SupportsPP, SupportsLoRA,
-                                      SupportsTranscription):
+@MULTIMODAL_REGISTRY.register_processor(
+    VoxtralMultiModalProcessor,
+    info=VoxtralProcessingInfo,
+    dummy_inputs=VoxtralDummyInputsBuilder,
+)
+class VoxtralForConditionalGeneration(
+    nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, SupportsTranscription
+):
     supported_languages = ISO639_1_SUPPORTED_LANGS

     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
-        "gate_up_proj": ["gate_proj", "up_proj"]
+        "gate_up_proj": ["gate_proj", "up_proj"],
     }

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
@@ -336,7 +356,8 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
         # match the vLLM model names
         if hasattr(vllm_config, "quant_config"):
             vllm_config.quant_config = self.maybe_update_quant_config(
-                vllm_config.quant_config)
+                vllm_config.quant_config
+            )

         config = vllm_config.model_config.hf_config
         self.config = config
@@ -378,17 +399,15 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
         if intermediate_tensors is not None:
             inputs_embeds = None

-        hidden_states = self.language_model.model(input_ids,
-                                                  positions,
-                                                  intermediate_tensors,
-                                                  inputs_embeds=inputs_embeds)
+        hidden_states = self.language_model.model(
+            input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
+        )

         return hidden_states

     def get_multimodal_embeddings(
         self, **kwargs
-    ) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...],
-               None]:
+    ) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...], None]:
         audio_inputs = self._parse_and_validate_audio_arrays(**kwargs)
         if audio_inputs is None:
             return None
@@ -399,34 +418,36 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
             seq_len, dim = audio_embedding.shape
             # Pad such that seq_len is divisible by downsample_factor
             target_seq_len = self.downsample_factor * math.ceil(
-                seq_len / self.downsample_factor)
+                seq_len / self.downsample_factor
+            )
             audio_embedding = torch.nn.functional.pad(
                 audio_embedding,
                 (0, 0, 0, target_seq_len - seq_len),
             )
             audio_embeddings[i] = audio_embedding.reshape(
-                target_seq_len // self.downsample_factor,
-                dim * self.downsample_factor)
+                target_seq_len // self.downsample_factor, dim * self.downsample_factor
+            )

         # Concat, project and resplit
         audio_embeddings_packed = torch.cat(audio_embeddings, dim=0)
-        audio_embeddings_packed = self.audio_language_adapter(
-            audio_embeddings_packed)
-        audio_embeddings = torch.split(audio_embeddings_packed,
-                                       [a.shape[0] for a in audio_embeddings],
-                                       dim=0)
+        audio_embeddings_packed = self.audio_language_adapter(audio_embeddings_packed)
+        audio_embeddings = torch.split(
+            audio_embeddings_packed, [a.shape[0] for a in audio_embeddings], dim=0
+        )

         return audio_embeddings

     def _parse_and_validate_audio_arrays(
-            self, **kwargs: object) -> Union[list[torch.Tensor], None]:
+        self, **kwargs: object
+    ) -> Union[list[torch.Tensor], None]:
         audio_arrays = kwargs.pop("audio_arrays", None)
         if audio_arrays is None:
             return None

         if not isinstance(audio_arrays, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of audio_arrays. "
-                             f"Got type: {type(audio_arrays)}")
+            raise ValueError(
+                f"Incorrect type of audio_arrays. Got type: {type(audio_arrays)}"
+            )

         audio_arrays = flatten_bn(audio_arrays)
         if isinstance(audio_arrays, torch.Tensor):
@@ -440,8 +461,9 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
         return self.language_model.compute_logits(hidden_states)

     @classmethod
-    def get_speech_to_text_config(cls, model_config: ModelConfig,
-                                  task_type: str) -> SpeechToTextConfig:
+    def get_speech_to_text_config(
+        cls, model_config: ModelConfig, task_type: str
+    ) -> SpeechToTextConfig:
         tokenizer = cached_tokenizer_from_config(model_config)
         audio_config = tokenizer.instruct.audio_encoder.audio_config
         max_audio_clip_s = audio_config.chunk_length_s
@@ -455,19 +477,23 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,

     @classmethod
     # for speech-to-text transcription
-    def get_generation_prompt(cls, audio: np.ndarray,
-                              model_config: ModelConfig,
-                              stt_config: SpeechToTextConfig,
-                              language: Optional[str],
-                              task_type: Literal["transcribe", "translate"],
-                              request_prompt: str,
-                              to_language: Optional[str]) -> PromptType:
+    def get_generation_prompt(
+        cls,
+        audio: np.ndarray,
+        model_config: ModelConfig,
+        stt_config: SpeechToTextConfig,
+        language: Optional[str],
+        task_type: Literal["transcribe", "translate"],
+        request_prompt: str,
+        to_language: Optional[str],
+    ) -> PromptType:
         tokenizer = cached_tokenizer_from_config(model_config)
-        audio = Audio(audio, int(stt_config.sample_rate),
-                      format="wav")  # lossless
-        req = TranscriptionRequest(model=model_config.model,
-                                   audio=RawAudio.from_audio(audio),
-                                   language=language)
+        audio = Audio(audio, int(stt_config.sample_rate), format="wav")  # lossless
+        req = TranscriptionRequest(
+            model=model_config.model,
+            audio=RawAudio.from_audio(audio),
+            language=language,
+        )

         tokenized = tokenizer.instruct.encode_transcription(req)
         audio = (tokenized.audios[0].audio_array, stt_config.sample_rate)
@@ -476,21 +502,24 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
         return cast(PromptType, prompts_dict)

     @classmethod
-    def get_num_audio_tokens(cls, audio_duration_s: float,
-                             stt_config: SpeechToTextConfig,
-                             model_config: ModelConfig) -> Optional[int]:
+    def get_num_audio_tokens(
+        cls,
+        audio_duration_s: float,
+        stt_config: SpeechToTextConfig,
+        model_config: ModelConfig,
+    ) -> Optional[int]:
         """
-        Map from audio duration to number of audio tokens produced by the ASR
+        Map from audio duration to number of audio tokens produced by the ASR
         model, without running a forward pass.
         This is used for estimating the amount of processing for this audio.
         """
         tokenizer = cached_tokenizer_from_config(model_config)
         adapter = VoxtralProcessorAdapter(tokenizer)
         return adapter.get_num_audio_tokens(
-            int(audio_duration_s * stt_config.sample_rate))
+            int(audio_duration_s * stt_config.sample_rate)
+        )

-    def load_weights(self, weights: Iterable[tuple[str,
-                                                   torch.Tensor]]) -> set[str]:
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         # fmt: off
         remapping_rules = [
             (r"mm_whisper_embeddings\.(.*)", r"\1"),
@@ -501,10 +530,12 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
         # fmt: on

         audio_params = dict(
-            nn.ModuleDict({
-                "audio_language_adapter":
-                self.audio_language_adapter,
-            }).named_parameters())
+            nn.ModuleDict(
+                {
+                    "audio_language_adapter": self.audio_language_adapter,
+                }
+            ).named_parameters()
+        )

         loaded_weights = set()

@@ -512,10 +543,12 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
             nonlocal loaded_weights
             for name, w in weights:
                 is_encoder = (
-                    name.startswith("mm_whisper_embeddings") and
-                    not name.startswith("mm_whisper_embeddings.tok_embeddings")
+                    name.startswith("mm_whisper_embeddings")
+                    and not name.startswith("mm_whisper_embeddings.tok_embeddings")
                     and not name.startswith(
-                        "mm_whisper_embeddings.audio_language_projection"))
+                        "mm_whisper_embeddings.audio_language_projection"
+                    )
+                )

                 for pattern, repl in remapping_rules:
                     if re.fullmatch(pattern, name):
@@ -546,7 +579,8 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
         return loaded_weights

     def maybe_update_quant_config(
-            self, quant_config: QuantizationConfig) -> QuantizationConfig:
+        self, quant_config: QuantizationConfig
+    ) -> QuantizationConfig:
         """
         Update quant config to so that ignored module and target module names
         match the vLLM model names.
@@ -555,32 +589,54 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
         """
         remapping_rules = [
             (r"output", r"language_model.lm_head"),
-            (r"layers\.(\d+)\.attention\.wo",
-             r"language_model.model.layers.\1.self_attn.out_proj"),
-            (r"layers\.(\d+)\.attention\.w(.*)",
-             r"language_model.model.layers.\1.self_attn.\2_proj"),
-            (r"layers\.(\d+)\.feed_forward\.w1",
-             r"language_model.model.layers.\1.mlp.gate_proj"),
-            (r"layers\.(\d+)\.feed_forward\.w2",
-             r"language_model.model.layers.\1.mlp.down_proj"),
-            (r"layers\.(\d+)\.feed_forward\.w3",
-             r"language_model.model.layers.\1.mlp.up_proj"),
-            (r"mm_whisper_embeddings\.whisper_encoder\.transformer\.layers\.(\d+)\.attention.w(.*)",
-             r"whisper_encoder.whisper_encoder.layers.\1.layers.self_attn.\2_proj"
-             ),
-            (r"mm_whisper_embeddings\.whisper_encoder\.transformer\.layers\.(\d+)\.attention.wo",
-             r"whisper_encoder.whisper_encoder.layers.\1.layers.self_attn.out_proj"
-             ),
-            (r"mm_whisper_embeddings\.whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward.w(\d+)",
-             r"whisper_encoder.whisper_encoder.layers.\1.layers.mlp.fc\2"),
-            (r"mm_whisper_embeddings\.whisper_encoder\.conv_layers\.0",
-             r"whisper_encoder.whisper_encoder.conv1"),
-            (r"mm_whisper_embeddings\.whisper_encoder\.conv_layers\.1",
-             r"whisper_encoder.whisper_encoder.conv2"),
-            (r"mm_whisper_embeddings\.audio_language_projection\.0",
-             r"audio_language_adapter.w_in"),
-            (r"mm_whisper_embeddings\.audio_language_projection\.2",
-             r"audio_language_adapter.w_out"),
+            (
+                r"layers\.(\d+)\.attention\.wo",
+                r"language_model.model.layers.\1.self_attn.out_proj",
+            ),
+            (
+                r"layers\.(\d+)\.attention\.w(.*)",
+                r"language_model.model.layers.\1.self_attn.\2_proj",
+            ),
+            (
+                r"layers\.(\d+)\.feed_forward\.w1",
+                r"language_model.model.layers.\1.mlp.gate_proj",
+            ),
+            (
+                r"layers\.(\d+)\.feed_forward\.w2",
+                r"language_model.model.layers.\1.mlp.down_proj",
+            ),
+            (
+                r"layers\.(\d+)\.feed_forward\.w3",
+                r"language_model.model.layers.\1.mlp.up_proj",
+            ),
+            (
+                r"mm_whisper_embeddings\.whisper_encoder\.transformer\.layers\.(\d+)\.attention.w(.*)",
+                r"whisper_encoder.whisper_encoder.layers.\1.layers.self_attn.\2_proj",
+            ),
+            (
+                r"mm_whisper_embeddings\.whisper_encoder\.transformer\.layers\.(\d+)\.attention.wo",
+                r"whisper_encoder.whisper_encoder.layers.\1.layers.self_attn.out_proj",
+            ),
+            (
+                r"mm_whisper_embeddings\.whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward.w(\d+)",
+                r"whisper_encoder.whisper_encoder.layers.\1.layers.mlp.fc\2",
+            ),
+            (
+                r"mm_whisper_embeddings\.whisper_encoder\.conv_layers\.0",
+                r"whisper_encoder.whisper_encoder.conv1",
+            ),
+            (
+                r"mm_whisper_embeddings\.whisper_encoder\.conv_layers\.1",
+                r"whisper_encoder.whisper_encoder.conv2",
+            ),
+            (
+                r"mm_whisper_embeddings\.audio_language_projection\.0",
+                r"audio_language_adapter.w_in",
+            ),
+            (
+                r"mm_whisper_embeddings\.audio_language_projection\.2",
+                r"audio_language_adapter.w_out",
+            ),
         ]

         # Update ignore list
@@ -613,7 +669,6 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,


 class AudioLanguageAdapter(nn.Module):
-
     def __init__(self, hidden_size: int, dim: int) -> None:
         super().__init__()
         self.w_in = nn.Linear(hidden_size, dim, bias=False)
@@ -650,10 +705,11 @@ class VoxtralEncoderModel(nn.Module):
         super().__init__()
         self.config = cast(WhisperConfig, vllm_config.model_config.hf_config)
         self.dtype: torch.dtype = vllm_config.model_config.dtype
-        self.whisper_encoder = WhisperEncoder(vllm_config=vllm_config,
-                                              prefix=maybe_prefix(
-                                                  prefix, "whisper_encoder"),
-                                              init_in_fp32=True)
+        self.whisper_encoder = WhisperEncoder(
+            vllm_config=vllm_config,
+            prefix=maybe_prefix(prefix, "whisper_encoder"),
+            init_in_fp32=True,
+        )
         mel_filters = mel_filter_bank(
             num_frequency_bins=1 + self.config.window_size // 2,
             num_mel_bins=self.config.num_mel_bins,
@@ -668,8 +724,7 @@ class VoxtralEncoderModel(nn.Module):
         audio_waveforms: torch.Tensor,
     ) -> torch.Tensor:
         input_dtype = audio_waveforms.dtype
-        window = torch.hann_window(self.config.window_size).to(
-            audio_waveforms.device)
+        window = torch.hann_window(self.config.window_size).to(audio_waveforms.device)
         stft = torch.stft(
             audio_waveforms,
             self.config.window_size,
@@ -677,7 +732,7 @@ class VoxtralEncoderModel(nn.Module):
             window=window,
             return_complex=True,
         )
-        magnitudes = stft[..., :-1].abs()**2
+        magnitudes = stft[..., :-1].abs() ** 2
         mel_spec = self.mel_filters.T @ magnitudes
         log_spec = torch.clamp(mel_spec, min=1e-10).log10()
         log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
@@ -686,8 +741,9 @@ class VoxtralEncoderModel(nn.Module):

     @property
     def downsample_factor(self) -> int:
-        return self.whisper_encoder.conv1.stride[
-            0] * self.whisper_encoder.conv2.stride[0]
+        return (
+            self.whisper_encoder.conv1.stride[0] * self.whisper_encoder.conv2.stride[0]
+        )

     @property
     def chunk_size(self) -> int:
@@ -721,8 +777,7 @@ class VoxtralEncoderModel(nn.Module):
             input_features = [input_features]

         # Split long inputs into chunks
-        input_embeds, chunks_per_example = (
-            self.prepare_inputs_for_conv(input_features))
+        input_embeds, chunks_per_example = self.prepare_inputs_for_conv(input_features)

         # [total_num_chunks, ceil(chunk_size / downsample_factor), hidden_size]
         out = self.whisper_encoder([input_embeds])
@@ -731,7 +786,7 @@ class VoxtralEncoderModel(nn.Module):
         chunk_idx = 0
         results = []
         for n_chunks in chunks_per_example:
-            result = out[chunk_idx:chunk_idx + n_chunks].flatten(0, 1)
+            result = out[chunk_idx : chunk_idx + n_chunks].flatten(0, 1)
             results.append(result)
             chunk_idx += n_chunks

@@ -751,7 +806,7 @@ class VoxtralEncoderModel(nn.Module):
                 if re.fullmatch(pattern, name):
                     name = re.sub(pattern, repl, name)

-            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+            for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
                 name = name.replace(weight_name, param_name)
@@ -762,8 +817,7 @@ class VoxtralEncoderModel(nn.Module):
                 break
             else:
                 param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)

             return name