[Model] MiniCPM-V/O supports V1 (#15487)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung authored 2025-03-27 21:07:29 +08:00 · committed by GitHub
parent 8063dfc61a
commit ac5bc615b0
4 changed files with 573 additions and 594 deletions


@@ -23,8 +23,8 @@
 # limitations under the License.
 """Inference-only MiniCPM-O model compatible with HuggingFace weights."""
 from collections.abc import Iterable, Mapping, Sequence
-from typing import (Any, Callable, Dict, Literal, Optional, Set, Tuple,
-                    TypedDict, Union)
+from typing import (Any, Callable, Literal, Optional, Set, Tuple, TypedDict,
+                    Union)
 import torch
 from torch import nn
@@ -42,8 +42,6 @@ from vllm.multimodal.parse import (AudioItem, AudioProcessorItems,
                                    MultiModalDataParser)
 from vllm.multimodal.processing import PromptReplacement, PromptUpdate
 from vllm.multimodal.profiling import ProcessorInputs
-from vllm.sequence import IntermediateTensors
-from vllm.utils import flatten_2d_lists
 from .minicpmv import (MiniCPMV2_6, MiniCPMVDummyInputsBuilder,
                        MiniCPMVMultiModalDataParser,
@@ -51,13 +49,14 @@ from .minicpmv import (MiniCPMV2_6, MiniCPMVDummyInputsBuilder,
                        _minicpmv_field_config)
 from .utils import (AutoWeightsLoader, cast_overflow_tensors, flatten_bn,
                     maybe_prefix)
+from .vision import scatter_patch_features
 CPU_DEVICE = torch.device("cpu")
 class MiniCPMOAudioFeatureInputs(TypedDict):
     type: Literal["audio_features"]
-    audio_features: torch.Tensor
+    audio_features: Union[torch.Tensor, list[torch.Tensor]]
     """
     Shape: `(batch_size * num_audios * num_slices, num_channels, length)`
     Slice here means chunk. Audio that is too long will be split into slices,
@@ -65,37 +64,40 @@ class MiniCPMOAudioFeatureInputs(TypedDict):
     Padding is used therefore `audio_features` is `torch.Tensor`.
     """
-    audio_feature_lens: torch.Tensor
+    audio_feature_lens: Union[torch.Tensor, list[torch.Tensor]]
     """
-    Shape: `(batch_size * num_audios * num_slices)`
+    Shape: `(batch_size * num_audios, num_slices)`
     This should be feature length of each audio slice,
     which equals to `audio_features.shape[-1]`
     """
-    audio_bounds: torch.Tensor
+    embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
     """
-    Shape: `(batch_size * num_audios * num_slices, 2)`
+    A boolean mask indicating which audio embeddings correspond
+    to patch tokens.
-    This should be in `(start, stop)` format.
+    Shape: `(batch_size * num_audios, num_embeds)`
     """
 class MiniCPMOAudioEmbeddingInputs(TypedDict):
     type: Literal["audio_embeds"]
-    audio_embeds: torch.Tensor
+    audio_embeds: Union[torch.Tensor, list[torch.Tensor]]
     """
-    Shape: `(batch_size * num_images * num_slices, hidden_size)`
+    Shape: `(batch_size * num_audios, num_slices, hidden_size)`
     `hidden_size` must match the hidden size of language model backbone.
-    instead of a batched tensor.
+    Length of each slice may vary, so pass it as a list.
     """
-    audio_bounds: torch.Tensor
-    """
-    Shape: `(batch_size * num_audios * num_slices, 2)`
-    This should be in `(start, stop)` format.
+    embed_is_patch: Union[torch.Tensor, list[torch.Tensor]]
+    """
+    A boolean mask indicating which audio embeddings correspond
+    to patch tokens.
+    Shape: `(batch_size * num_audios, num_embeds)`
     """
@@ -104,11 +106,16 @@ MiniCPMOAudioInputs = Union[MiniCPMOAudioFeatureInputs,
 def _minicpmo_field_config(hf_inputs: Mapping[str, torch.Tensor]):
+    audio_features = hf_inputs.get("audio_features", torch.empty(0))
+    num_audios = len(audio_features)
     return dict(
         **_minicpmv_field_config(hf_inputs),
         audio_features=MultiModalFieldConfig.batched("audio"),
         audio_feature_lens=MultiModalFieldConfig.batched("audio"),
         audio_embeds=MultiModalFieldConfig.batched("audio"),
+        audio_embed_is_patch=MultiModalFieldConfig.batched("audio"),
+        audio_token_id=MultiModalFieldConfig.shared("audio", num_audios),
     )
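In this field config, `batched("audio")` entries are split along their first dimension so each audio item keeps its own slice, while the `shared("audio", num_audios)` token id is reused unchanged for every item. A simplified per-item view of that split (plain Python with dummy values, not the vLLM field machinery):

import torch

audio_feature_lens = torch.tensor([[375], [212]])  # batched: row i belongs to audio i
audio_token_id = torch.tensor(0)                   # shared: dummy id, same for every audio

items = [
    {"audio_feature_lens": audio_feature_lens[i], "audio_token_id": audio_token_id}
    for i in range(len(audio_feature_lens))
]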
@@ -149,7 +156,7 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
     audio_pattern = "(<audio>./</audio>)"
     def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
-        return {"image": None, "video": None, "audio": None}
+        return {**super().get_supported_mm_limits(), "audio": None}
     def get_mm_max_tokens_per_item(
         self,
@@ -157,11 +164,25 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
         mm_counts: Mapping[str, int],
     ) -> Mapping[str, int]:
         return {
-            "image": self.get_max_image_tokens(),
-            "audio": self.get_max_audio_tokens(),
-            "video": self.get_max_video_tokens(seq_len),
+            **super().get_mm_max_tokens_per_item(seq_len, mm_counts),
+            "audio":
+            self.get_max_audio_tokens(),
         }
+    def get_audio_placeholder(
+        self,
+        audio_lens: int,
+        chunk_input: bool = True,
+        chunk_length: int = 1,
+    ) -> str:
+        hf_processor = self.get_hf_processor()
+        return hf_processor.get_audio_placeholder(
+            audio_lens,
+            chunk_input=chunk_input,
+            chunk_length=chunk_length,
+        )
     def get_default_audio_pool_step(self) -> int:
         return 2
@@ -197,12 +218,8 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
         max_videos = mm_config.get_limit_per_prompt("video")
         max_audios = mm_config.get_limit_per_prompt("audio")
-        # count <image_idx></image_idx> tokens
-        # which are not in get_max_image_tokens
-        max_image_tokens = self.get_max_image_tokens(
-        ) * max_images + 4 * max_images
-        max_audio_tokens = self.get_max_audio_tokens(
-        ) * max_audios + 2 * max_audios
+        max_image_tokens = self.get_max_image_tokens() * max_images
+        max_audio_tokens = self.get_max_audio_tokens() * max_audios
         max_total_frames = self.get_max_video_frames(seq_len -
                                                      max_image_tokens -
                                                      max_audio_tokens)
@@ -224,20 +241,20 @@ class MiniCPMODummyInputsBuilder(
         processor_inputs = super().get_dummy_processor_inputs(
             seq_len, mm_counts)
-        mm_data = {
-            "image":
-            processor_inputs.mm_data["image"],
-            "video":
-            processor_inputs.mm_data["video"],
+        audio_prompt_texts = self.info.audio_pattern * num_audios
+        audio_mm_data = {
             "audio":
             self._get_dummy_audios(length=audio_len, num_audios=num_audios)
         }
-        audio_prompt_texts = self.info.audio_pattern * num_audios
-        return ProcessorInputs(prompt_text=processor_inputs.prompt_text + \
-                               audio_prompt_texts,
-                               mm_data=mm_data)
+        return ProcessorInputs(
+            prompt_text=processor_inputs.prompt_text + audio_prompt_texts,
+            mm_data={
+                **processor_inputs.mm_data,
+                **audio_mm_data,
+            },
+        )
 class MiniCPMOMultiModalProcessor(
@@ -247,22 +264,17 @@ class MiniCPMOMultiModalProcessor(
         return MiniCPMOMultiModalDataParser(
             target_sr=self.info.get_default_audio_sampling_rate())
-    def get_audio_prompt_texts(self,
-                               audio_lens: int,
-                               chunk_input: bool = True,
-                               chunk_length: int = 1) -> str:
-        return self.info.get_hf_processor().get_audio_placeholder(
-            audio_lens, chunk_input, chunk_length)
-    def get_special_tokens(self) -> Dict[str, torch.Tensor]:
-        tokenizer = self.info.get_tokenizer()
-        special_tokens = super().get_special_tokens()
-        if hasattr(tokenizer, "audio_start_id"):
-            special_tokens["audio_start_id"] = torch.tensor(
-                tokenizer.audio_start_id)
-            special_tokens["audio_end_id"] = torch.tensor(
-                tokenizer.audio_end_id)
-        return special_tokens
+    def get_audio_prompt_texts(
+        self,
+        audio_lens: int,
+        chunk_input: bool = True,
+        chunk_length: int = 1,
+    ) -> str:
+        return self.info.get_audio_placeholder(
+            audio_lens,
+            chunk_input=chunk_input,
+            chunk_length=chunk_length,
+        )
     def process_audios(
         self,
@@ -274,32 +286,65 @@
         parsed_audios = (self._get_data_parser().parse_mm_data({
             "audio": audios
-        }).get_items("audio", AudioProcessorItems))
+        }).get_items("audio",
+                     (MiniCPMOAudioEmbeddingItems, AudioProcessorItems)))
-        audio_inputs = self._base_call_hf_processor(
-            prompts=[self.info.audio_pattern] * len(parsed_audios),
-            mm_data={"audios": [[audio] for audio in parsed_audios]},
-            mm_kwargs={
-                **mm_kwargs, "chunk_input": True
-            },
-            out_keys={"audio_features", "audio_feature_lens"},
-        )
+        if isinstance(parsed_audios, MiniCPMOAudioEmbeddingItems):
+            audio_inputs = {}
-        # Avoid padding since we need the output for each audio to be
-        # independent of other audios for the cache to work correctly
-        unpadded_audio_features = [
-            feat[:, :feature_len] for feat, feature_len in zip(
-                audio_inputs["audio_features"],
-                audio_inputs["audio_feature_lens"],
+            audio_lens = [
+                self.info.get_audio_len_by_num_chunks(
+                    sum(map(len,
+                            parsed_audios.get(i)["audio_embeds"])))
+                for i in range(len(parsed_audios))
+            ]
+        else:
+            audio_inputs = self._base_call_hf_processor(
+                prompts=[self.info.audio_pattern] * len(parsed_audios),
+                mm_data={"audios": [[audio] for audio in parsed_audios]},
+                mm_kwargs={
+                    **mm_kwargs,
+                    "chunk_input": True,
+                },
+                out_keys={"audio_features", "audio_feature_lens"},
+            )
+            # Avoid padding since we need the output for each audio to be
+            # independent of other audios for the cache to work correctly
+            unpadded_audio_features = [
+                feat[:, :feature_len] for feat, feature_len in zip(
                    audio_inputs["audio_features"],
+                    audio_inputs["audio_feature_lens"],
+                )
+            ]
+            audio_inputs["audio_features"] = unpadded_audio_features
+            audio_lens = [
+                parsed_audios.get_audio_length(i)
+                for i in range(len(parsed_audios))
+            ]
+        audio_repl_features = [
+            self.get_audio_prompt_texts(audio_len) for audio_len in audio_lens
+        ]
-        audio_inputs["audio_features"] = unpadded_audio_features
+        tokenizer = self.info.get_tokenizer()
+        audio_repls_feature_tokens = [
+            tokenizer.encode(audio_repl, add_special_tokens=False)
+            for audio_repl in audio_repl_features
+        ]
+        embed_is_patch = [
+            self.get_embed_is_patch(audio_repl_tokens)
+            for audio_repl_tokens in audio_repls_feature_tokens
+        ]
+        audio_inputs["audio_embed_is_patch"] = embed_is_patch
+        unk_token_id = tokenizer.get_vocab()["<unk>"]
+        audio_inputs["audio_token_id"] = torch.tensor(unk_token_id)
         return audio_inputs
     def get_placeholder_match_pattern(self) -> str:
         return r"\(<(image|video|audio)>./</\1>\)"
     def process_mm_inputs(
         self,
         mm_data: Mapping[str, object],
@@ -331,8 +376,7 @@ class MiniCPMOMultiModalProcessor(
         if isinstance(audios, MiniCPMOAudioEmbeddingItems):
             single_audio_embeds = audios.get(item_idx)["audio_embeds"]
             audio_len = self.info.get_audio_len_by_num_chunks(
-                sum(chunk_embeds.shape[0]
-                    for chunk_embeds in single_audio_embeds))
+                sum(map(len, single_audio_embeds)))
         else:
             audio_len = audios.get_audio_length(item_idx)
@@ -514,6 +558,8 @@ class MiniCPMO(MiniCPMV2_6):
         self.apm = self.init_audio_module(vllm_config=vllm_config,
                                           prefix=maybe_prefix(prefix, "apm"))
+        self.audio_token_id = None
     def init_audio_module(self, *, vllm_config: VllmConfig, prefix: str = ""):
         # Do not use parameters temporarily
         audio_config = self.config.audio_config
@@ -563,18 +609,30 @@ class MiniCPMO(MiniCPMV2_6):
         return input_lengths_after_cnn, input_lengths_after_pooling
     # Copied from HF repo of MiniCPM-o-2_6,
     # designed for batched inputs and outputs
-    def get_audio_hidden_states(self, data: MiniCPMOAudioInputs,
-                                chunk_length: int) -> list[torch.Tensor]:
-        wavforms = data.get(
-            "audio_features",
-            [])  # (bs, 80, frames) or [], multi audios need filled in advance
-        audio_feature_lens_raw = [data.get("audio_feature_lens",
-                                           [])]  # list, [[x1, x2], [y1], [z1]]
+    def get_audio_hidden_states(
+            self, data: MiniCPMOAudioFeatureInputs) -> list[torch.Tensor]:
+        chunk_length = self.config.audio_chunk_length
-        if len(wavforms) == 0:
-            return []
+        # (bs, 80, frames) or [], multi audios need filled in advance
+        wavforms_raw = data["audio_features"]
+        if isinstance(wavforms_raw, list):
+            B = len(wavforms_raw)
+            C = wavforms_raw[0].shape[-2]
+            L = max(item.shape[-1] for item in wavforms_raw)
+            device = wavforms_raw[0].device
+            dtype = wavforms_raw[0].dtype
+            wavforms = torch.zeros((B, C, L), dtype=dtype, device=device)
+            for i, wavforms_item in enumerate(wavforms_raw):
+                L_item = wavforms_item.shape[-1]
+                wavforms[i, ..., :L_item] = wavforms_item
+        else:
+            wavforms = wavforms_raw
+        # list, [[x1, x2], [y1], [z1]]
+        audio_feature_lens_raw = data["audio_feature_lens"]
+        if isinstance(audio_feature_lens_raw, torch.Tensor):
+            audio_feature_lens_raw = audio_feature_lens_raw.unbind(0)
         audio_feature_lens = torch.hstack(audio_feature_lens_raw)
         batch_size, _, max_mel_seq_len = wavforms.shape
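`audio_feature_lens_raw` is the per-audio list of per-slice lengths noted in the `[[x1, x2], [y1], [z1]]` comment; `unbind(0)` turns a batched tensor into that list form, and `torch.hstack` flattens it into one vector with an entry per slice. A concrete example with made-up lengths:

import torch

audio_feature_lens_raw = [torch.tensor([375, 98]), torch.tensor([212])]
audio_feature_lens = torch.hstack(audio_feature_lens_raw)
# tensor([375,  98, 212]) -- one entry per slice across all audios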
@@ -625,159 +683,104 @@ class MiniCPMO(MiniCPMV2_6):
         num_audio_tokens = feature_lens_after_pooling
-        final_audio_embeds = []
+        final_audio_embeds = list[torch.Tensor]()
         idx = 0
         for i in range(len(audio_feature_lens_raw)):
-            target_audio_embeds = []
+            target_audio_embeds_lst = list[torch.Tensor]()
             for _ in range(len(audio_feature_lens_raw[i])):
-                target_audio_embeds.append(
+                target_audio_embeds_lst.append(
                     audio_embeds[idx, :num_audio_tokens[idx], :])
                 idx += 1
-            final_audio_embeds.append(target_audio_embeds)
+            final_audio_embeds.append(torch.cat(target_audio_embeds_lst))
         return final_audio_embeds
-    def get_embedding_with_audios(self, vlm_embedding: torch.Tensor,
-                                  audio_inputs: MiniCPMOAudioInputs,
-                                  chunk_length: int) -> torch.Tensor:
-        device, dtype = vlm_embedding.device, vlm_embedding.dtype
-        if audio_inputs["type"] == "audio_embeds":
-            audio_embeddings = [
-                item.to(device=device, dtype=dtype)
-                for item in audio_inputs["audio_embeds"]
-            ]
-        else:
-            audio_embeddings = self.get_audio_hidden_states(
-                audio_inputs, chunk_length)[0]
-        if audio_embeddings is None or len(audio_embeddings) == 0:
-            return vlm_embedding
-        audio_bounds = audio_inputs["audio_bounds"]
-        if self.config.chunk_input:
-            audio_embs = torch.cat(audio_embeddings, dim=0).to(device=device,
-                                                               dtype=dtype)
-            audio_start_pos = 0
-            for bound in audio_bounds:
-                audio_len = bound[1] - bound[0]
-                vlm_embedding[bound[0]:bound[1]] = audio_embs[
-                    audio_start_pos:audio_start_pos + audio_len, :]
-                audio_start_pos += audio_len
-        else:
-            for embs, bound in zip(audio_embeddings, audio_bounds):
-                audio_indices = torch.arange(bound[0],
-                                             bound[1],
-                                             dtype=torch.long).to(device)
-                if embs.shape[0] != len(audio_indices):
-                    raise ValueError(
-                        "Shape mismatch: Trying to assign embeddings "
-                        f"of shape {embs.shape} "
-                        f"to input indices of length {len(audio_indices)}")
-                vlm_embedding[audio_indices] = embs.to(dtype)
-        return vlm_embedding
-    def _get_audio_bounds(self, input_ids: torch.Tensor,
-                          audio_start_id: torch.Tensor,
-                          audio_end_id: torch.Tensor) -> torch.Tensor:
-        audio_start_tokens, = torch.where(input_ids == audio_start_id[0])
-        audio_start_tokens += 1
-        audio_end_tokens, = torch.where(input_ids == audio_end_id[0])
-        valid_audio_nums = max(len(audio_start_tokens), len(audio_end_tokens))
-        return torch.hstack([
-            audio_start_tokens[:valid_audio_nums].unsqueeze(-1),
-            audio_end_tokens[:valid_audio_nums].unsqueeze(-1)
-        ])
-    def _parse_and_validate_audio_inputs(
-            self, input_ids: torch.Tensor,
-            **kwargs: object) -> Optional[MiniCPMOAudioInputs]:
+    def _parse_and_validate_audio_input(
+            self, **kwargs: object) -> Optional[MiniCPMOAudioInputs]:
         audio_features = kwargs.pop("audio_features", None)
         audio_embeds = kwargs.pop("audio_embeds", None)
         if audio_features is None and audio_embeds is None:
             return None
-        audio_start_id = kwargs.pop("audio_start_id")
-        if not isinstance(audio_start_id, torch.Tensor):
-            raise ValueError("Incorrect type of audio_start_id. "
-                             f"Got type: {type(audio_start_id)}")
+        audio_token_id = kwargs.pop("audio_token_id")
+        if audio_token_id is not None:
+            assert isinstance(audio_token_id, torch.Tensor)
+            self.mm_token_ids.add(audio_token_id.flatten().unique().item())
-        audio_end_id = kwargs.pop("audio_end_id")
-        if not isinstance(audio_end_id, torch.Tensor):
-            raise ValueError("Incorrect type of audio_end_id. "
-                             f"Got type: {type(audio_end_id)}")
+        audio_embed_is_patch = kwargs.pop("audio_embed_is_patch")
+        if not isinstance(audio_embed_is_patch, (torch.Tensor, list)):
+            raise ValueError("Incorrect type of audio_embed_is_patch. "
+                             f"Got type: {type(audio_embed_is_patch)}")
+        audio_embed_is_patch = flatten_bn(audio_embed_is_patch)
         if audio_embeds is not None:
             if not isinstance(audio_embeds, (torch.Tensor, list)):
                 raise ValueError("Incorrect type of audio_embeds. "
                                  f"Got type: {type(audio_embeds)}")
+            audio_embeds_flat = flatten_bn(audio_embeds)
             return MiniCPMOAudioEmbeddingInputs(
                 type="audio_embeds",
-                audio_embeds=flatten_bn(flatten_2d_lists(audio_embeds),
-                                        concat=True),
-                audio_bounds=self._get_audio_bounds(input_ids, audio_start_id,
-                                                    audio_end_id),
+                audio_embeds=audio_embeds_flat,
+                embed_is_patch=audio_embed_is_patch,
             )
-        if audio_features is not None:
-            if not isinstance(audio_features, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of audio_features. "
-                                 f"Got type: {type(audio_features)}")
+        if not isinstance(audio_features, (torch.Tensor, list)):
+            raise ValueError("Incorrect type of audio_features. "
+                             f"Got type: {type(audio_features)}")
-            audio_feature_lens = kwargs.pop("audio_feature_lens")
-            if not isinstance(audio_feature_lens, (torch.Tensor, list)):
-                raise ValueError("Incorrect type of audio_feature_lens. "
-                                 f"Got type: {type(audio_feature_lens)}")
+        audio_feature_lens = kwargs.pop("audio_feature_lens")
+        if not isinstance(audio_feature_lens, (torch.Tensor, list)):
+            raise ValueError("Incorrect type of audio_feature_lens. "
+                             f"Got type: {type(audio_feature_lens)}")
-            return MiniCPMOAudioFeatureInputs(
-                type="audio_features",
-                audio_features=flatten_bn(audio_features, concat=True),
-                audio_feature_lens=flatten_bn(
-                    flatten_2d_lists(audio_feature_lens), concat=True),
-                audio_bounds=self._get_audio_bounds(input_ids, audio_start_id,
-                                                    audio_end_id),
-            )
+        audio_features_flat = flatten_bn(audio_features)
+        audio_feature_lens_flat = flatten_bn(audio_feature_lens)
-        raise AssertionError("This line should be unreachable.")
-    def _parse_and_validate_inputs(self, input_ids: torch.Tensor,
-                                   **kwargs: object):
-        image_inputs = self._parse_and_validate_image_inputs(
-            input_ids, **kwargs)
-        if not any("audio" in key for key in kwargs):
-            return image_inputs, None
-        audio_inputs = self._parse_and_validate_audio_inputs(
-            input_ids, **kwargs)
-        return image_inputs, audio_inputs
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-        **kwargs: Any,
-    ) -> torch.Tensor:
-        if intermediate_tensors is not None:
-            vlm_embeddings = None
-        else:
-            image_inputs, audio_inputs = \
-                self._parse_and_validate_inputs(input_ids, **kwargs)
-            vlm_embeddings = self.get_embedding_with_vision(
-                input_ids, image_inputs)
-            if audio_inputs is not None:
-                vlm_embeddings = self.get_embedding_with_audios(
-                    vlm_embeddings, audio_inputs,
-                    self.config.audio_chunk_length)
-        # always pass the input via `inputs_embeds`
-        # to make sure the computation graph is consistent
-        # for `torch.compile` integration
-        input_ids = None
-        output = self.llm.model(
-            input_ids=input_ids,
-            positions=positions,
-            intermediate_tensors=intermediate_tensors,
-            inputs_embeds=vlm_embeddings,
+        return MiniCPMOAudioFeatureInputs(
+            type="audio_features",
+            audio_features=audio_features_flat,
+            audio_feature_lens=audio_feature_lens_flat,
+            embed_is_patch=audio_embed_is_patch,
         )
-        return output
+    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
+        modalities = super()._parse_and_validate_multimodal_inputs(**kwargs)
+        # Preserve the order of modalities if there are multiple of them
+        # from the order of kwargs.
+        for input_key in kwargs:
+            if input_key in ("audio_features",
+                             "audio_embeds") and "audios" not in modalities:
+                modalities["audios"] = self._parse_and_validate_audio_input(
+                    **kwargs)
+        return modalities
+    def _process_audio_input(
+        self,
+        audio_input: MiniCPMOAudioInputs,
+    ) -> Union[torch.Tensor, list[torch.Tensor]]:
+        if audio_input["type"] == "audio_embeds":
+            return audio_input["audio_embeds"]
+        return self.get_audio_hidden_states(audio_input)
+    def _process_multimodal_inputs(self, modalities: dict):
+        multimodal_embeddings = super()._process_multimodal_inputs(modalities)
+        for modality in modalities:
+            if modality == "audios":
+                audio_input = modalities["audios"]
+                audio_features = self._process_audio_input(audio_input)
+                multimodal_embeddings += tuple(
+                    scatter_patch_features(
+                        audio_features,
+                        audio_input["embed_is_patch"],
+                    ))
+        return multimodal_embeddings
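With the custom forward(), audio_bounds bookkeeping, and get_embedding_with_audios removed, the audio path now flows through the standard multimodal embedding interface, so MiniCPM-O can run on vLLM V1 like other audio models. A rough offline-inference sketch (the model id, chat formatting, sampling settings, and the exact multi_modal_data layout are assumptions to verify against the vLLM audio examples; only the "(<audio>./</audio>)" placeholder comes from this file):

import librosa
from vllm import LLM, SamplingParams

llm = LLM(model="openbmb/MiniCPM-o-2_6", trust_remote_code=True,
          max_model_len=4096, limit_mm_per_prompt={"audio": 1})

audio, sr = librosa.load("speech.wav", sr=None)
prompt = ("<|im_start|>user\n(<audio>./</audio>)\n"
          "What is said in this audio?<|im_end|>\n<|im_start|>assistant\n")

outputs = llm.generate(
    {"prompt": prompt, "multi_modal_data": {"audio": (audio, sr)}},
    SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)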