# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""Inference-only FunAudioChat model compatible with HuggingFace weights.

FunAudioChat is a Qwen3 text model augmented with:
- a continuous audio encoder (Whisper-mel frontend + transformer)
- a discrete audio encoder (speech tokenizer + projector)

In the HF implementation, audio features are scattered into `<|AUDIO|>` token
positions via `inputs_embeds`, while `position_ids` (RoPE) remains standard 1D.
"""

from __future__ import annotations

import os
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from typing import Any

import numpy as np
import torch
import torch.nn as nn
from transformers import PreTrainedTokenizerFast, WhisperFeatureExtractor
from transformers.activations import get_activation
from transformers.feature_extraction_utils import BatchFeature
from transformers.modeling_outputs import BaseModelOutput

from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.model_executor.layers.attention.mm_encoder_attention import MMEncoderAttention
from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
    MultiModalDataDict,
    MultiModalFieldConfig,
    MultiModalKwargsItems,
)
from vllm.multimodal.parse import (
    AudioProcessorItems,
    MultiModalDataItems,
    MultiModalDataParser,
)
from vllm.multimodal.processing import (
    BaseDummyInputsBuilder,
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
    PromptUpdateDetails,
)
from vllm.sequence import IntermediateTensors
from vllm.utils.import_utils import _has_module

from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix


class _SinusoidsPositionEmbedding(nn.Module):
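    """Fixed sinusoidal position table (Whisper-style).

    The buffer concatenates the sin and cos halves along the channel dim,
    i.e. ``[sin(t * w_0), ..., sin(t * w_{C/2-1}), cos(t * w_0), ...]``,
    where the per-channel timescales grow geometrically from 1 up to
    ``max_timescale``.
    """
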
    def __init__(self, length: int, channels: int, max_timescale: float = 10000.0):
        super().__init__()
        if channels % 2 != 0:
            raise ValueError("SinusoidsPositionEmbedding needs even channels input")

        log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
        inv_timescales = torch.exp(
            -log_timescale_increment * torch.arange(channels // 2).float()
        )
        scaled_time = (
            torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
        )
        self.register_buffer(
            "positional_embedding",
            torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1),
            persistent=False,
        )


class FunAudioChatAudioAttention(nn.Module):
    """Multi-headed attention used inside the continuous audio tower."""

    def __init__(self, config: Any):
        super().__init__()
        self.embed_dim = int(config.d_model)
        self.total_num_heads = int(config.encoder_attention_heads)
        self.dropout = float(getattr(config, "attention_dropout", 0.0))
        self.head_dim = self.embed_dim // self.total_num_heads
        self.num_key_value_groups = 1  # needed for eager attention
        self.config = config

        if self.head_dim * self.total_num_heads != self.embed_dim:
            raise ValueError(
                "embed_dim must be divisible by num_heads "
                f"(got embed_dim={self.embed_dim}, "
                f"num_heads={self.total_num_heads})."
            )
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = 0.0
        self.is_decoder = False
        self.is_causal = False

        self.qkv_proj = QKVParallelLinear(
            self.embed_dim,
            self.head_dim,
            self.total_num_heads,
            bias=True,
        )
        self.num_heads = self.qkv_proj.num_heads
        self.num_kv_heads = self.qkv_proj.num_kv_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim

        self.attn = MMEncoderAttention(
            num_heads=self.num_heads,
            head_size=self.head_dim,
            scale=self.scaling,
            num_kv_heads=self.num_kv_heads,
            prefix="funaudiochat_audio_tower.attn",
        )
        self.out_proj = RowParallelLinear(
            self.embed_dim,
            self.embed_dim,
            bias=True,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        stacked_params_mapping = [
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        params_dict = dict(self.named_parameters())
        with torch.no_grad():
            if self.qkv_proj.bias is not None:
                # HF FunAudioChat uses bias=False for k_proj. Ensure the missing
                # shard starts as zeros, while allowing q/v shards to load.
                self.qkv_proj.bias.zero_()

        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            for param_name, shard_name, shard_id in stacked_params_mapping:
                if shard_name not in name:
                    continue
                name = name.replace(shard_name, param_name)
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)

            loaded_params.add(name)

        return loaded_params

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor:
        del kwargs
        del attention_mask
        seq_length, _ = hidden_states.size()

        qkv, _ = self.qkv_proj(hidden_states)
        query_states, key_states, value_states = qkv.split(
            [self.q_size, self.kv_size, self.kv_size], dim=-1
        )

        max_seqlen: torch.Tensor | None = None
        if cu_seqlens is not None:
            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()

        attn_output = self.attn(
            query_states.reshape(1, seq_length, self.q_size),
            key_states.reshape(1, seq_length, self.kv_size),
            value_states.reshape(1, seq_length, self.kv_size),
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        ).reshape(seq_length, -1)

        output, _ = self.out_proj(attn_output)
        return output


class FunAudioChatAudioEncoderLayer(nn.Module):
    def __init__(self, config: Any):
        super().__init__()
        self.embed_dim = int(config.d_model)
        self.self_attn = FunAudioChatAudioAttention(config)
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
        self.dropout = float(config.dropout)
        self.activation_fn = get_activation(str(config.activation_function))
        self.activation_dropout = float(config.activation_dropout)
        self.fc1 = nn.Linear(self.embed_dim, int(config.encoder_ffn_dim))
        self.fc2 = nn.Linear(int(config.encoder_ffn_dim), self.embed_dim)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        **kwargs: object,
    ) -> tuple[torch.Tensor]:
        residual = hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            cu_seqlens=cu_seqlens,
            attention_mask=attention_mask,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        residual = hidden_states
        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = nn.functional.dropout(
            hidden_states, p=self.activation_dropout, training=self.training
        )
        hidden_states = self.fc2(hidden_states)
        hidden_states = nn.functional.dropout(
            hidden_states, p=self.dropout, training=self.training
        )
        hidden_states = residual + hidden_states

        return (hidden_states,)


class FunAudioChatAudioEncoder(nn.Module):
    """Continuous audio tower."""

    def __init__(self, config: Any):
        super().__init__()
        self.config = config

        embed_dim = int(config.d_model)
        self.num_mel_bins = int(config.num_mel_bins)
        self.max_source_positions = int(config.max_source_positions)
        self.embed_scale = (embed_dim**0.5) if bool(config.scale_embedding) else 1.0
        self.n_window = int(config.n_window)

        self.conv1 = nn.Conv1d(self.num_mel_bins, embed_dim, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1)
        self.layers = nn.ModuleList(
            [
                FunAudioChatAudioEncoderLayer(config)
                for _ in range(int(config.encoder_layers))
            ]
        )
        self.ln_post = nn.LayerNorm(embed_dim)
        self.avg_pooler = nn.AvgPool1d(2, stride=2)
        self.proj = nn.Linear(embed_dim, int(config.output_dim))
        self.positional_embedding = _SinusoidsPositionEmbedding(
            self.max_source_positions, embed_dim
        )

        # Present in HF weights even if unused during S2T.
        self.audio_bos_eos_token = nn.Embedding(2, int(config.output_dim))

    @property
    def dtype(self) -> torch.dtype:
        return self.conv1.weight.dtype

    def _prepare_attention_mask(
        self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor
    ) -> torch.Tensor | None:
        if getattr(self.config, "_attn_implementation", "eager") == "flash_attention_2":
            return None

        seq_length = inputs_tensor.shape[0]
        attention_mask = torch.full(
            (1, 1, seq_length, seq_length),
            torch.finfo(inputs_tensor.dtype).min,
            device=inputs_tensor.device,
            dtype=inputs_tensor.dtype,
        )
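        # Unmask the block-diagonal regions so that each chunk delimited by
        # cu_seqlens attends only to itself; everything outside a chunk stays
        # at the dtype minimum.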
        for i in range(1, len(cu_seqlens)):
            start = int(cu_seqlens[i - 1].item())
            end = int(cu_seqlens[i].item())
            attention_mask[..., start:end, start:end] = 0
        return attention_mask

    def forward(
        self,
        input_features: torch.Tensor,
        feature_lens: torch.Tensor,
        aftercnn_lens: torch.Tensor,
        speech_maxlen: int,
        **kwargs: object,
    ) -> BaseModelOutput:
        # For max-length audio (300s => ~7500 speech frames at 25Hz), the
        # Torch SDPA path can be prohibitively memory hungry (~O(n^2) inside the
        # longest chunks). Require FlashAttention for such inputs to avoid OOM
        # and performance cliffs.
        if int(speech_maxlen) >= 7500:
            if not _has_module("flash_attn"):
                raise RuntimeError(
                    "FunAudioChat long audio (~300s) requires FlashAttention-2 "
                    "for the continuous audio tower, but `flash_attn` is not "
                    "installed in the runtime environment."
                )
            if not getattr(
                self.layers[0].self_attn.attn, "is_flash_attn_backend", False
            ):
                raise RuntimeError(
                    "FunAudioChat long audio (~300s) requires FlashAttention for the "
                    "continuous audio tower, but the selected MM encoder attention "
                    "backend is not FlashAttention."
                )

        # Handle empty / invalid items (feature_lens == 0) without crashing.
        original_batch_size = int(feature_lens.size(0))
        device = input_features.device

        valid_mask = feature_lens > 0
        valid_indices = torch.where(valid_mask)[0]

        if valid_indices.numel() == 0:
            output_dim = int(self.proj.out_features)
            return BaseModelOutput(
                last_hidden_state=torch.zeros(
                    (original_batch_size, speech_maxlen, output_dim),
                    device=device,
                    dtype=self.proj.weight.dtype,
                )
            )

        input_features_list = input_features.split(feature_lens.tolist(), dim=1)
        valid_input_features_list = [input_features_list[int(i)] for i in valid_indices]
        valid_input_features = torch.cat(valid_input_features_list, dim=1)

        valid_feature_lens = feature_lens[valid_mask]
        valid_aftercnn_lens = aftercnn_lens[valid_mask]

        chunk_num = torch.ceil(valid_feature_lens / (self.n_window * 2)).long()

        chunk_lengths_list: list[int] = []
        full_chunk_len = self.n_window * 2
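        # Split every sample into fixed windows of n_window * 2 mel frames;
        # the final window holds the remainder (or a full window when the
        # length divides evenly).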
        for i, length in enumerate(valid_feature_lens):
            num_chunks_for_sample = int(chunk_num[i].item())
            if num_chunks_for_sample == 0:
                continue
            chunk_lengths_list.extend([full_chunk_len] * (num_chunks_for_sample - 1))
            last_chunk_len = int(length.item()) % full_chunk_len
            if last_chunk_len == 0:
                last_chunk_len = full_chunk_len
            chunk_lengths_list.append(last_chunk_len)

        chunk_lengths = torch.tensor(
            chunk_lengths_list, dtype=torch.long, device=device
        )

        chunk_list = valid_input_features.split(chunk_lengths.tolist(), dim=1)
        padded_feature, padded_mask, padded_mask_after_cnn = (
            self.padded_and_mask_function(
                chunk_list, chunk_lengths, padding_value=0, padding_side="right"
            )
        )

        padded_embed = nn.functional.gelu(self.conv1(padded_feature)) * padded_mask
        padded_embed = nn.functional.gelu(self.conv2(padded_embed)).transpose(1, 2)

        padded_embed = padded_embed + self.positional_embedding.positional_embedding[
            : padded_embed.shape[1], :
        ].unsqueeze(0).to(padded_embed.dtype)

        hidden_states = padded_embed[padded_mask_after_cnn]
        cu_seqlens = torch.cat(
            (
                torch.zeros(1, device=padded_mask_after_cnn.device, dtype=torch.int32),
                padded_mask_after_cnn.sum(1).cumsum(0),
            )
        ).to(torch.int32)

        for encoder_layer in self.layers:
            (hidden_states,) = encoder_layer(
                hidden_states,
                cu_seqlens=cu_seqlens,
                **kwargs,
            )

        hidden_states_list = hidden_states.split(valid_aftercnn_lens.tolist(), dim=0)

        pooled_list: list[torch.Tensor] = []
        pooled_lengths: list[int] = []
        for each_audio_states in hidden_states_list:
            seq_len = int(each_audio_states.shape[0])
            if seq_len >= 2:
                pooled = nn.functional.avg_pool1d(
                    each_audio_states.transpose(0, 1), kernel_size=2, stride=2
                ).transpose(0, 1)
            else:
                pooled = each_audio_states
            pooled_list.append(pooled)
            pooled_lengths.append(int(pooled.shape[0]))

        pooled_concat = torch.cat(pooled_list, dim=0)
        processed_concat = self.proj(self.ln_post(pooled_concat))
        processed_audio_list = list(processed_concat.split(pooled_lengths, dim=0))

        output_dim = (
            int(processed_audio_list[0].shape[-1])
            if processed_audio_list
            else int(self.proj.out_features)
        )
        output_hidden_states = torch.zeros(
            (original_batch_size, speech_maxlen, output_dim),
            dtype=processed_audio_list[0].dtype
            if processed_audio_list
            else self.proj.weight.dtype,
            device=device,
        )

        for valid_idx, processed in zip(valid_indices, processed_audio_list):
            seq_len = min(int(processed.shape[0]), int(speech_maxlen))
            output_hidden_states[int(valid_idx), :seq_len] = processed[:seq_len]

        return BaseModelOutput(last_hidden_state=output_hidden_states)

    def padded_and_mask_function(
        self,
        tensor_list: Sequence[torch.Tensor],
        tensor_len: torch.Tensor,
        padding_value: float = 0.0,
        padding_side: str = "right",
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        max_len = int(tensor_len.max().item())
        dim = int(tensor_list[0].shape[0])
        padded_tensor = torch.full(
            size=(len(tensor_list), dim, max_len),
            fill_value=padding_value,
            dtype=self.dtype,
            device=tensor_list[0].device,
        )

        batch_mask = torch.zeros(
            (len(tensor_len), max_len), dtype=torch.long, device=padded_tensor.device
        )
        for i, length in enumerate(tensor_len):
            length_val = int(length.item())
            batch_mask[i, :length_val] = 1
            padded_tensor[i, :, :length_val] = tensor_list[i]

        feature_lens_after_cnn = (tensor_len - 1) // 2 + 1
        max_len_after_cnn = int(feature_lens_after_cnn.max().item())
        batch_mask_after_cnn = torch.zeros(
            (len(tensor_len), max_len_after_cnn),
            dtype=torch.long,
            device=padded_tensor.device,
        )
        for i, length in enumerate(feature_lens_after_cnn):
            batch_mask_after_cnn[i, : int(length.item())] = 1

        if padding_side != "right":
            raise NotImplementedError("Only right padding is supported.")

        return (
            padded_tensor,
            batch_mask.unsqueeze(1).to(padded_tensor.dtype),
            batch_mask_after_cnn.bool(),
        )

    # From the HF FunAudioChat implementation.
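    # The first formula models conv2 (kernel 3, stride 2, padding 1) and the
    # second the stride-2 average pooling, e.g. 100 mel frames (~1 s of
    # audio) -> (100 - 1) // 2 + 1 = 50 post-CNN frames ->
    # (50 - 2) // 2 + 1 = 25 output frames, i.e. ~25 feature vectors per
    # second.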
    def _get_feat_extract_output_lengths(
        self, input_lengths: torch.LongTensor
    ) -> tuple[torch.LongTensor, torch.LongTensor]:
        input_lengths = (input_lengths - 1) // 2 + 1
        output_lengths = (input_lengths - 2) // 2 + 1
        return input_lengths, output_lengths


class FunAudioChatDiscreteEncoder(nn.Module):
    """Discrete audio encoder (speech tokenizer -> grouped embeddings)."""

    def __init__(self, config: Any):
        super().__init__()
        self.padding_idx = int(config.pad_token_id)
        self.group_size = int(config.group_size)
        self.hidden_size = int(config.output_dim)
        self.continuous_features_mode = getattr(
            config, "continuous_features_mode", "add"
        )
        self.embed_tokens = nn.Embedding(
            int(config.codebook_size), self.hidden_size, self.padding_idx
        )
        self.output_matching = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
        self.continual_output_matching = nn.Linear(
            self.hidden_size, self.hidden_size, bias=False
        )

    def forward(
        self,
        audio_ids: torch.Tensor,
        continuous_audio_features: torch.Tensor | None = None,
        continuous_audio_output_lengths: torch.Tensor | None = None,
        feature_exist_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        del continuous_audio_output_lengths

        inputs_embeds = self.embed_tokens(audio_ids)
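        # Group consecutive speech codes and mean-pool their embeddings:
        # each run of group_size codes collapses into a single vector before
        # the linear projection below.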
        hidden_states = inputs_embeds.reshape(
            inputs_embeds.shape[0], -1, self.group_size * self.hidden_size
        )
        hidden_states = hidden_states.reshape(
            hidden_states.shape[0], -1, self.group_size, self.hidden_size
        ).mean(dim=2)
        hidden_states = self.output_matching(hidden_states)

        if continuous_audio_features is not None:
            continuous_audio_features = continuous_audio_features.reshape(
                continuous_audio_features.shape[0],
                -1,
                self.group_size,
                self.hidden_size,
            ).mean(dim=2)
            continuous_audio_hidden_states = self.continual_output_matching(
                continuous_audio_features
            )

            if feature_exist_mask is None:
                feature_exist_mask = torch.ones(
                    (hidden_states.shape[0],),
                    dtype=torch.bool,
                    device=hidden_states.device,
                )
            if self.continuous_features_mode == "add":
                hidden_states[feature_exist_mask] += continuous_audio_hidden_states
            else:
                hidden_states[feature_exist_mask] = continuous_audio_hidden_states

        return hidden_states

    def _get_feat_extract_output_lengths(
        self, input_lengths: torch.LongTensor
    ) -> tuple[torch.LongTensor, torch.LongTensor]:
        output_lengths = (input_lengths + self.group_size - 1) // self.group_size
        return input_lengths, output_lengths


class FunAudioChatProcessingInfo(BaseProcessingInfo):
    token_fps: int = 25

    @cached_property
    def feature_extractor(self) -> WhisperFeatureExtractor:
        return WhisperFeatureExtractor.from_pretrained(self.model_id)

    @cached_property
    def speech_tokenizer(self) -> PreTrainedTokenizerFast:
        return PreTrainedTokenizerFast.from_pretrained(
            self.model_id, subfolder="speech_tokenizer"
        )

    def get_feature_extractor(self, **kwargs: object) -> WhisperFeatureExtractor:
        # Extra HF-processor kwargs are accepted for interface compatibility
        # with callers that forward mm kwargs; they are ignored here.
        return self.feature_extractor

    def get_speech_tokenizer(self) -> PreTrainedTokenizerFast:
        return self.speech_tokenizer

    def get_data_parser(self):
        return MultiModalDataParser(
            target_sr=int(self.feature_extractor.sampling_rate),
            target_channels=self.get_target_channels(),
            expected_hidden_size=self._get_expected_hidden_size(),
        )

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        return {"audio": None}

    def get_target_channels(self) -> int:
        return 1

    def get_mm_max_tokens_per_item(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
    ) -> Mapping[str, int] | None:
        # The discrete audio encoder downsamples 25Hz frames with group_size=5,
        # so for a 300s clip the max number of `<|AUDIO|>` placeholders is 1500.
        cfg = self.get_hf_config()
        audio_cfg = getattr(cfg, "audio_config", None)
        max_audio_tokens = int(getattr(audio_cfg, "max_source_positions", 1500))
        return {"audio": max_audio_tokens}

    def get_audio_group_size(self) -> int:
        cfg = self.get_hf_config()
        audio_cfg = getattr(cfg, "audio_config", None)
        return int(getattr(audio_cfg, "group_size", 5))


class FunAudioChatDummyInputsBuilder(
    BaseDummyInputsBuilder[FunAudioChatProcessingInfo]
):
    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        num_audios = mm_counts.get("audio", 0)
        return "<|audio_bos|><|AUDIO|><|audio_eos|>" * int(num_audios)

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions] | None = None,
        mm_processor_kwargs: Mapping[str, object] | None = None,
    ) -> MultiModalDataDict:
        feature_extractor = self.info.get_feature_extractor(
            **(mm_processor_kwargs or {})
        )
        sampling_rate = int(feature_extractor.sampling_rate)

        # Dummy inputs are used for profiling; construct the worst-case audio
        # length that maximizes the number of encoder tokens.
        cfg = self.info.get_hf_config()
        audio_cfg = getattr(cfg, "audio_config", None)
        max_audio_tokens = int(getattr(audio_cfg, "max_source_positions", 1500))
        group_size = self.info.get_audio_group_size()
        token_fps = int(getattr(self.info, "token_fps", 25))
        target_num_frames = max(1, max_audio_tokens) * max(1, group_size)
        audio_len = max(
            1,
            (target_num_frames * sampling_rate + token_fps - 1) // token_fps,
        )
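        # With the defaults implied above (1500 audio tokens, group_size 5,
        # 25 token fps, 16 kHz feature-extractor sampling rate) this works
        # out to 7500 frames, i.e. roughly 300 s or ~4.8M samples.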
        num_audios = int(mm_counts.get("audio", 0))

        audio_overrides = mm_options.get("audio") if mm_options else None
        return {
            "audio": self._get_dummy_audios(
                length=audio_len,
                num_audios=num_audios,
                overrides=audio_overrides,
            )
        }


class FunAudioChatMultiModalProcessor(
    BaseMultiModalProcessor[FunAudioChatProcessingInfo]
):
    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        tokenizer = self.info.get_tokenizer()
        input_ids = torch.tensor([tokenizer.encode(prompt, **tok_kwargs)])

        audios = mm_data.get("audios", [])
        if not audios:
            return BatchFeature({"input_ids": input_ids})

        feature_extractor = self.info.get_feature_extractor(**mm_kwargs)
        sr = int(feature_extractor.sampling_rate)
        min_samples = int(getattr(feature_extractor, "n_fft", 400) or 400)

        wavs: list[np.ndarray] = []
        speech_strs: list[str] = []

        speech_tokenizer = self.info.get_speech_tokenizer()
        pad_token = speech_tokenizer.pad_token or "<|audio_pad|>"
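        # Only pad tokens are emitted into speech_strs below: the string
        # length encodes the number of 25 Hz frames per clip, so the speech
        # tokenizer returns id/attention-mask tensors of the right size for
        # placeholder bookkeeping.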
        for audio in audios:
            if isinstance(audio, torch.Tensor):
                audio = audio.detach().cpu().numpy()
            audio_np = np.asarray(audio, dtype=np.float32)

            if min_samples > 0 and audio_np.shape[0] < min_samples:
                audio_np = np.pad(
                    audio_np, (0, min_samples - audio_np.shape[0]), mode="constant"
                )

            wavs.append(audio_np)
            num_frames = int(
                (float(audio_np.shape[0]) / float(sr)) * float(self.info.token_fps)
            )
            speech_strs.append(pad_token * max(1, int(num_frames)))

        audio_group_size = self.info.get_audio_group_size()
        speech_inputs = speech_tokenizer(
            speech_strs,
            return_attention_mask=True,
            return_token_type_ids=False,
            padding=True,
            pad_to_multiple_of=audio_group_size,
            return_tensors="pt",
        )

        wav_inputs = feature_extractor(
            wavs,
            sampling_rate=sr,
            return_attention_mask=True,
            padding="max_length",
            return_tensors="pt",
        )

        mm_inputs: dict[str, torch.Tensor] = {
            "speech_ids": speech_inputs["input_ids"],
            "speech_attention_mask": speech_inputs["attention_mask"],
            "input_features": wav_inputs["input_features"],
            "feature_attention_mask": wav_inputs["attention_mask"],
            "feature_exist_mask": torch.ones((len(wavs),), dtype=torch.bool),
        }

        return BatchFeature({"input_ids": input_ids, **mm_inputs})

    def _hf_processor_applies_updates(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
    ) -> bool:
        return False

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return {
            "speech_ids": MultiModalFieldConfig.batched("audio"),
            "speech_attention_mask": MultiModalFieldConfig.batched("audio"),
            "input_features": MultiModalFieldConfig.batched("audio"),
            "feature_attention_mask": MultiModalFieldConfig.batched("audio"),
            "feature_exist_mask": MultiModalFieldConfig.batched("audio"),
        }

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        tokenizer = self.info.get_tokenizer()
        vocab = tokenizer.get_vocab()

        audio_token = "<|AUDIO|>"
        audio_token_id = vocab[audio_token]

        out_mm_data = out_mm_kwargs.get_data()
        speech_attention_mask = out_mm_data.get("speech_attention_mask")
        if speech_attention_mask is None:
            audio_output_lengths: list[int] = []
        else:
            assert isinstance(speech_attention_mask, torch.Tensor)
            speech_lengths = speech_attention_mask.sum(-1)
            group_size = self.info.get_audio_group_size()
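            # Each `<|AUDIO|>` placeholder expands to
            # ceil(num_speech_codes / group_size) tokens, matching the
            # discrete encoder's _get_feat_extract_output_lengths.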
            audio_output_lengths = (
                (speech_lengths + group_size - 1) // group_size
            ).tolist()

        def get_replacement_funaudiochat(item_idx: int):
            num_features = (
                int(audio_output_lengths[item_idx]) if audio_output_lengths else 1
            )
            if num_features <= 0:
                audios = mm_items.get_items("audio", AudioProcessorItems)
                audio_len = audios.get_audio_length(item_idx)
                raise ValueError(
                    f"The audio (len={audio_len}) is too short to be "
                    "represented inside the model"
                )

            audio_tokens = [audio_token_id] * num_features
            return PromptUpdateDetails.select_token_id(
                audio_tokens,
                embed_token_id=audio_token_id,
            )

        return [
            PromptReplacement(
                modality="audio",
                target=audio_token,
                replacement=get_replacement_funaudiochat,
            )
        ]


@MULTIMODAL_REGISTRY.register_processor(
    FunAudioChatMultiModalProcessor,
    info=FunAudioChatProcessingInfo,
    dummy_inputs=FunAudioChatDummyInputsBuilder,
)
class FunAudioChatForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        if modality.startswith("audio"):
            return "<|audio_bos|><|AUDIO|><|audio_eos|>"

        raise ValueError("Only audio modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.multimodal_config = multimodal_config
        self.quant_config = quant_config

        with self._mark_tower_model(vllm_config, "audio"):
            self.continuous_audio_tower = FunAudioChatAudioEncoder(config.audio_config)
            self.audio_tower = FunAudioChatDiscreteEncoder(config.audio_config)

        with self._mark_language_model(vllm_config):
            self.language_model = init_vllm_registered_model(
                vllm_config=vllm_config,
                hf_config=config.text_config,
                prefix=maybe_prefix(prefix, "language_model"),
                architectures=["Qwen3ForCausalLM"],
            )

        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors
        )

    def _get_continuous_audio_features(
        self,
        input_features: torch.Tensor,
        feature_attention_mask: torch.Tensor,
        speech_maxlen: int,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Align mask and features to avoid indexing errors when padding differs.
        if (
            input_features.dim() == 3
            and feature_attention_mask.shape[1] != input_features.shape[-1]
        ):
            min_len = min(
                int(feature_attention_mask.shape[1]), int(input_features.shape[-1])
            )
            feature_attention_mask = feature_attention_mask[:, :min_len]
            input_features = input_features[:, :, :min_len]

        feature_lens = torch.sum(feature_attention_mask, dim=1)

        flat_features = input_features.permute(0, 2, 1)[
            feature_attention_mask.bool()
        ].permute(1, 0)

        audio_feat_lengths, audio_output_lengths = (
            self.continuous_audio_tower._get_feat_extract_output_lengths(feature_lens)
        )

        audio_outputs = self.continuous_audio_tower(
            flat_features,
            feature_lens=feature_lens,
            aftercnn_lens=audio_feat_lengths,
            speech_maxlen=speech_maxlen,
        )
        return audio_outputs.last_hidden_state, audio_output_lengths

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        speech_ids = kwargs.get("speech_ids")
        speech_attention_mask = kwargs.get("speech_attention_mask")
        input_features = kwargs.get("input_features")
        feature_attention_mask = kwargs.get("feature_attention_mask")
        feature_exist_mask = kwargs.get("feature_exist_mask")

        if speech_ids is None:
            return []

        pad_id = int(getattr(self.audio_tower, "padding_idx", 0))

        if not isinstance(speech_ids, torch.Tensor):
            if (
                isinstance(speech_ids, (list, tuple))
                and len(speech_ids) > 0
                and all(isinstance(t, torch.Tensor) for t in speech_ids)
            ):
                speech_ids_tensors = []
                for t in speech_ids:
                    if t.dim() == 2 and t.shape[0] == 1:
                        t = t.squeeze(0)
                    if t.dim() != 1:
                        raise TypeError(
                            "FunAudioChat speech_ids must be a 1D tensor per item "
                            f"(got shape={tuple(t.shape)})"
                        )
                    speech_ids_tensors.append(t)
                speech_ids = nn.utils.rnn.pad_sequence(
                    speech_ids_tensors,
                    batch_first=True,
                    padding_value=pad_id,
                )
            else:
                raise TypeError(
                    "FunAudioChat speech_ids must be a Tensor or a sequence of Tensors "
                    f"(got {type(speech_ids)})"
                )

        if speech_attention_mask is None:
            speech_attention_mask = speech_ids.ne(pad_id).to(dtype=torch.int64)

        if not isinstance(speech_attention_mask, torch.Tensor):
            if (
                isinstance(speech_attention_mask, (list, tuple))
                and len(speech_attention_mask) > 0
                and all(isinstance(t, torch.Tensor) for t in speech_attention_mask)
            ):
                mask_tensors = []
                for t in speech_attention_mask:
                    if t.dim() == 2 and t.shape[0] == 1:
                        t = t.squeeze(0)
                    if t.dim() != 1:
                        raise TypeError(
                            "FunAudioChat speech_attention_mask must be a 1D tensor "
                            f"per item (got shape={tuple(t.shape)})"
                        )
                    mask_tensors.append(t)
                speech_attention_mask = nn.utils.rnn.pad_sequence(
                    mask_tensors,
                    batch_first=True,
                    padding_value=0,
                )
            else:
                raise TypeError(
                    "FunAudioChat speech_attention_mask must be a Tensor or a "
                    f"sequence of Tensors (got {type(speech_attention_mask)})"
                )

        debug = os.getenv("VLLM_FUN_AUDIOCHAT_DEBUG", "") == "1"
        if debug:
            print(
                f"[FunAudioChat] embed_multimodal speech_ids={tuple(speech_ids.shape)} "
                f"speech_attention_mask={tuple(speech_attention_mask.shape)}",
                flush=True,
            )
            attn_impl = getattr(
                self.continuous_audio_tower.config, "_attn_implementation", None
            )
            print(
                f"[FunAudioChat] audio_attn_impl={attn_impl}",
                flush=True,
            )
            if hasattr(self.continuous_audio_tower, "conv1"):
                conv1_w = self.continuous_audio_tower.conv1.weight
                print(
                    f"[FunAudioChat] conv1_w_norm={float(conv1_w.norm().item()):.6g}",
                    flush=True,
                )
            try:
                # q/k/v are fused into qkv_proj in this vLLM port, so the
                # per-projection weights of the HF checkpoint do not exist
                # here; report the fused and output projection norms instead.
                attn0 = self.continuous_audio_tower.layers[0].self_attn
                qkv_norm = float(attn0.qkv_proj.weight.norm().item())
                o_norm = float(attn0.out_proj.weight.norm().item())
                print(
                    f"[FunAudioChat] attn0_qkv_norm={qkv_norm:.6g} "
                    f"o_norm={o_norm:.6g}",
                    flush=True,
                )
            except Exception:
                pass
            if isinstance(input_features, torch.Tensor):
                print(
                    f"[FunAudioChat] input_features={tuple(input_features.shape)}",
                    flush=True,
                )
            if isinstance(feature_attention_mask, torch.Tensor):
                print(
                    "[FunAudioChat] feature_attention_mask="
                    f"{tuple(feature_attention_mask.shape)}",
                    flush=True,
                )

        group_size = int(self.audio_tower.group_size)
        speech_maxlen = int(speech_ids.shape[-1])

        # Ensure token length is divisible by group_size.
        target_len = ((speech_maxlen + group_size - 1) // group_size) * group_size
        if target_len > speech_maxlen:
            pad_id = int(self.audio_tower.padding_idx)
            pad_len = target_len - speech_maxlen
            speech_ids = nn.functional.pad(speech_ids, (0, pad_len), value=pad_id)
            speech_attention_mask = nn.functional.pad(
                speech_attention_mask, (0, pad_len), value=0
            )
            speech_maxlen = int(speech_ids.shape[-1])

        continuous_audio_features = None
        continuous_audio_output_lengths = None
        if input_features is not None and feature_attention_mask is not None:
            assert isinstance(input_features, torch.Tensor)
            assert isinstance(feature_attention_mask, torch.Tensor)
            continuous_audio_features, continuous_audio_output_lengths = (
                self._get_continuous_audio_features(
                    input_features=input_features,
                    feature_attention_mask=feature_attention_mask,
                    speech_maxlen=speech_maxlen,
                )
            )

        if feature_exist_mask is None:
            feature_exist_mask = torch.ones(
                (speech_ids.shape[0],), dtype=torch.bool, device=speech_ids.device
            )
        assert isinstance(feature_exist_mask, torch.Tensor)

        audio_features = self.audio_tower(
            speech_ids,
            continuous_audio_features=continuous_audio_features,
            continuous_audio_output_lengths=continuous_audio_output_lengths,
            feature_exist_mask=feature_exist_mask,
        )

        _, audio_output_lengths = self.audio_tower._get_feat_extract_output_lengths(
            speech_attention_mask.sum(-1)
        )
        lengths = audio_output_lengths.tolist()

        embeds = tuple(
            audio_features[i, : int(length)] for i, length in enumerate(lengths)
        )
        if debug:
            embed_lens = [int(t.shape[0]) for t in embeds]
            print(f"[FunAudioChat] embed_multimodal out_lens={embed_lens}", flush=True)
            if embeds:
                t0 = embeds[0]
                print(
                    f"[FunAudioChat] embed0 dtype={t0.dtype} device={t0.device} "
                    f"nan={bool(torch.isnan(t0).any())} "
                    f"norm={float(t0.norm().item()):.6g}",
                    flush=True,
                )
            dump_path = os.getenv("VLLM_FUN_AUDIOCHAT_DUMP_PATH", "")
            if (
                dump_path
                and speech_ids.shape[0] == 1
                and len(embeds) == 1
                and embed_lens[0] > 10
            ):
                if not os.path.exists(dump_path):
                    np.save(dump_path, embeds[0].detach().float().cpu().numpy())
                    print(f"[FunAudioChat] dumped embeds to {dump_path}", flush=True)
                cont_path = dump_path.replace(".npy", "_cont.npy")
                if continuous_audio_features is not None and not os.path.exists(
                    cont_path
                ):
                    np.save(
                        cont_path,
                        continuous_audio_features.detach().float().cpu().numpy(),
                    )
                    print(
                        f"[FunAudioChat] dumped continuous to {cont_path}", flush=True
                    )
        return embeds

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor | IntermediateTensors:
        del kwargs
        if intermediate_tensors is not None:
            inputs_embeds = None

        return self.language_model.model(
            input_ids,
            positions,
            intermediate_tensors,
            inputs_embeds=inputs_embeds,
        )

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor | None:
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        loader = AutoWeightsLoader(self, skip_prefixes=["audio_invert_tower."])
        return loader.load_weights(weights)