[Model][Bugfix] Fix issues in MiDashengLM implementation for quantized models (#25854)

Signed-off-by: zhoukz <me@zhoukz.com>
Author: Zhou Jiahao
Date: 2025-09-29 18:59:04 +08:00
Committed by: GitHub
Commit: 8616300ae2 (parent: edbaadd91f)


@@ -22,6 +22,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only MiDashengLM model compatible with HuggingFace weights."""
+
 import collections
 import collections.abc
 from collections.abc import Iterable, Mapping, Sequence
@@ -30,10 +31,10 @@ from typing import Any, Callable, Optional, TypedDict, Union, cast
 import numpy as np
 import torch
 import torch.nn as nn
-import torchaudio.transforms as audio_transforms
+import torchaudio.functional as F
+from torch.nn.functional import scaled_dot_product_attention
 from transformers import BatchFeature

-from vllm.attention.layer import MultiHeadAttention
 from vllm.config import VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
@@ -41,7 +42,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.model_executor.model_loader.utils import set_default_torch_dtype
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalKwargsItems)
@@ -147,15 +147,19 @@ class DashengMlp(nn.Module):
         super().__init__()
         out_features = out_features or in_features
         hidden_features = hidden_features or in_features
-        self.fc1 = ColumnParallelLinear(input_size=in_features,
-                                        output_size=hidden_features,
-                                        quant_config=quant_config,
-                                        prefix=f"{prefix}.fc1")
+        self.fc1 = ColumnParallelLinear(
+            input_size=in_features,
+            output_size=hidden_features,
+            quant_config=quant_config,
+            prefix=f"{prefix}.fc1",
+        )
         self.act = get_act_fn("gelu")
-        self.fc2 = RowParallelLinear(input_size=hidden_features,
-                                     output_size=out_features,
-                                     quant_config=quant_config,
-                                     prefix=f"{prefix}.fc2")
+        self.fc2 = RowParallelLinear(
+            input_size=hidden_features,
+            output_size=out_features,
+            quant_config=quant_config,
+            prefix=f"{prefix}.fc2",
+        )

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x, _ = self.fc1(x)
@@ -171,7 +175,6 @@ class DashengAttention(nn.Module):
         dim: int,
         num_heads: int = 8,
         qkv_bias: bool = False,
-        causal: bool = False,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ):
@@ -205,33 +208,30 @@ class DashengAttention(nn.Module):
             quant_config=quant_config,
             prefix=f"{prefix}.qkv",
         )
-        self.attn = MultiHeadAttention(
-            self.num_heads,
-            self.head_dim,
-            self.scale,
-            num_kv_heads=self.num_kv_heads,
-        )
         self.proj = RowParallelLinear(
             input_size=dim,
             output_size=dim,
             quant_config=quant_config,
             prefix=f"{prefix}.proj",
         )
-        self.causal = causal

     def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None):
         B, N, C = x.shape
-        qkv_out, _ = self.qkv(x)
-        q, k, v = qkv_out.split([self.q_size, self.kv_size, self.kv_size],
-                                dim=-1)
-
-        attn_out = self.attn(q, k, v)
-        C_local = attn_out.numel() // (B * N)  # C_local for parallel
-        attn_out = attn_out.view(B, N, C_local)
-
-        x, _ = self.proj(attn_out)
+        qkv, _ = self.qkv(x)
+        qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads)
+        qkv = qkv.permute(2, 0, 3, 1, 4)
+        q, k, v = qkv.unbind(0)
+        x = scaled_dot_product_attention(
+            q,
+            k,
+            v,
+            attn_mask=mask[:, None, None, :] if mask is not None else None,
+        )
+        x = x.transpose(1, 2).reshape(B, N, C)
+        x, _ = self.proj(x)

         return x
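
Why the attention path changed: the encoder previously routed q/k/v through
vLLM's MultiHeadAttention layer; the fix calls
torch.nn.functional.scaled_dot_product_attention directly, making the
bidirectional attention and the optional key-padding mask explicit. A minimal
standalone sketch of the reshape/attend/merge pattern used above (shapes are
illustrative, not the model's):

    import torch
    from torch.nn.functional import scaled_dot_product_attention

    B, N, num_heads, head_dim = 2, 16, 8, 64
    fused_qkv = torch.randn(B, N, 3 * num_heads * head_dim)

    # (B, N, 3*H*D) -> (3, B, H, N, D), then peel off q, k, v.
    qkv = fused_qkv.reshape(B, N, 3, num_heads, head_dim)
    q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)

    # Boolean key-padding mask: True marks positions that may be attended.
    mask = torch.ones(B, N, dtype=torch.bool)
    out = scaled_dot_product_attention(q, k, v,
                                       attn_mask=mask[:, None, None, :])
    out = out.transpose(1, 2).reshape(B, N, num_heads * head_dim)
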
@@ -280,6 +280,63 @@ class DashengBlock(nn.Module):
         return x


+class DashengFrontend(nn.Module):
+
+    def __init__(self, config: DashengConfig):
+        super().__init__()
+        self.config = config
+
+        spectrogram_window = torch.hann_window(self.config.win_length)
+        self.register_buffer(
+            "spectrogram_window",
+            spectrogram_window,
+            persistent=False,
+        )
+        self.spectrogram_window: torch.Tensor
+
+        melscale_fbanks = F.melscale_fbanks(
+            n_freqs=self.config.n_fft // 2 + 1,
+            f_min=self.config.f_min,
+            f_max=self.config.f_max,
+            n_mels=self.config.n_mels,
+            sample_rate=self.config.sample_rate,
+        )
+        self.register_buffer("melscale_fbanks",
+                             melscale_fbanks,
+                             persistent=False)
+        self.melscale_fbanks: torch.Tensor
+
+    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
+        spectrogram = F.spectrogram(
+            waveform=waveform.to(torch.float32),
+            pad=0,
+            window=self.spectrogram_window,
+            n_fft=self.config.n_fft,
+            hop_length=self.config.hop_length,
+            win_length=self.config.win_length,
+            power=2,
+            normalized=False,
+            center=self.config.center,
+        )
+        mel_spectrogram = (
+            spectrogram.mT @ self.melscale_fbanks.to(torch.float32)).mT
+        # mel_spectrogram has shape [batch, freq, time].
+        # F.amplitude_to_DB accepts inputs shaped as:
+        #   - [freq, time]
+        #   - [channel, freq, time]
+        #   - [..., channel, freq, time]
+        # Here we insert a channel dimension of size 1 before calling it,
+        # then remove that extra dimension afterward.
+        log_mel_spectrogram = F.amplitude_to_DB(
+            mel_spectrogram.unsqueeze(1),
+            multiplier=10,
+            amin=1e-10,
+            db_multiplier=0,
+            top_db=120,
+        ).squeeze(1)
+        return log_mel_spectrogram.to(waveform.dtype)
+
+
 class DashengAudioTransformer(nn.Module):

     def __init__(
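
The new DashengFrontend keeps the Hann window and mel filterbank as
non-persistent float32 buffers, so they follow the model across devices but
are never read from (or quantized along with) the checkpoint. A quick
standalone check that the functional composition matches the
torchaudio.transforms pipeline it replaces (parameter values here are
illustrative, not the model's config):

    import torch
    import torchaudio.functional as F
    import torchaudio.transforms as T

    sr, n_fft, hop, win, n_mels = 16000, 512, 160, 512, 64
    wave = torch.randn(1, sr)

    # Reference: the old transforms-based pipeline.
    ref = T.AmplitudeToDB(stype="power", top_db=120)(
        T.MelSpectrogram(sample_rate=sr, n_fft=n_fft, hop_length=hop,
                         win_length=win, n_mels=n_mels, power=2)(wave))

    # Functional composition, mirroring DashengFrontend.forward.
    spec = F.spectrogram(waveform=wave, pad=0, window=torch.hann_window(win),
                         n_fft=n_fft, hop_length=hop, win_length=win,
                         power=2, normalized=False)
    fb = F.melscale_fbanks(n_freqs=n_fft // 2 + 1, f_min=0.0, f_max=sr / 2,
                           n_mels=n_mels, sample_rate=sr)
    mel = (spec.mT @ fb).mT
    out = F.amplitude_to_DB(mel.unsqueeze(1), multiplier=10.0, amin=1e-10,
                            db_multiplier=0.0, top_db=120).squeeze(1)

    torch.testing.assert_close(out, ref, atol=1e-4, rtol=1e-4)
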
@@ -293,7 +350,7 @@ class DashengAudioTransformer(nn.Module):
         self.target_length = config.target_length
         self.hop_length = config.hop_length

-        self._init_front_end(config)
+        self.front_end = DashengFrontend(config)

         self.init_bn = nn.BatchNorm2d(config.n_mels, momentum=0.01)
@@ -318,34 +375,10 @@ class DashengAudioTransformer(nn.Module):
                 qkv_bias=config.qkv_bias,
                 init_values=config.init_values,
                 quant_config=quant_config,
-                prefix=f"{prefix}.block{i}",
+                prefix=f"{prefix}.blocks.{i}",
             ) for i in range(config.depth))
         self.norm = nn.LayerNorm(config.embed_dim, eps=1e-6)

-    def _init_front_end(self, config):
-        with set_default_torch_dtype(torch.float32):
-            self.front_end = nn.Sequential(
-                audio_transforms.MelSpectrogram(
-                    f_min=config.f_min,
-                    f_max=config.f_max,
-                    center=config.center,
-                    win_length=config.win_length,
-                    hop_length=config.hop_length,
-                    sample_rate=config.sample_rate,
-                    n_fft=config.n_fft,
-                    n_mels=config.n_mels,
-                ),
-                audio_transforms.AmplitudeToDB(top_db=120),
-            )
-
-            mel_spectrogram = self.front_end[0]
-            fb = mel_spectrogram.mel_scale.fb
-            win = mel_spectrogram.spectrogram.window
-            mel_spectrogram.mel_scale.fb = fb.to(torch.bfloat16).to(
-                torch.float32)
-            mel_spectrogram.spectrogram.window = win.to(torch.bfloat16).to(
-                torch.float32)
-
     def forward_features(
         self,
         x: torch.Tensor,
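
The prefix correction matters for quantized checkpoints: vLLM quantization
configs match per-layer options (such as ignored-layer lists) against module
prefixes, and the encoder blocks live under the blocks.<i> attribute path, so
the old block{i} prefix could never match. A hypothetical illustration (the
layer names below are made up):

    # Hypothetical per-layer quantization exclusions keyed by module path.
    ignored_layers = {"audio_encoder.blocks.3.attn.qkv"}

    new_prefix = "audio_encoder.blocks.3.attn.qkv"  # matches as intended
    old_prefix = "audio_encoder.block3.attn.qkv"    # silently never matched
    assert new_prefix in ignored_layers
    assert old_prefix not in ignored_layers
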
@@ -430,14 +463,16 @@ class AudioProjectorSubsample(nn.Module):
                 quant_config=quant_config,
                 prefix=f"{prefix}.net.0",
                 return_bias=False,
-            ), get_act_fn("gelu"),
+            ),
+            get_act_fn("gelu"),
             RowParallelLinear(
                 input_size=out_dim,
                 output_size=out_dim,
                 quant_config=quant_config,
                 prefix=f"{prefix}.net.2",
                 return_bias=False,
-            ))
+            ),
+        )

     def forward(self, x, mask=None):
         batch_size, seq_len, dim = x.shape
@@ -534,9 +569,12 @@ class MiDashengLMMultiModalProcessor(
         # + Padding
         min_audio_len = self.info.get_min_audio_len()
         processed_audios = [
-            np.pad(audio, (0, min_audio_len - audio.shape[-1]),
-                   mode='constant',
-                   constant_values=0) if isinstance(audio, np.ndarray)
+            np.pad(
+                audio,
+                (0, min_audio_len - audio.shape[-1]),
+                mode="constant",
+                constant_values=0,
+            ) if isinstance(audio, np.ndarray)
             and audio.shape[-1] < min_audio_len else audio for audio in audios
         ]
@@ -585,8 +623,8 @@ class MiDashengLMMultiModalProcessor(
         if audio_length is None:
             audio_output_lengths = []
         else:
-            audio_length_np = audio_length.cpu().numpy() if isinstance(
-                audio_length, torch.Tensor) else audio_length
+            audio_length_np = (audio_length.cpu().numpy() if isinstance(
+                audio_length, torch.Tensor) else audio_length)
             audio_output_lengths = [
                 max(1, calculate_mel_frames_dasheng(
                     int(length)))  # at least one frame
@@ -617,6 +655,17 @@ class MiDashengLMMultiModalProcessor(
     dummy_inputs=MiDashengLMDummyInputsBuilder,
 )
 class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }

     @classmethod
     def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
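
packed_modules_mapping tells vLLM's weight loader and quantization logic that
the decoder's fused qkv_proj and gate_up_proj parameters are stored in the HF
checkpoint as separate q_proj/k_proj/v_proj and gate_proj/up_proj tensors;
without it, quantized checkpoints whose weights and scales are keyed by the
original projection names cannot be routed into the fused modules. A sketch of
the kind of name routing this enables (hypothetical helper, not vLLM's actual
loader):

    from typing import Optional, Tuple

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    def resolve(name: str) -> Optional[Tuple[str, int]]:
        """Map a checkpoint tensor name to (fused param name, shard index)."""
        for fused, shards in packed_modules_mapping.items():
            for idx, shard in enumerate(shards):
                if shard in name:
                    return name.replace(shard, fused), idx
        return None

    assert resolve("model.layers.0.self_attn.k_proj.weight") == (
        "model.layers.0.self_attn.qkv_proj.weight", 1)
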
@@ -660,8 +709,8 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
     def _validate_and_reshape_mm_tensor(self, mm_input: object,
                                         name: str) -> torch.Tensor:
         if not isinstance(mm_input, (torch.Tensor, list)):
-            raise ValueError(f"Incorrect type of {name}. "
-                             f"Got type: {type(mm_input)}")
+            raise ValueError(
+                f"Incorrect type of {name}. Got type: {type(mm_input)}")
         if isinstance(mm_input, torch.Tensor):
             return mm_input.reshape(-1, *mm_input.shape[2:])
@@ -710,8 +759,8 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
audio_input["input_values"].dtype) audio_input["input_values"].dtype)
batch_size, max_audio_tokens, embed_dim = audio_embeddings.shape batch_size, max_audio_tokens, embed_dim = audio_embeddings.shape
audio_length_np = audio_length.cpu().numpy() if isinstance( audio_length_np = (audio_length.cpu().numpy() if isinstance(
audio_length, torch.Tensor) else audio_length audio_length, torch.Tensor) else audio_length)
audio_output_lengths = [ audio_output_lengths = [
max(1, calculate_mel_frames_dasheng( max(1, calculate_mel_frames_dasheng(
int(length))) # at least one frame int(length))) # at least one frame
@@ -720,11 +769,11 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
         audio_output_lengths = torch.tensor(audio_output_lengths).to(
             audio_embeddings.device)

-        audio_feature_mask = (torch.arange(
-            max_audio_tokens,
-            device=audio_embeddings.device).unsqueeze(0).expand(
-                batch_size, max_audio_tokens)
-                              < audio_output_lengths.unsqueeze(1))
+        audio_feature_mask = torch.arange(
+            max_audio_tokens,
+            device=audio_embeddings.device).unsqueeze(0).expand(
+                batch_size,
+                max_audio_tokens) < audio_output_lengths.unsqueeze(1)

         masked_audio_features = audio_embeddings[audio_feature_mask].view(
             -1, embed_dim)
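
The mask built above is the standard length-mask idiom: broadcast an arange
over the padded token axis against each clip's valid length, then use the
boolean mask to pack only the real embeddings. A small worked example:

    import torch

    lengths = torch.tensor([3, 1])  # valid audio tokens per clip
    max_tokens = 4
    mask = (torch.arange(max_tokens).unsqueeze(0).expand(2, max_tokens)
            < lengths.unsqueeze(1))
    # tensor([[ True,  True,  True, False],
    #         [ True, False, False, False]])

    embeddings = torch.randn(2, max_tokens, 8)
    packed = embeddings[mask].view(-1, 8)  # (3 + 1, 8); padding dropped
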
@@ -762,10 +811,12 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
             )
             input_ids = None

-        return self.decoder.model(input_ids,
-                                  positions,
-                                  intermediate_tensors,
-                                  inputs_embeds=inputs_embeds)
+        return self.decoder.model(
+            input_ids,
+            positions,
+            intermediate_tensors,
+            inputs_embeds=inputs_embeds,
+        )

     def compute_logits(
         self,