Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-05 15:06:22 +01:00
committed by GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions

View File

@@ -9,43 +9,58 @@ from typing import Annotated, Literal, Optional, Union, cast
import numpy as np
import torch
from torch import nn
from transformers import (BatchFeature, WhisperConfig, WhisperFeatureExtractor,
WhisperProcessor)
from transformers import (
BatchFeature,
WhisperConfig,
WhisperFeatureExtractor,
WhisperProcessor,
)
from transformers.models.whisper.modeling_whisper import sinusoids
from vllm.attention import Attention, AttentionType
from vllm.attention.layer import MultiHeadAttention
from vllm.attention.layers.cross_attention import CrossAttention
from vllm.config import (CacheConfig, ModelConfig, SpeechToTextConfig,
VllmConfig)
from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.inputs.data import PromptType
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargsItems)
from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalFieldConfig,
MultiModalKwargsItems,
)
from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser
from vllm.multimodal.processing import (BaseProcessingInfo,
EncDecMultiModalProcessor,
PromptReplacement, PromptUpdate)
from vllm.multimodal.processing import (
BaseProcessingInfo,
EncDecMultiModalProcessor,
PromptReplacement,
PromptUpdate,
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.transformers_utils.processor import cached_get_processor
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import (MultiModalEmbeddings, SupportsMultiModal,
SupportsTranscription)
from .utils import (AutoWeightsLoader, WeightsMapper, cast_overflow_tensors,
make_layers, maybe_prefix)
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsTranscription
from .utils import (
AutoWeightsLoader,
WeightsMapper,
cast_overflow_tensors,
make_layers,
maybe_prefix,
)
logger = init_logger(__name__)
@@ -108,7 +123,7 @@ ISO639_1_SUPPORTED_LANGS = {
"uk": "Ukrainian",
"ur": "Urdu",
"vi": "Vietnamese",
"cy": "Welsh"
"cy": "Welsh",
}
@@ -120,8 +135,7 @@ class WhisperAudioInputs(TensorSchema):
- t: Time frames (M)
"""
input_features: Annotated[Optional[NestedTensors],
TensorShape("b", "nmb", "t")]
input_features: Annotated[Optional[NestedTensors], TensorShape("b", "nmb", "t")]
class WhisperEncoderAttention(MultiHeadAttention):
@@ -153,7 +167,6 @@ class WhisperEncoderAttention(MultiHeadAttention):
class WhisperPositionalEmbedding(nn.Embedding):
def __init__(self, num_positions: int, embedding_dim: int):
super().__init__(num_positions, embedding_dim)
@@ -162,7 +175,6 @@ class WhisperPositionalEmbedding(nn.Embedding):
class WhisperAttention(nn.Module):
def __init__(
self,
embed_dim: int,
@@ -196,7 +208,8 @@ class WhisperAttention(nn.Module):
if (self.head_dim * num_heads) != self.embed_dim:
raise ValueError(
f"embed_dim must be divisible by num_heads (got `embed_dim`: "
f"{self.embed_dim} and `num_heads`: {num_heads}).")
f"{self.embed_dim} and `num_heads`: {num_heads})."
)
self.scaling = self.head_dim**-0.5
self._init_qkv(embed_dim, bias, quant_config, prefix=prefix)
@@ -269,7 +282,6 @@ class WhisperAttention(nn.Module):
class WhisperCrossAttention(WhisperAttention):
def __init__(
self,
embed_dim: int,
@@ -336,7 +348,6 @@ class WhisperCrossAttention(WhisperAttention):
class WhisperMLP(nn.Module):
def __init__(
self,
embed_dim: int,
@@ -369,7 +380,6 @@ class WhisperMLP(nn.Module):
class WhisperEncoderLayer(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
@@ -414,7 +424,6 @@ class WhisperEncoderLayer(nn.Module):
class WhisperDecoderLayer(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
@@ -474,48 +483,39 @@ class WhisperDecoderLayer(nn.Module):
class WhisperEncoder(nn.Module):
def __init__(self,
*,
vllm_config: VllmConfig,
prefix: str = "",
init_in_fp32: bool = False):
def __init__(
self, *, vllm_config: VllmConfig, prefix: str = "", init_in_fp32: bool = False
):
super().__init__()
config = vllm_config.model_config.hf_config
embed_dim = config.d_model
self.num_mel_bins = config.num_mel_bins
self.max_source_positions = config.max_source_positions
self.embed_scale = (math.sqrt(embed_dim)
if config.scale_embedding else 1.0)
self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0
self.conv1 = nn.Conv1d(self.num_mel_bins,
embed_dim,
kernel_size=3,
padding=1)
self.conv2 = nn.Conv1d(embed_dim,
embed_dim,
kernel_size=3,
stride=2,
padding=1)
self.conv1 = nn.Conv1d(self.num_mel_bins, embed_dim, kernel_size=3, padding=1)
self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1)
self.start_layer, self.end_layer, self.layers = make_layers(
config.encoder_layers,
lambda prefix: WhisperEncoderLayer(vllm_config=vllm_config,
prefix=f"{prefix}.layers"),
lambda prefix: WhisperEncoderLayer(
vllm_config=vllm_config, prefix=f"{prefix}.layers"
),
prefix=f"{prefix}.layers",
)
self.layer_norm = nn.LayerNorm(config.d_model)
maybe_fp32_init_ctx = set_default_torch_dtype(
torch.float32) if init_in_fp32 else nullcontext()
maybe_fp32_init_ctx = (
set_default_torch_dtype(torch.float32) if init_in_fp32 else nullcontext()
)
with (
torch.no_grad(),
maybe_fp32_init_ctx,
torch.no_grad(),
maybe_fp32_init_ctx,
):
self.embed_positions = nn.Embedding(self.max_source_positions,
embed_dim)
self.embed_positions = nn.Embedding(self.max_source_positions, embed_dim)
self.embed_positions.weight.copy_(
sinusoids(*self.embed_positions.weight.shape))
sinusoids(*self.embed_positions.weight.shape)
)
def forward(self, input_features: Union[torch.Tensor, list[torch.Tensor]]):
hidden_states = []
@@ -523,9 +523,9 @@ class WhisperEncoder(nn.Module):
embeds = nn.functional.gelu(self.conv1(features))
embeds = nn.functional.gelu(self.conv2(embeds))
embeds = embeds.transpose(-1, -2)
embeds = (embeds +
self.embed_positions.weight[:embeds.size(-2), :]).to(
embeds.dtype)
embeds = (embeds + self.embed_positions.weight[: embeds.size(-2), :]).to(
embeds.dtype
)
hidden_states.append(embeds)
hidden_states = torch.cat(hidden_states)
@@ -537,7 +537,6 @@ class WhisperEncoder(nn.Module):
class WhisperDecoder(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
@@ -545,17 +544,19 @@ class WhisperDecoder(nn.Module):
self.padding_idx = config.pad_token_id
self.max_target_positions = config.max_target_positions
self.max_source_positions = config.max_source_positions
self.embed_scale = (math.sqrt(config.d_model)
if config.scale_embedding else 1.0)
self.embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model,
self.padding_idx)
self.embed_tokens = nn.Embedding(
config.vocab_size, config.d_model, self.padding_idx
)
self.embed_positions = WhisperPositionalEmbedding(
self.max_target_positions, config.d_model)
self.max_target_positions, config.d_model
)
self.start_layer, self.end_layer, self.layers = make_layers(
config.decoder_layers,
lambda prefix: WhisperDecoderLayer(vllm_config=vllm_config,
prefix=f"{prefix}.layers"),
lambda prefix: WhisperDecoderLayer(
vllm_config=vllm_config, prefix=f"{prefix}.layers"
),
prefix=f"{prefix}.layers",
)
self.layer_norm = nn.LayerNorm(config.d_model)
@@ -584,13 +585,14 @@ class WhisperDecoder(nn.Module):
class WhisperModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
self.encoder = WhisperEncoder(vllm_config=vllm_config,
prefix=f"{prefix}.encoder")
self.decoder = WhisperDecoder(vllm_config=vllm_config,
prefix=f"{prefix}.decoder")
self.encoder = WhisperEncoder(
vllm_config=vllm_config, prefix=f"{prefix}.encoder"
)
self.decoder = WhisperDecoder(
vllm_config=vllm_config, prefix=f"{prefix}.decoder"
)
def forward(
self,
@@ -614,8 +616,7 @@ class WhisperModel(nn.Module):
return None
return self.encoder(input_features)
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".self_attn.qkv_proj", ".self_attn.q_proj", "q"),
@@ -645,15 +646,13 @@ class WhisperModel(nn.Module):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
class WhisperProcessingInfo(BaseProcessingInfo):
def get_hf_config(self) -> WhisperConfig:
return self.ctx.get_hf_config(WhisperConfig)
@@ -670,8 +669,7 @@ class WhisperProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"audio": 1}
def get_feature_extractor(self,
**kwargs: object) -> WhisperFeatureExtractor:
def get_feature_extractor(self, **kwargs: object) -> WhisperFeatureExtractor:
hf_processor = self.get_hf_processor(**kwargs)
feature_extractor = hf_processor.feature_extractor # type: ignore
assert isinstance(feature_extractor, WhisperFeatureExtractor)
@@ -682,7 +680,6 @@ class WhisperProcessingInfo(BaseProcessingInfo):
class WhisperDummyInputsBuilder(BaseDummyInputsBuilder[WhisperProcessingInfo]):
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_audios = mm_counts.get("audio", 0)
@@ -703,16 +700,13 @@ class WhisperDummyInputsBuilder(BaseDummyInputsBuilder[WhisperProcessingInfo]):
audio_overrides = mm_options.get("audio") if mm_options else None
return {
"audio":
self._get_dummy_audios(length=audio_len,
num_audios=num_audios,
overrides=audio_overrides)
"audio": self._get_dummy_audios(
length=audio_len, num_audios=num_audios, overrides=audio_overrides
)
}
class WhisperMultiModalProcessor(
EncDecMultiModalProcessor[WhisperProcessingInfo]):
class WhisperMultiModalProcessor(EncDecMultiModalProcessor[WhisperProcessingInfo]):
def _get_data_parser(self) -> MultiModalDataParser:
feature_extractor = self.info.get_feature_extractor()
return MultiModalDataParser(target_sr=feature_extractor.sampling_rate)
@@ -779,11 +773,14 @@ class WhisperMultiModalProcessor(
]
@MULTIMODAL_REGISTRY.register_processor(WhisperMultiModalProcessor,
info=WhisperProcessingInfo,
dummy_inputs=WhisperDummyInputsBuilder)
class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
SupportsMultiModal):
@MULTIMODAL_REGISTRY.register_processor(
WhisperMultiModalProcessor,
info=WhisperProcessingInfo,
dummy_inputs=WhisperDummyInputsBuilder,
)
class WhisperForConditionalGeneration(
nn.Module, SupportsTranscription, SupportsMultiModal
):
packed_modules_mapping = {
"self_attn.qkv_proj": [
"self_attn.q_proj",
@@ -793,10 +790,9 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
"encoder_attn.kv_proj": ["encoder_attn.k_proj", "encoder_attn.v_proj"],
}
hf_to_vllm_mapper = WeightsMapper(orig_to_new_substr={
".fc1.": ".mlp.fc1.",
".fc2.": ".mlp.fc2."
})
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_substr={".fc1.": ".mlp.fc1.", ".fc2.": ".mlp.fc2."}
)
# Whisper only supports audio-conditioned generation.
supports_transcription_only = True
@@ -811,23 +807,26 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
logger.warning(
"Defaulting to language='en'. If you wish to transcribe "
"audio in a different language, pass the `language` field "
"in the TranscriptionRequest.")
"in the TranscriptionRequest."
)
language = "en"
return super().validate_language(language)
@classmethod
def get_generation_prompt(
cls,
audio: np.ndarray,
model_config: ModelConfig, # not needed here
stt_config: SpeechToTextConfig,
language: Optional[str],
task_type: Literal["transcribe", "translate"],
request_prompt: str,
to_language: Optional[str]) -> PromptType:
cls,
audio: np.ndarray,
model_config: ModelConfig, # not needed here
stt_config: SpeechToTextConfig,
language: Optional[str],
task_type: Literal["transcribe", "translate"],
request_prompt: str,
to_language: Optional[str],
) -> PromptType:
if language is None:
raise ValueError(
"Language must be specified when creating the Whisper prompt")
"Language must be specified when creating the Whisper prompt"
)
prompt = {
"encoder_prompt": {
# Whisper does not support encoder prompt.
@@ -836,10 +835,11 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
"audio": (audio, stt_config.sample_rate),
},
},
"decoder_prompt":
((f"<|prev|>{request_prompt}" if request_prompt else "") +
f"<|startoftranscript|><|{language}|>" +
f"<|{task_type}|><|notimestamps|>")
"decoder_prompt": (
(f"<|prev|>{request_prompt}" if request_prompt else "")
+ f"<|startoftranscript|><|{language}|>"
+ f"<|{task_type}|><|notimestamps|>"
),
}
return cast(PromptType, prompt)
@@ -851,8 +851,9 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
raise ValueError("Only audio modality is supported")
@classmethod
def get_speech_to_text_config(cls, model_config: ModelConfig,
task_type: str) -> SpeechToTextConfig:
def get_speech_to_text_config(
cls, model_config: ModelConfig, task_type: str
) -> SpeechToTextConfig:
processor = cached_get_processor(model_config.model)
return SpeechToTextConfig(
@@ -861,9 +862,12 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
)
@classmethod
def get_num_audio_tokens(cls, audio_duration_s: float,
stt_config: SpeechToTextConfig,
model_config: ModelConfig) -> Optional[int]:
def get_num_audio_tokens(
cls,
audio_duration_s: float,
stt_config: SpeechToTextConfig,
model_config: ModelConfig,
) -> Optional[int]:
processor = cached_get_processor(model_config.model)
hop_length = processor.feature_extractor.hop_length
assert hop_length is not None
@@ -871,8 +875,7 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
# prompts directly at least not to Whisper.
# One indicator of the encoder amount of processing
# is the log-mel spectrogram length.
return math.ceil(audio_duration_s * stt_config.sample_rate /
hop_length)
return math.ceil(audio_duration_s * stt_config.sample_rate / hop_length)
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
@@ -883,15 +886,17 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
self.model = WhisperModel(vllm_config=vllm_config, prefix=prefix)
self.unpadded_vocab_size = config.vocab_size
self.proj_out = ParallelLMHead(config.vocab_size,
config.d_model,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "proj_out"))
self.proj_out = self.proj_out.tie_weights(
self.model.decoder.embed_tokens)
self.proj_out = ParallelLMHead(
config.vocab_size,
config.d_model,
quant_config=quant_config,
prefix=maybe_prefix(prefix, "proj_out"),
)
self.proj_out = self.proj_out.tie_weights(self.model.decoder.embed_tokens)
logit_scale = getattr(config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size, logit_scale)
self.logits_processor = LogitsProcessor(
self.unpadded_vocab_size, config.vocab_size, logit_scale
)
def forward(
self,
@@ -910,8 +915,7 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
def get_language_model(self) -> torch.nn.Module:
return self.model.decoder
def get_multimodal_embeddings(self,
**kwargs: object) -> MultiModalEmbeddings:
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
# Required as part of SupportsMultiModal interface.
audio_input = self._parse_and_validate_audio_input(**kwargs)
return [self.model.get_encoder_outputs(audio_input["input_features"])]
@@ -928,16 +932,16 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
# Whisper does not have encoder text tokens.
return self.model.decoder.get_input_embeddings(input_ids)
def _parse_and_validate_audio_input(
self, **kwargs: object) -> WhisperAudioInputs:
def _parse_and_validate_audio_input(self, **kwargs: object) -> WhisperAudioInputs:
input_features = kwargs.pop("input_features", None)
if input_features is not None:
if not isinstance(input_features, (torch.Tensor, list)):
raise ValueError("Incorrect type of audio features. "
f"Got type: {type(input_features)}")
input_features = torch.cat(
[feat.to(self.dtype) for feat in input_features])
raise ValueError(
"Incorrect type of audio features. "
f"Got type: {type(input_features)}"
)
input_features = torch.cat([feat.to(self.dtype) for feat in input_features])
return WhisperAudioInputs(input_features=input_features)
@@ -945,8 +949,7 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
logits = self.logits_processor(self.proj_out, hidden_states)
return logits
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self, skip_prefixes=["proj_out."])
# add fake zeros bias for k_proj to state_dict
@@ -955,7 +958,7 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
def _create_fake_bias_for_k_proj(
weights: Iterable[tuple[str, torch.Tensor]]
weights: Iterable[tuple[str, torch.Tensor]],
) -> Iterable[tuple[str, torch.Tensor]]:
"""
Create full zeros bias for k_proj weight in self-attn and x-attn layers.