Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
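For readers skimming the change: this is a pure re-formatting commit. yapf's hanging-indent continuation style is replaced by ruff format's Black-compatible layout, and isort's import sorting is replaced by Ruff's isort-compatible lint rules, so every hunk below is a re-wrapped statement, a re-flowed import block, or a removed blank line after a class header, with no behavioural change. A minimal sketch of the layout difference (the helper function and arguments here are hypothetical, not code from this diff):

def some_function(first_argument: int, second_argument: int, third_argument: int) -> int:
    # Hypothetical helper, defined only so the example runs on its own.
    return first_argument + second_argument + third_argument


# yapf-style hanging indent (the "before" layout in the hunks below):
result = some_function(1,
                       2,
                       3)

# ruff format / Black-style layout (the "after"): arguments move onto their
# own lines and the trailing comma keeps them there.
result = some_function(
    1,
    2,
    3,
)

In practice such a conversion is applied mechanically, typically by running ruff check --fix followed by ruff format over the tree and dropping yapf and isort from the tooling configuration.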
@@ -39,17 +39,26 @@ from vllm.config import VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
-from vllm.model_executor.layers.linear import (ColumnParallelLinear,
-                                               QKVParallelLinear,
-                                               RowParallelLinear)
+from vllm.model_executor.layers.linear import (
+    ColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
-                                    MultiModalKwargsItems)
+from vllm.multimodal.inputs import (
+    MultiModalDataDict,
+    MultiModalFieldConfig,
+    MultiModalKwargsItems,
+)
 from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser
-from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        BaseProcessingInfo, PromptReplacement,
-                                        PromptUpdate, PromptUpdateDetails)
+from vllm.multimodal.processing import (
+    BaseMultiModalProcessor,
+    BaseProcessingInfo,
+    PromptReplacement,
+    PromptUpdate,
+    PromptUpdateDetails,
+)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs.midashenglm import DashengConfig
@@ -63,7 +72,8 @@ _Tuple2 = Union[int, tuple[int, int], Sequence[int]]
 def _resolve_tuple2(x: _Tuple2) -> tuple[int, int]:
     if isinstance(x, collections.abc.Sequence):
         assert len(x) == 2, (
-            f"Expected a sequence of length 2, got {x} with length {len(x)}")
+            f"Expected a sequence of length 2, got {x} with length {len(x)}"
+        )
         return cast(tuple[int, int], tuple(x))
     return (x, x)
 
@@ -80,12 +90,14 @@ def calculate_mel_frames_dasheng(
     if center:
         audio_length_samples = audio_length_samples + n_fft
 
-    return (int(1 + ((audio_length_samples - n_fft) / hop_size)) //
-            dasheng_subsampling // model_subsampling)
+    return (
+        int(1 + ((audio_length_samples - n_fft) / hop_size))
+        // dasheng_subsampling
+        // model_subsampling
+    )
 
 
 class AudioPatchEmbed(nn.Module):
-
     def __init__(
         self,
         input_size: _Tuple2 = 64,
@@ -118,14 +130,14 @@ class AudioPatchEmbed(nn.Module):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.proj(x)
         if self.flatten:
-            x = torch.permute(torch.flatten(
-                x, 2, 3), (0, 2, 1))  # rearrange(x, "b c f t -> b (f t) c")
+            x = torch.permute(
+                torch.flatten(x, 2, 3), (0, 2, 1)
+            )  # rearrange(x, "b c f t -> b (f t) c")
         x = self.norm(x)
         return x
 
 
 class LayerScale(nn.Module):
-
     def __init__(self, dim, init_values=1e-5, inplace=False):
         super().__init__()
         self.inplace = inplace
@@ -136,7 +148,6 @@ class LayerScale(nn.Module):
 
 
 class DashengMlp(nn.Module):
-
     def __init__(
         self,
         in_features: int,
@@ -170,7 +181,6 @@ class DashengMlp(nn.Module):
 
 
 class DashengAttention(nn.Module):
-
     def __init__(
         self,
         dim: int,
@@ -237,7 +247,6 @@ class DashengAttention(nn.Module):
 
 
 class DashengBlock(nn.Module):
-
     def __init__(
         self,
         dim: int,
@@ -257,8 +266,9 @@ class DashengBlock(nn.Module):
             quant_config=quant_config,
             prefix=f"{prefix}.attn",
         )
-        self.ls1 = (LayerScale(dim, init_values=init_values)
-                    if init_values else nn.Identity())
+        self.ls1 = (
+            LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+        )
 
         self.norm2 = nn.LayerNorm(dim, eps=1e-6)
         self.mlp = DashengMlp(
@@ -267,8 +277,9 @@ class DashengBlock(nn.Module):
             quant_config=quant_config,
             prefix=f"{prefix}.mlp",
         )
-        self.ls2 = (LayerScale(dim, init_values=init_values)
-                    if init_values else nn.Identity())
+        self.ls2 = (
+            LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+        )
 
     # Kwargs usually has a mask parameter that is passed to Attention
     def forward(
@@ -282,7 +293,6 @@ class DashengBlock(nn.Module):
 
 
 class DashengFrontend(nn.Module):
-
     def __init__(self, config: DashengConfig):
         super().__init__()
         self.config = config
@@ -302,9 +312,7 @@ class DashengFrontend(nn.Module):
             n_mels=self.config.n_mels,
             sample_rate=self.config.sample_rate,
         )
-        self.register_buffer("melscale_fbanks",
-                             melscale_fbanks,
-                             persistent=False)
+        self.register_buffer("melscale_fbanks", melscale_fbanks, persistent=False)
         self.melscale_fbanks: torch.Tensor
 
     def forward(self, waveform: torch.Tensor) -> torch.Tensor:
@@ -319,8 +327,7 @@ class DashengFrontend(nn.Module):
             normalized=False,
             center=self.config.center,
         )
-        mel_spectrogram = (
-            spectrogram.mT @ self.melscale_fbanks.to(torch.float32)).mT
+        mel_spectrogram = (spectrogram.mT @ self.melscale_fbanks.to(torch.float32)).mT
         # x has shape [batch, freq, time].
         # F.amplitude_to_DB accepts inputs shaped as:
         # - [freq, time]
@@ -339,7 +346,6 @@ class DashengFrontend(nn.Module):
 
 
 class DashengAudioTransformer(nn.Module):
-
     def __init__(
         self,
         config: DashengConfig,
@@ -365,9 +371,11 @@ class DashengAudioTransformer(nn.Module):
         )
 
         self.time_pos_embed = nn.Parameter(
-            torch.empty(1, config.embed_dim, 1, self.patch_embed.grid_size[1]))
+            torch.empty(1, config.embed_dim, 1, self.patch_embed.grid_size[1])
+        )
         self.freq_pos_embed = nn.Parameter(
-            torch.empty(1, config.embed_dim, self.patch_embed.grid_size[0], 1))
+            torch.empty(1, config.embed_dim, self.patch_embed.grid_size[0], 1)
+        )
         self.blocks = nn.ModuleList(
             DashengBlock(
                 dim=config.embed_dim,
@@ -377,7 +385,9 @@ class DashengAudioTransformer(nn.Module):
                 init_values=config.init_values,
                 quant_config=quant_config,
                 prefix=f"{prefix}.blocks.{i}",
-            ) for i in range(config.depth))
+            )
+            for i in range(config.depth)
+        )
         self.norm = nn.LayerNorm(config.embed_dim, eps=1e-6)
 
     def forward_features(
@@ -387,10 +397,12 @@ class DashengAudioTransformer(nn.Module):
     ) -> torch.Tensor:
         t = x.shape[-1]
         x = x + self.time_pos_embed[:, :, :, :t]
-        x = (x + self.freq_pos_embed[:, :, :, :]
-             )  # Just to support __getitem__ in posembed
-        x = torch.permute(torch.flatten(x, 2, 3),
-                          (0, 2, 1))  # rearrange(x, "b c f t -> b (f t) c")
+        x = (
+            x + self.freq_pos_embed[:, :, :, :]
+        )  # Just to support __getitem__ in posembed
+        x = torch.permute(
+            torch.flatten(x, 2, 3), (0, 2, 1)
+        )  # rearrange(x, "b c f t -> b (f t) c")
         for block in self.blocks:
             x = block(x, mask)
         x = self.norm(x)
@@ -423,7 +435,8 @@ class DashengAudioTransformer(nn.Module):
 
         if x_length is not None:
             assert len(x_length) == len(x), (
-                "batchsizes of input x and x_length need to be same")
+                "batchsizes of input x and x_length need to be same"
+            )
             assert x_length.ndim == 1, "Lengths are of size (B,)"
             scaled_lengths = (x_length / (self.hop_length * 4)).long()
             mask = self._to_mask(max_length=t, lengths=scaled_lengths)
@@ -444,7 +457,6 @@ class DashengAudioTransformer(nn.Module):
 
 
 class AudioProjectorSubsample(nn.Module):
-
    def __init__(
        self,
        in_dim: int,
@@ -483,13 +495,14 @@ class AudioProjectorSubsample(nn.Module):
                 mask = mask[:, :-num_frames_to_discard]
         if mask is None:
             mask = torch.ones(x.shape[:-1], dtype=torch.long, device=x.device)
-        x = x.reshape(batch_size, -1, self.k *
-                      dim)  # rearrange(x, "b (s k) d -> b s (k d)", k=self.k)
+        x = x.reshape(
+            batch_size, -1, self.k * dim
+        )  # rearrange(x, "b (s k) d -> b s (k d)", k=self.k)
         for layer in self.net:
             x = layer(x)
         mask = mask.reshape(
-            batch_size, -1,
-            self.k)  # rearrange(mask, "b (s k) -> b s k", k=self.k)
+            batch_size, -1, self.k
+        )  # rearrange(mask, "b (s k) -> b s k", k=self.k)
         mask = mask.any(dim=-1).long()
         return x, mask
 
@@ -503,7 +516,6 @@ class MiDashengLMAudioInputs(TypedDict):
 
 
 class MiDashengLMProcessingInfo(BaseProcessingInfo):
-
     def get_hf_config(self):
         return self.ctx.get_hf_config()
 
@@ -522,9 +534,7 @@ class MiDashengLMProcessingInfo(BaseProcessingInfo):
         return 160000
 
 
-class MiDashengLMDummyInputsBuilder(
-        BaseDummyInputsBuilder[MiDashengLMProcessingInfo]):
-
+class MiDashengLMDummyInputsBuilder(BaseDummyInputsBuilder[MiDashengLMProcessingInfo]):
     def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
         num_audios = mm_counts.get("audio", 0)
 
@@ -547,16 +557,17 @@ class MiDashengLMDummyInputsBuilder(
         audio_overrides = mm_options.get("audio") if mm_options else None
 
         return {
-            "audio":
-            self._get_dummy_audios(length=self.info.get_max_audio_len(),
-                                   num_audios=num_audios,
-                                   overrides=audio_overrides)
+            "audio": self._get_dummy_audios(
+                length=self.info.get_max_audio_len(),
+                num_audios=num_audios,
+                overrides=audio_overrides,
+            )
         }
 
 
 class MiDashengLMMultiModalProcessor(
-        BaseMultiModalProcessor[MiDashengLMProcessingInfo]):
-
+    BaseMultiModalProcessor[MiDashengLMProcessingInfo]
+):
     def _get_data_parser(self) -> MultiModalDataParser:
         feature_extractor = self.info.get_feature_extractor()
         return MultiModalDataParser(target_sr=feature_extractor.sampling_rate)
@@ -578,8 +589,10 @@ class MiDashengLMMultiModalProcessor(
                 (0, min_audio_len - audio.shape[-1]),
                 mode="constant",
                 constant_values=0,
-            ) if isinstance(audio, np.ndarray)
-            and audio.shape[-1] < min_audio_len else audio for audio in audios
+            )
+            if isinstance(audio, np.ndarray) and audio.shape[-1] < min_audio_len
+            else audio
+            for audio in audios
         ]
 
         if processed_audios:
@@ -590,7 +603,9 @@ class MiDashengLMMultiModalProcessor(
             prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
             return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
 
-        mm_kwargs = dict(**mm_kwargs, )
+        mm_kwargs = dict(
+            **mm_kwargs,
+        )
 
         return super()._call_hf_processor(
             prompt=prompt,
@@ -627,11 +642,13 @@ class MiDashengLMMultiModalProcessor(
         if audio_length is None:
             audio_output_lengths = []
         else:
-            audio_length_np = (audio_length.cpu().numpy() if isinstance(
-                audio_length, torch.Tensor) else audio_length)
+            audio_length_np = (
+                audio_length.cpu().numpy()
+                if isinstance(audio_length, torch.Tensor)
+                else audio_length
+            )
             audio_output_lengths = [
-                max(1, calculate_mel_frames_dasheng(
-                    int(length)))  # at least one frame
+                max(1, calculate_mel_frames_dasheng(int(length)))  # at least one frame
                 for length in audio_length_np
             ]
 
@@ -708,22 +725,23 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
 
         self.quant_config = quant_config
         self.make_empty_intermediate_tensors = (
-            self.decoder.make_empty_intermediate_tensors)
+            self.decoder.make_empty_intermediate_tensors
+        )
 
-    def _validate_and_reshape_mm_tensor(self, mm_input: object,
-                                        name: str) -> torch.Tensor:
+    def _validate_and_reshape_mm_tensor(
+        self, mm_input: object, name: str
+    ) -> torch.Tensor:
         if not isinstance(mm_input, (torch.Tensor, list)):
-            raise ValueError(
-                f"Incorrect type of {name}. Got type: {type(mm_input)}")
+            raise ValueError(f"Incorrect type of {name}. Got type: {type(mm_input)}")
         if isinstance(mm_input, torch.Tensor):
             return mm_input.reshape(-1, *mm_input.shape[2:])
 
         if name == "input_values":
             max_length = max(tensor.shape[1] for tensor in mm_input)
             padded_mm_input = [
-                torch.nn.functional.pad(tensor,
-                                        (0, max_length - tensor.shape[1]))
-                if tensor.shape[1] < max_length else tensor
+                torch.nn.functional.pad(tensor, (0, max_length - tensor.shape[1]))
+                if tensor.shape[1] < max_length
+                else tensor
                 for tensor in mm_input
             ]
             return torch.concat(padded_mm_input)
@@ -731,65 +749,67 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
         return torch.concat(mm_input)
 
     def _parse_and_validate_audio_input(
-            self, **kwargs: object) -> Optional[MiDashengLMAudioInputs]:
+        self, **kwargs: object
+    ) -> Optional[MiDashengLMAudioInputs]:
         input_values = kwargs.pop("input_values", None)
         audio_length = kwargs.pop("audio_length", None)
 
         if input_values is None:
             return None
         input_values = self._validate_and_reshape_mm_tensor(
-            input_values, "input_values")
+            input_values, "input_values"
+        )
         audio_length = self._validate_and_reshape_mm_tensor(
-            audio_length, "audio_length")
+            audio_length, "audio_length"
+        )
         if not isinstance(input_values, (torch.Tensor, list)):
-            raise ValueError("Incorrect type of audio input features. "
-                             f"Got type: {type(input_values)}")
+            raise ValueError(
+                "Incorrect type of audio input features. "
+                f"Got type: {type(input_values)}"
+            )
 
         return MiDashengLMAudioInputs(
             input_values=input_values,
             audio_length=audio_length,
         )
 
-    def _process_audio_input(
-            self, audio_input: MiDashengLMAudioInputs) -> torch.Tensor:
+    def _process_audio_input(self, audio_input: MiDashengLMAudioInputs) -> torch.Tensor:
         # Process audio through encoder and projector
         input_values = audio_input["input_values"]
         audio_length = audio_input["audio_length"]
 
-        encoder_out, encoder_atts = self.audio_encoder(input_values,
-                                                       audio_length)
+        encoder_out, encoder_atts = self.audio_encoder(input_values, audio_length)
         audio_embeddings, _ = self.audio_projector(encoder_out, encoder_atts)
-        audio_embeddings = audio_embeddings.to(
-            audio_input["input_values"].dtype)
+        audio_embeddings = audio_embeddings.to(audio_input["input_values"].dtype)
         batch_size, max_audio_tokens, embed_dim = audio_embeddings.shape
 
-        audio_length_np = (audio_length.cpu().numpy() if isinstance(
-            audio_length, torch.Tensor) else audio_length)
+        audio_length_np = (
+            audio_length.cpu().numpy()
+            if isinstance(audio_length, torch.Tensor)
+            else audio_length
+        )
         audio_output_lengths = [
-            max(1, calculate_mel_frames_dasheng(
-                int(length)))  # at least one frame
+            max(1, calculate_mel_frames_dasheng(int(length)))  # at least one frame
            for length in audio_length_np
        ]
        audio_output_lengths = torch.tensor(audio_output_lengths).to(
-            audio_embeddings.device)
+            audio_embeddings.device
+        )
 
         audio_feature_mask = torch.arange(
-            max_audio_tokens,
-            device=audio_embeddings.device).unsqueeze(0).expand(
-                batch_size,
-                max_audio_tokens) < audio_output_lengths.unsqueeze(1)
+            max_audio_tokens, device=audio_embeddings.device
+        ).unsqueeze(0).expand(
+            batch_size, max_audio_tokens
+        ) < audio_output_lengths.unsqueeze(1)
 
-        masked_audio_features = audio_embeddings[audio_feature_mask].view(
-            -1, embed_dim)
+        masked_audio_features = audio_embeddings[audio_feature_mask].view(-1, embed_dim)
 
-        return torch.split(masked_audio_features,
-                           audio_output_lengths.tolist())
+        return torch.split(masked_audio_features, audio_output_lengths.tolist())
 
     def get_language_model(self) -> torch.nn.Module:
         return self.decoder
 
-    def get_multimodal_embeddings(self,
-                                  **kwargs: object) -> MultiModalEmbeddings:
+    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
         audio_input = self._parse_and_validate_audio_input(**kwargs)
 
         if audio_input is None:
@@ -828,7 +848,6 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
     ) -> Optional[torch.Tensor]:
         return self.decoder.compute_logits(hidden_states)
 
-    def load_weights(self, weights: Iterable[tuple[str,
-                                                   torch.Tensor]]) -> set[str]:
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
         loader = AutoWeightsLoader(self)
         return loader.load_weights(weights)