[Model][Bugfix] Fix issues in MiDashengLM implementation for quantized models (#25854)
Signed-off-by: zhoukz <me@zhoukz.com>
@@ -22,6 +22,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only MiDashengLM model compatible with HuggingFace weights."""
+
 import collections
 import collections.abc
 from collections.abc import Iterable, Mapping, Sequence
@@ -30,10 +31,10 @@ from typing import Any, Callable, Optional, TypedDict, Union, cast
 import numpy as np
 import torch
 import torch.nn as nn
-import torchaudio.transforms as audio_transforms
+import torchaudio.functional as F
+from torch.nn.functional import scaled_dot_product_attention
 from transformers import BatchFeature
 
-from vllm.attention.layer import MultiHeadAttention
 from vllm.config import VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
@@ -41,7 +42,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.model_executor.model_loader.utils import set_default_torch_dtype
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalKwargsItems)
@@ -147,15 +147,19 @@ class DashengMlp(nn.Module):
         super().__init__()
         out_features = out_features or in_features
         hidden_features = hidden_features or in_features
-        self.fc1 = ColumnParallelLinear(input_size=in_features,
-                                        output_size=hidden_features,
-                                        quant_config=quant_config,
-                                        prefix=f"{prefix}.fc1")
+        self.fc1 = ColumnParallelLinear(
+            input_size=in_features,
+            output_size=hidden_features,
+            quant_config=quant_config,
+            prefix=f"{prefix}.fc1",
+        )
         self.act = get_act_fn("gelu")
-        self.fc2 = RowParallelLinear(input_size=hidden_features,
-                                     output_size=out_features,
-                                     quant_config=quant_config,
-                                     prefix=f"{prefix}.fc2")
+        self.fc2 = RowParallelLinear(
+            input_size=hidden_features,
+            output_size=out_features,
+            quant_config=quant_config,
+            prefix=f"{prefix}.fc2",
+        )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x, _ = self.fc1(x)
@@ -171,7 +175,6 @@ class DashengAttention(nn.Module):
         dim: int,
         num_heads: int = 8,
         qkv_bias: bool = False,
-        causal: bool = False,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ):
@@ -205,33 +208,30 @@ class DashengAttention(nn.Module):
             quant_config=quant_config,
             prefix=f"{prefix}.qkv",
         )
-        self.attn = MultiHeadAttention(
-            self.num_heads,
-            self.head_dim,
-            self.scale,
-            num_kv_heads=self.num_kv_heads,
-        )
         self.proj = RowParallelLinear(
             input_size=dim,
             output_size=dim,
             quant_config=quant_config,
             prefix=f"{prefix}.proj",
         )
-        self.causal = causal
 
     def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None):
         B, N, C = x.shape
 
-        qkv_out, _ = self.qkv(x)
-        q, k, v = qkv_out.split([self.q_size, self.kv_size, self.kv_size],
-                                dim=-1)
+        qkv, _ = self.qkv(x)
+        qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads)
+        qkv = qkv.permute(2, 0, 3, 1, 4)
+        q, k, v = qkv.unbind(0)
 
-        attn_out = self.attn(q, k, v)
-        C_local = attn_out.numel() // (B * N)  # C_local for parallel
-        attn_out = attn_out.view(B, N, C_local)
-        x, _ = self.proj(attn_out)
+        x = scaled_dot_product_attention(
+            q,
+            k,
+            v,
+            attn_mask=mask[:, None, None, :] if mask is not None else None,
+        )
 
+        x = x.transpose(1, 2).reshape(B, N, C)
+        x, _ = self.proj(x)
         return x
 
 
@@ -280,6 +280,63 @@ class DashengBlock(nn.Module):
         return x
 
 
+class DashengFrontend(nn.Module):
+
+    def __init__(self, config: DashengConfig):
+        super().__init__()
+        self.config = config
+
+        spectrogram_window = torch.hann_window(self.config.win_length)
+        self.register_buffer(
+            "spectrogram_window",
+            spectrogram_window,
+            persistent=False,
+        )
+        self.spectrogram_window: torch.Tensor
+
+        melscale_fbanks = F.melscale_fbanks(
+            n_freqs=self.config.n_fft // 2 + 1,
+            f_min=self.config.f_min,
+            f_max=self.config.f_max,
+            n_mels=self.config.n_mels,
+            sample_rate=self.config.sample_rate,
+        )
+        self.register_buffer("melscale_fbanks",
+                             melscale_fbanks,
+                             persistent=False)
+        self.melscale_fbanks: torch.Tensor
+
+    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
+        spectrogram = F.spectrogram(
+            waveform=waveform.to(torch.float32),
+            pad=0,
+            window=self.spectrogram_window,
+            n_fft=self.config.n_fft,
+            hop_length=self.config.hop_length,
+            win_length=self.config.win_length,
+            power=2,
+            normalized=False,
+            center=self.config.center,
+        )
+        mel_spectrogram = (
+            spectrogram.mT @ self.melscale_fbanks.to(torch.float32)).mT
+        # x has shape [batch, freq, time].
+        # F.amplitude_to_DB accepts inputs shaped as:
+        # - [freq, time]
+        # - [channel, freq, time]
+        # - [..., channel, freq, time]
+        # Here we insert a channel dimension of size 1 before calling it,
+        # then remove that extra dimension afterward.
+        log_mel_spectrogram = F.amplitude_to_DB(
+            mel_spectrogram.unsqueeze(1),
+            multiplier=10,
+            amin=1e-10,
+            db_multiplier=0,
+            top_db=120,
+        ).squeeze(1)
+        return log_mel_spectrogram.to(waveform.dtype)
+
+
 class DashengAudioTransformer(nn.Module):
 
     def __init__(
@@ -293,7 +350,7 @@ class DashengAudioTransformer(nn.Module):
         self.target_length = config.target_length
         self.hop_length = config.hop_length
 
-        self._init_front_end(config)
+        self.front_end = DashengFrontend(config)
 
         self.init_bn = nn.BatchNorm2d(config.n_mels, momentum=0.01)
 
@@ -318,34 +375,10 @@ class DashengAudioTransformer(nn.Module):
                 qkv_bias=config.qkv_bias,
                 init_values=config.init_values,
                 quant_config=quant_config,
-                prefix=f"{prefix}.block{i}",
+                prefix=f"{prefix}.blocks.{i}",
             ) for i in range(config.depth))
         self.norm = nn.LayerNorm(config.embed_dim, eps=1e-6)
 
-    def _init_front_end(self, config):
-        with set_default_torch_dtype(torch.float32):
-            self.front_end = nn.Sequential(
-                audio_transforms.MelSpectrogram(
-                    f_min=config.f_min,
-                    f_max=config.f_max,
-                    center=config.center,
-                    win_length=config.win_length,
-                    hop_length=config.hop_length,
-                    sample_rate=config.sample_rate,
-                    n_fft=config.n_fft,
-                    n_mels=config.n_mels,
-                ),
-                audio_transforms.AmplitudeToDB(top_db=120),
-            )
-
-            mel_spectrogram = self.front_end[0]
-            fb = mel_spectrogram.mel_scale.fb
-            win = mel_spectrogram.spectrogram.window
-            mel_spectrogram.mel_scale.fb = fb.to(torch.bfloat16).to(
-                torch.float32)
-            mel_spectrogram.spectrogram.window = win.to(torch.bfloat16).to(
-                torch.float32)
-
     def forward_features(
         self,
         x: torch.Tensor,
@@ -430,14 +463,16 @@ class AudioProjectorSubsample(nn.Module):
                 quant_config=quant_config,
                 prefix=f"{prefix}.net.0",
                 return_bias=False,
-            ), get_act_fn("gelu"),
+            ),
+            get_act_fn("gelu"),
             RowParallelLinear(
                 input_size=out_dim,
                 output_size=out_dim,
                 quant_config=quant_config,
                 prefix=f"{prefix}.net.2",
                 return_bias=False,
-            ))
+            ),
+        )
 
     def forward(self, x, mask=None):
         batch_size, seq_len, dim = x.shape
@@ -534,9 +569,12 @@ class MiDashengLMMultiModalProcessor(
         # + Padding
         min_audio_len = self.info.get_min_audio_len()
         processed_audios = [
-            np.pad(audio, (0, min_audio_len - audio.shape[-1]),
-                   mode='constant',
-                   constant_values=0) if isinstance(audio, np.ndarray)
+            np.pad(
+                audio,
+                (0, min_audio_len - audio.shape[-1]),
+                mode="constant",
+                constant_values=0,
+            ) if isinstance(audio, np.ndarray)
             and audio.shape[-1] < min_audio_len else audio for audio in audios
         ]
 
@@ -585,8 +623,8 @@ class MiDashengLMMultiModalProcessor(
         if audio_length is None:
             audio_output_lengths = []
         else:
-            audio_length_np = audio_length.cpu().numpy() if isinstance(
-                audio_length, torch.Tensor) else audio_length
+            audio_length_np = (audio_length.cpu().numpy() if isinstance(
+                audio_length, torch.Tensor) else audio_length)
             audio_output_lengths = [
                 max(1, calculate_mel_frames_dasheng(
                     int(length)))  # at least one frame
@@ -617,6 +655,17 @@ class MiDashengLMMultiModalProcessor(
     dummy_inputs=MiDashengLMDummyInputsBuilder,
 )
 class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
 
     @classmethod
     def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
@@ -660,8 +709,8 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
     def _validate_and_reshape_mm_tensor(self, mm_input: object,
                                         name: str) -> torch.Tensor:
         if not isinstance(mm_input, (torch.Tensor, list)):
-            raise ValueError(f"Incorrect type of {name}. "
-                             f"Got type: {type(mm_input)}")
+            raise ValueError(
+                f"Incorrect type of {name}. Got type: {type(mm_input)}")
         if isinstance(mm_input, torch.Tensor):
             return mm_input.reshape(-1, *mm_input.shape[2:])
 
@@ -710,8 +759,8 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
             audio_input["input_values"].dtype)
         batch_size, max_audio_tokens, embed_dim = audio_embeddings.shape
 
-        audio_length_np = audio_length.cpu().numpy() if isinstance(
-            audio_length, torch.Tensor) else audio_length
+        audio_length_np = (audio_length.cpu().numpy() if isinstance(
+            audio_length, torch.Tensor) else audio_length)
         audio_output_lengths = [
             max(1, calculate_mel_frames_dasheng(
                 int(length)))  # at least one frame
@@ -720,11 +769,11 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
         audio_output_lengths = torch.tensor(audio_output_lengths).to(
             audio_embeddings.device)
 
-        audio_feature_mask = (torch.arange(
+        audio_feature_mask = torch.arange(
             max_audio_tokens,
             device=audio_embeddings.device).unsqueeze(0).expand(
-                batch_size, max_audio_tokens)
-            < audio_output_lengths.unsqueeze(1))
+                batch_size,
+                max_audio_tokens) < audio_output_lengths.unsqueeze(1)
 
         masked_audio_features = audio_embeddings[audio_feature_mask].view(
             -1, embed_dim)
@@ -762,10 +811,12 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
             )
             input_ids = None
 
-        return self.decoder.model(input_ids,
-                                  positions,
-                                  intermediate_tensors,
-                                  inputs_embeds=inputs_embeds)
+        return self.decoder.model(
+            input_ids,
+            positions,
+            intermediate_tensors,
+            inputs_embeds=inputs_embeds,
+        )
 
     def compute_logits(
         self,
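For readers unfamiliar with the torchaudio functional API that the new DashengFrontend relies on, the standalone sketch below runs the same spectrogram, mel-filterbank, and decibel-conversion steps on a dummy waveform. The parameter values (16 kHz sample rate, 512-point FFT, hop of 160, 64 mel bins) are illustrative placeholders, not the model's actual configuration.

import torch
import torchaudio.functional as F

# Illustrative front-end parameters; the real values come from the model config.
sample_rate, n_fft, win_length, hop_length, n_mels = 16000, 512, 512, 160, 64

waveform = torch.randn(1, sample_rate)  # one second of dummy audio

# Power spectrogram, the same call DashengFrontend.forward makes.
spec = F.spectrogram(
    waveform=waveform.to(torch.float32),
    pad=0,
    window=torch.hann_window(win_length),
    n_fft=n_fft,
    hop_length=hop_length,
    win_length=win_length,
    power=2,
    normalized=False,
    center=True,
)  # [batch, n_fft // 2 + 1, time]

# Mel filterbank projection followed by conversion to decibels.
fbanks = F.melscale_fbanks(
    n_freqs=n_fft // 2 + 1,
    f_min=0.0,
    f_max=sample_rate / 2,
    n_mels=n_mels,
    sample_rate=sample_rate,
)  # [n_freqs, n_mels]
mel = (spec.mT @ fbanks).mT  # [batch, n_mels, time]
log_mel = F.amplitude_to_DB(
    mel.unsqueeze(1),  # add a channel dim, as amplitude_to_DB expects
    multiplier=10,
    amin=1e-10,
    db_multiplier=0,
    top_db=120,
).squeeze(1)
print(log_mel.shape)  # torch.Size([1, 64, 101])

In the new DashengFrontend, the window and filterbank are registered as non-persistent buffers and cast to float32 per call, replacing the bfloat16 round-trip that the removed _init_front_end applied to the torchaudio.transforms modules.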
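The packed_modules_mapping added to MiDashengLMModel declares which per-projection checkpoint weights are packed into the fused linear modules of the decoder (q/k/v into qkv_proj, gate/up into gate_up_proj); quantization configs consult this mapping when resolving per-module quantization parameters, which is presumably why it is needed for quantized checkpoints. The toy sketch below illustrates only the packing idea; the dictionary contents and shapes are made up and this is not vLLM's actual weight loader.

import torch

# Toy checkpoint with separate per-projection weights (shapes are made up).
packed_modules_mapping = {
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"],
}
checkpoint = {
    "q_proj.weight": torch.randn(8, 16),
    "k_proj.weight": torch.randn(8, 16),
    "v_proj.weight": torch.randn(8, 16),
    "gate_proj.weight": torch.randn(32, 16),
    "up_proj.weight": torch.randn(32, 16),
}

# Concatenate the per-projection weights along the output dimension, in the
# order declared by the mapping, to obtain each fused parameter.
fused = {
    f"{fused_name}.weight":
    torch.cat([checkpoint[f"{n}.weight"] for n in names], dim=0)
    for fused_name, names in packed_modules_mapping.items()
}

print(fused["qkv_proj.weight"].shape)      # torch.Size([24, 16])
print(fused["gate_up_proj.weight"].shape)  # torch.Size([64, 16])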