[Model] Use merge_by_field_config for MM models (M-N) (#26710)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -71,7 +71,7 @@ from .minicpmv import (
|
||||
MiniCPMVProcessingInfo,
|
||||
_minicpmv_field_config,
|
||||
)
|
||||
from .utils import AutoWeightsLoader, cast_overflow_tensors, flatten_bn, maybe_prefix
|
||||
from .utils import AutoWeightsLoader, cast_overflow_tensors, maybe_prefix
|
||||
|
||||
CPU_DEVICE = torch.device("cpu")
|
||||
|
||||
@@ -132,15 +132,11 @@ MiniCPMOAudioInputs: TypeAlias = (
|
||||
|
||||
|
||||
def _minicpmo_field_config(hf_inputs: Mapping[str, torch.Tensor]):
|
||||
audio_features = hf_inputs.get("audio_features", torch.empty(0))
|
||||
num_audios = len(audio_features)
|
||||
|
||||
return dict(
|
||||
**_minicpmv_field_config(hf_inputs),
|
||||
audio_features=MultiModalFieldConfig.batched("audio"),
|
||||
audio_feature_lens=MultiModalFieldConfig.batched("audio"),
|
||||
audio_embeds=MultiModalFieldConfig.batched("audio"),
|
||||
audio_token_id=MultiModalFieldConfig.shared("audio", num_audios),
|
||||
)
|
||||
|
||||
|
||||
@@ -332,10 +328,6 @@ class MiniCPMOMultiModalProcessor(MiniCPMVMultiModalProcessor[MiniCPMOProcessing
|
||||
]
|
||||
audio_inputs["audio_features"] = unpadded_audio_features
|
||||
|
||||
tokenizer = self.info.get_tokenizer()
|
||||
unk_token_id = tokenizer.get_vocab()["<unk>"]
|
||||
audio_inputs["audio_token_id"] = torch.tensor(unk_token_id)
|
||||
|
||||
return audio_inputs
|
||||
|
||||
def process_mm_inputs(
|
||||
@@ -436,12 +428,10 @@ class MiniCPMWhisperEncoderLayer(nn.Module):
|
||||
attention_mask: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
past_key_values = None
|
||||
hidden_states = self.self_attn_layer_norm(hidden_states)
|
||||
hidden_states, attn_weights, past_key_values = self.self_attn(
|
||||
hidden_states, _ = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
past_key_value=past_key_values,
|
||||
)
|
||||
hidden_states = nn.functional.dropout(
|
||||
hidden_states, p=self.dropout, training=self.training
|
||||
@@ -567,8 +557,6 @@ class MiniCPMO(MiniCPMV2_6):
|
||||
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "apm")
|
||||
)
|
||||
|
||||
self.audio_token_id = None
|
||||
|
||||
def init_audio_module(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
# Do not use parameters temporarily
|
||||
audio_config = self.config.audio_config
|
||||
@@ -731,43 +719,18 @@ class MiniCPMO(MiniCPMV2_6):
|
||||
if audio_features is None and audio_embeds is None:
|
||||
return None
|
||||
|
||||
audio_token_id = kwargs.pop("audio_token_id")
|
||||
if audio_token_id is not None:
|
||||
assert isinstance(audio_token_id, torch.Tensor)
|
||||
self.mm_token_ids.add(audio_token_id.flatten().unique().item())
|
||||
|
||||
if audio_embeds is not None:
|
||||
if not isinstance(audio_embeds, (torch.Tensor, list)):
|
||||
raise ValueError(
|
||||
f"Incorrect type of audio_embeds. Got type: {type(audio_embeds)}"
|
||||
)
|
||||
|
||||
audio_embeds_flat = flatten_bn(audio_embeds)
|
||||
|
||||
return MiniCPMOAudioEmbeddingInputs(
|
||||
type="audio_embeds",
|
||||
audio_embeds=audio_embeds_flat,
|
||||
)
|
||||
|
||||
if not isinstance(audio_features, (torch.Tensor, list)):
|
||||
raise ValueError(
|
||||
f"Incorrect type of audio_features. Got type: {type(audio_features)}"
|
||||
audio_embeds=audio_embeds,
|
||||
)
|
||||
|
||||
audio_feature_lens = kwargs.pop("audio_feature_lens")
|
||||
if not isinstance(audio_feature_lens, (torch.Tensor, list)):
|
||||
raise ValueError(
|
||||
"Incorrect type of audio_feature_lens. "
|
||||
f"Got type: {type(audio_feature_lens)}"
|
||||
)
|
||||
|
||||
audio_features_flat = flatten_bn(audio_features)
|
||||
audio_feature_lens_flat = flatten_bn(audio_feature_lens)
|
||||
|
||||
return MiniCPMOAudioFeatureInputs(
|
||||
type="audio_features",
|
||||
audio_features=audio_features_flat,
|
||||
audio_feature_lens=audio_feature_lens_flat,
|
||||
audio_features=audio_features,
|
||||
audio_feature_lens=audio_feature_lens,
|
||||
)
|
||||
|
||||
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
|
||||
|
||||
Reference in New Issue
Block a user