Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Author: Harry Mellor
Date: 2025-10-05 15:06:22 +01:00
Committed by: GitHub
Parent: 17edd8a807
Commit: d6953beb91
1508 changed files with 115244 additions and 94146 deletions
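The change is mechanical: yapf's aligned hanging indents, `# yapf: disable`/`# yapf: enable` pragmas, and isort-wrapped imports are replaced by ruff's Black-compatible output. A conversion like this is typically produced by running `ruff format` and `ruff check --fix` over the tree; the exact commands used are not recorded in this excerpt. A minimal before/after sketch of the dominant pattern, using a hypothetical call site rather than code from this diff:

def some_function(a, b, c):  # hypothetical helper, for illustration only
    return a + b + c

# Before (yapf): continuation lines aligned under the opening bracket
result = some_function(1,
                       2,
                       3)

# After (ruff format, Black-compatible): break after the opening bracket,
# indent arguments by four spaces, keep a magic trailing comma
result = some_function(
    1,
    2,
    3,
)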

vllm/model_executor/models/voxtral.py

@@ -12,8 +12,12 @@ import regex as re
 import torch
 import torch.nn as nn
 from mistral_common.audio import mel_filter_bank
-from mistral_common.protocol.instruct.messages import (AudioChunk, RawAudio,
-                                                       TextChunk, UserMessage)
+from mistral_common.protocol.instruct.messages import (
+    AudioChunk,
+    RawAudio,
+    TextChunk,
+    UserMessage,
+)
 from mistral_common.protocol.instruct.request import ChatCompletionRequest
 from mistral_common.protocol.transcription.request import TranscriptionRequest
 from mistral_common.tokens.tokenizers.audio import Audio, AudioEncoder
@@ -28,23 +32,37 @@ from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models import SupportsPP
 from vllm.model_executor.models.module_mapping import MultiModelKeys
-# yapf: disable
 from vllm.model_executor.models.whisper import WhisperEncoder
-# yapf: enable
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
-                                    MultiModalKwargsItems, MultiModalUUIDDict,
-                                    NestedTensors)
-from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems,
-                                   MultiModalDataParser)
-from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        BaseProcessingInfo,
-                                        MultiModalProcessingInfo,
-                                        PromptReplacement, PromptUpdate)
+from vllm.multimodal.inputs import (
+    MultiModalDataDict,
+    MultiModalFieldConfig,
+    MultiModalKwargsItems,
+    MultiModalUUIDDict,
+    NestedTensors,
+)
+from vllm.multimodal.parse import (
+    AudioProcessorItems,
+    MultiModalDataItems,
+    MultiModalDataParser,
+)
+from vllm.multimodal.processing import (
+    BaseMultiModalProcessor,
+    BaseProcessingInfo,
+    MultiModalProcessingInfo,
+    PromptReplacement,
+    PromptUpdate,
+)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
-from vllm.transformers_utils.tokenizer import (MistralTokenizer,
-                                               cached_tokenizer_from_config)
+from vllm.transformers_utils.tokenizer import (
+    MistralTokenizer,
+    cached_tokenizer_from_config,
+)

 from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsTranscription
 from .utils import flatten_bn, init_vllm_registered_model, maybe_prefix
@@ -109,7 +127,8 @@ class VoxtralProcessorAdapter:
         audio_length: int,
     ) -> int:
         pad_audio_length = self._audio_processor.next_multiple_of_chunk_frames(
-            audio_length, self.sampling_rate)
+            audio_length, self.sampling_rate
+        )
         return ceil(pad_audio_length / (self.sampling_rate // self.frame_rate))

     def __call__(
@@ -139,7 +158,8 @@ class VoxtralProcessorAdapter:
                 "Make sure to process your input via `mistral_common`'s "
                 "tokenizer or pass a chat completion request. "
                 "For more info, see: "
-                "https://github.com/vllm-project/vllm/issues/8411.")
+                "https://github.com/vllm-project/vllm/issues/8411."
+            )

         audios_tokens = list[torch.Tensor]()
         audios_processed = list[torch.Tensor]()
@@ -150,23 +170,22 @@ class VoxtralProcessorAdapter:
             # pad if necessary
             audio = self._audio_processor.pad(audio, self.sampling_rate)

-            audio_tokens = [
-                self.begin_audio_token_id
-            ] + [self.audio_token_id] * self.get_num_audio_tokens(len(audio))
+            audio_tokens = [self.begin_audio_token_id] + [
+                self.audio_token_id
+            ] * self.get_num_audio_tokens(len(audio))
             audios_tokens.append(torch.tensor(audio_tokens))
             audios_processed.append(torch.tensor(audio))

-        return BatchFeature({
-            "input_ids":
-            torch.cat(audios_tokens)[None].expand(len(text), -1),
-            "audio_arrays":
-            audios_processed,
-        })
+        return BatchFeature(
+            {
+                "input_ids": torch.cat(audios_tokens)[None].expand(len(text), -1),
+                "audio_arrays": audios_processed,
+            }
+        )


 class VoxtralProcessingInfo(BaseProcessingInfo):
-
     def get_tokenizer(self) -> MistralTokenizer:
         tokenizer = cached_tokenizer_from_config(self.ctx.model_config)
         if not isinstance(tokenizer, MistralTokenizer):
@@ -193,11 +212,11 @@ class VoxtralProcessingInfo(BaseProcessingInfo):
     def get_max_audio_array_len(self) -> int:
         processor = self.get_hf_processor()
         return self.get_max_audio_tokens() * int(
-            processor.sampling_rate // processor.frame_rate)
+            processor.sampling_rate // processor.frame_rate
+        )


 class VoxtralDummyInputsBuilder(BaseDummyInputsBuilder[VoxtralProcessingInfo]):
-
     def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
         return ""
@@ -214,10 +233,9 @@ class VoxtralDummyInputsBuilder(BaseDummyInputsBuilder[VoxtralProcessingInfo]):
         audio_overrides = mm_options.get("audio") if mm_options else None

         return {
-            "audio":
-            self._get_dummy_audios(length=target_length,
-                                   num_audios=num_audios,
-                                   overrides=audio_overrides)
+            "audio": self._get_dummy_audios(
+                length=target_length, num_audios=num_audios, overrides=audio_overrides
+            )
         }

     def get_dummy_processor_inputs(
@@ -243,9 +261,11 @@ class VoxtralDummyInputsBuilder(BaseDummyInputsBuilder[VoxtralProcessingInfo]):
             chunk = AudioChunk(input_audio=RawAudio.from_audio(audio_item))
             audio_chunks.append(chunk)

-        request = ChatCompletionRequest(messages=[
-            UserMessage(content=[TextChunk(text=dummy_text), *audio_chunks]),
-        ])
+        request = ChatCompletionRequest(
+            messages=[
+                UserMessage(content=[TextChunk(text=dummy_text), *audio_chunks]),
+            ]
+        )
         res = tokenizer.mistral.encode_chat_completion(request)
         dummy_tokens = res.tokens
         # whixtral tokenizer adds padding to the audio
@@ -255,9 +275,7 @@ class VoxtralDummyInputsBuilder(BaseDummyInputsBuilder[VoxtralProcessingInfo]):
         return ProcessorInputs(prompt=dummy_tokens, mm_data=dummy_mm_data)


-class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo]
-                                 ):
-
+class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo]):
     def _get_mm_fields_config(
         self,
         hf_inputs: Mapping[str, NestedTensors],
@@ -315,17 +333,19 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo]
         return MultiModalDataParser(target_sr=sampling_rate)


-@MULTIMODAL_REGISTRY.register_processor(VoxtralMultiModalProcessor,
-                                        info=VoxtralProcessingInfo,
-                                        dummy_inputs=VoxtralDummyInputsBuilder)
-class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
-                                      SupportsPP, SupportsLoRA,
-                                      SupportsTranscription):
+@MULTIMODAL_REGISTRY.register_processor(
+    VoxtralMultiModalProcessor,
+    info=VoxtralProcessingInfo,
+    dummy_inputs=VoxtralDummyInputsBuilder,
+)
+class VoxtralForConditionalGeneration(
+    nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA, SupportsTranscription
+):
     supported_languages = ISO639_1_SUPPORTED_LANGS

     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
-        "gate_up_proj": ["gate_proj", "up_proj"]
+        "gate_up_proj": ["gate_proj", "up_proj"],
     }

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
@@ -336,7 +356,8 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
         # match the vLLM model names
         if hasattr(vllm_config, "quant_config"):
             vllm_config.quant_config = self.maybe_update_quant_config(
-                vllm_config.quant_config)
+                vllm_config.quant_config
+            )

         config = vllm_config.model_config.hf_config
         self.config = config
@@ -378,17 +399,15 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
         if intermediate_tensors is not None:
             inputs_embeds = None

-        hidden_states = self.language_model.model(input_ids,
-                                                  positions,
-                                                  intermediate_tensors,
-                                                  inputs_embeds=inputs_embeds)
+        hidden_states = self.language_model.model(
+            input_ids, positions, intermediate_tensors, inputs_embeds=inputs_embeds
+        )
         return hidden_states

     def get_multimodal_embeddings(
         self, **kwargs
-    ) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...],
-               None]:
+    ) -> Union[list[torch.Tensor], torch.Tensor, tuple[torch.Tensor, ...], None]:
         audio_inputs = self._parse_and_validate_audio_arrays(**kwargs)
         if audio_inputs is None:
             return None
@@ -399,34 +418,36 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
             seq_len, dim = audio_embedding.shape
             # Pad such that seq_len is divisible by downsample_factor
             target_seq_len = self.downsample_factor * math.ceil(
-                seq_len / self.downsample_factor)
+                seq_len / self.downsample_factor
+            )
             audio_embedding = torch.nn.functional.pad(
                 audio_embedding,
                 (0, 0, 0, target_seq_len - seq_len),
             )
             audio_embeddings[i] = audio_embedding.reshape(
-                target_seq_len // self.downsample_factor,
-                dim * self.downsample_factor)
+                target_seq_len // self.downsample_factor, dim * self.downsample_factor
+            )

         # Concat, project and resplit
         audio_embeddings_packed = torch.cat(audio_embeddings, dim=0)
-        audio_embeddings_packed = self.audio_language_adapter(
-            audio_embeddings_packed)
-        audio_embeddings = torch.split(audio_embeddings_packed,
-                                       [a.shape[0] for a in audio_embeddings],
-                                       dim=0)
+        audio_embeddings_packed = self.audio_language_adapter(audio_embeddings_packed)
+        audio_embeddings = torch.split(
+            audio_embeddings_packed, [a.shape[0] for a in audio_embeddings], dim=0
+        )
         return audio_embeddings

     def _parse_and_validate_audio_arrays(
-            self, **kwargs: object) -> Union[list[torch.Tensor], None]:
+        self, **kwargs: object
+    ) -> Union[list[torch.Tensor], None]:
         audio_arrays = kwargs.pop("audio_arrays", None)
         if audio_arrays is None:
             return None

         if not isinstance(audio_arrays, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of audio_arrays. "
-                             f"Got type: {type(audio_arrays)}")
+            raise ValueError(
+                f"Incorrect type of audio_arrays. Got type: {type(audio_arrays)}"
+            )

         audio_arrays = flatten_bn(audio_arrays)
         if isinstance(audio_arrays, torch.Tensor):
@@ -440,8 +461,9 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
         return self.language_model.compute_logits(hidden_states)

     @classmethod
-    def get_speech_to_text_config(cls, model_config: ModelConfig,
-                                  task_type: str) -> SpeechToTextConfig:
+    def get_speech_to_text_config(
+        cls, model_config: ModelConfig, task_type: str
+    ) -> SpeechToTextConfig:
         tokenizer = cached_tokenizer_from_config(model_config)
         audio_config = tokenizer.instruct.audio_encoder.audio_config
         max_audio_clip_s = audio_config.chunk_length_s
@@ -455,19 +477,23 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,

     @classmethod
     # for speech-to-text transcription
-    def get_generation_prompt(cls, audio: np.ndarray,
-                              model_config: ModelConfig,
-                              stt_config: SpeechToTextConfig,
-                              language: Optional[str],
-                              task_type: Literal["transcribe", "translate"],
-                              request_prompt: str,
-                              to_language: Optional[str]) -> PromptType:
+    def get_generation_prompt(
+        cls,
+        audio: np.ndarray,
+        model_config: ModelConfig,
+        stt_config: SpeechToTextConfig,
+        language: Optional[str],
+        task_type: Literal["transcribe", "translate"],
+        request_prompt: str,
+        to_language: Optional[str],
+    ) -> PromptType:
         tokenizer = cached_tokenizer_from_config(model_config)
-        audio = Audio(audio, int(stt_config.sample_rate),
-                      format="wav")  # lossless
-        req = TranscriptionRequest(model=model_config.model,
-                                   audio=RawAudio.from_audio(audio),
-                                   language=language)
+        audio = Audio(audio, int(stt_config.sample_rate), format="wav")  # lossless
+        req = TranscriptionRequest(
+            model=model_config.model,
+            audio=RawAudio.from_audio(audio),
+            language=language,
+        )
         tokenized = tokenizer.instruct.encode_transcription(req)
         audio = (tokenized.audios[0].audio_array, stt_config.sample_rate)
@@ -476,21 +502,24 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
         return cast(PromptType, prompts_dict)

     @classmethod
-    def get_num_audio_tokens(cls, audio_duration_s: float,
-                             stt_config: SpeechToTextConfig,
-                             model_config: ModelConfig) -> Optional[int]:
+    def get_num_audio_tokens(
+        cls,
+        audio_duration_s: float,
+        stt_config: SpeechToTextConfig,
+        model_config: ModelConfig,
+    ) -> Optional[int]:
         """
-        Map from audio duration to number of audio tokens produced by the ASR
+        Map from audio duration to number of audio tokens produced by the ASR
         model, without running a forward pass.
         This is used for estimating the amount of processing for this audio.
         """
         tokenizer = cached_tokenizer_from_config(model_config)
         adapter = VoxtralProcessorAdapter(tokenizer)
         return adapter.get_num_audio_tokens(
-            int(audio_duration_s * stt_config.sample_rate))
+            int(audio_duration_s * stt_config.sample_rate)
+        )

-    def load_weights(self, weights: Iterable[tuple[str,
-                                                   torch.Tensor]]) -> set[str]:
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         # fmt: off
         remapping_rules = [
             (r"mm_whisper_embeddings\.(.*)", r"\1"),
@@ -501,10 +530,12 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
         # fmt: on

         audio_params = dict(
-            nn.ModuleDict({
-                "audio_language_adapter":
-                self.audio_language_adapter,
-            }).named_parameters())
+            nn.ModuleDict(
+                {
+                    "audio_language_adapter": self.audio_language_adapter,
+                }
+            ).named_parameters()
+        )

         loaded_weights = set()
@@ -512,10 +543,12 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
             nonlocal loaded_weights
             for name, w in weights:
                 is_encoder = (
-                    name.startswith("mm_whisper_embeddings") and
-                    not name.startswith("mm_whisper_embeddings.tok_embeddings")
+                    name.startswith("mm_whisper_embeddings")
+                    and not name.startswith("mm_whisper_embeddings.tok_embeddings")
                     and not name.startswith(
-                        "mm_whisper_embeddings.audio_language_projection"))
+                        "mm_whisper_embeddings.audio_language_projection"
+                    )
+                )

                 for pattern, repl in remapping_rules:
                     if re.fullmatch(pattern, name):
@@ -546,7 +579,8 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
         return loaded_weights

     def maybe_update_quant_config(
-            self, quant_config: QuantizationConfig) -> QuantizationConfig:
+        self, quant_config: QuantizationConfig
+    ) -> QuantizationConfig:
        """
        Update quant config to so that ignored module and target module names
        match the vLLM model names.
@@ -555,32 +589,54 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,
         """
         remapping_rules = [
             (r"output", r"language_model.lm_head"),
-            (r"layers\.(\d+)\.attention\.wo",
-             r"language_model.model.layers.\1.self_attn.out_proj"),
-            (r"layers\.(\d+)\.attention\.w(.*)",
-             r"language_model.model.layers.\1.self_attn.\2_proj"),
-            (r"layers\.(\d+)\.feed_forward\.w1",
-             r"language_model.model.layers.\1.mlp.gate_proj"),
-            (r"layers\.(\d+)\.feed_forward\.w2",
-             r"language_model.model.layers.\1.mlp.down_proj"),
-            (r"layers\.(\d+)\.feed_forward\.w3",
-             r"language_model.model.layers.\1.mlp.up_proj"),
-            (r"mm_whisper_embeddings\.whisper_encoder\.transformer\.layers\.(\d+)\.attention.w(.*)",
-             r"whisper_encoder.whisper_encoder.layers.\1.layers.self_attn.\2_proj"
-             ),
-            (r"mm_whisper_embeddings\.whisper_encoder\.transformer\.layers\.(\d+)\.attention.wo",
-             r"whisper_encoder.whisper_encoder.layers.\1.layers.self_attn.out_proj"
-             ),
-            (r"mm_whisper_embeddings\.whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward.w(\d+)",
-             r"whisper_encoder.whisper_encoder.layers.\1.layers.mlp.fc\2"),
-            (r"mm_whisper_embeddings\.whisper_encoder\.conv_layers\.0",
-             r"whisper_encoder.whisper_encoder.conv1"),
-            (r"mm_whisper_embeddings\.whisper_encoder\.conv_layers\.1",
-             r"whisper_encoder.whisper_encoder.conv2"),
-            (r"mm_whisper_embeddings\.audio_language_projection\.0",
-             r"audio_language_adapter.w_in"),
-            (r"mm_whisper_embeddings\.audio_language_projection\.2",
-             r"audio_language_adapter.w_out"),
+            (
+                r"layers\.(\d+)\.attention\.wo",
+                r"language_model.model.layers.\1.self_attn.out_proj",
+            ),
+            (
+                r"layers\.(\d+)\.attention\.w(.*)",
+                r"language_model.model.layers.\1.self_attn.\2_proj",
+            ),
+            (
+                r"layers\.(\d+)\.feed_forward\.w1",
+                r"language_model.model.layers.\1.mlp.gate_proj",
+            ),
+            (
+                r"layers\.(\d+)\.feed_forward\.w2",
+                r"language_model.model.layers.\1.mlp.down_proj",
+            ),
+            (
+                r"layers\.(\d+)\.feed_forward\.w3",
+                r"language_model.model.layers.\1.mlp.up_proj",
+            ),
+            (
+                r"mm_whisper_embeddings\.whisper_encoder\.transformer\.layers\.(\d+)\.attention.w(.*)",
+                r"whisper_encoder.whisper_encoder.layers.\1.layers.self_attn.\2_proj",
+            ),
+            (
+                r"mm_whisper_embeddings\.whisper_encoder\.transformer\.layers\.(\d+)\.attention.wo",
+                r"whisper_encoder.whisper_encoder.layers.\1.layers.self_attn.out_proj",
+            ),
+            (
+                r"mm_whisper_embeddings\.whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward.w(\d+)",
+                r"whisper_encoder.whisper_encoder.layers.\1.layers.mlp.fc\2",
+            ),
+            (
+                r"mm_whisper_embeddings\.whisper_encoder\.conv_layers\.0",
+                r"whisper_encoder.whisper_encoder.conv1",
+            ),
+            (
+                r"mm_whisper_embeddings\.whisper_encoder\.conv_layers\.1",
+                r"whisper_encoder.whisper_encoder.conv2",
+            ),
+            (
+                r"mm_whisper_embeddings\.audio_language_projection\.0",
+                r"audio_language_adapter.w_in",
+            ),
+            (
+                r"mm_whisper_embeddings\.audio_language_projection\.2",
+                r"audio_language_adapter.w_out",
+            ),
         ]

         # Update ignore list
@@ -613,7 +669,6 @@ class VoxtralForConditionalGeneration(nn.Module, SupportsMultiModal,


 class AudioLanguageAdapter(nn.Module):
-
     def __init__(self, hidden_size: int, dim: int) -> None:
         super().__init__()
         self.w_in = nn.Linear(hidden_size, dim, bias=False)
@@ -650,10 +705,11 @@ class VoxtralEncoderModel(nn.Module):
         super().__init__()
         self.config = cast(WhisperConfig, vllm_config.model_config.hf_config)
         self.dtype: torch.dtype = vllm_config.model_config.dtype
-        self.whisper_encoder = WhisperEncoder(vllm_config=vllm_config,
-                                              prefix=maybe_prefix(
-                                                  prefix, "whisper_encoder"),
-                                              init_in_fp32=True)
+        self.whisper_encoder = WhisperEncoder(
+            vllm_config=vllm_config,
+            prefix=maybe_prefix(prefix, "whisper_encoder"),
+            init_in_fp32=True,
+        )
         mel_filters = mel_filter_bank(
             num_frequency_bins=1 + self.config.window_size // 2,
             num_mel_bins=self.config.num_mel_bins,
@@ -668,8 +724,7 @@ class VoxtralEncoderModel(nn.Module):
         audio_waveforms: torch.Tensor,
     ) -> torch.Tensor:
         input_dtype = audio_waveforms.dtype
-        window = torch.hann_window(self.config.window_size).to(
-            audio_waveforms.device)
+        window = torch.hann_window(self.config.window_size).to(audio_waveforms.device)
         stft = torch.stft(
             audio_waveforms,
             self.config.window_size,
@@ -677,7 +732,7 @@ class VoxtralEncoderModel(nn.Module):
             window=window,
             return_complex=True,
         )
-        magnitudes = stft[..., :-1].abs()**2
+        magnitudes = stft[..., :-1].abs() ** 2
         mel_spec = self.mel_filters.T @ magnitudes
         log_spec = torch.clamp(mel_spec, min=1e-10).log10()
         log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
@@ -686,8 +741,9 @@ class VoxtralEncoderModel(nn.Module):

     @property
     def downsample_factor(self) -> int:
-        return self.whisper_encoder.conv1.stride[
-            0] * self.whisper_encoder.conv2.stride[0]
+        return (
+            self.whisper_encoder.conv1.stride[0] * self.whisper_encoder.conv2.stride[0]
+        )

     @property
     def chunk_size(self) -> int:
@@ -721,8 +777,7 @@ class VoxtralEncoderModel(nn.Module):
             input_features = [input_features]

         # Split long inputs into chunks
-        input_embeds, chunks_per_example = (
-            self.prepare_inputs_for_conv(input_features))
+        input_embeds, chunks_per_example = self.prepare_inputs_for_conv(input_features)

         # [total_num_chunks, ceil(chunk_size / downsample_factor), hidden_size]
         out = self.whisper_encoder([input_embeds])
@@ -731,7 +786,7 @@ class VoxtralEncoderModel(nn.Module):
         chunk_idx = 0
         results = []
         for n_chunks in chunks_per_example:
-            result = out[chunk_idx:chunk_idx + n_chunks].flatten(0, 1)
+            result = out[chunk_idx : chunk_idx + n_chunks].flatten(0, 1)
             results.append(result)
             chunk_idx += n_chunks
@@ -751,7 +806,7 @@ class VoxtralEncoderModel(nn.Module):
                 if re.fullmatch(pattern, name):
                     name = re.sub(pattern, repl, name)

-            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+            for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
                 name = name.replace(weight_name, param_name)
@@ -762,8 +817,7 @@ class VoxtralEncoderModel(nn.Module):
                     break
             else:
                 param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)

             return name