[Misc] Clean up MiniCPM-V/O code (#15337)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
@@ -23,7 +23,7 @@
 # limitations under the License.
 """Inference-only MiniCPM-O model compatible with HuggingFace weights."""
 from collections.abc import Iterable, Mapping, Sequence
-from typing import (Any, Callable, Dict, List, Literal, Optional, Set, Tuple,
+from typing import (Any, Callable, Dict, Literal, Optional, Set, Tuple,
                     TypedDict, Union)

 import torch
@@ -43,24 +43,26 @@ from vllm.multimodal.parse import (AudioItem, AudioProcessorItems,
 from vllm.multimodal.processing import PromptReplacement, PromptUpdate
 from vllm.multimodal.profiling import ProcessorInputs
 from vllm.sequence import IntermediateTensors
+from vllm.utils import flatten_2d_lists

 from .minicpmv import (MiniCPMV2_6, MiniCPMVDummyInputsBuilder,
                        MiniCPMVMultiModalDataParser,
                        MiniCPMVMultiModalProcessor, MiniCPMVProcessingInfo,
                        _minicpmv_field_config)
-from .utils import AutoWeightsLoader, cast_overflow_tensors, maybe_prefix
+from .utils import (AutoWeightsLoader, cast_overflow_tensors, flatten_bn,
+                    maybe_prefix)

 CPU_DEVICE = torch.device("cpu")


 class MiniCPMOAudioFeatureInputs(TypedDict):
     type: Literal["audio_features"]
-    data: torch.Tensor
+    audio_features: torch.Tensor
     """
     Shape: `(batch_size * num_audios * num_slices, num_channels, length)`
     Slice here means chunk. Audio that is too long will be split into slices,
     which is the same as image.
-    Padding is used therefore `data` is `torch.Tensor`.
+    Padding is used therefore `audio_features` is `torch.Tensor`.
     """

     audio_feature_lens: torch.Tensor
@@ -68,7 +70,7 @@ class MiniCPMOAudioFeatureInputs(TypedDict):
     Shape: `(batch_size * num_audios * num_slices)`

     This should be feature length of each audio slice,
-    which equals to `data.shape[-1]`
+    which equals to `audio_features.shape[-1]`
     """

     audio_bounds: torch.Tensor
@@ -81,7 +83,7 @@ class MiniCPMOAudioFeatureInputs(TypedDict):

 class MiniCPMOAudioEmbeddingInputs(TypedDict):
     type: Literal["audio_embeds"]
-    data: List[torch.Tensor]
+    audio_embeds: torch.Tensor
     """
     Shape: `(batch_size * num_images * num_slices, hidden_size)`

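Note on the renames above: replacing the generic `data` key with modality-specific keys (`audio_features`, `audio_embeds`) works together with the `type` literal tag, which lets downstream code narrow which variant of the union it is holding. A minimal standalone sketch of the pattern (illustrative names only, not the vLLM classes themselves):

from typing import Literal, TypedDict, Union

import torch


class AudioFeatureInputs(TypedDict):
    type: Literal["audio_features"]
    audio_features: torch.Tensor  # padded (num_slices, num_channels, length)


class AudioEmbeddingInputs(TypedDict):
    type: Literal["audio_embeds"]
    audio_embeds: torch.Tensor  # precomputed (num_slices, hidden_size)


AudioInputs = Union[AudioFeatureInputs, AudioEmbeddingInputs]


def describe(inputs: AudioInputs) -> str:
    # Checking the "type" tag tells us which key is present, so each
    # branch only indexes the field that exists on that variant.
    if inputs["type"] == "audio_embeds":
        return f"{inputs['audio_embeds'].shape[0]} precomputed slice embeddings"
    return f"{inputs['audio_features'].shape[0]} feature slices to encode"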
@@ -102,18 +104,11 @@ MiniCPMOAudioInputs = Union[MiniCPMOAudioFeatureInputs,


 def _minicpmo_field_config(hf_inputs: Mapping[str, torch.Tensor]):
-    audio_num_slices = hf_inputs.get("audio_num_slices", torch.empty(0))
-
     return dict(
         **_minicpmv_field_config(hf_inputs),
-        audio_features=MultiModalFieldConfig.flat_from_sizes(
-            "audio", audio_num_slices),
-        audio_feature_lens=MultiModalFieldConfig.flat_from_sizes(
-            "audio", audio_num_slices),
-        audio_num_slices=MultiModalFieldConfig.batched("audio"),
-        audio_orders_in_mm_data=MultiModalFieldConfig.batched("audio"),
-        audio_embeds=MultiModalFieldConfig.flat_from_sizes(
-            "audio", audio_num_slices),
+        audio_features=MultiModalFieldConfig.batched("audio"),
+        audio_feature_lens=MultiModalFieldConfig.batched("audio"),
+        audio_embeds=MultiModalFieldConfig.batched("audio"),
     )


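For context on the simplification above: as I understand vLLM's multimodal field configs, `MultiModalFieldConfig.batched("audio")` reads the leading dimension of a processor output as one entry per audio item, whereas `flat_from_sizes("audio", sizes)` describes a single tensor formed by concatenating variable-size per-item chunks, which is why the old version had to carry the `audio_num_slices` bookkeeping tensor. A rough illustration of the difference in plain torch (not the vLLM internals themselves):

import torch

# flat_from_sizes-style layout: 5 slices from 2 audios, concatenated
# along dim 0, recoverable only via the per-item sizes.
audio_num_slices = torch.tensor([2, 3])
flat = torch.randn(5, 80, 100)

offsets = torch.cumsum(audio_num_slices, dim=0)
starts = torch.cat([torch.zeros(1, dtype=torch.long), offsets[:-1]])
per_audio = [flat[s:e] for s, e in zip(starts.tolist(), offsets.tolist())]
assert [t.shape[0] for t in per_audio] == [2, 3]

# batched-style layout: one entry per audio item up front, so no size
# bookkeeping is needed.
batched = [torch.randn(n, 80, 100) for n in (2, 3)]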
@@ -153,9 +148,6 @@ class MiniCPMOMultiModalDataParser(MiniCPMVMultiModalDataParser):
 class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
     audio_pattern = "(<audio>./</audio>)"

-    def get_supported_mm_modalities(self) -> List[str]:
-        return ["image", "video", "audio"]
-
     def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
         return {"image": None, "video": None, "audio": None}

@@ -277,95 +269,47 @@ class MiniCPMOMultiModalProcessor(
         mm_data: Mapping[str, object],
         mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, NestedTensors]:
-        mm_data = dict(mm_data)
-
-        audios = mm_data.pop("audios", [])
-        audio_embeds = mm_data.pop("audio_embeds", [])
-        if isinstance(audios, (list, torch.Tensor)) and len(audios) > 0:
-            audio_outputs = {
-                "audio_lens": [],
-                "audio_features": [],
-                "audio_feature_lens": [],
-                "audio_num_segments": []
-            }
-            for audio in audios:
-                single_audio_outputs = super().call_base_hf_processor(
-                    prompt=self.info.audio_pattern,
-                    mm_data={
-                        "audios": audio,
-                        "chunk_input": True
-                    },
-                    mm_kwargs=mm_kwargs)
-                audio_outputs["audio_lens"].append(len(audio))
-                audio_outputs["audio_features"].append(
-                    single_audio_outputs["audio_features"])
-                audio_outputs["audio_num_segments"].append(
-                    len(single_audio_outputs["audio_feature_lens"][0]))
-                audio_outputs["audio_feature_lens"] += \
-                    single_audio_outputs["audio_feature_lens"]
-            audio_outputs["audio_features"] = [
-                audio_feature for single_audio_features in \
-                    audio_outputs["audio_features"]
-                for audio_feature in single_audio_features
-            ]
-            audio_outputs["audio_feature_lens"] = torch.cat(
-                audio_outputs["audio_feature_lens"])
-        elif len(audio_embeds):
-            audio_outputs = {
-                "audio_lens": [
-                    self.info.get_audio_len_by_num_chunks(
-                        sum(chunk_embeds.shape[0]
-                            for chunk_embeds in single_audio_embeds))
-                    for single_audio_embeds in audio_embeds
-                ],
-                "audio_embeds": [
-                    chunk_embeds for single_audio_embeds in audio_embeds
-                    for chunk_embeds in single_audio_embeds
-                ],
-                "audio_num_segments": [
-                    len(single_audio_embeds)
-                    for single_audio_embeds in audio_embeds
-                ]
-            }
-        else:
-            audio_outputs = {}
-        return audio_outputs
+        if (audios := mm_data.get("audios")) is None:
+            return {}
+
+        parsed_audios = (self._get_data_parser().parse_mm_data({
+            "audio": audios
+        }).get_items("audio", AudioProcessorItems))
+
+        audio_inputs = self._base_call_hf_processor(
+            prompts=[self.info.audio_pattern] * len(parsed_audios),
+            mm_data={"audios": [[audio] for audio in parsed_audios]},
+            mm_kwargs={
+                **mm_kwargs, "chunk_input": True
+            },
+            out_keys={"audio_features", "audio_feature_lens"},
+        )
+
+        # Avoid padding since we need the output for each audio to be
+        # independent of other audios for the cache to work correctly
+        unpadded_audio_features = [
+            feat[:, :feature_len] for feat, feature_len in zip(
+                audio_inputs["audio_features"],
+                audio_inputs["audio_feature_lens"],
+            )
+        ]
+        audio_inputs["audio_features"] = unpadded_audio_features
+
+        return audio_inputs

     def get_placeholder_match_pattern(self) -> str:
         return r"\(<(image|video|audio)>./</\1>\)"

-    def get_placeholder_split_pattern(self) -> str:
-        return r"\(<(?:image|video|audio)>./</(?:image|video|audio)>\)"
-
     def process_mm_inputs(
         self,
         mm_data: Mapping[str, object],
         mm_kwargs: Mapping[str, object],
-    ) -> Mapping[str, Mapping[str, NestedTensors]]:
+    ) -> Mapping[str, NestedTensors]:
         return {
-            "image": self.process_images(mm_data, mm_kwargs),
-            "video": self.process_videos(mm_data, mm_kwargs),
-            "audio": self.process_audios(mm_data, mm_kwargs),
+            **super().process_mm_inputs(mm_data, mm_kwargs),
+            **self.process_audios(mm_data, mm_kwargs),
         }

-    def get_modality_num_counter(self, modality: str) -> str:
-        if modality == "audio":
-            return "audio_lens"
-        return super().get_modality_num_counter(modality)
-
-    def get_num_slices_by_modality(self, inputs: Dict[str, object],
-                                   modality: str, index: int) -> int:
-        if modality == "audio":
-            return inputs["audio"]["audio_num_segments"][index]
-        return super().get_num_slices_by_modality(inputs, modality, index)
-
-    def get_prompt_texts_by_modality(self, inputs: Dict[str, object],
-                                     modality: str, index: int) -> str:
-        if modality == "audio":
-            return self.get_audio_prompt_texts(
-                inputs["audio"]["audio_lens"][index])
-        return super().get_prompt_texts_by_modality(inputs, modality, index)
-
     def _get_prompt_updates(
         self,
         mm_items: MultiModalDataItems,
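The comment inside the new `process_audios` above is the key constraint: processor outputs are cached per item, so one audio's features must not depend on how long the other audios in the batch were. A small sketch of the unpadding step in plain torch (standalone toy tensors, not the real processor output):

import torch

# Two feature tensors padded to a common length of 120 frames.
audio_features = [torch.randn(80, 120), torch.randn(80, 120)]
audio_feature_lens = torch.tensor([100, 120])

# Trim each tensor back to its true length, as the new code does.
unpadded = [
    feat[:, :int(feat_len)]
    for feat, feat_len in zip(audio_features, audio_feature_lens)
]
assert [t.shape[-1] for t in unpadded] == [100, 120]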
@@ -622,86 +566,84 @@ class MiniCPMO(MiniCPMV2_6):
     # Copied from HF repo of MiniCPM-o-2_6,
     # designed for batched inputs and outputs
     def get_audio_hidden_states(self, data: MiniCPMOAudioInputs,
-                                chunk_length: int) -> torch.Tensor:
+                                chunk_length: int) -> list[torch.Tensor]:
         wavforms = data.get(
-            "data",
+            "audio_features",
             [])  # (bs, 80, frames) or [], multi audios need filled in advance
         audio_feature_lens_raw = [data.get("audio_feature_lens",
                                            [])]  # list, [[x1, x2], [y1], [z1]]

-        # exist audio
-        if len(wavforms) > 0:
-            audio_feature_lens = torch.hstack(audio_feature_lens_raw)
-            batch_size, _, max_mel_seq_len = wavforms.shape
-            max_seq_len = (max_mel_seq_len - 1) // 2 + 1
-
-            # Create a sequence tensor of shape (batch_size, max_seq_len)
-            seq_range = (torch.arange(
-                0,
-                max_seq_len,
-                dtype=audio_feature_lens.dtype,
-                device=audio_feature_lens.device).unsqueeze(0).expand(
-                    batch_size, max_seq_len))
-            lengths_expand = audio_feature_lens.unsqueeze(1).expand(
-                batch_size, max_seq_len)
-            # Create mask
-            padding_mask = seq_range >= lengths_expand  # 1 for padded values
-
-            audio_attention_mask_ = padding_mask.view(
-                batch_size, 1, 1, max_seq_len).expand(batch_size, 1,
-                                                      max_seq_len, max_seq_len)
-            audio_attention_mask = audio_attention_mask_.to(
-                dtype=self.apm.conv1.weight.dtype,
-                device=self.apm.conv1.weight.device)
-
-            if chunk_length > 0:
-                chunk_num_frame = int(chunk_length * 50)
-                chunk_mask = self.subsequent_chunk_mask(
-                    size=max_seq_len,
-                    chunk_size=chunk_num_frame,
-                    num_left_chunks=-1,
-                    device=audio_attention_mask_.device,
-                )
-                audio_attention_mask_ = torch.logical_or(
-                    audio_attention_mask_, torch.logical_not(chunk_mask))
-
-            audio_attention_mask[audio_attention_mask_] = float("-inf")
-            audio_states = self.apm(
-                wavforms, attention_mask=audio_attention_mask).hidden_states[
-                    self.audio_encoder_layer]
-            audio_embeds = self.audio_projection_layer(audio_states)
-
-            audio_embeds = audio_embeds.transpose(1, 2)
-            audio_embeds = self.audio_avg_pooler(audio_embeds)
-            audio_embeds = audio_embeds.transpose(1, 2)
-
-            _, feature_lens_after_pooling = \
-                self._get_feat_extract_output_lengths(audio_feature_lens)
-
-            num_audio_tokens = feature_lens_after_pooling
-
-            final_audio_embeds = []
-            idx = 0
-            for i in range(len(audio_feature_lens_raw)):
-                target_audio_embeds = []
-                for _ in range(len(audio_feature_lens_raw[i])):
-                    target_audio_embeds.append(
-                        audio_embeds[idx, :num_audio_tokens[idx], :])
-                    idx += 1
-                final_audio_embeds.append(target_audio_embeds)
-            return final_audio_embeds
-        else:
-            return []
+        if len(wavforms) == 0:
+            return []
+
+        audio_feature_lens = torch.hstack(audio_feature_lens_raw)
+        batch_size, _, max_mel_seq_len = wavforms.shape
+        max_seq_len = (max_mel_seq_len - 1) // 2 + 1
+
+        # Create a sequence tensor of shape (batch_size, max_seq_len)
+        seq_range = (torch.arange(
+            0,
+            max_seq_len,
+            dtype=audio_feature_lens.dtype,
+            device=audio_feature_lens.device).unsqueeze(0).expand(
+                batch_size, max_seq_len))
+        lengths_expand = audio_feature_lens.unsqueeze(1).expand(
+            batch_size, max_seq_len)
+        # Create mask
+        padding_mask = seq_range >= lengths_expand  # 1 for padded values
+
+        audio_attention_mask_ = padding_mask.view(
+            batch_size, 1, 1, max_seq_len).expand(batch_size, 1, max_seq_len,
+                                                  max_seq_len)
+        audio_attention_mask = audio_attention_mask_.to(
+            dtype=self.apm.conv1.weight.dtype,
+            device=self.apm.conv1.weight.device)
+
+        if chunk_length > 0:
+            chunk_num_frame = int(chunk_length * 50)
+            chunk_mask = self.subsequent_chunk_mask(
+                size=max_seq_len,
+                chunk_size=chunk_num_frame,
+                num_left_chunks=-1,
+                device=audio_attention_mask_.device,
+            )
+            audio_attention_mask_ = torch.logical_or(
+                audio_attention_mask_, torch.logical_not(chunk_mask))
+
+        audio_attention_mask[audio_attention_mask_] = float("-inf")
+        audio_states = self.apm(
+            wavforms, attention_mask=audio_attention_mask).hidden_states[
+                self.audio_encoder_layer]
+        audio_embeds = self.audio_projection_layer(audio_states)
+
+        audio_embeds = audio_embeds.transpose(1, 2)
+        audio_embeds = self.audio_avg_pooler(audio_embeds)
+        audio_embeds = audio_embeds.transpose(1, 2)
+
+        _, feature_lens_after_pooling = \
+            self._get_feat_extract_output_lengths(audio_feature_lens)
+
+        num_audio_tokens = feature_lens_after_pooling
+
+        final_audio_embeds = []
+        idx = 0
+        for i in range(len(audio_feature_lens_raw)):
+            target_audio_embeds = []
+            for _ in range(len(audio_feature_lens_raw[i])):
+                target_audio_embeds.append(
+                    audio_embeds[idx, :num_audio_tokens[idx], :])
+                idx += 1
+            final_audio_embeds.append(target_audio_embeds)
+        return final_audio_embeds

     def get_embedding_with_audios(self, vlm_embedding: torch.Tensor,
-                                  audio_inputs: Optional[MiniCPMOAudioInputs],
+                                  audio_inputs: MiniCPMOAudioInputs,
                                   chunk_length: int) -> torch.Tensor:
         device, dtype = vlm_embedding.device, vlm_embedding.dtype
         if audio_inputs["type"] == "audio_embeds":
-            audio_embeddings = audio_inputs["data"]
             audio_embeddings = [
-                audio_embeddings[i].to(device=device, dtype=dtype)
-                for i in range(len(audio_embeddings))
+                item.to(device=device, dtype=dtype)
+                for item in audio_inputs["audio_embeds"]
             ]
         else:
             audio_embeddings = self.get_audio_hidden_states(
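A worked example of the padding-mask construction that this hunk dedents (the logic is unchanged, just moved out of the `if` body): broadcasting a position ramp against per-sequence lengths marks the padded tail of each sequence. The `(max_mel_seq_len - 1) // 2 + 1` appears to account for the stride-2 downsampling of the Whisper-style audio encoder, if I read the surrounding code correctly.

import torch

audio_feature_lens = torch.tensor([3, 5])  # valid length per sequence
batch_size, max_seq_len = 2, 5

seq_range = (torch.arange(0, max_seq_len)
             .unsqueeze(0).expand(batch_size, max_seq_len))
lengths_expand = audio_feature_lens.unsqueeze(1).expand(
    batch_size, max_seq_len)

padding_mask = seq_range >= lengths_expand  # True marks padded positions
# tensor([[False, False, False,  True,  True],
#         [False, False, False, False, False]])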
@@ -746,40 +688,56 @@ class MiniCPMO(MiniCPMV2_6):

     def _parse_and_validate_audio_inputs(
             self, input_ids: torch.Tensor,
-            **kwargs: object) -> Tuple[MiniCPMOAudioInputs]:
-        audio_features = kwargs.pop("audio_features", [])
-        audio_feature_lens = kwargs.pop("audio_feature_lens", [])
+            **kwargs: object) -> Optional[MiniCPMOAudioInputs]:
+        audio_features = kwargs.pop("audio_features", None)
         audio_embeds = kwargs.pop("audio_embeds", None)
-        audio_start_id = kwargs.pop("audio_start_id", None)
-        audio_end_id = kwargs.pop("audio_end_id", None)
+
+        if audio_features is None and audio_embeds is None:
+            return None
+
+        audio_start_id = kwargs.pop("audio_start_id")
+        if not isinstance(audio_start_id, torch.Tensor):
+            raise ValueError("Incorrect type of audio_start_id. "
+                             f"Got type: {type(audio_start_id)}")
+
+        audio_end_id = kwargs.pop("audio_end_id")
+        if not isinstance(audio_end_id, torch.Tensor):
+            raise ValueError("Incorrect type of audio_end_id. "
+                             f"Got type: {type(audio_end_id)}")
+
         if audio_embeds is not None:
-            audio_embeds = [
-                audio_embeds[i][j] for i in range(len(audio_embeds))
-                for j in range(len(audio_embeds[i]))
-            ]
+            if not isinstance(audio_embeds, (torch.Tensor, list)):
+                raise ValueError("Incorrect type of audio_embeds. "
+                                 f"Got type: {type(audio_embeds)}")
+
             return MiniCPMOAudioEmbeddingInputs(
+                type="audio_embeds",
+                audio_embeds=flatten_bn(flatten_2d_lists(audio_embeds),
+                                        concat=True),
                 audio_bounds=self._get_audio_bounds(input_ids, audio_start_id,
                                                     audio_end_id),
-                data=audio_embeds,
-                type="audio_embeds")
-        if len(audio_features) > 0:
-            audio_features_all = [
-                i.permute(1, 0) for audio_feature in audio_features
-                for i in audio_feature
-            ]
-            audio_features = torch.nn.utils.rnn.pad_sequence(
-                audio_features_all, batch_first=True,
-                padding_value=0.0).permute(0, 2, 1)
-            audio_feature_lens = torch.cat(
-                [item for item in audio_feature_lens])
+            )
+
+        if audio_features is not None:
+            if not isinstance(audio_features, (torch.Tensor, list)):
+                raise ValueError("Incorrect type of audio_features. "
+                                 f"Got type: {type(audio_features)}")
+
+            audio_feature_lens = kwargs.pop("audio_feature_lens")
+            if not isinstance(audio_feature_lens, (torch.Tensor, list)):
+                raise ValueError("Incorrect type of audio_feature_lens. "
+                                 f"Got type: {type(audio_feature_lens)}")

             return MiniCPMOAudioFeatureInputs(
+                type="audio_features",
+                audio_features=flatten_bn(audio_features, concat=True),
+                audio_feature_lens=flatten_bn(
+                    flatten_2d_lists(audio_feature_lens), concat=True),
                 audio_bounds=self._get_audio_bounds(input_ids, audio_start_id,
                                                     audio_end_id),
-                data=audio_features,
-                audio_feature_lens=audio_feature_lens,
-                type="audio_features")
-        return None
+            )
+
+        raise AssertionError("This line should be unreachable.")

     def _parse_and_validate_inputs(self, input_ids: torch.Tensor,
                                    **kwargs: object):
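On the new return paths above: assuming `flatten_2d_lists` flattens one level of list nesting and `flatten_bn(..., concat=True)` concatenates the per-request tensors along dim 0 (their behavior as I understand the vllm.utils and .utils helpers), nested per-request, per-audio inputs collapse into a single slice-indexed tensor. A behavior illustration in plain torch:

import torch


def flatten_2d_lists(lst):  # same idea as vllm.utils.flatten_2d_lists
    return [item for sub in lst for item in sub]


# Batch of 2 requests: request 0 has two audios, request 1 has one.
audio_embeds = [
    [torch.randn(4, 3584), torch.randn(2, 3584)],
    [torch.randn(5, 3584)],
]

flat = flatten_2d_lists(audio_embeds)  # 3 per-audio tensors
merged = torch.cat(flat, dim=0)        # what concat=True amounts to
assert merged.shape == (11, 3584)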
@@ -803,7 +761,7 @@ class MiniCPMO(MiniCPMV2_6):
         else:
             image_inputs, audio_inputs = \
                 self._parse_and_validate_inputs(input_ids, **kwargs)
-            vlm_embeddings, _ = self.get_embedding_with_vision(
+            vlm_embeddings = self.get_embedding_with_vision(
                 input_ids, image_inputs)

             if audio_inputs is not None: