Fix AudioFlamingo3/MusicFlamingo HF parity and RoTE handling (#37643)
Signed-off-by: Lasha <26011196+lashahub@users.noreply.github.com>
parent 43877a620b
commit e7767eccae
@@ -69,10 +69,7 @@ from .utils import (
    maybe_prefix,
)

MAX_AUDIO_LEN = 10 * 60


# === Audio Inputs === #
class AudioFlamingo3FeatureInputs(TensorSchema):
    """
    Dimensions:
@@ -127,14 +124,12 @@ class AudioFlamingo3Encoder(Qwen2AudioEncoder):
    ):
        super().__init__(config)
        self.avg_pooler = nn.AvgPool1d(kernel_size=2, stride=2)
        # self.layer_norm is already initialized in super().__init__

    def forward(
        self,
        input_features: torch.Tensor | list[torch.Tensor],
        attention_mask: torch.Tensor | None = None,
    ):
        # input_features: (batch, num_mel_bins, seq_len)
        if isinstance(input_features, list):
            input_features = torch.stack(input_features)

@@ -146,17 +141,14 @@ class AudioFlamingo3Encoder(Qwen2AudioEncoder):
        ).to(hidden_states.dtype)

        for layer in self.layers:
            # Qwen2AudioEncoderLayer expects layer_head_mask as third arg.
            layer_outputs = layer(hidden_states, attention_mask, None)
            hidden_states = layer_outputs[0]
            layer_outputs = layer(hidden_states, attention_mask)
            hidden_states = (
                layer_outputs[0] if isinstance(layer_outputs, tuple) else layer_outputs
            )

        # AvgPool (time/2) + LayerNorm
        # hidden_states: (batch, seq_len, hidden_size)
        hidden_states = hidden_states.permute(0, 2, 1)  # (batch, hidden_size, seq_len)
        hidden_states = hidden_states.permute(0, 2, 1)
        hidden_states = self.avg_pooler(hidden_states)
        hidden_states = hidden_states.permute(
            0, 2, 1
        )  # (batch, seq_len/2, hidden_size)
        hidden_states = hidden_states.permute(0, 2, 1)
        hidden_states = self.layer_norm(hidden_states)

        return hidden_states
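
For reference, the pooling step works on a channels-last layout by permuting twice: AvgPool1d runs over the final (time) dimension and halves it. A minimal standalone sketch, with made-up shapes:

import torch
from torch import nn

pool = nn.AvgPool1d(kernel_size=2, stride=2)
x = torch.randn(2, 1500, 64)                    # (batch, seq_len, hidden_size)
y = pool(x.permute(0, 2, 1)).permute(0, 2, 1)   # pool over time, then restore layout
assert y.shape == (2, 750, 64)                  # time axis halved: 1500 -> 750
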
@@ -193,22 +185,6 @@ class AudioFlamingo3MultiModalProjector(nn.Module):
        return hidden_states


class AudioFlamingo3MultiModalDataParser(MultiModalDataParser):
    def _parse_audio_data(
        self,
        data: dict[str, torch.Tensor] | ModalityData[Any],
    ) -> ModalityDataItems[Any, Any] | None:
        if isinstance(data, dict):
            return DictEmbeddingItems(
                data,
                modality="audio",
                required_fields={"audio_embeds"},
                fields_factory=_audioflamingo3_field_config,
            )

        return super()._parse_audio_data(data)


class AudioFlamingo3ProcessingInfo(BaseProcessingInfo):
    def get_hf_config(self):
        return self.ctx.get_hf_config(AudioFlamingo3Config)
@@ -217,20 +193,17 @@ class AudioFlamingo3ProcessingInfo(BaseProcessingInfo):
        return self.ctx.get_hf_processor(AudioFlamingo3Processor, **kwargs)

    def get_feature_extractor(self, **kwargs: object):
        hf_processor = self.get_hf_processor(**kwargs)
        feature_extractor = hf_processor.feature_extractor
        return feature_extractor
        return self.get_hf_processor(**kwargs).feature_extractor

    def get_data_parser(self):
    def get_data_parser(self) -> MultiModalDataParser:
        feature_extractor = self.get_feature_extractor()

        return AudioFlamingo3MultiModalDataParser(
            target_sr=feature_extractor.sampling_rate,
            expected_hidden_size=self._get_expected_hidden_size(),
        )

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        return {"audio": 1}
        return {"audio": None}


class AudioFlamingo3DummyInputsBuilder(
@@ -248,9 +221,10 @@ class AudioFlamingo3DummyInputsBuilder(
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions],
    ) -> MultiModalDataDict:
        hf_processor = self.info.get_hf_processor()
        feature_extractor = self.info.get_feature_extractor()
        sampling_rate = feature_extractor.sampling_rate
        audio_len = MAX_AUDIO_LEN * sampling_rate
        audio_len = int(hf_processor.max_audio_len * sampling_rate)
        num_audios = mm_counts.get("audio", 0)
        audio_overrides = mm_options.get("audio")

@@ -284,6 +258,118 @@ def _audioflamingo3_field_config(hf_inputs: Mapping[str, torch.Tensor]):
    )


def _get_audio_post_pool_output_lengths(input_lengths: torch.Tensor) -> torch.Tensor:
    conv_lengths = (input_lengths - 1) // 2 + 1
    return (conv_lengths - 2) // 2 + 1
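
As a sanity check of the two downsampling stages (a stride-2 conv, then a kernel-2/stride-2 average pool), the same arithmetic can be traced for a full 30 s mel window of 3000 frames:

import torch

def post_pool_lengths(input_lengths: torch.Tensor) -> torch.Tensor:
    conv_lengths = (input_lengths - 1) // 2 + 1   # after the stride-2 conv
    return (conv_lengths - 2) // 2 + 1            # after the kernel-2/stride-2 pool

print(post_pool_lengths(torch.tensor([3000])))    # tensor([750]): 3000 -> 1500 -> 750
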


def _build_audio_encoder_attention_mask(
    feature_attention_mask: torch.Tensor,
    *,
    dtype: torch.dtype,
    device: torch.device,
) -> torch.Tensor:
    input_lengths = feature_attention_mask.sum(-1).to(torch.long)
    conv_lengths = (input_lengths - 1) // 2 + 1

    batch_size, max_mel_seq_len = feature_attention_mask.shape
    max_seq_len = (max_mel_seq_len - 1) // 2 + 1

    seq_range = (
        torch.arange(
            max_seq_len,
            dtype=conv_lengths.dtype,
            device=conv_lengths.device,
        )
        .unsqueeze(0)
        .expand(batch_size, max_seq_len)
    )
    padding_mask = seq_range >= conv_lengths[:, None]

    attention_mask = padding_mask.view(batch_size, 1, 1, max_seq_len).expand(
        batch_size, 1, max_seq_len, max_seq_len
    )
    attention_mask = attention_mask.to(dtype=dtype, device=device)
    attention_mask.masked_fill_(padding_mask[:, None, None, :], float("-inf"))

    return attention_mask
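
The result is an additive mask: 0.0 where keys are valid, -inf on padded key positions, broadcast across query rows. A toy illustration with two mel streams of valid lengths 6 and 2 (post-conv lengths 3 and 1):

import torch

feature_attention_mask = torch.tensor([[1, 1, 1, 1, 1, 1],
                                       [1, 1, 0, 0, 0, 0]])
conv_lengths = (feature_attention_mask.sum(-1) - 1) // 2 + 1   # tensor([3, 1])
max_seq_len = (feature_attention_mask.shape[1] - 1) // 2 + 1   # 3
padding_mask = torch.arange(max_seq_len)[None, :] >= conv_lengths[:, None]
mask = padding_mask[:, None, None, :].expand(2, 1, max_seq_len, max_seq_len)
mask = mask.to(torch.float32).masked_fill(padding_mask[:, None, None, :], float("-inf"))
print(mask[1, 0, 0])   # tensor([0., -inf, -inf]): only the first key is attendable
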


def _flatten_valid_audio_embeddings(
    audio_embeddings: torch.Tensor,
    feature_attention_mask: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    input_lengths = feature_attention_mask.sum(-1).to(torch.long)
    output_lengths = _get_audio_post_pool_output_lengths(input_lengths)
    valid_mask = (
        torch.arange(audio_embeddings.shape[1], device=output_lengths.device)[None, :]
        < output_lengths[:, None]
    )

    return audio_embeddings[valid_mask], output_lengths
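
The boolean indexing collapses the (chunks, max_tokens, hidden) tensor into one (sum_of_lengths, hidden) tensor, dropping padded positions. A toy run with made-up shapes:

import torch

embeds = torch.randn(2, 4, 8)                       # (chunks, max_tokens, hidden)
lengths = torch.tensor([3, 1])                      # valid tokens per chunk
valid = torch.arange(4)[None, :] < lengths[:, None]
flat = embeds[valid]                                # rows of chunk 0, then chunk 1
assert flat.shape == (4, 8)                         # 3 + 1 valid rows kept
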


def _count_audio_tokens_from_mask(
    feature_attention_mask: torch.Tensor | list[torch.Tensor],
    chunk_counts: torch.Tensor | list[torch.Tensor] | list[int] | None,
    item_idx: int,
) -> int:
    if chunk_counts is not None:
        if isinstance(chunk_counts, torch.Tensor):
            counts = chunk_counts.tolist()
        elif chunk_counts and isinstance(chunk_counts[0], torch.Tensor):
            counts = [count.item() for count in chunk_counts]
        else:
            counts = chunk_counts

        start_idx = sum(counts[:item_idx])
        count = counts[item_idx]
        end_idx = start_idx + count

        if isinstance(feature_attention_mask, list):
            sample_mask = feature_attention_mask[start_idx:end_idx]
            if len(sample_mask) == 0:
                raise ValueError("Expected non-empty audio mask slice.")
            if isinstance(sample_mask[0], torch.Tensor):
                sample_mask = torch.stack(sample_mask)
            else:
                sample_mask = torch.tensor(sample_mask)
        else:
            sample_mask = feature_attention_mask[start_idx:end_idx]
    else:
        # Indexing works the same for a stacked tensor or a list of tensors.
        sample_mask = feature_attention_mask[item_idx]

    if sample_mask.ndim == 1:
        sample_input_lengths = sample_mask.sum().unsqueeze(0)
    else:
        # Match the HF processor, which derives placeholder lengths from the
        # total pre-encoder feature length for each original audio sample.
        sample_input_lengths = sample_mask.sum().reshape(1)

    post_lengths = _get_audio_post_pool_output_lengths(
        sample_input_lengths.to(torch.long)
    )
    return int(post_lengths[0].item())
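
Note that for a multi-chunk item the mask rows are summed together before the length formula is applied, matching the HF placeholder count. For example, with hypothetical chunk counts [2, 1] and fully valid 3000-frame masks, item 0 covers 2 * 3000 = 6000 input frames:

import torch

feature_attention_mask = torch.ones(3, 3000, dtype=torch.long)  # 3 chunks, all valid
total = feature_attention_mask[0:2].sum()       # item 0 spans the first two chunks
conv = (total - 1) // 2 + 1                     # 6000 -> 3000
print(int((conv - 2) // 2 + 1))                 # 3000 -> 1500 placeholder tokens
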


class AudioFlamingo3MultiModalDataParser(MultiModalDataParser):
    def _parse_audio_data(
        self,
        data: dict[str, torch.Tensor] | ModalityData[Any],
    ) -> ModalityDataItems[Any, Any] | None:
        if isinstance(data, dict):
            return DictEmbeddingItems(
                data,
                modality="audio",
                required_fields={"audio_embeds"},
                fields_factory=_audioflamingo3_field_config,
            )
        return super()._parse_audio_data(data)


class AudioFlamingo3MultiModalProcessor(
    BaseMultiModalProcessor[AudioFlamingo3ProcessingInfo]
):
@@ -303,13 +389,13 @@ class AudioFlamingo3MultiModalProcessor(
            prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

        feature_extractor = self.info.get_feature_extractor(**mm_kwargs)
        processor = self.info.get_hf_processor(**mm_kwargs)
        feature_extractor = processor.feature_extractor
        mm_kwargs = dict(
            **mm_kwargs,
            sampling_rate=feature_extractor.sampling_rate,
        )

        # Calculate chunk counts
        audio_list = mm_data.get("audio")
        if not isinstance(audio_list, list):
            audio_list = [audio_list]
@@ -318,8 +404,7 @@ class AudioFlamingo3MultiModalProcessor(
        sampling_rate = feature_extractor.sampling_rate
        chunk_length = feature_extractor.chunk_length
        window_size = int(sampling_rate * chunk_length)
        # MAX_AUDIO_LEN is 10 * 60 in HF processor.
        max_windows = int(MAX_AUDIO_LEN // chunk_length)
        max_windows = int(processor.max_audio_len // chunk_length)

        for audio in audio_list:
            # audio is numpy array or list
@@ -364,7 +449,6 @@ class AudioFlamingo3MultiModalProcessor(
        audio_token = getattr(processor, "audio_token", "<sound>")
        audio_token_id = vocab.get(audio_token)
        if audio_token_id is None:
            # Fallback if not found, though it should be there
            audio_token_id = processor.audio_token_id

        out_mm_data = out_mm_kwargs.get_data()
@@ -373,38 +457,11 @@ class AudioFlamingo3MultiModalProcessor(

        def get_replacement_audioflamingo3(item_idx: int):
            if feature_attention_mask is not None:
                if chunk_counts is not None:
                    counts = (
                        chunk_counts.tolist()
                        if isinstance(chunk_counts, torch.Tensor)
                        else chunk_counts
                    )
                    start_idx = sum(counts[:item_idx])
                    count = counts[item_idx]
                    end_idx = start_idx + count

                    if isinstance(feature_attention_mask, list):
                        mask_list = feature_attention_mask[start_idx:end_idx]
                        if len(mask_list) > 0 and isinstance(
                            mask_list[0], torch.Tensor
                        ):
                            mask = torch.stack(mask_list)
                        else:
                            mask = torch.tensor(mask_list)
                    else:
                        mask = feature_attention_mask[start_idx:end_idx]
                else:
                    # feature_attention_mask is list[Tensor] or Tensor
                    if isinstance(feature_attention_mask, list):
                        mask = feature_attention_mask[item_idx]
                    else:
                        mask = feature_attention_mask[item_idx].unsqueeze(0)

                # mask shape: (num_chunks, 3000)
                input_lengths = mask.sum(-1)
                conv_lengths = (input_lengths - 1) // 2 + 1
                audio_output_lengths = (conv_lengths - 2) // 2 + 1
                num_features = audio_output_lengths.sum().item()
                num_features = _count_audio_tokens_from_mask(
                    feature_attention_mask,
                    chunk_counts,
                    item_idx,
                )
            else:
                audio_embeds = out_mm_data["audio_embeds"][item_idx]
                num_features = audio_embeds.shape[0]
@@ -435,13 +492,6 @@ class AudioFlamingo3MultiModalProcessor(
class AudioFlamingo3ForConditionalGeneration(
    nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA
):
    """
    AudioFlamingo3 model for conditional generation.

    This model integrates a Whisper-based audio encoder with a Qwen2 language model.
    It supports multi-chunk audio processing.
    """

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
@@ -517,6 +567,25 @@ class AudioFlamingo3ForConditionalGeneration(
            audio_embeds = audio_input["audio_embeds"]
            return tuple(audio_embeds)

        (
            input_features,
            feature_attention_mask,
            chunk_counts,
        ) = self._normalize_audio_feature_inputs(audio_input)
        audio_hidden_states = self._encode_audio_features(
            input_features,
            feature_attention_mask,
        )
        audio_features = self.multi_modal_projector(audio_hidden_states)
        return self._group_audio_embeddings(
            audio_features,
            feature_attention_mask,
            chunk_counts,
        )

    def _normalize_audio_feature_inputs(
        self, audio_input: AudioFlamingo3FeatureInputs
    ) -> tuple[torch.Tensor, torch.Tensor, list[int]]:
        input_features = audio_input["input_features"]
        feature_attention_mask = audio_input["feature_attention_mask"]
        chunk_counts = audio_input.get("chunk_counts")
@@ -534,66 +603,36 @@ class AudioFlamingo3ForConditionalGeneration(
            and chunk_counts
            and isinstance(chunk_counts[0], torch.Tensor)
        ):
            chunk_counts = [c.item() for c in chunk_counts]
            chunk_counts = [count.item() for count in chunk_counts]

        # Calculate output lengths
        input_lengths = feature_attention_mask.sum(-1)
        # Conv downsampling
        conv_lengths = (input_lengths - 1) // 2 + 1
        # AvgPool downsampling
        audio_output_lengths = (conv_lengths - 2) // 2 + 1
        return input_features, feature_attention_mask, chunk_counts

        batch_size, _, max_mel_seq_len = input_features.shape

        # Calculate max_seq_len after convs (before pooling) for attention mask
        max_seq_len = (max_mel_seq_len - 1) // 2 + 1

        # Create a sequence tensor of shape (batch_size, max_seq_len)
        seq_range = (
            torch.arange(
                0,
                max_seq_len,
                dtype=conv_lengths.dtype,
                device=conv_lengths.device,
            )
            .unsqueeze(0)
            .expand(batch_size, max_seq_len)
        )
        lengths_expand = conv_lengths.unsqueeze(-1).expand(batch_size, max_seq_len)
        # Create mask
        padding_mask = seq_range >= lengths_expand

        audio_attention_mask_ = padding_mask.view(batch_size, 1, 1, max_seq_len).expand(
            batch_size, 1, max_seq_len, max_seq_len
        )
        audio_attention_mask = audio_attention_mask_.to(
    def _encode_audio_features(
        self,
        input_features: torch.Tensor,
        feature_attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        audio_attention_mask = _build_audio_encoder_attention_mask(
            feature_attention_mask,
            dtype=self.audio_tower.conv1.weight.dtype,
            device=self.audio_tower.conv1.weight.device,
        )
        audio_attention_mask[audio_attention_mask_] = float("-inf")

        # Forward pass
        audio_features = self.audio_tower(
            input_features, attention_mask=audio_attention_mask
        return self.audio_tower(input_features, attention_mask=audio_attention_mask)

    def _group_audio_embeddings(
        self,
        audio_features: torch.Tensor,
        feature_attention_mask: torch.Tensor,
        chunk_counts: list[int],
    ) -> tuple[torch.Tensor, ...]:
        masked_audio_features, audio_output_lengths = _flatten_valid_audio_embeddings(
            audio_features,
            feature_attention_mask,
        )

        # Project
        audio_features = self.multi_modal_projector(audio_features)

        # Masking after pooling
        num_audios, max_audio_tokens, embed_dim = audio_features.shape
        audio_output_lengths = audio_output_lengths.unsqueeze(1)
        audio_features_mask = (
            torch.arange(max_audio_tokens)
            .expand(num_audios, max_audio_tokens)
            .to(audio_output_lengths.device)
            < audio_output_lengths
        )
        masked_audio_features = audio_features[audio_features_mask].view(-1, embed_dim)

        # Split to tuple of embeddings for individual audio input.
        chunk_embeddings = torch.split(
            masked_audio_features, audio_output_lengths.flatten().tolist()
            masked_audio_features,
            audio_output_lengths.tolist(),
        )

        grouped_embeddings = []
@@ -613,7 +652,7 @@ class AudioFlamingo3ForConditionalGeneration(

    def forward(
        self,
        input_ids: torch.Tensor | None,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,

@@ -1,63 +1,209 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""MusicFlamingo model adapter.
# Copyright 2026 The vLLM team.
# Copyright 2026 NVIDIA CORPORATION and the HuggingFace Inc. team. All rights
# reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

MusicFlamingo shares the AudioFlamingo3 architecture, so we reuse the same
implementation and multimodal processor, while accepting MusicFlamingo config
and processor classes when available.
"""
from collections.abc import Callable, Mapping, Sequence
from math import pi
from typing import Annotated, Any, Optional, TypeAlias

from collections.abc import Mapping

from transformers.models.audioflamingo3 import (
    AudioFlamingo3Config,
    AudioFlamingo3Processor,
import torch
from torch import Tensor, broadcast_tensors, nn
from transformers import BatchFeature
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
from transformers.models.musicflamingo import (
    MusicFlamingoConfig,
    MusicFlamingoProcessor,
)

from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.processing import BaseProcessingInfo
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
from vllm.multimodal.parse import (
    DictEmbeddingItems,
    ModalityData,
    ModalityDataItems,
    MultiModalDataItems,
    MultiModalDataParser,
)
from vllm.multimodal.processing import (
    PromptReplacement,
    PromptUpdate,
    PromptUpdateDetails,
)
from vllm.utils.tensor_schema import TensorShape

from .audioflamingo3 import (
    AudioFlamingo3DummyInputsBuilder,
    AudioFlamingo3EmbeddingInputs,
    AudioFlamingo3Encoder,
    AudioFlamingo3FeatureInputs,
    AudioFlamingo3ForConditionalGeneration,
    AudioFlamingo3MultiModalDataParser,
    AudioFlamingo3MultiModalProcessor,
    AudioFlamingo3MultiModalProjector,
    AudioFlamingo3ProcessingInfo,
    _audioflamingo3_field_config,
    _count_audio_tokens_from_mask,
)

try:
    # Optional dependency: use MusicFlamingo classes when transformers provides them.
    from transformers.models.musicflamingo import (
        MusicFlamingoConfig,
        MusicFlamingoProcessor,
    )
except Exception:  # pragma: no cover - optional dependency
    MusicFlamingoConfig = None
    MusicFlamingoProcessor = None


def rotate_half(x):
    x = x.reshape(*x.shape[:-1], -1, 2)
    x1, x2 = x.unbind(dim=-1)
    x = torch.stack((-x2, x1), dim=-1)
    return x.flatten(-2)
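
rotate_half here uses the interleaved convention: each adjacent feature pair (x0, x1) is mapped to (-x1, x0), i.e. a 90-degree rotation of the pair. A quick numeric check, with the function inlined:

import torch

def rotate_half(x):
    x = x.reshape(*x.shape[:-1], -1, 2)
    x1, x2 = x.unbind(dim=-1)
    return torch.stack((-x2, x1), dim=-1).flatten(-2)

print(rotate_half(torch.tensor([1.0, 2.0, 3.0, 4.0])))  # tensor([-2., 1., -4., 3.])
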


class MusicFlamingoProcessingInfo(BaseProcessingInfo):
    def get_hf_config(self):
        if MusicFlamingoConfig is None:
            return self.ctx.get_hf_config(AudioFlamingo3Config)
        return self.ctx.get_hf_config((MusicFlamingoConfig, AudioFlamingo3Config))

    def get_hf_processor(self, **kwargs: object):
        if MusicFlamingoProcessor is None:
            return self.ctx.get_hf_processor(AudioFlamingo3Processor, **kwargs)
        # Tuple triggers AutoProcessor path and accepts either processor class.
        return self.ctx.get_hf_processor(
            (MusicFlamingoProcessor, AudioFlamingo3Processor), **kwargs
def apply_rotary_time_emb(hidden_states, cos, sin):
    original_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float64)
    cos = cos.to(hidden_states)
    sin = sin.to(hidden_states)
    rot_dim = cos.shape[-1]
    if rot_dim > hidden_states.shape[-1]:
        raise ValueError(
            f"feature dimension {hidden_states.shape[-1]} is not of "
            f"sufficient size to rotate in all the positions {rot_dim}"
        )

    def get_feature_extractor(self, **kwargs: object):
        hf_processor = self.get_hf_processor(**kwargs)
        return hf_processor.feature_extractor
    rotated = hidden_states[..., :rot_dim]
    passthrough = hidden_states[..., rot_dim:]
    rotated = (rotated * cos) + (rotate_half(rotated) * sin)
    return torch.cat((rotated, passthrough), dim=-1).to(original_dtype)
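
Since cos and sin come from a shared angle per pair, the rotation is norm-preserving on the first rot_dim features while the remainder passes through untouched. A self-contained check of that identity (rotate_half as above):

import torch

def rotate_half(x):
    x = x.reshape(*x.shape[:-1], -1, 2)
    x1, x2 = x.unbind(dim=-1)
    return torch.stack((-x2, x1), dim=-1).flatten(-2)

theta = torch.tensor(0.3)
cos = theta.cos().expand(4)          # angles already repeat_interleave'd pairwise
sin = theta.sin().expand(4)
x = torch.tensor([1.0, 2.0, 3.0, 4.0])
y = (x * cos) + (rotate_half(x) * sin)
print(torch.allclose(y.norm(), x.norm()))  # True: each 2-D pair is rotated by theta
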

    def get_data_parser(self):

class MusicFlamingoRotaryEmbedding(nn.Module):
    inv_freq: torch.Tensor

    def __init__(self, config: MusicFlamingoConfig, device=None):
        super().__init__()
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_type = self.config.rope_parameters["rope_type"]
        rope_init_fn: Callable = self.compute_default_rope_parameters
        if self.rope_type != "default":
            rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
        position_angles = self._compute_position_angles(self.inv_freq)
        self.register_buffer("position_angles", position_angles, persistent=False)

    @staticmethod
    def compute_default_rope_parameters(
        config: MusicFlamingoConfig | None = None,
        device: Optional["torch.device"] = None,
        seq_len: int | None = None,
    ) -> tuple["torch.Tensor", float]:
        del seq_len
        base = config.rope_parameters["rope_theta"]
        dim = getattr(config, "head_dim", None) or (
            config.hidden_size // config.num_attention_heads
        )
        attention_factor = 1.0

        inv_freq = 1.0 / (
            base
            ** (
                torch.arange(0, dim, 2, dtype=torch.int64).to(
                    device=device,
                    dtype=torch.float,
                )
                / dim
            )
        )
        return inv_freq, attention_factor
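
This is the standard RoPE inverse-frequency schedule, 1 / base^(2i / dim). For example, with base 10000 and dim = 8:

import torch

base, dim = 10000.0, 8
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
print(inv_freq)  # tensor([1.0000, 0.1000, 0.0100, 0.0010])
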

    def _compute_position_angles(self, inv_freq):
        positions = torch.arange(
            int(self.max_seq_len_cached),
            device=inv_freq.device,
            dtype=inv_freq.dtype,
        )
        positions = positions / self.max_seq_len_cached * (2 * pi)
        position_angles = positions.unsqueeze(-1) * inv_freq
        position_angles = torch.repeat_interleave(position_angles, 2, dim=-1)
        return position_angles.to(dtype=inv_freq.dtype)

    @torch.no_grad()
    def forward(self, timestamps: Tensor, seq_len: int) -> tuple[Tensor, Tensor]:
        batch_positions = torch.arange(
            timestamps.shape[0],
            device=self.inv_freq.device,
            dtype=self.inv_freq.dtype,
        )
        batch_positions = batch_positions / self.max_seq_len_cached
        batch_freqs = batch_positions.unsqueeze(-1) * self.inv_freq
        batch_freqs = torch.repeat_interleave(batch_freqs, 2, dim=-1)

        batch_freqs = batch_freqs[:, None, :]
        time_freqs = self.position_angles[:seq_len][None, :, :]
        batch_freqs, time_freqs = broadcast_tensors(batch_freqs, time_freqs)
        freqs = torch.cat((batch_freqs, time_freqs), dim=-1)
        angle = (-timestamps * 2 * pi).to(freqs)
        freqs = freqs * angle.unsqueeze(-1)
        return freqs.cos(), freqs.sin()
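
A shape-level sketch of this forward pass (values are placeholders, not real config numbers): per-chunk frequencies and per-timestep angles are broadcast together, concatenated, then scaled by the chunk timestamps:

import torch
from math import pi

chunks, seq_len, half = 2, 5, 8
inv_freq = torch.rand(half // 2)
batch_freqs = (torch.arange(chunks, dtype=torch.float) / 1024).unsqueeze(-1) * inv_freq
batch_freqs = torch.repeat_interleave(batch_freqs, 2, dim=-1)[:, None, :]  # (2, 1, 8)
time_freqs = torch.rand(seq_len, half)[None, :, :]                         # (1, 5, 8)
batch_freqs, time_freqs = torch.broadcast_tensors(batch_freqs, time_freqs)
freqs = torch.cat((batch_freqs, time_freqs), dim=-1)                       # (2, 5, 16)
timestamps = torch.rand(chunks, seq_len)
freqs = freqs * (-timestamps * 2 * pi).unsqueeze(-1)
print(freqs.cos().shape, freqs.sin().shape)  # torch.Size([2, 5, 16]) twice
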


class MusicFlamingoFeatureInputs(AudioFlamingo3FeatureInputs):
    rote_timestamps: Annotated[
        torch.Tensor,
        TensorShape(
            "num_chunks",
            "num_audio_time_steps",
            dynamic_dims={"num_audio_time_steps"},
        ),
    ]


MusicFlamingoEmbeddingInputs = AudioFlamingo3EmbeddingInputs

MusicFlamingoInputs: TypeAlias = (
    MusicFlamingoFeatureInputs | MusicFlamingoEmbeddingInputs
)


class MusicFlamingoEncoder(AudioFlamingo3Encoder):
    pass


class MusicFlamingoMultiModalProjector(AudioFlamingo3MultiModalProjector):
    pass


class MusicFlamingoProcessingInfo(AudioFlamingo3ProcessingInfo):
    def get_hf_config(self) -> MusicFlamingoConfig:
        return self.ctx.get_hf_config(MusicFlamingoConfig)

    def get_hf_processor(self, **kwargs: object) -> MusicFlamingoProcessor:
        return self.ctx.get_hf_processor(MusicFlamingoProcessor, **kwargs)

    def get_data_parser(self) -> MultiModalDataParser:
        feature_extractor = self.get_feature_extractor()

        return AudioFlamingo3MultiModalDataParser(
        return MusicFlamingoMultiModalDataParser(
            target_sr=feature_extractor.sampling_rate,
            expected_hidden_size=self._get_expected_hidden_size(),
        )
@@ -67,13 +213,230 @@ class MusicFlamingoProcessingInfo(BaseProcessingInfo):


class MusicFlamingoDummyInputsBuilder(AudioFlamingo3DummyInputsBuilder):
    pass
    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        num_audios = mm_counts.get("audio", 0)
        hf_processor = self.info.get_hf_processor()
        return hf_processor.audio_token * num_audios

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions],
    ) -> MultiModalDataDict:
        hf_processor = self.info.get_hf_processor()
        feature_extractor = self.info.get_feature_extractor()
        sampling_rate = feature_extractor.sampling_rate
        audio_len = int(hf_processor.max_audio_len * sampling_rate)
        num_audios = mm_counts.get("audio", 0)
        audio_overrides = mm_options.get("audio")

        return {
            "audio": self._get_dummy_audios(
                length=audio_len,
                num_audios=num_audios,
                overrides=audio_overrides,
            )
        }


def _musicflamingo_field_config(hf_inputs: Mapping[str, torch.Tensor]):
    fields = dict(_audioflamingo3_field_config(hf_inputs))
    chunk_counts = hf_inputs.get("chunk_counts")
    if chunk_counts is not None:
        fields["rote_timestamps"] = MultiModalFieldConfig.flat_from_sizes(
            "audio", chunk_counts, dim=0
        )
    else:
        fields["rote_timestamps"] = MultiModalFieldConfig.batched("audio")
    return fields


class MusicFlamingoMultiModalDataParser(AudioFlamingo3MultiModalDataParser):
    def _parse_audio_data(
        self,
        data: dict[str, torch.Tensor] | ModalityData[Any],
    ) -> ModalityDataItems[Any, Any] | None:
        if isinstance(data, dict):
            return DictEmbeddingItems(
                data,
                modality="audio",
                required_fields={"audio_embeds"},
                fields_factory=_musicflamingo_field_config,
            )
        return super()._parse_audio_data(data)


class MusicFlamingoMultiModalProcessor(AudioFlamingo3MultiModalProcessor):
    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: dict[str, object],
        mm_kwargs: Mapping[str, Any],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )

        audio_data = mm_data.get("audio")
        if audio_data is None:
            return outputs

        audio_list = audio_data if isinstance(audio_data, list) else [audio_data]
        if len(audio_list) == 0:
            return outputs

        processor = self.info.get_hf_processor(**mm_kwargs)
        feature_extractor = processor.feature_extractor
        sampling_rate = feature_extractor.sampling_rate
        chunk_length = feature_extractor.chunk_length
        window_size = int(sampling_rate * chunk_length)
        max_windows = int(processor.max_audio_len // chunk_length)

        chunk_counts = []
        for audio in audio_list:
            n_samples = len(audio) if isinstance(audio, list) else audio.shape[0]
            n_win = max(1, (n_samples + window_size - 1) // window_size)
            chunk_counts.append(min(n_win, max_windows))
        outputs["chunk_counts"] = torch.tensor(chunk_counts, dtype=torch.long)

        if "rote_timestamps" not in outputs:
            raise KeyError(
                "MusicFlamingoProcessor output must include `rote_timestamps`."
            )

        return outputs
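
The chunk count is a ceiling division of the waveform length by the feature extractor's window, capped at max_windows. For instance, assuming the usual Whisper-style 16 kHz / 30 s settings:

sampling_rate, chunk_length = 16_000, 30      # assumed values for illustration
window_size = sampling_rate * chunk_length    # 480_000 samples per window

for seconds in (5, 30, 70):
    n_samples = seconds * sampling_rate
    n_win = max(1, (n_samples + window_size - 1) // window_size)  # ceil division
    print(seconds, "->", n_win)               # 5 -> 1, 30 -> 1, 70 -> 3
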

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return _musicflamingo_field_config(hf_inputs)

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
        tokenizer = self.info.get_tokenizer()
        vocab = tokenizer.get_vocab()

        audio_token = processor.audio_token
        audio_token_id = vocab.get(audio_token, processor.audio_token_id)

        audio_bos_token = processor.audio_bos_token
        audio_bos_token_id = vocab.get(audio_bos_token, processor.audio_bos_token_id)

        audio_eos_token = processor.audio_eos_token
        audio_eos_token_id = vocab.get(audio_eos_token, processor.audio_eos_token_id)

        out_mm_data = out_mm_kwargs.get_data()
        feature_attention_mask = out_mm_data.get("feature_attention_mask")
        chunk_counts = out_mm_data.get("chunk_counts")

        def get_replacement_musicflamingo(item_idx: int):
            if feature_attention_mask is not None:
                num_features = _count_audio_tokens_from_mask(
                    feature_attention_mask,
                    chunk_counts,
                    item_idx,
                )
            else:
                audio_embeds = out_mm_data["audio_embeds"][item_idx]
                num_features = audio_embeds.shape[0]

            if num_features == 0:
                raise ValueError("Audio is too short")

            full_tokens = [
                audio_bos_token_id,
                *([audio_token_id] * int(num_features)),
                audio_eos_token_id,
            ]

            return PromptUpdateDetails.select_token_id(
                full_tokens,
                embed_token_id=audio_token_id,
            )

        return [
            PromptReplacement(
                modality="audio",
                target=audio_token,
                replacement=get_replacement_musicflamingo,
            )
        ]
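
Each audio placeholder in the prompt thus expands to a bos/eos-delimited run of audio tokens, with only the inner run selected for embedding. With hypothetical token ids (bos=10, audio=11, eos=12) and three features:

audio_bos_token_id, audio_token_id, audio_eos_token_id = 10, 11, 12  # illustrative ids
num_features = 3

full_tokens = [
    audio_bos_token_id,
    *([audio_token_id] * num_features),
    audio_eos_token_id,
]
print(full_tokens)  # [10, 11, 11, 11, 12]
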


@MULTIMODAL_REGISTRY.register_processor(
    AudioFlamingo3MultiModalProcessor,
    MusicFlamingoMultiModalProcessor,
    info=MusicFlamingoProcessingInfo,
    dummy_inputs=MusicFlamingoDummyInputsBuilder,
)
class MusicFlamingoForConditionalGeneration(AudioFlamingo3ForConditionalGeneration):
    """MusicFlamingo model for conditional generation."""
    """vLLM MusicFlamingo model aligned with HF modular_musicflamingo."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config, prefix=prefix)
        self.audio_tower = MusicFlamingoEncoder(self.config.audio_config)
        self.multi_modal_projector = MusicFlamingoMultiModalProjector(self.config)
        self.pos_emb = MusicFlamingoRotaryEmbedding(self.config)

    def _parse_and_validate_audio_input(
        self, **kwargs: object
    ) -> MusicFlamingoInputs | None:
        rote_timestamps = kwargs.pop("rote_timestamps", None)
        audio_input = super()._parse_and_validate_audio_input(**kwargs)
        if audio_input is None or audio_input["type"] == "audio_embeds":
            return audio_input

        return MusicFlamingoFeatureInputs(
            type="audio_features",
            input_features=audio_input["input_features"],
            feature_attention_mask=audio_input["feature_attention_mask"],
            chunk_counts=audio_input["chunk_counts"],
            rote_timestamps=rote_timestamps,
        )

    def _process_audio_input(
        self, audio_input: MusicFlamingoInputs
    ) -> torch.Tensor | tuple[torch.Tensor, ...]:
        if audio_input["type"] == "audio_embeds":
            return super()._process_audio_input(audio_input)

        rote_timestamps = audio_input["rote_timestamps"]
        if rote_timestamps is None:
            raise ValueError(
                "MusicFlamingo audio feature inputs must include `rote_timestamps`."
            )
        if isinstance(rote_timestamps, list):
            rote_timestamps = torch.cat(rote_timestamps, dim=0)

        (
            input_features,
            feature_attention_mask,
            chunk_counts,
        ) = self._normalize_audio_feature_inputs(audio_input)
        hidden_states = self._encode_audio_features(
            input_features,
            feature_attention_mask,
        )
        cos, sin = self.pos_emb(
            rote_timestamps.to(hidden_states.device),
            seq_len=hidden_states.shape[-2],
        )
        hidden_states = apply_rotary_time_emb(hidden_states, cos, sin)
        audio_features = self.multi_modal_projector(hidden_states)

        return self._group_audio_embeddings(
            audio_features,
            feature_attention_mask,
            chunk_counts,
        )