[Model] Use merge_by_field_config for MM models (M-N) (#26710)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
@@ -26,7 +26,7 @@
 import collections
 import collections.abc
 from collections.abc import Callable, Iterable, Mapping, Sequence
-from typing import Any, TypeAlias, TypedDict, cast
+from typing import Annotated, Any, TypeAlias, cast
 
 import numpy as np
 import torch
@@ -62,6 +62,7 @@ from vllm.multimodal.processing import (
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs.midashenglm import DashengConfig
+from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
 from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
@@ -508,11 +509,16 @@ class AudioProjectorSubsample(nn.Module):
 
 
 # === Audio Inputs === #
-class MiDashengLMAudioInputs(TypedDict):
-    input_values: torch.Tensor
-    """Shape: `(num_audios, num_sampling_points)`"""
-    audio_length: torch.Tensor
-    """Shape: `(num_audios, 1)`"""
+class MiDashengLMAudioInputs(TensorSchema):
+    """
+    Dimensions:
+        - bn: Batch size * number of audios
+        - p: Number of sampling points
+    """
+
+    input_values: Annotated[torch.Tensor, TensorShape("bn", "p")]
+    audio_length: Annotated[torch.Tensor, TensorShape("bn")]
 
 
 class MiDashengLMProcessingInfo(BaseProcessingInfo):
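Note: TensorSchema validates each annotated field against the named dimensions in its TensorShape metadata, replacing the free-text shape docstrings of the old TypedDict. Below is a minimal standalone sketch of the idea, not vLLM's actual implementation (Shape, check_shapes, and AudioInputs are hypothetical names):

from typing import Annotated, get_type_hints

import torch


class Shape:
    """Named-dimension spec attached via Annotated metadata (hypothetical)."""

    def __init__(self, *dims: str) -> None:
        self.dims = dims


def check_shapes(tensors: dict[str, torch.Tensor], schema: type) -> None:
    # Resolve each named dim on first sight, then enforce it everywhere.
    resolved: dict[str, int] = {}
    for field, hint in get_type_hints(schema, include_extras=True).items():
        shape = hint.__metadata__[0]
        tensor = tensors[field]
        if tensor.ndim != len(shape.dims):
            raise ValueError(f"{field}: expected {len(shape.dims)}-D tensor")
        for name, size in zip(shape.dims, tensor.shape):
            if resolved.setdefault(name, int(size)) != int(size):
                raise ValueError(
                    f"{field}: dim '{name}' is {size}, expected {resolved[name]}"
                )


class AudioInputs:
    input_values: Annotated[torch.Tensor, Shape("bn", "p")]
    audio_length: Annotated[torch.Tensor, Shape("bn")]


# Passes: both fields agree that bn == 4.
check_shapes(
    {"input_values": torch.zeros(4, 16000), "audio_length": torch.zeros(4)},
    AudioInputs,
)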
@@ -676,6 +682,8 @@ class MiDashengLMMultiModalProcessor(
     dummy_inputs=MiDashengLMDummyInputsBuilder,
 )
 class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
+    merge_by_field_config = True
+
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
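Note: merge_by_field_config = True opts the model into vLLM's field-config-driven batching, in which the engine merges each multimodal keyword field into its batched form before forward() runs. That is why the hand-rolled _validate_and_reshape_mm_tensor helper is deleted in the next hunk. A rough sketch of the reshape the model no longer performs itself (shapes are made up):

import torch

# What the deleted helper did for stacked inputs: collapse the leading
# (batch, num_audios) dims into one batch axis.
stacked = torch.zeros(2, 3, 16000)              # (batch, num_audios, points)
flat = stacked.reshape(-1, *stacked.shape[2:])  # (6, 16000)
assert flat.shape == (6, 16000)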
@@ -728,26 +736,6 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
             self.decoder.make_empty_intermediate_tensors
         )
 
-    def _validate_and_reshape_mm_tensor(
-        self, mm_input: object, name: str
-    ) -> torch.Tensor:
-        if not isinstance(mm_input, (torch.Tensor, list)):
-            raise ValueError(f"Incorrect type of {name}. Got type: {type(mm_input)}")
-        if isinstance(mm_input, torch.Tensor):
-            return mm_input.reshape(-1, *mm_input.shape[2:])
-
-        if name == "input_values":
-            max_length = max(tensor.shape[1] for tensor in mm_input)
-            padded_mm_input = [
-                torch.nn.functional.pad(tensor, (0, max_length - tensor.shape[1]))
-                if tensor.shape[1] < max_length
-                else tensor
-                for tensor in mm_input
-            ]
-            return torch.concat(padded_mm_input)
-
-        return torch.concat(mm_input)
-
     def _parse_and_validate_audio_input(
         self, **kwargs: object
     ) -> MiDashengLMAudioInputs | None:
@@ -756,16 +744,11 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
 
         if input_values is None:
             return None
-        input_values = self._validate_and_reshape_mm_tensor(
-            input_values, "input_values"
-        )
-        audio_length = self._validate_and_reshape_mm_tensor(
-            audio_length, "audio_length"
-        )
-        if not isinstance(input_values, (torch.Tensor, list)):
-            raise ValueError(
-                "Incorrect type of audio input features. "
-                f"Got type: {type(input_values)}"
-            )
+        if isinstance(input_values, list):
+            input_values = torch.nn.utils.rnn.pad_sequence(
+                input_values,
+                batch_first=True,
+            )
 
         return MiDashengLMAudioInputs(
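Note: torch.nn.utils.rnn.pad_sequence with batch_first=True right-pads a list of variable-length 1-D tensors into one (batch, max_len) tensor with zeros, which subsumes the manual torch.nn.functional.pad loop from the deleted helper. For example:

import torch

clips = [torch.ones(3), torch.ones(5), torch.ones(2)]
batch = torch.nn.utils.rnn.pad_sequence(clips, batch_first=True)
print(batch.shape)  # torch.Size([3, 5])
print(batch[2])     # tensor([1., 1., 0., 0., 0.])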
@@ -773,7 +756,10 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
             audio_length=audio_length,
         )
 
-    def _process_audio_input(self, audio_input: MiDashengLMAudioInputs) -> torch.Tensor:
+    def _process_audio_input(
+        self,
+        audio_input: MiDashengLMAudioInputs,
+    ) -> tuple[torch.Tensor, ...]:
         # Process audio through encoder and projector
         input_values = audio_input["input_values"]
         audio_length = audio_input["audio_length"]
@@ -783,17 +769,13 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
         audio_embeddings = audio_embeddings.to(audio_input["input_values"].dtype)
         batch_size, max_audio_tokens, embed_dim = audio_embeddings.shape
 
-        audio_length_np = (
-            audio_length.cpu().numpy()
-            if isinstance(audio_length, torch.Tensor)
-            else audio_length
-        )
         audio_output_lengths = [
             max(1, calculate_mel_frames_dasheng(int(length)))  # at least one frame
-            for length in audio_length_np
+            for length in audio_length.tolist()
         ]
-        audio_output_lengths = torch.tensor(audio_output_lengths).to(
-            audio_embeddings.device
+        audio_output_lengths = torch.tensor(
+            audio_output_lengths,
+            device=audio_embeddings.device,
         )
 
         audio_feature_mask = torch.arange(
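Note: the audio_feature_mask built just after this hunk follows the standard arange-versus-lengths broadcast pattern. A small self-contained illustration (values are made up):

import torch

lengths = torch.tensor([2, 4, 1])  # valid output tokens per audio clip
max_tokens = 4

# (1, max_tokens) compared against (batch, 1) broadcasts to (batch, max_tokens)
mask = torch.arange(max_tokens).unsqueeze(0) < lengths.unsqueeze(1)
print(mask)
# tensor([[ True,  True, False, False],
#         [ True,  True,  True,  True],
#         [ True, False, False, False]])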
@@ -826,14 +808,6 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
     ) -> torch.Tensor | IntermediateTensors:
         if intermediate_tensors is not None:
             inputs_embeds = None
-        elif inputs_embeds is None:
-            multimodal_embeddings = self.get_multimodal_embeddings(**kwargs)
-            inputs_embeds = self.get_input_embeddings(
-                input_ids,
-                multimodal_embeddings,
-                is_multimodal=input_ids == self.config.audio_token_id,
-            )
-            input_ids = None
 
         return self.decoder.model(
             input_ids,
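Note: the removed branch shows the merge contract: is_multimodal is a boolean mask over input_ids marking audio placeholder positions, and get_input_embeddings scatters the audio embeddings into the text embeddings at exactly those positions. A hedged sketch of that merge (merge_embeddings is illustrative, not vLLM's signature):

import torch


def merge_embeddings(
    text_embeds: torch.Tensor,    # (seq_len, hidden)
    mm_embeds: torch.Tensor,      # (num_mm_tokens, hidden)
    is_multimodal: torch.Tensor,  # (seq_len,) bool mask of placeholder slots
) -> torch.Tensor:
    # The number of placeholder positions must match the number of
    # multimodal embeddings being scattered in.
    assert int(is_multimodal.sum()) == mm_embeds.shape[0]
    merged = text_embeds.clone()
    merged[is_multimodal] = mm_embeds
    return merged


text = torch.zeros(6, 8)
audio = torch.ones(2, 8)
mask = torch.tensor([False, True, True, False, False, False])
out = merge_embeddings(text, audio, mask)
print(out[1].sum().item(), out[3].sum().item())  # 8.0 0.0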