[Model] MiniCPM-V/O supports V1 (#15487)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
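Note: a minimal usage sketch of what this change enables, not part of the diff. The model ID, chat template, audio asset helper, and the `VLLM_USE_V1` opt-in flag are assumptions based on vLLM docs and examples of the same era; the audio placeholder follows the `(<audio>./</audio>)` pattern used by the processor below.

import os

os.environ["VLLM_USE_V1"] = "1"  # opt in to the V1 engine (assumption)

from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset  # bundled demo audio clip

audio_and_sr = AudioAsset("mary_had_lamb").audio_and_sample_rate
prompt = ("<|im_start|>user\n(<audio>./</audio>)\n"
          "What is in this audio?<|im_end|>\n"
          "<|im_start|>assistant\n")

llm = LLM(model="openbmb/MiniCPM-o-2_6", trust_remote_code=True,
          max_model_len=4096, limit_mm_per_prompt={"audio": 1})
outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"audio": audio_and_sr}},
    SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)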
@@ -23,8 +23,8 @@
 # limitations under the License.
 """Inference-only MiniCPM-O model compatible with HuggingFace weights."""
 from collections.abc import Iterable, Mapping, Sequence
-from typing import (Any, Callable, Dict, Literal, Optional, Set, Tuple,
-                    TypedDict, Union)
+from typing import (Any, Callable, Literal, Optional, Set, Tuple, TypedDict,
+                    Union)
 
 import torch
 from torch import nn
@@ -42,8 +42,6 @@ from vllm.multimodal.parse import (AudioItem, AudioProcessorItems,
                                    MultiModalDataParser)
 from vllm.multimodal.processing import PromptReplacement, PromptUpdate
 from vllm.multimodal.profiling import ProcessorInputs
-from vllm.sequence import IntermediateTensors
-from vllm.utils import flatten_2d_lists
 
 from .minicpmv import (MiniCPMV2_6, MiniCPMVDummyInputsBuilder,
                        MiniCPMVMultiModalDataParser,
@@ -51,13 +49,14 @@ from .minicpmv import (MiniCPMV2_6, MiniCPMVDummyInputsBuilder,
                        _minicpmv_field_config)
 from .utils import (AutoWeightsLoader, cast_overflow_tensors, flatten_bn,
                     maybe_prefix)
+from .vision import scatter_patch_features
 
 CPU_DEVICE = torch.device("cpu")
 
 
 class MiniCPMOAudioFeatureInputs(TypedDict):
     type: Literal["audio_features"]
-    audio_features: torch.Tensor
+    audio_features: Union[torch.Tensor, list[torch.Tensor]]
     """
     Shape: `(batch_size * num_audios * num_slices, num_channels, length)`
     Slice here means chunk. Audio that is too long will be split into slices,
@@ -65,37 +64,40 @@ class MiniCPMOAudioFeatureInputs(TypedDict):
     Padding is used therefore `audio_features` is `torch.Tensor`.
     """
 
-    audio_feature_lens: torch.Tensor
+    audio_feature_lens: Union[torch.Tensor, list[torch.Tensor]]
     """
-    Shape: `(batch_size * num_audios * num_slices)`
+    Shape: `(batch_size * num_audios, num_slices)`
 
     This should be feature length of each audio slice,
     which equals to `audio_features.shape[-1]`
     """
 
-    audio_bounds: torch.Tensor
+    embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
     """
-    Shape: `(batch_size * num_audios * num_slices, 2)`
+    A boolean mask indicating which audio embeddings correspond
+    to patch tokens.
 
-    This should be in `(start, stop)` format.
+    Shape: `(batch_size * num_audios, num_embeds)`
     """
 
 
 class MiniCPMOAudioEmbeddingInputs(TypedDict):
     type: Literal["audio_embeds"]
-    audio_embeds: torch.Tensor
+    audio_embeds: Union[torch.Tensor, list[torch.Tensor]]
     """
-    Shape: `(batch_size * num_images * num_slices, hidden_size)`
+    Shape: `(batch_size * num_audios, num_slices, hidden_size)`
 
     `hidden_size` must match the hidden size of language model backbone.
-    instead of a batched tensor.
+    Length of each slice may vary, so pass it as a list.
     """
-    audio_bounds: torch.Tensor
+
+    embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
     """
-    Shape: `(batch_size * num_audios * num_slices, 2)`
+    A boolean mask indicating which audio embeddings correspond
+    to patch tokens.
 
-    This should be in `(start, stop)` format.
+    Shape: `(batch_size * num_audios, num_embeds)`
     """
 
 
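Note: the `embed_is_patch` mask replaces the old `(start, stop)` `audio_bounds` bookkeeping. Rather than writing embeddings into absolute index ranges, each audio's embeddings are scattered into the placeholder positions marked True. A toy, self-contained sketch of the idea (plain PyTorch; vLLM's actual `scatter_patch_features` also handles batching and nesting):

import torch

hidden_size = 4
# 7 prompt positions; positions 2..4 are audio patch slots.
embed_is_patch = torch.tensor([False, False, True, True, True, False, False])
text_embeds = torch.zeros(7, hidden_size)
audio_embeds = torch.arange(3 * hidden_size, dtype=torch.float32).reshape(3, hidden_size)

merged = text_embeds.clone()
merged[embed_is_patch] = audio_embeds  # boolean-mask scatter
assert torch.equal(merged[2:5], audio_embeds)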
@@ -104,11 +106,16 @@ MiniCPMOAudioInputs = Union[MiniCPMOAudioFeatureInputs,
 
 
 def _minicpmo_field_config(hf_inputs: Mapping[str, torch.Tensor]):
+    audio_features = hf_inputs.get("audio_features", torch.empty(0))
+    num_audios = len(audio_features)
+
     return dict(
         **_minicpmv_field_config(hf_inputs),
         audio_features=MultiModalFieldConfig.batched("audio"),
         audio_feature_lens=MultiModalFieldConfig.batched("audio"),
         audio_embeds=MultiModalFieldConfig.batched("audio"),
+        audio_embed_is_patch=MultiModalFieldConfig.batched("audio"),
+        audio_token_id=MultiModalFieldConfig.shared("audio", num_audios),
     )
 
 
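Note: a rough illustration of the field-config semantics used above (hypothetical helpers, not vLLM's implementation): `batched("audio")` marks the leading dimension as indexing audio items so processor outputs can be split per item, while `shared("audio", num_audios)` attaches one value (here the audio token id) to every item.

import torch

def split_batched(value: torch.Tensor) -> list[torch.Tensor]:
    # One slice per multimodal item along dim 0.
    return list(value.unbind(0))

def split_shared(value: torch.Tensor, num_items: int) -> list[torch.Tensor]:
    # The same value is attached to every item.
    return [value] * num_items

audio_feature_lens = torch.tensor([[10, 7], [3, 0]])  # 2 audios, 2 slices each
print(split_batched(audio_feature_lens))  # [tensor([10,  7]), tensor([3, 0])]
print(split_shared(torch.tensor(0), 2))   # same (placeholder) token id for both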
@@ -149,7 +156,7 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
     audio_pattern = "(<audio>./</audio>)"
 
     def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
-        return {"image": None, "video": None, "audio": None}
+        return {**super().get_supported_mm_limits(), "audio": None}
 
     def get_mm_max_tokens_per_item(
         self,
@@ -157,11 +164,25 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
         mm_counts: Mapping[str, int],
     ) -> Mapping[str, int]:
         return {
-            "image": self.get_max_image_tokens(),
-            "audio": self.get_max_audio_tokens(),
-            "video": self.get_max_video_tokens(seq_len),
+            **super().get_mm_max_tokens_per_item(seq_len, mm_counts),
+            "audio":
+            self.get_max_audio_tokens(),
         }
 
+    def get_audio_placeholder(
+        self,
+        audio_lens: int,
+        chunk_input: bool = True,
+        chunk_length: int = 1,
+    ) -> str:
+        hf_processor = self.get_hf_processor()
+
+        return hf_processor.get_audio_placeholder(
+            audio_lens,
+            chunk_input=chunk_input,
+            chunk_length=chunk_length,
+        )
+
     def get_default_audio_pool_step(self) -> int:
         return 2
 
@@ -197,12 +218,8 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
         max_videos = mm_config.get_limit_per_prompt("video")
         max_audios = mm_config.get_limit_per_prompt("audio")
 
-        # count <image_idx></image_idx> tokens
-        # which are not in get_max_image_tokens
-        max_image_tokens = self.get_max_image_tokens(
-        ) * max_images + 4 * max_images
-        max_audio_tokens = self.get_max_audio_tokens(
-        ) * max_audios + 2 * max_audios
+        max_image_tokens = self.get_max_image_tokens() * max_images
+        max_audio_tokens = self.get_max_audio_tokens() * max_audios
         max_total_frames = self.get_max_video_frames(seq_len -
                                                      max_image_tokens -
                                                      max_audio_tokens)
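Note: with the wrapper-token corrections (`+ 4 * max_images`, `+ 2 * max_audios`) dropped, the video frame budget is simply whatever remains of the sequence after images and audio. Illustrative arithmetic only; the per-item token counts are made up:

seq_len = 8192
max_image_tokens = 1225 * 2  # hypothetical per-image tokens * max_images
max_audio_tokens = 300 * 1   # hypothetical per-audio tokens * max_audios
frame_budget = seq_len - max_image_tokens - max_audio_tokens
assert frame_budget == 5442  # tokens left for video frames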
@@ -224,20 +241,20 @@ class MiniCPMODummyInputsBuilder(
 
         processor_inputs = super().get_dummy_processor_inputs(
             seq_len, mm_counts)
-        mm_data = {
-            "image":
-            processor_inputs.mm_data["image"],
-            "video":
-            processor_inputs.mm_data["video"],
+
+        audio_prompt_texts = self.info.audio_pattern * num_audios
+        audio_mm_data = {
             "audio":
             self._get_dummy_audios(length=audio_len, num_audios=num_audios)
         }
 
-        audio_prompt_texts = self.info.audio_pattern * num_audios
-
-        return ProcessorInputs(prompt_text=processor_inputs.prompt_text + \
-                               audio_prompt_texts,
-                               mm_data=mm_data)
+        return ProcessorInputs(
+            prompt_text=processor_inputs.prompt_text + audio_prompt_texts,
+            mm_data={
+                **processor_inputs.mm_data,
+                **audio_mm_data,
+            },
+        )
 
 
 class MiniCPMOMultiModalProcessor(
@@ -247,22 +264,17 @@ class MiniCPMOMultiModalProcessor(
         return MiniCPMOMultiModalDataParser(
             target_sr=self.info.get_default_audio_sampling_rate())
 
-    def get_audio_prompt_texts(self,
-                               audio_lens: int,
-                               chunk_input: bool = True,
-                               chunk_length: int = 1) -> str:
-        return self.info.get_hf_processor().get_audio_placeholder(
-            audio_lens, chunk_input, chunk_length)
-
-    def get_special_tokens(self) -> Dict[str, torch.Tensor]:
-        tokenizer = self.info.get_tokenizer()
-        special_tokens = super().get_special_tokens()
-        if hasattr(tokenizer, "audio_start_id"):
-            special_tokens["audio_start_id"] = torch.tensor(
-                tokenizer.audio_start_id)
-            special_tokens["audio_end_id"] = torch.tensor(
-                tokenizer.audio_end_id)
-        return special_tokens
+    def get_audio_prompt_texts(
+        self,
+        audio_lens: int,
+        chunk_input: bool = True,
+        chunk_length: int = 1,
+    ) -> str:
+        return self.info.get_audio_placeholder(
+            audio_lens,
+            chunk_input=chunk_input,
+            chunk_length=chunk_length,
+        )
 
     def process_audios(
         self,
@@ -274,32 +286,65 @@ class MiniCPMOMultiModalProcessor(
 
         parsed_audios = (self._get_data_parser().parse_mm_data({
             "audio": audios
-        }).get_items("audio", AudioProcessorItems))
+        }).get_items("audio",
+                     (MiniCPMOAudioEmbeddingItems, AudioProcessorItems)))
 
-        audio_inputs = self._base_call_hf_processor(
-            prompts=[self.info.audio_pattern] * len(parsed_audios),
-            mm_data={"audios": [[audio] for audio in parsed_audios]},
-            mm_kwargs={
-                **mm_kwargs, "chunk_input": True
-            },
-            out_keys={"audio_features", "audio_feature_lens"},
-        )
+        if isinstance(parsed_audios, MiniCPMOAudioEmbeddingItems):
+            audio_inputs = {}
 
-        # Avoid padding since we need the output for each audio to be
-        # independent of other audios for the cache to work correctly
-        unpadded_audio_features = [
-            feat[:, :feature_len] for feat, feature_len in zip(
-                audio_inputs["audio_features"],
-                audio_inputs["audio_feature_lens"],
+            audio_lens = [
+                self.info.get_audio_len_by_num_chunks(
+                    sum(map(len,
+                            parsed_audios.get(i)["audio_embeds"])))
+                for i in range(len(parsed_audios))
+            ]
+        else:
+            audio_inputs = self._base_call_hf_processor(
+                prompts=[self.info.audio_pattern] * len(parsed_audios),
+                mm_data={"audios": [[audio] for audio in parsed_audios]},
+                mm_kwargs={
+                    **mm_kwargs,
+                    "chunk_input": True,
+                },
+                out_keys={"audio_features", "audio_feature_lens"},
             )
-        ]
 
-        audio_lens = [
-            parsed_audios.get_audio_length(i)
-            for i in range(len(parsed_audios))
-        ]
+            # Avoid padding since we need the output for each audio to be
+            # independent of other audios for the cache to work correctly
+            unpadded_audio_features = [
+                feat[:, :feature_len] for feat, feature_len in zip(
+                    audio_inputs["audio_features"],
+                    audio_inputs["audio_feature_lens"],
+                )
+            ]
+            audio_inputs["audio_features"] = unpadded_audio_features
+
+            audio_lens = [
+                parsed_audios.get_audio_length(i)
+                for i in range(len(parsed_audios))
+            ]
+
         audio_repl_features = [
             self.get_audio_prompt_texts(audio_len) for audio_len in audio_lens
         ]
-        audio_inputs["audio_features"] = unpadded_audio_features
+
+        tokenizer = self.info.get_tokenizer()
+        audio_repls_feature_tokens = [
+            tokenizer.encode(audio_repl, add_special_tokens=False)
+            for audio_repl in audio_repl_features
+        ]
+
+        embed_is_patch = [
+            self.get_embed_is_patch(audio_repl_tokens)
+            for audio_repl_tokens in audio_repls_feature_tokens
+        ]
+        audio_inputs["audio_embed_is_patch"] = embed_is_patch
+
+        unk_token_id = tokenizer.get_vocab()["<unk>"]
+        audio_inputs["audio_token_id"] = torch.tensor(unk_token_id)
 
         return audio_inputs
 
     def get_placeholder_match_pattern(self) -> str:
         return r"\(<(image|video|audio)>./</\1>\)"
 
     def process_mm_inputs(
         self,
         mm_data: Mapping[str, object],
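Note: the unpadding step is what makes per-item caching sound: each audio's features must not depend on how long the other audios in the batch were, so the padded `(num_slices, channels, L_max)` output is trimmed back to each slice's true length. The same slicing in isolation:

import torch

feats = torch.randn(2, 80, 100)         # two slices padded to length 100
feature_lens = torch.tensor([100, 60])  # true lengths
unpadded = [feat[:, :n] for feat, n in zip(feats, feature_lens)]
assert unpadded[1].shape == (80, 60)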
@@ -331,8 +376,7 @@ class MiniCPMOMultiModalProcessor(
         if isinstance(audios, MiniCPMOAudioEmbeddingItems):
             single_audio_embeds = audios.get(item_idx)["audio_embeds"]
             audio_len = self.info.get_audio_len_by_num_chunks(
-                sum(chunk_embeds.shape[0]
-                    for chunk_embeds in single_audio_embeds))
+                sum(map(len, single_audio_embeds)))
         else:
             audio_len = audios.get_audio_length(item_idx)
 
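Note: `sum(map(len, single_audio_embeds))` is equivalent to the removed expression because `len(t)` on a tensor is `t.shape[0]`:

import torch

t = torch.zeros(5, 16)
assert len(t) == t.shape[0] == 5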
@@ -514,6 +558,8 @@ class MiniCPMO(MiniCPMV2_6):
         self.apm = self.init_audio_module(vllm_config=vllm_config,
                                           prefix=maybe_prefix(prefix, "apm"))
 
+        self.audio_token_id = None
+
     def init_audio_module(self, *, vllm_config: VllmConfig, prefix: str = ""):
         # Do not use parameters temporarily
         audio_config = self.config.audio_config
@@ -563,18 +609,30 @@ class MiniCPMO(MiniCPMV2_6):
 
         return input_lengths_after_cnn, input_lengths_after_pooling
 
     # Copied from HF repo of MiniCPM-o-2_6,
     # designed for batched inputs and outputs
-    def get_audio_hidden_states(self, data: MiniCPMOAudioInputs,
-                                chunk_length: int) -> list[torch.Tensor]:
-        wavforms = data.get(
-            "audio_features",
-            [])  # (bs, 80, frames) or [], multi audios need filled in advance
-        audio_feature_lens_raw = [data.get("audio_feature_lens",
-                                           [])]  # list, [[x1, x2], [y1], [z1]]
+    def get_audio_hidden_states(
+            self, data: MiniCPMOAudioFeatureInputs) -> list[torch.Tensor]:
+        chunk_length = self.config.audio_chunk_length
 
-        if len(wavforms) == 0:
-            return []
+        # (bs, 80, frames) or [], multi audios need filled in advance
+        wavforms_raw = data["audio_features"]
+        if isinstance(wavforms_raw, list):
+            B = len(wavforms_raw)
+            C = wavforms_raw[0].shape[-2]
+            L = max(item.shape[-1] for item in wavforms_raw)
+            device = wavforms_raw[0].device
+            dtype = wavforms_raw[0].dtype
+
+            wavforms = torch.zeros((B, C, L), dtype=dtype, device=device)
+            for i, wavforms_item in enumerate(wavforms_raw):
+                L_item = wavforms_item.shape[-1]
+                wavforms[i, ..., :L_item] = wavforms_item
+        else:
+            wavforms = wavforms_raw
+
+        # list, [[x1, x2], [y1], [z1]]
+        audio_feature_lens_raw = data["audio_feature_lens"]
+        if isinstance(audio_feature_lens_raw, torch.Tensor):
+            audio_feature_lens_raw = audio_feature_lens_raw.unbind(0)
 
         audio_feature_lens = torch.hstack(audio_feature_lens_raw)
         batch_size, _, max_mel_seq_len = wavforms.shape
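Note: since `audio_features` may now arrive as a list of per-item tensors of different lengths, the method re-pads them into one `(B, C, L_max)` batch before the encoder runs. The zero-padding loop in isolation:

import torch

wavforms_raw = [torch.ones(80, 30), torch.ones(80, 50)]
B = len(wavforms_raw)
C = wavforms_raw[0].shape[-2]
L = max(item.shape[-1] for item in wavforms_raw)

wavforms = torch.zeros(B, C, L)
for i, item in enumerate(wavforms_raw):
    wavforms[i, ..., :item.shape[-1]] = item

assert wavforms.shape == (2, 80, 50)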
@@ -625,159 +683,104 @@ class MiniCPMO(MiniCPMV2_6):
 
         num_audio_tokens = feature_lens_after_pooling
 
-        final_audio_embeds = []
+        final_audio_embeds = list[torch.Tensor]()
         idx = 0
         for i in range(len(audio_feature_lens_raw)):
-            target_audio_embeds = []
+            target_audio_embeds_lst = list[torch.Tensor]()
             for _ in range(len(audio_feature_lens_raw[i])):
-                target_audio_embeds.append(
+                target_audio_embeds_lst.append(
                     audio_embeds[idx, :num_audio_tokens[idx], :])
                 idx += 1
-            final_audio_embeds.append(target_audio_embeds)
+
+            final_audio_embeds.append(torch.cat(target_audio_embeds_lst))
+
         return final_audio_embeds
 
-    def get_embedding_with_audios(self, vlm_embedding: torch.Tensor,
-                                  audio_inputs: MiniCPMOAudioInputs,
-                                  chunk_length: int) -> torch.Tensor:
-        device, dtype = vlm_embedding.device, vlm_embedding.dtype
-        if audio_inputs["type"] == "audio_embeds":
-            audio_embeddings = [
-                item.to(device=device, dtype=dtype)
-                for item in audio_inputs["audio_embeds"]
-            ]
-        else:
-            audio_embeddings = self.get_audio_hidden_states(
-                audio_inputs, chunk_length)[0]
-        if audio_embeddings is None or len(audio_embeddings) == 0:
-            return vlm_embedding
-        audio_bounds = audio_inputs["audio_bounds"]
-        if self.config.chunk_input:
-            audio_embs = torch.cat(audio_embeddings, dim=0).to(device=device,
-                                                               dtype=dtype)
-            audio_start_pos = 0
-            for bound in audio_bounds:
-                audio_len = bound[1] - bound[0]
-                vlm_embedding[bound[0]:bound[1]] = audio_embs[
-                    audio_start_pos:audio_start_pos + audio_len, :]
-                audio_start_pos += audio_len
-        else:
-            for embs, bound in zip(audio_embeddings, audio_bounds):
-                audio_indices = torch.arange(bound[0],
-                                             bound[1],
-                                             dtype=torch.long).to(device)
-
-                if embs.shape[0] != len(audio_indices):
-                    raise ValueError(
-                        "Shape mismatch: Trying to assign embeddings "
-                        f"of shape {embs.shape} "
-                        f"to input indices of length {len(audio_indices)}")
-                vlm_embedding[audio_indices] = embs.to(dtype)
-        return vlm_embedding
-
-    def _get_audio_bounds(self, input_ids: torch.Tensor,
-                          audio_start_id: torch.Tensor,
-                          audio_end_id: torch.Tensor) -> torch.Tensor:
-        audio_start_tokens, = torch.where(input_ids == audio_start_id[0])
-        audio_start_tokens += 1
-        audio_end_tokens, = torch.where(input_ids == audio_end_id[0])
-        valid_audio_nums = max(len(audio_start_tokens), len(audio_end_tokens))
-        return torch.hstack([
-            audio_start_tokens[:valid_audio_nums].unsqueeze(-1),
-            audio_end_tokens[:valid_audio_nums].unsqueeze(-1)
-        ])
-
-    def _parse_and_validate_audio_inputs(
-            self, input_ids: torch.Tensor,
-            **kwargs: object) -> Optional[MiniCPMOAudioInputs]:
+    def _parse_and_validate_audio_input(
+            self, **kwargs: object) -> Optional[MiniCPMOAudioInputs]:
         audio_features = kwargs.pop("audio_features", None)
         audio_embeds = kwargs.pop("audio_embeds", None)
 
         if audio_features is None and audio_embeds is None:
             return None
 
-        audio_start_id = kwargs.pop("audio_start_id")
-        if not isinstance(audio_start_id, torch.Tensor):
-            raise ValueError("Incorrect type of audio_start_id. "
-                             f"Got type: {type(audio_start_id)}")
+        audio_token_id = kwargs.pop("audio_token_id")
+        if audio_token_id is not None:
+            assert isinstance(audio_token_id, torch.Tensor)
+            self.mm_token_ids.add(audio_token_id.flatten().unique().item())
 
-        audio_end_id = kwargs.pop("audio_end_id")
-        if not isinstance(audio_end_id, torch.Tensor):
-            raise ValueError("Incorrect type of audio_end_id. "
-                             f"Got type: {type(audio_end_id)}")
+        audio_embed_is_patch = kwargs.pop("audio_embed_is_patch")
+        if not isinstance(audio_embed_is_patch, (torch.Tensor, list)):
+            raise ValueError("Incorrect type of audio_embed_is_patch. "
+                             f"Got type: {type(audio_embed_is_patch)}")
+
+        audio_embed_is_patch = flatten_bn(audio_embed_is_patch)
 
         if audio_embeds is not None:
             if not isinstance(audio_embeds, (torch.Tensor, list)):
                 raise ValueError("Incorrect type of audio_embeds. "
                                  f"Got type: {type(audio_embeds)}")
 
+            audio_embeds_flat = flatten_bn(audio_embeds)
+
             return MiniCPMOAudioEmbeddingInputs(
                 type="audio_embeds",
-                audio_embeds=flatten_bn(flatten_2d_lists(audio_embeds),
-                                        concat=True),
-                audio_bounds=self._get_audio_bounds(input_ids, audio_start_id,
-                                                    audio_end_id),
+                audio_embeds=audio_embeds_flat,
+                embed_is_patch=audio_embed_is_patch,
             )
 
-        if audio_features is not None:
-            if not isinstance(audio_features, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of audio_features. "
-                                 f"Got type: {type(audio_features)}")
+        if not isinstance(audio_features, (torch.Tensor, list)):
+            raise ValueError("Incorrect type of audio_features. "
+                             f"Got type: {type(audio_features)}")
 
-            audio_feature_lens = kwargs.pop("audio_feature_lens")
-            if not isinstance(audio_feature_lens, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of audio_feature_lens. "
-                                 f"Got type: {type(audio_feature_lens)}")
+        audio_feature_lens = kwargs.pop("audio_feature_lens")
+        if not isinstance(audio_feature_lens, (torch.Tensor, list)):
+            raise ValueError("Incorrect type of audio_feature_lens. "
                              f"Got type: {type(audio_feature_lens)}")
 
-            return MiniCPMOAudioFeatureInputs(
-                type="audio_features",
-                audio_features=flatten_bn(audio_features, concat=True),
-                audio_feature_lens=flatten_bn(
-                    flatten_2d_lists(audio_feature_lens), concat=True),
-                audio_bounds=self._get_audio_bounds(input_ids, audio_start_id,
-                                                    audio_end_id),
-            )
+        audio_features_flat = flatten_bn(audio_features)
+        audio_feature_lens_flat = flatten_bn(audio_feature_lens)
 
-        raise AssertionError("This line should be unreachable.")
-
-    def _parse_and_validate_inputs(self, input_ids: torch.Tensor,
-                                   **kwargs: object):
-        image_inputs = self._parse_and_validate_image_inputs(
-            input_ids, **kwargs)
-        if not any("audio" in key for key in kwargs):
-            return image_inputs, None
-        audio_inputs = self._parse_and_validate_audio_inputs(
-            input_ids, **kwargs)
-        return image_inputs, audio_inputs
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-        **kwargs: Any,
-    ) -> torch.Tensor:
-        if intermediate_tensors is not None:
-            vlm_embeddings = None
-        else:
-            image_inputs, audio_inputs = \
-                self._parse_and_validate_inputs(input_ids, **kwargs)
-            vlm_embeddings = self.get_embedding_with_vision(
-                input_ids, image_inputs)
-
-            if audio_inputs is not None:
-                vlm_embeddings = self.get_embedding_with_audios(
-                    vlm_embeddings, audio_inputs,
-                    self.config.audio_chunk_length)
-
-        # always pass the input via `inputs_embeds`
-        # to make sure the computation graph is consistent
-        # for `torch.compile` integration
-        input_ids = None
-
-        output = self.llm.model(
-            input_ids=input_ids,
-            positions=positions,
-            intermediate_tensors=intermediate_tensors,
-            inputs_embeds=vlm_embeddings,
+        return MiniCPMOAudioFeatureInputs(
+            type="audio_features",
+            audio_features=audio_features_flat,
+            audio_feature_lens=audio_feature_lens_flat,
+            embed_is_patch=audio_embed_is_patch,
         )
-        return output
+
+    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
+        modalities = super()._parse_and_validate_multimodal_inputs(**kwargs)
+
+        # Preserve the order of modalities if there are multiple of them
+        # from the order of kwargs.
+        for input_key in kwargs:
+            if input_key in ("audio_features",
+                             "audio_embeds") and "audios" not in modalities:
+                modalities["audios"] = self._parse_and_validate_audio_input(
+                    **kwargs)
+
+        return modalities
+
+    def _process_audio_input(
+        self,
+        audio_input: MiniCPMOAudioInputs,
+    ) -> Union[torch.Tensor, list[torch.Tensor]]:
+        if audio_input["type"] == "audio_embeds":
+            return audio_input["audio_embeds"]
+
+        return self.get_audio_hidden_states(audio_input)
+
+    def _process_multimodal_inputs(self, modalities: dict):
+        multimodal_embeddings = super()._process_multimodal_inputs(modalities)
+
+        for modality in modalities:
+            if modality == "audios":
+                audio_input = modalities["audios"]
+                audio_features = self._process_audio_input(audio_input)
+                multimodal_embeddings += tuple(
+                    scatter_patch_features(
+                        audio_features,
+                        audio_input["embed_is_patch"],
+                    ))
+
+        return multimodal_embeddings
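Note: under V1 each audio now yields one flat `(num_tokens, hidden_size)` tensor (its slices concatenated) which `scatter_patch_features` places onto the positions flagged by `embed_is_patch`, replacing the bounds-based patching that `get_embedding_with_audios` used to do in `forward`. A compact sketch of the per-audio concatenation (all shapes are made up):

import torch

hidden_size = 16
num_audio_tokens = torch.tensor([5, 3, 4])     # valid tokens per slice
audio_embeds = torch.randn(3, 6, hidden_size)  # padded per-slice encoder output
slices_per_audio = [2, 1]  # audio 0 has 2 slices, audio 1 has 1

final, idx = [], 0
for n in slices_per_audio:
    parts = [audio_embeds[idx + j, :num_audio_tokens[idx + j]] for j in range(n)]
    idx += n
    final.append(torch.cat(parts))

assert final[0].shape == (8, hidden_size)  # 5 + 3 tokens
assert final[1].shape == (4, hidden_size)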