Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -23,6 +23,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only MiniCPM-O model compatible with HuggingFace weights."""
|
||||
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from typing import Annotated, Any, Callable, Literal, Optional, Union
|
||||
|
||||
@@ -30,31 +31,47 @@ import torch
|
||||
from torch import nn
|
||||
from transformers import BatchFeature
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||
from transformers.models.whisper.modeling_whisper import (ACT2FN,
|
||||
WhisperAttention,
|
||||
WhisperConfig,
|
||||
WhisperEncoder)
|
||||
from transformers.models.whisper.modeling_whisper import (
|
||||
ACT2FN,
|
||||
WhisperAttention,
|
||||
WhisperConfig,
|
||||
WhisperEncoder,
|
||||
)
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargsItems
|
||||
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
|
||||
NestedTensors)
|
||||
from vllm.multimodal.parse import (AudioItem, AudioProcessorItems,
|
||||
DictEmbeddingItems, ModalityData,
|
||||
ModalityDataItems, MultiModalDataItems,
|
||||
MultiModalDataParser)
|
||||
from vllm.multimodal.processing import (PromptReplacement, PromptUpdate,
|
||||
PromptUpdateDetails)
|
||||
from vllm.multimodal.inputs import (
|
||||
MultiModalDataDict,
|
||||
MultiModalFieldConfig,
|
||||
NestedTensors,
|
||||
)
|
||||
from vllm.multimodal.parse import (
|
||||
AudioItem,
|
||||
AudioProcessorItems,
|
||||
DictEmbeddingItems,
|
||||
ModalityData,
|
||||
ModalityDataItems,
|
||||
MultiModalDataItems,
|
||||
MultiModalDataParser,
|
||||
)
|
||||
from vllm.multimodal.processing import (
|
||||
PromptReplacement,
|
||||
PromptUpdate,
|
||||
PromptUpdateDetails,
|
||||
)
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
from .minicpmv import (_MAX_FRAMES_PER_VIDEO, MiniCPMV2_6,
|
||||
MiniCPMVDummyInputsBuilder,
|
||||
MiniCPMVMultiModalDataParser,
|
||||
MiniCPMVMultiModalProcessor, MiniCPMVProcessingInfo,
|
||||
_minicpmv_field_config)
|
||||
from .utils import (AutoWeightsLoader, cast_overflow_tensors, flatten_bn,
|
||||
maybe_prefix)
|
||||
from .minicpmv import (
|
||||
_MAX_FRAMES_PER_VIDEO,
|
||||
MiniCPMV2_6,
|
||||
MiniCPMVDummyInputsBuilder,
|
||||
MiniCPMVMultiModalDataParser,
|
||||
MiniCPMVMultiModalProcessor,
|
||||
MiniCPMVProcessingInfo,
|
||||
_minicpmv_field_config,
|
||||
)
|
||||
from .utils import AutoWeightsLoader, cast_overflow_tensors, flatten_bn, maybe_prefix
|
||||
|
||||
CPU_DEVICE = torch.device("cpu")
|
||||
|
||||
@@ -68,6 +85,7 @@ class MiniCPMOAudioFeatureInputs(TensorSchema):
|
||||
- l: Length
|
||||
- s: Number of slices
|
||||
"""
|
||||
|
||||
type: Literal["audio_features"] = "audio_features"
|
||||
|
||||
audio_features: Annotated[
|
||||
@@ -96,9 +114,10 @@ class MiniCPMOAudioEmbeddingInputs(TensorSchema):
|
||||
- bn: Batch size * number of audios
|
||||
- s: Number of slices
|
||||
- h: Hidden size (must match language model backbone)
|
||||
|
||||
|
||||
Length of each slice may vary, so pass it as a list.
|
||||
"""
|
||||
|
||||
type: Literal["audio_embeds"] = "audio_embeds"
|
||||
|
||||
audio_embeds: Annotated[
|
||||
@@ -107,8 +126,7 @@ class MiniCPMOAudioEmbeddingInputs(TensorSchema):
|
||||
]
|
||||
|
||||
|
||||
MiniCPMOAudioInputs = Union[MiniCPMOAudioFeatureInputs,
|
||||
MiniCPMOAudioEmbeddingInputs]
|
||||
MiniCPMOAudioInputs = Union[MiniCPMOAudioFeatureInputs, MiniCPMOAudioEmbeddingInputs]
|
||||
|
||||
|
||||
def _minicpmo_field_config(hf_inputs: Mapping[str, torch.Tensor]):
|
||||
@@ -125,7 +143,6 @@ def _minicpmo_field_config(hf_inputs: Mapping[str, torch.Tensor]):
|
||||
|
||||
|
||||
class MiniCPMOAudioEmbeddingItems(DictEmbeddingItems):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data: Mapping[str, torch.Tensor],
|
||||
@@ -143,7 +160,6 @@ class MiniCPMOAudioEmbeddingItems(DictEmbeddingItems):
|
||||
|
||||
|
||||
class MiniCPMOMultiModalDataParser(MiniCPMVMultiModalDataParser):
|
||||
|
||||
def _parse_audio_data(
|
||||
self,
|
||||
data: Union[dict[str, torch.Tensor], ModalityData[AudioItem]],
|
||||
@@ -215,18 +231,17 @@ class MiniCPMOProcessingInfo(MiniCPMVProcessingInfo):
|
||||
|
||||
max_image_tokens = self.get_max_image_tokens() * max_images
|
||||
max_audio_tokens = self.get_max_audio_tokens() * max_audios
|
||||
max_total_frames = self.get_max_video_frames(seq_len -
|
||||
max_image_tokens -
|
||||
max_audio_tokens)
|
||||
max_frames_per_video = min(max_total_frames // max(max_videos, 1),
|
||||
_MAX_FRAMES_PER_VIDEO)
|
||||
max_total_frames = self.get_max_video_frames(
|
||||
seq_len - max_image_tokens - max_audio_tokens
|
||||
)
|
||||
max_frames_per_video = min(
|
||||
max_total_frames // max(max_videos, 1), _MAX_FRAMES_PER_VIDEO
|
||||
)
|
||||
|
||||
return max(max_frames_per_video, 1)
|
||||
|
||||
|
||||
class MiniCPMODummyInputsBuilder(
|
||||
MiniCPMVDummyInputsBuilder[MiniCPMOProcessingInfo]):
|
||||
|
||||
class MiniCPMODummyInputsBuilder(MiniCPMVDummyInputsBuilder[MiniCPMOProcessingInfo]):
|
||||
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
|
||||
num_audios = mm_counts.get("audio", 0)
|
||||
|
||||
@@ -241,16 +256,17 @@ class MiniCPMODummyInputsBuilder(
|
||||
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
|
||||
) -> MultiModalDataDict:
|
||||
num_audios = mm_counts.get("audio", 0)
|
||||
audio_len = self.info.get_max_audio_chunks_with_most_features() * \
|
||||
self.info.get_default_audio_sampling_rate()
|
||||
audio_len = (
|
||||
self.info.get_max_audio_chunks_with_most_features()
|
||||
* self.info.get_default_audio_sampling_rate()
|
||||
)
|
||||
|
||||
audio_overrides = mm_options.get("audio") if mm_options else None
|
||||
|
||||
audio_mm_data = {
|
||||
"audio":
|
||||
self._get_dummy_audios(length=audio_len,
|
||||
num_audios=num_audios,
|
||||
overrides=audio_overrides)
|
||||
"audio": self._get_dummy_audios(
|
||||
length=audio_len, num_audios=num_audios, overrides=audio_overrides
|
||||
)
|
||||
}
|
||||
|
||||
return {
|
||||
@@ -259,12 +275,11 @@ class MiniCPMODummyInputsBuilder(
|
||||
}
|
||||
|
||||
|
||||
class MiniCPMOMultiModalProcessor(
|
||||
MiniCPMVMultiModalProcessor[MiniCPMOProcessingInfo]):
|
||||
|
||||
class MiniCPMOMultiModalProcessor(MiniCPMVMultiModalProcessor[MiniCPMOProcessingInfo]):
|
||||
def _get_data_parser(self) -> MultiModalDataParser:
|
||||
return MiniCPMOMultiModalDataParser(
|
||||
target_sr=self.info.get_default_audio_sampling_rate())
|
||||
target_sr=self.info.get_default_audio_sampling_rate()
|
||||
)
|
||||
|
||||
def get_audio_prompt_texts(
|
||||
self,
|
||||
@@ -287,10 +302,11 @@ class MiniCPMOMultiModalProcessor(
|
||||
if (audios := mm_data.get("audios")) is None:
|
||||
return {}
|
||||
|
||||
parsed_audios = (self._get_data_parser().parse_mm_data({
|
||||
"audio": audios
|
||||
}).get_items("audio",
|
||||
(MiniCPMOAudioEmbeddingItems, AudioProcessorItems)))
|
||||
parsed_audios = (
|
||||
self._get_data_parser()
|
||||
.parse_mm_data({"audio": audios})
|
||||
.get_items("audio", (MiniCPMOAudioEmbeddingItems, AudioProcessorItems))
|
||||
)
|
||||
|
||||
if isinstance(parsed_audios, MiniCPMOAudioEmbeddingItems):
|
||||
audio_inputs = {}
|
||||
@@ -298,9 +314,7 @@ class MiniCPMOMultiModalProcessor(
|
||||
audio_inputs = self._base_call_hf_processor(
|
||||
prompts=[self.info.audio_pattern] * len(parsed_audios),
|
||||
mm_data={"audios": [[audio] for audio in parsed_audios]},
|
||||
mm_kwargs={
|
||||
**mm_kwargs, "chunk_input": True
|
||||
},
|
||||
mm_kwargs={**mm_kwargs, "chunk_input": True},
|
||||
tok_kwargs=tok_kwargs,
|
||||
out_keys={"audio_features", "audio_feature_lens"},
|
||||
)
|
||||
@@ -308,7 +322,8 @@ class MiniCPMOMultiModalProcessor(
|
||||
# Avoid padding since we need the output for each audio to be
|
||||
# independent of other audios for the cache to work correctly
|
||||
unpadded_audio_features = [
|
||||
feat[:, :feature_len] for feat, feature_len in zip(
|
||||
feat[:, :feature_len]
|
||||
for feat, feature_len in zip(
|
||||
audio_inputs["audio_features"],
|
||||
audio_inputs["audio_feature_lens"],
|
||||
)
|
||||
@@ -348,12 +363,14 @@ class MiniCPMOMultiModalProcessor(
|
||||
|
||||
def get_audio_replacement(item_idx: int):
|
||||
audios = mm_items.get_items(
|
||||
"audio", (MiniCPMOAudioEmbeddingItems, AudioProcessorItems))
|
||||
"audio", (MiniCPMOAudioEmbeddingItems, AudioProcessorItems)
|
||||
)
|
||||
|
||||
if isinstance(audios, MiniCPMOAudioEmbeddingItems):
|
||||
single_audio_embeds = audios.get(item_idx)["audio_embeds"]
|
||||
audio_len = self.info.get_audio_len_by_num_chunks(
|
||||
sum(map(len, single_audio_embeds)))
|
||||
sum(map(len, single_audio_embeds))
|
||||
)
|
||||
else:
|
||||
audio_len = audios.get_audio_length(item_idx)
|
||||
|
||||
@@ -364,9 +381,11 @@ class MiniCPMOMultiModalProcessor(
|
||||
|
||||
return [
|
||||
*base_updates,
|
||||
PromptReplacement(modality="audio",
|
||||
target=audio_placeholder,
|
||||
replacement=get_audio_replacement),
|
||||
PromptReplacement(
|
||||
modality="audio",
|
||||
target=audio_placeholder,
|
||||
replacement=get_audio_replacement,
|
||||
),
|
||||
]
|
||||
|
||||
def _get_mm_fields_config(
|
||||
@@ -378,16 +397,11 @@ class MiniCPMOMultiModalProcessor(
|
||||
|
||||
|
||||
class MultiModalProjector(nn.Module):
|
||||
|
||||
def __init__(self, in_dim: int, out_dim: int):
|
||||
super().__init__()
|
||||
self.linear1 = nn.Linear(in_features=in_dim,
|
||||
out_features=out_dim,
|
||||
bias=True)
|
||||
self.linear1 = nn.Linear(in_features=in_dim, out_features=out_dim, bias=True)
|
||||
self.relu = nn.ReLU()
|
||||
self.linear2 = nn.Linear(in_features=out_dim,
|
||||
out_features=out_dim,
|
||||
bias=True)
|
||||
self.linear2 = nn.Linear(in_features=out_dim, out_features=out_dim, bias=True)
|
||||
|
||||
def forward(self, audio_features: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.relu(self.linear1(audio_features))
|
||||
@@ -396,7 +410,6 @@ class MultiModalProjector(nn.Module):
|
||||
|
||||
|
||||
class MiniCPMWhisperEncoderLayer(nn.Module):
|
||||
|
||||
def __init__(self, config: WhisperConfig, layer_idx: int):
|
||||
super().__init__()
|
||||
self.embed_dim = config.d_model
|
||||
@@ -428,39 +441,40 @@ class MiniCPMWhisperEncoderLayer(nn.Module):
|
||||
attention_mask=attention_mask,
|
||||
past_key_value=past_key_values,
|
||||
)
|
||||
hidden_states = nn.functional.dropout(hidden_states,
|
||||
p=self.dropout,
|
||||
training=self.training)
|
||||
hidden_states = nn.functional.dropout(
|
||||
hidden_states, p=self.dropout, training=self.training
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.final_layer_norm(hidden_states)
|
||||
hidden_states = self.activation_fn(self.fc1(hidden_states))
|
||||
hidden_states = nn.functional.dropout(hidden_states,
|
||||
p=self.activation_dropout,
|
||||
training=self.training)
|
||||
hidden_states = nn.functional.dropout(
|
||||
hidden_states, p=self.activation_dropout, training=self.training
|
||||
)
|
||||
hidden_states = self.fc2(hidden_states)
|
||||
hidden_states = nn.functional.dropout(hidden_states,
|
||||
p=self.dropout,
|
||||
training=self.training)
|
||||
hidden_states = nn.functional.dropout(
|
||||
hidden_states, p=self.dropout, training=self.training
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
if hidden_states.dtype == torch.float16:
|
||||
hidden_states = cast_overflow_tensors(hidden_states)
|
||||
|
||||
outputs = (hidden_states, )
|
||||
outputs = (hidden_states,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class MiniCPMWhisperEncoder(WhisperEncoder):
|
||||
|
||||
def __init__(self, config: WhisperConfig):
|
||||
super().__init__(config)
|
||||
self.layers = nn.ModuleList([
|
||||
MiniCPMWhisperEncoderLayer(config, layer_idx=i)
|
||||
for i in range(config.encoder_layers)
|
||||
])
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
MiniCPMWhisperEncoderLayer(config, layer_idx=i)
|
||||
for i in range(config.encoder_layers)
|
||||
]
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -468,8 +482,9 @@ class MiniCPMWhisperEncoder(WhisperEncoder):
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
) -> BaseModelOutputWithPast:
|
||||
# Ignore copy
|
||||
input_features = input_features.to(dtype=self.conv1.weight.dtype,
|
||||
device=self.conv1.weight.device)
|
||||
input_features = input_features.to(
|
||||
dtype=self.conv1.weight.dtype, device=self.conv1.weight.device
|
||||
)
|
||||
|
||||
inputs_embeds = nn.functional.gelu(self.conv1(input_features))
|
||||
inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds))
|
||||
@@ -478,17 +493,17 @@ class MiniCPMWhisperEncoder(WhisperEncoder):
|
||||
|
||||
embed_pos = self.embed_positions.weight
|
||||
|
||||
embed_pos = embed_pos[:inputs_embeds.shape[1], :]
|
||||
embed_pos = embed_pos[: inputs_embeds.shape[1], :]
|
||||
|
||||
hidden_states = inputs_embeds + embed_pos
|
||||
hidden_states = nn.functional.dropout(hidden_states,
|
||||
p=self.dropout,
|
||||
training=self.training)
|
||||
hidden_states = nn.functional.dropout(
|
||||
hidden_states, p=self.dropout, training=self.training
|
||||
)
|
||||
|
||||
encoder_states = ()
|
||||
|
||||
for idx, encoder_layer in enumerate(self.layers):
|
||||
encoder_states = encoder_states + (hidden_states, )
|
||||
encoder_states = encoder_states + (hidden_states,)
|
||||
to_drop = False
|
||||
if self.training:
|
||||
dropout_probability = torch.rand([])
|
||||
@@ -507,7 +522,7 @@ class MiniCPMWhisperEncoder(WhisperEncoder):
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
hidden_states = self.layer_norm(hidden_states)
|
||||
encoder_states = encoder_states + (hidden_states, )
|
||||
encoder_states = encoder_states + (hidden_states,)
|
||||
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
@@ -518,7 +533,8 @@ class MiniCPMWhisperEncoder(WhisperEncoder):
|
||||
@MULTIMODAL_REGISTRY.register_processor(
|
||||
MiniCPMOMultiModalProcessor,
|
||||
info=MiniCPMOProcessingInfo,
|
||||
dummy_inputs=MiniCPMODummyInputsBuilder)
|
||||
dummy_inputs=MiniCPMODummyInputsBuilder,
|
||||
)
|
||||
class MiniCPMO(MiniCPMV2_6):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
@@ -545,8 +561,9 @@ class MiniCPMO(MiniCPMV2_6):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||
self.apm = self.init_audio_module(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "apm"))
|
||||
self.apm = self.init_audio_module(
|
||||
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "apm")
|
||||
)
|
||||
|
||||
self.audio_token_id = None
|
||||
|
||||
@@ -555,16 +572,16 @@ class MiniCPMO(MiniCPMV2_6):
|
||||
audio_config = self.config.audio_config
|
||||
model = MiniCPMWhisperEncoder(audio_config)
|
||||
audio_output_dim = int(audio_config.encoder_ffn_dim // 4)
|
||||
self.audio_avg_pooler = \
|
||||
nn.AvgPool1d(self.config.audio_pool_step,
|
||||
stride=self.config.audio_pool_step)
|
||||
self.audio_projection_layer = \
|
||||
MultiModalProjector(in_dim=audio_output_dim,out_dim=self.embed_dim)
|
||||
self.audio_avg_pooler = nn.AvgPool1d(
|
||||
self.config.audio_pool_step, stride=self.config.audio_pool_step
|
||||
)
|
||||
self.audio_projection_layer = MultiModalProjector(
|
||||
in_dim=audio_output_dim, out_dim=self.embed_dim
|
||||
)
|
||||
self.audio_encoder_layer = -1
|
||||
return model
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
loader = AutoWeightsLoader(self, skip_prefixes=["tts"])
|
||||
return loader.load_weights(weights)
|
||||
|
||||
@@ -585,14 +602,13 @@ class MiniCPMO(MiniCPMV2_6):
|
||||
start_indices = torch.zeros_like(row_indices)
|
||||
else:
|
||||
# Compute start indices vectorially
|
||||
start_chunk_indices = torch.clamp(chunk_indices - num_left_chunks,
|
||||
min=0)
|
||||
start_chunk_indices = torch.clamp(chunk_indices - num_left_chunks, min=0)
|
||||
start_indices = start_chunk_indices * chunk_size
|
||||
# Compute ending indices vectorially
|
||||
end_chunk_indices = chunk_indices + 1
|
||||
end_indices = torch.clamp(end_chunk_indices * chunk_size +
|
||||
num_lookhead,
|
||||
max=size)
|
||||
end_indices = torch.clamp(
|
||||
end_chunk_indices * chunk_size + num_lookhead, max=size
|
||||
)
|
||||
# Create column indices for broadcasting
|
||||
col_indices = torch.arange(size, device=device).unsqueeze(0)
|
||||
start_indices = start_indices.unsqueeze(1)
|
||||
@@ -601,19 +617,18 @@ class MiniCPMO(MiniCPMV2_6):
|
||||
ret = (col_indices >= start_indices) & (col_indices < end_indices)
|
||||
return ret
|
||||
|
||||
def _get_feat_extract_output_lengths(self,
|
||||
input_lengths: torch.LongTensor):
|
||||
def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
|
||||
input_lengths_after_cnn = (input_lengths - 1) // 2 + 1
|
||||
input_lengths_after_pooling = (
|
||||
input_lengths_after_cnn -
|
||||
self.config.audio_pool_step) // self.config.audio_pool_step + 1
|
||||
input_lengths_after_pooling = input_lengths_after_pooling.to(
|
||||
dtype=torch.int32)
|
||||
input_lengths_after_cnn - self.config.audio_pool_step
|
||||
) // self.config.audio_pool_step + 1
|
||||
input_lengths_after_pooling = input_lengths_after_pooling.to(dtype=torch.int32)
|
||||
|
||||
return input_lengths_after_cnn, input_lengths_after_pooling
|
||||
|
||||
def get_audio_hidden_states(
|
||||
self, data: MiniCPMOAudioFeatureInputs) -> list[torch.Tensor]:
|
||||
self, data: MiniCPMOAudioFeatureInputs
|
||||
) -> list[torch.Tensor]:
|
||||
chunk_length = self.config.audio_chunk_length
|
||||
|
||||
# (bs, 80, frames) or [], multi audios need filled in advance
|
||||
@@ -642,23 +657,26 @@ class MiniCPMO(MiniCPMV2_6):
|
||||
max_seq_len = (max_mel_seq_len - 1) // 2 + 1
|
||||
|
||||
# Create a sequence tensor of shape (batch_size, max_seq_len)
|
||||
seq_range = (torch.arange(
|
||||
0,
|
||||
max_seq_len,
|
||||
dtype=audio_feature_lens.dtype,
|
||||
device=audio_feature_lens.device).unsqueeze(0).expand(
|
||||
batch_size, max_seq_len))
|
||||
lengths_expand = audio_feature_lens.unsqueeze(1).expand(
|
||||
batch_size, max_seq_len)
|
||||
seq_range = (
|
||||
torch.arange(
|
||||
0,
|
||||
max_seq_len,
|
||||
dtype=audio_feature_lens.dtype,
|
||||
device=audio_feature_lens.device,
|
||||
)
|
||||
.unsqueeze(0)
|
||||
.expand(batch_size, max_seq_len)
|
||||
)
|
||||
lengths_expand = audio_feature_lens.unsqueeze(1).expand(batch_size, max_seq_len)
|
||||
# Create mask
|
||||
padding_mask = seq_range >= lengths_expand # 1 for padded values
|
||||
|
||||
audio_attention_mask_ = padding_mask.view(
|
||||
batch_size, 1, 1, max_seq_len).expand(batch_size, 1, max_seq_len,
|
||||
max_seq_len)
|
||||
audio_attention_mask_ = padding_mask.view(batch_size, 1, 1, max_seq_len).expand(
|
||||
batch_size, 1, max_seq_len, max_seq_len
|
||||
)
|
||||
audio_attention_mask = audio_attention_mask_.to(
|
||||
dtype=self.apm.conv1.weight.dtype,
|
||||
device=self.apm.conv1.weight.device)
|
||||
dtype=self.apm.conv1.weight.dtype, device=self.apm.conv1.weight.device
|
||||
)
|
||||
|
||||
if chunk_length > 0:
|
||||
chunk_num_frame = int(chunk_length * 50)
|
||||
@@ -669,20 +687,22 @@ class MiniCPMO(MiniCPMV2_6):
|
||||
device=audio_attention_mask_.device,
|
||||
)
|
||||
audio_attention_mask_ = torch.logical_or(
|
||||
audio_attention_mask_, torch.logical_not(chunk_mask))
|
||||
audio_attention_mask_, torch.logical_not(chunk_mask)
|
||||
)
|
||||
|
||||
audio_attention_mask[audio_attention_mask_] = float("-inf")
|
||||
audio_states = self.apm(
|
||||
wavforms, attention_mask=audio_attention_mask).hidden_states[
|
||||
self.audio_encoder_layer]
|
||||
wavforms, attention_mask=audio_attention_mask
|
||||
).hidden_states[self.audio_encoder_layer]
|
||||
audio_embeds = self.audio_projection_layer(audio_states)
|
||||
|
||||
audio_embeds = audio_embeds.transpose(1, 2)
|
||||
audio_embeds = self.audio_avg_pooler(audio_embeds)
|
||||
audio_embeds = audio_embeds.transpose(1, 2)
|
||||
|
||||
_, feature_lens_after_pooling = \
|
||||
self._get_feat_extract_output_lengths(audio_feature_lens)
|
||||
_, feature_lens_after_pooling = self._get_feat_extract_output_lengths(
|
||||
audio_feature_lens
|
||||
)
|
||||
|
||||
num_audio_tokens = feature_lens_after_pooling
|
||||
|
||||
@@ -692,7 +712,8 @@ class MiniCPMO(MiniCPMV2_6):
|
||||
target_audio_embeds_lst = list[torch.Tensor]()
|
||||
for _ in range(len(audio_feature_lens_raw[i])):
|
||||
target_audio_embeds_lst.append(
|
||||
audio_embeds[idx, :num_audio_tokens[idx], :])
|
||||
audio_embeds[idx, : num_audio_tokens[idx], :]
|
||||
)
|
||||
idx += 1
|
||||
|
||||
final_audio_embeds.append(torch.cat(target_audio_embeds_lst))
|
||||
@@ -700,7 +721,8 @@ class MiniCPMO(MiniCPMV2_6):
|
||||
return final_audio_embeds
|
||||
|
||||
def _parse_and_validate_audio_input(
|
||||
self, **kwargs: object) -> Optional[MiniCPMOAudioInputs]:
|
||||
self, **kwargs: object
|
||||
) -> Optional[MiniCPMOAudioInputs]:
|
||||
audio_features = kwargs.pop("audio_features", None)
|
||||
audio_embeds = kwargs.pop("audio_embeds", None)
|
||||
|
||||
@@ -714,8 +736,9 @@ class MiniCPMO(MiniCPMV2_6):
|
||||
|
||||
if audio_embeds is not None:
|
||||
if not isinstance(audio_embeds, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of audio_embeds. "
|
||||
f"Got type: {type(audio_embeds)}")
|
||||
raise ValueError(
|
||||
f"Incorrect type of audio_embeds. Got type: {type(audio_embeds)}"
|
||||
)
|
||||
|
||||
audio_embeds_flat = flatten_bn(audio_embeds)
|
||||
|
||||
@@ -725,13 +748,16 @@ class MiniCPMO(MiniCPMV2_6):
|
||||
)
|
||||
|
||||
if not isinstance(audio_features, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of audio_features. "
|
||||
f"Got type: {type(audio_features)}")
|
||||
raise ValueError(
|
||||
f"Incorrect type of audio_features. Got type: {type(audio_features)}"
|
||||
)
|
||||
|
||||
audio_feature_lens = kwargs.pop("audio_feature_lens")
|
||||
if not isinstance(audio_feature_lens, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of audio_feature_lens. "
|
||||
f"Got type: {type(audio_feature_lens)}")
|
||||
raise ValueError(
|
||||
"Incorrect type of audio_feature_lens. "
|
||||
f"Got type: {type(audio_feature_lens)}"
|
||||
)
|
||||
|
||||
audio_features_flat = flatten_bn(audio_features)
|
||||
audio_feature_lens_flat = flatten_bn(audio_feature_lens)
|
||||
@@ -748,10 +774,11 @@ class MiniCPMO(MiniCPMV2_6):
|
||||
# Preserve the order of modalities if there are multiple of them
|
||||
# from the order of kwargs.
|
||||
for input_key in kwargs:
|
||||
if input_key in ("audio_features",
|
||||
"audio_embeds") and "audios" not in modalities:
|
||||
modalities["audios"] = self._parse_and_validate_audio_input(
|
||||
**kwargs)
|
||||
if (
|
||||
input_key in ("audio_features", "audio_embeds")
|
||||
and "audios" not in modalities
|
||||
):
|
||||
modalities["audios"] = self._parse_and_validate_audio_input(**kwargs)
|
||||
|
||||
return modalities
|
||||
|
||||
|
||||
Reference in New Issue
Block a user