Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Author: Harry Mellor
Date: 2025-10-05 15:06:22 +01:00
Committed by: GitHub
Parent: 17edd8a807
Commit: d6953beb91
1508 changed files with 115244 additions and 94146 deletions
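
For context, the change this commit applies across the tree is mechanical: yapf (code formatting) and isort (import sorting) are both replaced by ruff. A minimal sketch of the replacement workflow is shown below; the helper function and the target path are hypothetical illustrations, but the two invocations are the standard ruff commands for import sorting and formatting.

# Hypothetical sketch (not part of this commit): the ruff invocations that
# take over from yapf (formatting) and isort (import ordering).
import subprocess

def run_ruff(path: str = "vllm/") -> None:
    # "ruff check --select I --fix" applies isort-style import sorting.
    subprocess.run(["ruff", "check", "--select", "I", "--fix", path], check=True)
    # "ruff format" rewrites code in the black-compatible style seen in this diff.
    subprocess.run(["ruff", "format", path], check=True)

if __name__ == "__main__":
    run_ruff()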

@@ -39,17 +39,26 @@ from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargsItems)
from vllm.multimodal.inputs import (
MultiModalDataDict,
MultiModalFieldConfig,
MultiModalKwargsItems,
)
from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.processing import (
BaseMultiModalProcessor,
BaseProcessingInfo,
PromptReplacement,
PromptUpdate,
PromptUpdateDetails,
)
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.midashenglm import DashengConfig
@@ -63,7 +72,8 @@ _Tuple2 = Union[int, tuple[int, int], Sequence[int]]
def _resolve_tuple2(x: _Tuple2) -> tuple[int, int]:
if isinstance(x, collections.abc.Sequence):
assert len(x) == 2, (
f"Expected a sequence of length 2, got {x} with length {len(x)}")
f"Expected a sequence of length 2, got {x} with length {len(x)}"
)
return cast(tuple[int, int], tuple(x))
return (x, x)
@@ -80,12 +90,14 @@ def calculate_mel_frames_dasheng(
if center:
audio_length_samples = audio_length_samples + n_fft
return (int(1 + ((audio_length_samples - n_fft) / hop_size)) //
dasheng_subsampling // model_subsampling)
return (
int(1 + ((audio_length_samples - n_fft) / hop_size))
// dasheng_subsampling
// model_subsampling
)
class AudioPatchEmbed(nn.Module):
def __init__(
self,
input_size: _Tuple2 = 64,
@@ -118,14 +130,14 @@ class AudioPatchEmbed(nn.Module):
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.proj(x)
if self.flatten:
x = torch.permute(torch.flatten(
x, 2, 3), (0, 2, 1)) # rearrange(x, "b c f t -> b (f t) c")
x = torch.permute(
torch.flatten(x, 2, 3), (0, 2, 1)
) # rearrange(x, "b c f t -> b (f t) c")
x = self.norm(x)
return x
class LayerScale(nn.Module):
def __init__(self, dim, init_values=1e-5, inplace=False):
super().__init__()
self.inplace = inplace
@@ -136,7 +148,6 @@ class LayerScale(nn.Module):
class DashengMlp(nn.Module):
def __init__(
self,
in_features: int,
@@ -170,7 +181,6 @@ class DashengMlp(nn.Module):
class DashengAttention(nn.Module):
def __init__(
self,
dim: int,
@@ -237,7 +247,6 @@ class DashengAttention(nn.Module):
class DashengBlock(nn.Module):
def __init__(
self,
dim: int,
@@ -257,8 +266,9 @@ class DashengBlock(nn.Module):
quant_config=quant_config,
prefix=f"{prefix}.attn",
)
self.ls1 = (LayerScale(dim, init_values=init_values)
if init_values else nn.Identity())
self.ls1 = (
LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
)
self.norm2 = nn.LayerNorm(dim, eps=1e-6)
self.mlp = DashengMlp(
@@ -267,8 +277,9 @@ class DashengBlock(nn.Module):
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
self.ls2 = (LayerScale(dim, init_values=init_values)
if init_values else nn.Identity())
self.ls2 = (
LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
)
# Kwargs usually has a mask parameter that is passed to Attention
def forward(
@@ -282,7 +293,6 @@ class DashengBlock(nn.Module):
class DashengFrontend(nn.Module):
def __init__(self, config: DashengConfig):
super().__init__()
self.config = config
@@ -302,9 +312,7 @@ class DashengFrontend(nn.Module):
n_mels=self.config.n_mels,
sample_rate=self.config.sample_rate,
)
self.register_buffer("melscale_fbanks",
melscale_fbanks,
persistent=False)
self.register_buffer("melscale_fbanks", melscale_fbanks, persistent=False)
self.melscale_fbanks: torch.Tensor
def forward(self, waveform: torch.Tensor) -> torch.Tensor:
@@ -319,8 +327,7 @@ class DashengFrontend(nn.Module):
normalized=False,
center=self.config.center,
)
mel_spectrogram = (
spectrogram.mT @ self.melscale_fbanks.to(torch.float32)).mT
mel_spectrogram = (spectrogram.mT @ self.melscale_fbanks.to(torch.float32)).mT
# x has shape [batch, freq, time].
# F.amplitude_to_DB accepts inputs shaped as:
# - [freq, time]
@@ -339,7 +346,6 @@ class DashengFrontend(nn.Module):
class DashengAudioTransformer(nn.Module):
def __init__(
self,
config: DashengConfig,
@@ -365,9 +371,11 @@ class DashengAudioTransformer(nn.Module):
)
self.time_pos_embed = nn.Parameter(
torch.empty(1, config.embed_dim, 1, self.patch_embed.grid_size[1]))
torch.empty(1, config.embed_dim, 1, self.patch_embed.grid_size[1])
)
self.freq_pos_embed = nn.Parameter(
torch.empty(1, config.embed_dim, self.patch_embed.grid_size[0], 1))
torch.empty(1, config.embed_dim, self.patch_embed.grid_size[0], 1)
)
self.blocks = nn.ModuleList(
DashengBlock(
dim=config.embed_dim,
@@ -377,7 +385,9 @@ class DashengAudioTransformer(nn.Module):
init_values=config.init_values,
quant_config=quant_config,
prefix=f"{prefix}.blocks.{i}",
) for i in range(config.depth))
)
for i in range(config.depth)
)
self.norm = nn.LayerNorm(config.embed_dim, eps=1e-6)
def forward_features(
@@ -387,10 +397,12 @@ class DashengAudioTransformer(nn.Module):
) -> torch.Tensor:
t = x.shape[-1]
x = x + self.time_pos_embed[:, :, :, :t]
x = (x + self.freq_pos_embed[:, :, :, :]
) # Just to support __getitem__ in posembed
x = torch.permute(torch.flatten(x, 2, 3),
(0, 2, 1)) # rearrange(x, "b c f t -> b (f t) c")
x = (
x + self.freq_pos_embed[:, :, :, :]
) # Just to support __getitem__ in posembed
x = torch.permute(
torch.flatten(x, 2, 3), (0, 2, 1)
) # rearrange(x, "b c f t -> b (f t) c")
for block in self.blocks:
x = block(x, mask)
x = self.norm(x)
@@ -423,7 +435,8 @@ class DashengAudioTransformer(nn.Module):
if x_length is not None:
assert len(x_length) == len(x), (
"batchsizes of input x and x_length need to be same")
"batchsizes of input x and x_length need to be same"
)
assert x_length.ndim == 1, "Lengths are of size (B,)"
scaled_lengths = (x_length / (self.hop_length * 4)).long()
mask = self._to_mask(max_length=t, lengths=scaled_lengths)
@@ -444,7 +457,6 @@ class DashengAudioTransformer(nn.Module):
class AudioProjectorSubsample(nn.Module):
def __init__(
self,
in_dim: int,
@@ -483,13 +495,14 @@ class AudioProjectorSubsample(nn.Module):
mask = mask[:, :-num_frames_to_discard]
if mask is None:
mask = torch.ones(x.shape[:-1], dtype=torch.long, device=x.device)
x = x.reshape(batch_size, -1, self.k *
dim) # rearrange(x, "b (s k) d -> b s (k d)", k=self.k)
x = x.reshape(
batch_size, -1, self.k * dim
) # rearrange(x, "b (s k) d -> b s (k d)", k=self.k)
for layer in self.net:
x = layer(x)
mask = mask.reshape(
batch_size, -1,
self.k) # rearrange(mask, "b (s k) -> b s k", k=self.k)
batch_size, -1, self.k
) # rearrange(mask, "b (s k) -> b s k", k=self.k)
mask = mask.any(dim=-1).long()
return x, mask
@@ -503,7 +516,6 @@ class MiDashengLMAudioInputs(TypedDict):
class MiDashengLMProcessingInfo(BaseProcessingInfo):
def get_hf_config(self):
return self.ctx.get_hf_config()
@@ -522,9 +534,7 @@ class MiDashengLMProcessingInfo(BaseProcessingInfo):
return 160000
class MiDashengLMDummyInputsBuilder(
BaseDummyInputsBuilder[MiDashengLMProcessingInfo]):
class MiDashengLMDummyInputsBuilder(BaseDummyInputsBuilder[MiDashengLMProcessingInfo]):
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_audios = mm_counts.get("audio", 0)
@@ -547,16 +557,17 @@ class MiDashengLMDummyInputsBuilder(
audio_overrides = mm_options.get("audio") if mm_options else None
return {
"audio":
self._get_dummy_audios(length=self.info.get_max_audio_len(),
num_audios=num_audios,
overrides=audio_overrides)
"audio": self._get_dummy_audios(
length=self.info.get_max_audio_len(),
num_audios=num_audios,
overrides=audio_overrides,
)
}
class MiDashengLMMultiModalProcessor(
BaseMultiModalProcessor[MiDashengLMProcessingInfo]):
BaseMultiModalProcessor[MiDashengLMProcessingInfo]
):
def _get_data_parser(self) -> MultiModalDataParser:
feature_extractor = self.info.get_feature_extractor()
return MultiModalDataParser(target_sr=feature_extractor.sampling_rate)
@@ -578,8 +589,10 @@ class MiDashengLMMultiModalProcessor(
(0, min_audio_len - audio.shape[-1]),
mode="constant",
constant_values=0,
) if isinstance(audio, np.ndarray)
and audio.shape[-1] < min_audio_len else audio for audio in audios
)
if isinstance(audio, np.ndarray) and audio.shape[-1] < min_audio_len
else audio
for audio in audios
]
if processed_audios:
@@ -590,7 +603,9 @@ class MiDashengLMMultiModalProcessor(
prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
mm_kwargs = dict(**mm_kwargs, )
mm_kwargs = dict(
**mm_kwargs,
)
return super()._call_hf_processor(
prompt=prompt,
@@ -627,11 +642,13 @@ class MiDashengLMMultiModalProcessor(
if audio_length is None:
audio_output_lengths = []
else:
audio_length_np = (audio_length.cpu().numpy() if isinstance(
audio_length, torch.Tensor) else audio_length)
audio_length_np = (
audio_length.cpu().numpy()
if isinstance(audio_length, torch.Tensor)
else audio_length
)
audio_output_lengths = [
max(1, calculate_mel_frames_dasheng(
int(length))) # at least one frame
max(1, calculate_mel_frames_dasheng(int(length))) # at least one frame
for length in audio_length_np
]
@@ -708,22 +725,23 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
self.quant_config = quant_config
self.make_empty_intermediate_tensors = (
self.decoder.make_empty_intermediate_tensors)
self.decoder.make_empty_intermediate_tensors
)
def _validate_and_reshape_mm_tensor(self, mm_input: object,
name: str) -> torch.Tensor:
def _validate_and_reshape_mm_tensor(
self, mm_input: object, name: str
) -> torch.Tensor:
if not isinstance(mm_input, (torch.Tensor, list)):
raise ValueError(
f"Incorrect type of {name}. Got type: {type(mm_input)}")
raise ValueError(f"Incorrect type of {name}. Got type: {type(mm_input)}")
if isinstance(mm_input, torch.Tensor):
return mm_input.reshape(-1, *mm_input.shape[2:])
if name == "input_values":
max_length = max(tensor.shape[1] for tensor in mm_input)
padded_mm_input = [
torch.nn.functional.pad(tensor,
(0, max_length - tensor.shape[1]))
if tensor.shape[1] < max_length else tensor
torch.nn.functional.pad(tensor, (0, max_length - tensor.shape[1]))
if tensor.shape[1] < max_length
else tensor
for tensor in mm_input
]
return torch.concat(padded_mm_input)
@@ -731,65 +749,67 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
return torch.concat(mm_input)
def _parse_and_validate_audio_input(
self, **kwargs: object) -> Optional[MiDashengLMAudioInputs]:
self, **kwargs: object
) -> Optional[MiDashengLMAudioInputs]:
input_values = kwargs.pop("input_values", None)
audio_length = kwargs.pop("audio_length", None)
if input_values is None:
return None
input_values = self._validate_and_reshape_mm_tensor(
input_values, "input_values")
input_values, "input_values"
)
audio_length = self._validate_and_reshape_mm_tensor(
audio_length, "audio_length")
audio_length, "audio_length"
)
if not isinstance(input_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of audio input features. "
f"Got type: {type(input_values)}")
raise ValueError(
"Incorrect type of audio input features. "
f"Got type: {type(input_values)}"
)
return MiDashengLMAudioInputs(
input_values=input_values,
audio_length=audio_length,
)
def _process_audio_input(
self, audio_input: MiDashengLMAudioInputs) -> torch.Tensor:
def _process_audio_input(self, audio_input: MiDashengLMAudioInputs) -> torch.Tensor:
# Process audio through encoder and projector
input_values = audio_input["input_values"]
audio_length = audio_input["audio_length"]
encoder_out, encoder_atts = self.audio_encoder(input_values,
audio_length)
encoder_out, encoder_atts = self.audio_encoder(input_values, audio_length)
audio_embeddings, _ = self.audio_projector(encoder_out, encoder_atts)
audio_embeddings = audio_embeddings.to(
audio_input["input_values"].dtype)
audio_embeddings = audio_embeddings.to(audio_input["input_values"].dtype)
batch_size, max_audio_tokens, embed_dim = audio_embeddings.shape
audio_length_np = (audio_length.cpu().numpy() if isinstance(
audio_length, torch.Tensor) else audio_length)
audio_length_np = (
audio_length.cpu().numpy()
if isinstance(audio_length, torch.Tensor)
else audio_length
)
audio_output_lengths = [
max(1, calculate_mel_frames_dasheng(
int(length))) # at least one frame
max(1, calculate_mel_frames_dasheng(int(length))) # at least one frame
for length in audio_length_np
]
audio_output_lengths = torch.tensor(audio_output_lengths).to(
audio_embeddings.device)
audio_embeddings.device
)
audio_feature_mask = torch.arange(
max_audio_tokens,
device=audio_embeddings.device).unsqueeze(0).expand(
batch_size,
max_audio_tokens) < audio_output_lengths.unsqueeze(1)
max_audio_tokens, device=audio_embeddings.device
).unsqueeze(0).expand(
batch_size, max_audio_tokens
) < audio_output_lengths.unsqueeze(1)
masked_audio_features = audio_embeddings[audio_feature_mask].view(
-1, embed_dim)
masked_audio_features = audio_embeddings[audio_feature_mask].view(-1, embed_dim)
return torch.split(masked_audio_features,
audio_output_lengths.tolist())
return torch.split(masked_audio_features, audio_output_lengths.tolist())
def get_language_model(self) -> torch.nn.Module:
return self.decoder
def get_multimodal_embeddings(self,
**kwargs: object) -> MultiModalEmbeddings:
def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
audio_input = self._parse_and_validate_audio_input(**kwargs)
if audio_input is None:
@@ -828,7 +848,6 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
) -> Optional[torch.Tensor]:
return self.decoder.compute_logits(hidden_states)
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)