[Model] Use merge_by_field_config for MM models (M-N) (#26710)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Cyrus Leung
Date: 2025-10-14 01:27:01 +08:00
Committed by: GitHub
parent e3b90c1ba2
commit afc47e4de7
11 changed files with 127 additions and 331 deletions


@@ -26,7 +26,7 @@
 import collections
 import collections.abc
 from collections.abc import Callable, Iterable, Mapping, Sequence
-from typing import Any, TypeAlias, TypedDict, cast
+from typing import Annotated, Any, TypeAlias, cast
 
 import numpy as np
 import torch
@@ -62,6 +62,7 @@ from vllm.multimodal.processing import (
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs.midashenglm import DashengConfig
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
@@ -508,11 +509,16 @@ class AudioProjectorSubsample(nn.Module):
 
 # === Audio Inputs === #
-class MiDashengLMAudioInputs(TypedDict):
-    input_values: torch.Tensor
-    """Shape: `(num_audios, num_sampling_points)`"""
-    audio_length: torch.Tensor
-    """Shape: `(num_audios, 1)`"""
+class MiDashengLMAudioInputs(TensorSchema):
+    """
+    Dimensions:
+        - bn: Batch size * number of audios
+        - p: Number of sampling points
+    """
+
+    input_values: Annotated[torch.Tensor, TensorShape("bn", "p")]
+    audio_length: Annotated[torch.Tensor, TensorShape("bn")]
 
 
 class MiDashengLMProcessingInfo(BaseProcessingInfo):
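
A note on the new schema: each TensorSchema field pairs a tensor with a symbolic TensorShape, and fields that share a dimension name (here "bn") must agree on its size at runtime. The snippet below is a self-contained toy illustrating that validation idea in plain PyTorch; it is not vLLM's actual TensorSchema implementation, and every name in it is made up for the sketch.

from typing import Annotated, get_type_hints

import torch


class TensorShape:
    """Toy symbolic shape: named dimensions must agree across fields."""

    def __init__(self, *dims: str) -> None:
        self.dims = dims


class AudioInputs:
    input_values: Annotated[torch.Tensor, TensorShape("bn", "p")]
    audio_length: Annotated[torch.Tensor, TensorShape("bn")]

    def __init__(self, **fields: torch.Tensor) -> None:
        sizes: dict[str, int] = {}
        hints = get_type_hints(type(self), include_extras=True)
        for name, hint in hints.items():
            tensor = fields[name]
            shape = next(m for m in hint.__metadata__ if isinstance(m, TensorShape))
            assert tensor.ndim == len(shape.dims), f"{name}: wrong rank"
            for dim, size in zip(shape.dims, tensor.shape):
                # The first field to use a name pins its size ("bn" here).
                assert sizes.setdefault(dim, size) == size, f"{name}: dim {dim}"
            setattr(self, name, tensor)


inputs = AudioInputs(
    input_values=torch.zeros(2, 16000),        # bn=2, p=16000
    audio_length=torch.tensor([16000, 8000]),  # bn=2, consistent
)
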
@@ -676,6 +682,8 @@ class MiDashengLMMultiModalProcessor(
dummy_inputs=MiDashengLMDummyInputsBuilder,
)
class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
merge_by_field_config = True
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
@@ -728,26 +736,6 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
             self.decoder.make_empty_intermediate_tensors
         )
 
-    def _validate_and_reshape_mm_tensor(
-        self, mm_input: object, name: str
-    ) -> torch.Tensor:
-        if not isinstance(mm_input, (torch.Tensor, list)):
-            raise ValueError(f"Incorrect type of {name}. Got type: {type(mm_input)}")
-        if isinstance(mm_input, torch.Tensor):
-            return mm_input.reshape(-1, *mm_input.shape[2:])
-        if name == "input_values":
-            max_length = max(tensor.shape[1] for tensor in mm_input)
-            padded_mm_input = [
-                torch.nn.functional.pad(tensor, (0, max_length - tensor.shape[1]))
-                if tensor.shape[1] < max_length
-                else tensor
-                for tensor in mm_input
-            ]
-            return torch.concat(padded_mm_input)
-        return torch.concat(mm_input)
-
     def _parse_and_validate_audio_input(
         self, **kwargs: object
     ) -> MiDashengLMAudioInputs | None:
@@ -756,16 +744,11 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
if input_values is None:
return None
input_values = self._validate_and_reshape_mm_tensor(
input_values, "input_values"
)
audio_length = self._validate_and_reshape_mm_tensor(
audio_length, "audio_length"
)
if not isinstance(input_values, (torch.Tensor, list)):
raise ValueError(
"Incorrect type of audio input features. "
f"Got type: {type(input_values)}"
if isinstance(input_values, list):
input_values = torch.nn.utils.rnn.pad_sequence(
input_values,
batch_first=True,
)
return MiDashengLMAudioInputs(
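
The replacement leans on torch.nn.utils.rnn.pad_sequence, which right-pads a list of variable-length tensors with zeros into a single (batch, max_len) tensor, the same result the deleted F.pad loop computed by hand. A minimal demonstration:

import torch

# Variable-length mono clips, as the list case would deliver them.
clips = [torch.ones(5), torch.ones(3), torch.ones(4)]

batch = torch.nn.utils.rnn.pad_sequence(clips, batch_first=True)
assert batch.shape == (3, 5)  # zero-padded on the right
# tensor([[1., 1., 1., 1., 1.],
#         [1., 1., 1., 0., 0.],
#         [1., 1., 1., 1., 0.]])
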
@@ -773,7 +756,10 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
             audio_length=audio_length,
         )
 
-    def _process_audio_input(self, audio_input: MiDashengLMAudioInputs) -> torch.Tensor:
+    def _process_audio_input(
+        self,
+        audio_input: MiDashengLMAudioInputs,
+    ) -> tuple[torch.Tensor, ...]:
         # Process audio through encoder and projector
         input_values = audio_input["input_values"]
         audio_length = audio_input["audio_length"]
@@ -783,17 +769,13 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
         audio_embeddings = audio_embeddings.to(audio_input["input_values"].dtype)
         batch_size, max_audio_tokens, embed_dim = audio_embeddings.shape
 
-        audio_length_np = (
-            audio_length.cpu().numpy()
-            if isinstance(audio_length, torch.Tensor)
-            else audio_length
-        )
         audio_output_lengths = [
             max(1, calculate_mel_frames_dasheng(int(length)))  # at least one frame
-            for length in audio_length_np
+            for length in audio_length.tolist()
         ]
-        audio_output_lengths = torch.tensor(audio_output_lengths).to(
-            audio_embeddings.device
+        audio_output_lengths = torch.tensor(
+            audio_output_lengths,
+            device=audio_embeddings.device,
         )
 
         audio_feature_mask = torch.arange(
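
The mask being built here follows the standard length-mask pattern: compare a [0, max_tokens) range against per-clip lengths, then use the mask to recover one variable-length embedding tensor per clip, matching the new tuple[torch.Tensor, ...] return type. A sketch with toy sizes (the exact statements in this file are truncated above):

import torch

audio_output_lengths = torch.tensor([3, 1, 4])  # valid tokens per clip
max_audio_tokens = 4

# True where an embedding slot holds real (non-padding) audio.
audio_feature_mask = (
    torch.arange(max_audio_tokens).unsqueeze(0)
    < audio_output_lengths.unsqueeze(1)
)

embeddings = torch.randn(3, max_audio_tokens, 8)  # (bn, t, embed_dim)
per_clip = embeddings[audio_feature_mask].split(audio_output_lengths.tolist())
assert per_clip[0].shape == (3, 8) and per_clip[1].shape == (1, 8)
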
@@ -826,14 +808,6 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
     ) -> torch.Tensor | IntermediateTensors:
         if intermediate_tensors is not None:
             inputs_embeds = None
-        elif inputs_embeds is None:
-            multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
-            inputs_embeds = self.get_input_embeddings(
-                input_ids,
-                multimodal_embeddings,
-                is_multimodal=input_ids == self.config.audio_token_id,
-            )
-            input_ids = None
 
         return self.decoder.model(
             input_ids,
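
The deleted elif branch is now redundant because the engine prepares inputs_embeds before calling forward: it masks the audio placeholder tokens and scatters the audio embeddings into those slots. Conceptually, with a made-up token id and toy sizes:

import torch

AUDIO_TOKEN_ID = 9999  # made-up placeholder id for the sketch
input_ids = torch.tensor([101, AUDIO_TOKEN_ID, AUDIO_TOKEN_ID, 102])

text_embeds = torch.zeros(4, 8)   # what the token embedding table returns
audio_embeds = torch.ones(2, 8)   # one projected vector per placeholder

# is_multimodal marks the slots to overwrite with audio embeddings.
is_multimodal = input_ids == AUDIO_TOKEN_ID
inputs_embeds = text_embeds.clone()
inputs_embeds[is_multimodal] = audio_embeds

assert inputs_embeds[1].sum() == 8 and inputs_embeds[0].sum() == 0
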