diff --git a/examples/offline_inference/audio_language.py b/examples/offline_inference/audio_language.py index f7106d7ea..62847778b 100755 --- a/examples/offline_inference/audio_language.py +++ b/examples/offline_inference/audio_language.py @@ -117,6 +117,31 @@ def run_glmasr(question: str, audio_count: int) -> ModelRequestData: ) +# FunAudioChat +def run_funaudiochat(question: str, audio_count: int) -> ModelRequestData: + # NOTE: FunAudioChat is not available on the HuggingFace Hub at the time of + # writing. Pass a local model path via `--model`. + model_name = "funaudiochat" + + engine_args = EngineArgs( + model=model_name, + max_model_len=4096, + max_num_seqs=2, + limit_mm_per_prompt={"audio": audio_count}, + enforce_eager=True, + ) + + audio_in_prompt = "".join( + ["<|audio_bos|><|AUDIO|><|audio_eos|>\n" for _ in range(audio_count)] + ) + prompt = f"{audio_in_prompt}{question}" + + return ModelRequestData( + engine_args=engine_args, + prompt=prompt, + ) + + # Granite Speech def run_granite_speech(question: str, audio_count: int) -> ModelRequestData: # NOTE - the setting in this example are somewhat different from what is @@ -410,6 +435,7 @@ model_example_map = { "audioflamingo3": run_audioflamingo3, "gemma3n": run_gemma3n, "glmasr": run_glmasr, + "funaudiochat": run_funaudiochat, "granite_speech": run_granite_speech, "midashenglm": run_midashenglm, "minicpmo": run_minicpmo, @@ -435,6 +461,12 @@ def parse_args(): choices=model_example_map.keys(), help='Huggingface "model_type".', ) + parser.add_argument( + "--model", + type=str, + default=None, + help="Model ID or local path override. Required for funaudiochat.", + ) parser.add_argument( "--num-prompts", type=int, default=1, help="Number of prompts to run." ) @@ -467,6 +499,9 @@ def main(args): if model not in model_example_map: raise ValueError(f"Model type {model} is not supported.") + if model == "funaudiochat" and not args.model: + raise ValueError("--model is required when --model-type=funaudiochat") + if args.tensor_parallel_size is not None and args.tensor_parallel_size < 1: raise ValueError( f"tensor_parallel_size must be a positive integer, " @@ -477,6 +512,8 @@ def main(args): req_data = model_example_map[model]( question_per_audio_count[audio_count], audio_count ) + if model == "funaudiochat": + req_data.engine_args.model = args.model # Disable other modalities to save memory default_limits = {"image": 0, "video": 0, "audio": 0} diff --git a/tests/models/registry.py b/tests/models/registry.py index fd6e4ecb1..8fd801dc6 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -692,6 +692,9 @@ _MULTIMODAL_EXAMPLE_MODELS = { "baidu/ERNIE-4.5-VL-28B-A3B-PT", trust_remote_code=True, ), + "FunAudioChatForConditionalGeneration": _HfExamplesInfo( + "funaudiochat", is_available_online=False + ), "FuyuForCausalLM": _HfExamplesInfo("adept/fuyu-8b"), "Gemma3ForConditionalGeneration": _HfExamplesInfo("google/gemma-3-4b-it"), "Gemma3nForConditionalGeneration": _HfExamplesInfo("google/gemma-3n-E2B-it"), diff --git a/vllm/model_executor/models/funaudiochat.py b/vllm/model_executor/models/funaudiochat.py new file mode 100644 index 000000000..995fb6944 --- /dev/null +++ b/vllm/model_executor/models/funaudiochat.py @@ -0,0 +1,1083 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +"""Inference-only FunAudioChat model compatible with HuggingFace weights. 
+ +FunAudioChat is a Qwen3 text model augmented with: + - a continuous audio encoder (Whisper-mel frontend + transformer) + - a discrete audio encoder (speech tokenizer + projector) + +In the HF implementation, audio features are scattered into `<|AUDIO|>` token +positions via `inputs_embeds`, while `position_ids` (RoPE) remains standard 1D. +""" + +from __future__ import annotations + +import os +from collections.abc import Iterable, Mapping, Sequence +from functools import cached_property +from typing import Any + +import numpy as np +import torch +import torch.nn as nn +from transformers import PreTrainedTokenizerFast, WhisperFeatureExtractor +from transformers.activations import get_activation +from transformers.feature_extraction_utils import BatchFeature +from transformers.modeling_outputs import BaseModelOutput + +from vllm.config import VllmConfig +from vllm.config.multimodal import BaseDummyOptions +from vllm.model_executor.layers.attention.mm_encoder_attention import MMEncoderAttention +from vllm.model_executor.layers.linear import QKVParallelLinear, RowParallelLinear +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, +) +from vllm.multimodal.parse import ( + AudioProcessorItems, + MultiModalDataItems, + MultiModalDataParser, +) +from vllm.multimodal.processing import ( + BaseDummyInputsBuilder, + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptReplacement, + PromptUpdate, + PromptUpdateDetails, +) +from vllm.sequence import IntermediateTensors +from vllm.utils.import_utils import _has_module + +from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP +from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix + + +class _SinusoidsPositionEmbedding(nn.Module): + def __init__(self, length: int, channels: int, max_timescale: float = 10000.0): + super().__init__() + if channels % 2 != 0: + raise ValueError("SinusoidsPositionEmbedding needs even channels input") + + log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1) + inv_timescales = torch.exp( + -log_timescale_increment * torch.arange(channels // 2).float() + ) + scaled_time = ( + torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :] + ) + self.register_buffer( + "positional_embedding", + torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1), + persistent=False, + ) + + +class FunAudioChatAudioAttention(nn.Module): + """Multi-headed attention used inside the continuous audio tower.""" + + def __init__(self, config: Any): + super().__init__() + self.embed_dim = int(config.d_model) + self.total_num_heads = int(config.encoder_attention_heads) + self.dropout = float(getattr(config, "attention_dropout", 0.0)) + self.head_dim = self.embed_dim // self.total_num_heads + self.num_key_value_groups = 1 # needed for eager attention + self.config = config + + if self.head_dim * self.total_num_heads != self.embed_dim: + raise ValueError( + "embed_dim must be divisible by num_heads " + f"(got embed_dim={self.embed_dim}, " + f"num_heads={self.total_num_heads})." 
+ ) + self.scaling = self.head_dim**-0.5 + self.attention_dropout = 0.0 + self.is_decoder = False + self.is_causal = False + + self.qkv_proj = QKVParallelLinear( + self.embed_dim, + self.head_dim, + self.total_num_heads, + bias=True, + ) + self.num_heads = self.qkv_proj.num_heads + self.num_kv_heads = self.qkv_proj.num_kv_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + + self.attn = MMEncoderAttention( + num_heads=self.num_heads, + head_size=self.head_dim, + scale=self.scaling, + num_kv_heads=self.num_kv_heads, + prefix="funaudiochat_audio_tower.attn", + ) + self.out_proj = RowParallelLinear( + self.embed_dim, + self.embed_dim, + bias=True, + ) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + + params_dict = dict(self.named_parameters()) + with torch.no_grad(): + if self.qkv_proj.bias is not None: + # HF FunAudioChat uses bias=False for k_proj. Ensure the missing + # shard starts as zeros, while allowing q/v shards to load. + self.qkv_proj.bias.zero_() + + loaded_params: set[str] = set() + for name, loaded_weight in weights: + for param_name, shard_name, shard_id in stacked_params_mapping: + if shard_name not in name: + continue + name = name.replace(shard_name, param_name) + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + weight_loader(param, loaded_weight) + + loaded_params.add(name) + + return loaded_params + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + **kwargs: object, + ) -> torch.Tensor: + del kwargs + del attention_mask + seq_length, _ = hidden_states.size() + + qkv, _ = self.qkv_proj(hidden_states) + query_states, key_states, value_states = qkv.split( + [self.q_size, self.kv_size, self.kv_size], dim=-1 + ) + + max_seqlen: torch.Tensor | None = None + if cu_seqlens is not None: + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max() + + attn_output = self.attn( + query_states.reshape(1, seq_length, self.q_size), + key_states.reshape(1, seq_length, self.kv_size), + value_states.reshape(1, seq_length, self.kv_size), + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ).reshape(seq_length, -1) + + output, _ = self.out_proj(attn_output) + return output + + +class FunAudioChatAudioEncoderLayer(nn.Module): + def __init__(self, config: Any): + super().__init__() + self.embed_dim = int(config.d_model) + self.self_attn = FunAudioChatAudioAttention(config) + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.dropout = float(config.dropout) + self.activation_fn = get_activation(str(config.activation_function)) + self.activation_dropout = float(config.activation_dropout) + self.fc1 = nn.Linear(self.embed_dim, int(config.encoder_ffn_dim)) + self.fc2 = nn.Linear(int(config.encoder_ffn_dim), self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + attention_mask: torch.Tensor | None = None, + **kwargs: object, + ) -> tuple[torch.Tensor]: + 
residual = hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states = self.self_attn( + hidden_states=hidden_states, + cu_seqlens=cu_seqlens, + attention_mask=attention_mask, + **kwargs, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout( + hidden_states, p=self.activation_dropout, training=self.training + ) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout( + hidden_states, p=self.dropout, training=self.training + ) + hidden_states = residual + hidden_states + + return (hidden_states,) + + +class FunAudioChatAudioEncoder(nn.Module): + """Continuous audio tower.""" + + def __init__(self, config: Any): + super().__init__() + self.config = config + + embed_dim = int(config.d_model) + self.num_mel_bins = int(config.num_mel_bins) + self.max_source_positions = int(config.max_source_positions) + self.embed_scale = (embed_dim**0.5) if bool(config.scale_embedding) else 1.0 + self.n_window = int(config.n_window) + + self.conv1 = nn.Conv1d(self.num_mel_bins, embed_dim, kernel_size=3, padding=1) + self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1) + self.layers = nn.ModuleList( + [ + FunAudioChatAudioEncoderLayer(config) + for _ in range(int(config.encoder_layers)) + ] + ) + self.ln_post = nn.LayerNorm(embed_dim) + self.avg_pooler = nn.AvgPool1d(2, stride=2) + self.proj = nn.Linear(embed_dim, int(config.output_dim)) + self.positional_embedding = _SinusoidsPositionEmbedding( + self.max_source_positions, embed_dim + ) + + # Present in HF weights even if unused during S2T. + self.audio_bos_eos_token = nn.Embedding(2, int(config.output_dim)) + + @property + def dtype(self) -> torch.dtype: + return self.conv1.weight.dtype + + def _prepare_attention_mask( + self, inputs_tensor: torch.Tensor, cu_seqlens: torch.Tensor + ) -> torch.Tensor | None: + if getattr(self.config, "_attn_implementation", "eager") == "flash_attention_2": + return None + + seq_length = inputs_tensor.shape[0] + attention_mask = torch.full( + (1, 1, seq_length, seq_length), + torch.finfo(inputs_tensor.dtype).min, + device=inputs_tensor.device, + dtype=inputs_tensor.dtype, + ) + for i in range(1, len(cu_seqlens)): + start = int(cu_seqlens[i - 1].item()) + end = int(cu_seqlens[i].item()) + attention_mask[..., start:end, start:end] = 0 + return attention_mask + + def forward( + self, + input_features: torch.Tensor, + feature_lens: torch.Tensor, + aftercnn_lens: torch.Tensor, + speech_maxlen: int, + **kwargs: object, + ) -> BaseModelOutput: + # For max-length audio (300s => ~7500 speech frames at 25Hz), the + # Torch SDPA path can be prohibitively memory hungry (~O(n^2) inside the + # longest chunks). Require FlashAttention for such inputs to avoid OOM + # and performance cliffs. + if int(speech_maxlen) >= 7500: + if not _has_module("flash_attn"): + raise RuntimeError( + "FunAudioChat long audio (~300s) requires FlashAttention-2 " + "for the continuous audio tower, but `flash_attn` is not " + "installed in the runtime environment." + ) + if not getattr( + self.layers[0].self_attn.attn, "is_flash_attn_backend", False + ): + raise RuntimeError( + "FunAudioChat long audio (~300s) requires FlashAttention for the " + "continuous audio tower, but the selected MM encoder attention " + "backend is not FlashAttention." 
+ ) + + # Handle empty / invalid items (feature_lens == 0) without crashing. + original_batch_size = int(feature_lens.size(0)) + device = input_features.device + + valid_mask = feature_lens > 0 + valid_indices = torch.where(valid_mask)[0] + + if valid_indices.numel() == 0: + output_dim = int(self.proj.out_features) + return BaseModelOutput( + last_hidden_state=torch.zeros( + (original_batch_size, speech_maxlen, output_dim), + device=device, + dtype=self.proj.weight.dtype, + ) + ) + + input_features_list = input_features.split(feature_lens.tolist(), dim=1) + valid_input_features_list = [input_features_list[int(i)] for i in valid_indices] + valid_input_features = torch.cat(valid_input_features_list, dim=1) + + valid_feature_lens = feature_lens[valid_mask] + valid_aftercnn_lens = aftercnn_lens[valid_mask] + + chunk_num = torch.ceil(valid_feature_lens / (self.n_window * 2)).long() + + chunk_lengths_list: list[int] = [] + full_chunk_len = self.n_window * 2 + for i, length in enumerate(valid_feature_lens): + num_chunks_for_sample = int(chunk_num[i].item()) + if num_chunks_for_sample == 0: + continue + chunk_lengths_list.extend([full_chunk_len] * (num_chunks_for_sample - 1)) + last_chunk_len = int(length.item()) % full_chunk_len + if last_chunk_len == 0: + last_chunk_len = full_chunk_len + chunk_lengths_list.append(last_chunk_len) + + chunk_lengths = torch.tensor( + chunk_lengths_list, dtype=torch.long, device=device + ) + + chunk_list = valid_input_features.split(chunk_lengths.tolist(), dim=1) + padded_feature, padded_mask, padded_mask_after_cnn = ( + self.padded_and_mask_function( + chunk_list, chunk_lengths, padding_value=0, padding_side="right" + ) + ) + + padded_embed = nn.functional.gelu(self.conv1(padded_feature)) * padded_mask + padded_embed = nn.functional.gelu(self.conv2(padded_embed)).transpose(1, 2) + + padded_embed = padded_embed + self.positional_embedding.positional_embedding[ + : padded_embed.shape[1], : + ].unsqueeze(0).to(padded_embed.dtype) + + hidden_states = padded_embed[padded_mask_after_cnn] + cu_seqlens = torch.cat( + ( + torch.zeros(1, device=padded_mask_after_cnn.device, dtype=torch.int32), + padded_mask_after_cnn.sum(1).cumsum(0), + ) + ).to(torch.int32) + + for encoder_layer in self.layers: + (hidden_states,) = encoder_layer( + hidden_states, + cu_seqlens=cu_seqlens, + **kwargs, + ) + + hidden_states_list = hidden_states.split(valid_aftercnn_lens.tolist(), dim=0) + + pooled_list: list[torch.Tensor] = [] + pooled_lengths: list[int] = [] + for each_audio_states in hidden_states_list: + seq_len = int(each_audio_states.shape[0]) + if seq_len >= 2: + pooled = nn.functional.avg_pool1d( + each_audio_states.transpose(0, 1), kernel_size=2, stride=2 + ).transpose(0, 1) + else: + pooled = each_audio_states + pooled_list.append(pooled) + pooled_lengths.append(int(pooled.shape[0])) + + pooled_concat = torch.cat(pooled_list, dim=0) + processed_concat = self.proj(self.ln_post(pooled_concat)) + processed_audio_list = list(processed_concat.split(pooled_lengths, dim=0)) + + output_dim = ( + int(processed_audio_list[0].shape[-1]) + if processed_audio_list + else int(self.proj.out_features) + ) + output_hidden_states = torch.zeros( + (original_batch_size, speech_maxlen, output_dim), + dtype=processed_audio_list[0].dtype + if processed_audio_list + else self.proj.weight.dtype, + device=device, + ) + + for valid_idx, processed in zip(valid_indices, processed_audio_list): + seq_len = min(int(processed.shape[0]), int(speech_maxlen)) + output_hidden_states[int(valid_idx), :seq_len] = 
processed[:seq_len] + + return BaseModelOutput(last_hidden_state=output_hidden_states) + + def padded_and_mask_function( + self, + tensor_list: Sequence[torch.Tensor], + tensor_len: torch.Tensor, + padding_value: float = 0.0, + padding_side: str = "right", + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + max_len = int(tensor_len.max().item()) + dim = int(tensor_list[0].shape[0]) + padded_tensor = torch.full( + size=(len(tensor_list), dim, max_len), + fill_value=padding_value, + dtype=self.dtype, + device=tensor_list[0].device, + ) + + batch_mask = torch.zeros( + (len(tensor_len), max_len), dtype=torch.long, device=padded_tensor.device + ) + for i, length in enumerate(tensor_len): + length_val = int(length.item()) + batch_mask[i, :length_val] = 1 + padded_tensor[i, :, :length_val] = tensor_list[i] + + feature_lens_after_cnn = (tensor_len - 1) // 2 + 1 + max_len_after_cnn = int(feature_lens_after_cnn.max().item()) + batch_mask_after_cnn = torch.zeros( + (len(tensor_len), max_len_after_cnn), + dtype=torch.long, + device=padded_tensor.device, + ) + for i, length in enumerate(feature_lens_after_cnn): + batch_mask_after_cnn[i, : int(length.item())] = 1 + + if padding_side != "right": + raise NotImplementedError("Only right padding is supported.") + + return ( + padded_tensor, + batch_mask.unsqueeze(1).to(padded_tensor.dtype), + batch_mask_after_cnn.bool(), + ) + + # From the HF FunAudioChat implementation. + def _get_feat_extract_output_lengths( + self, input_lengths: torch.LongTensor + ) -> tuple[torch.LongTensor, torch.LongTensor]: + input_lengths = (input_lengths - 1) // 2 + 1 + output_lengths = (input_lengths - 2) // 2 + 1 + return input_lengths, output_lengths + + +class FunAudioChatDiscreteEncoder(nn.Module): + """Discrete audio encoder (speech tokenizer -> grouped embeddings).""" + + def __init__(self, config: Any): + super().__init__() + self.padding_idx = int(config.pad_token_id) + self.group_size = int(config.group_size) + self.hidden_size = int(config.output_dim) + self.continuous_features_mode = getattr( + config, "continuous_features_mode", "add" + ) + self.embed_tokens = nn.Embedding( + int(config.codebook_size), self.hidden_size, self.padding_idx + ) + self.output_matching = nn.Linear(self.hidden_size, self.hidden_size, bias=False) + self.continual_output_matching = nn.Linear( + self.hidden_size, self.hidden_size, bias=False + ) + + def forward( + self, + audio_ids: torch.Tensor, + continuous_audio_features: torch.Tensor | None = None, + continuous_audio_output_lengths: torch.Tensor | None = None, + feature_exist_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + del continuous_audio_output_lengths + + inputs_embeds = self.embed_tokens(audio_ids) + hidden_states = inputs_embeds.reshape( + inputs_embeds.shape[0], -1, self.group_size * self.hidden_size + ) + hidden_states = hidden_states.reshape( + hidden_states.shape[0], -1, self.group_size, self.hidden_size + ).mean(dim=2) + hidden_states = self.output_matching(hidden_states) + + if continuous_audio_features is not None: + continuous_audio_features = continuous_audio_features.reshape( + continuous_audio_features.shape[0], + -1, + self.group_size, + self.hidden_size, + ).mean(dim=2) + continuous_audio_hidden_states = self.continual_output_matching( + continuous_audio_features + ) + + if feature_exist_mask is None: + feature_exist_mask = torch.ones( + (hidden_states.shape[0],), + dtype=torch.bool, + device=hidden_states.device, + ) + if self.continuous_features_mode == "add": + hidden_states[feature_exist_mask] += 
continuous_audio_hidden_states + else: + hidden_states[feature_exist_mask] = continuous_audio_hidden_states + + return hidden_states + + def _get_feat_extract_output_lengths( + self, input_lengths: torch.LongTensor + ) -> tuple[torch.LongTensor, torch.LongTensor]: + output_lengths = (input_lengths + self.group_size - 1) // self.group_size + return input_lengths, output_lengths + + +class FunAudioChatProcessingInfo(BaseProcessingInfo): + token_fps: int = 25 + + def get_supported_mm_limits(self) -> Mapping[str, int | None]: + return {"audio": None} + + def get_target_channels(self) -> int: + return 1 + + def get_mm_max_tokens_per_item( + self, + seq_len: int, + mm_counts: Mapping[str, int], + ) -> Mapping[str, int] | None: + # The discrete audio encoder downsamples 25Hz frames with group_size=5, + # so for a 300s clip the max number of `<|AUDIO|>` placeholders is 1500. + cfg = self.get_hf_config() + audio_cfg = getattr(cfg, "audio_config", None) + max_audio_tokens = int(getattr(audio_cfg, "max_source_positions", 1500)) + return {"audio": max_audio_tokens} + + @cached_property + def feature_extractor(self) -> WhisperFeatureExtractor: + return WhisperFeatureExtractor.from_pretrained(self.model_id) + + @cached_property + def speech_tokenizer(self) -> PreTrainedTokenizerFast: + return PreTrainedTokenizerFast.from_pretrained( + self.model_id, subfolder="speech_tokenizer" + ) + + def get_feature_extractor(self) -> WhisperFeatureExtractor: + return self.feature_extractor + + def get_speech_tokenizer(self) -> PreTrainedTokenizerFast: + return self.speech_tokenizer + + def get_audio_group_size(self) -> int: + cfg = self.get_hf_config() + audio_cfg = getattr(cfg, "audio_config", None) + return int(getattr(audio_cfg, "group_size", 5)) + + +class FunAudioChatDummyInputsBuilder( + BaseDummyInputsBuilder[FunAudioChatProcessingInfo] +): + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_audios = mm_counts.get("audio", 0) + return "<|audio_bos|><|AUDIO|><|audio_eos|>" * int(num_audios) + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + mm_options: Mapping[str, BaseDummyOptions] | None = None, + ) -> MultiModalDataDict: + feature_extractor = self.info.get_feature_extractor() + sampling_rate = int(feature_extractor.sampling_rate) + + # Dummy inputs are used for profiling; construct the worst-case audio + # length that maximizes the number of encoder tokens. 
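+        # With the defaults below (1500 max audio tokens, group_size=5,
+        # 25 token-frames/s, 16 kHz input) this works out to
+        # 1500 * 5 = 7500 frames -> 7500 * 16000 / 25 = 4_800_000 samples,
+        # i.e. a 300 s clip.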
+ cfg = self.info.get_hf_config() + audio_cfg = getattr(cfg, "audio_config", None) + max_audio_tokens = int(getattr(audio_cfg, "max_source_positions", 1500)) + group_size = self.info.get_audio_group_size() + token_fps = int(getattr(self.info, "token_fps", 25)) + target_num_frames = max(1, max_audio_tokens) * max(1, group_size) + audio_len = max( + 1, + (target_num_frames * sampling_rate + token_fps - 1) // token_fps, + ) + num_audios = int(mm_counts.get("audio", 0)) + + audio_overrides = mm_options.get("audio") if mm_options else None + return { + "audio": self._get_dummy_audios( + length=audio_len, + num_audios=num_audios, + overrides=audio_overrides, + ) + } + + +class FunAudioChatMultiModalProcessor( + BaseMultiModalProcessor[FunAudioChatProcessingInfo] +): + def _get_data_parser(self) -> MultiModalDataParser: + feature_extractor = self.info.get_feature_extractor() + return MultiModalDataParser( + target_sr=int(feature_extractor.sampling_rate), + target_channels=self.info.get_target_channels(), + ) + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], + ) -> BatchFeature: + tokenizer = self.info.get_tokenizer() + input_ids = torch.tensor([tokenizer.encode(prompt, **tok_kwargs)]) + + audios = mm_data.get("audios", []) + if not audios: + return BatchFeature({"input_ids": input_ids}) + + feature_extractor = self.info.get_feature_extractor() + sr = int(feature_extractor.sampling_rate) + min_samples = int(getattr(feature_extractor, "n_fft", 400) or 400) + + wavs: list[np.ndarray] = [] + speech_strs: list[str] = [] + + speech_tokenizer = self.info.get_speech_tokenizer() + pad_token = speech_tokenizer.pad_token or "<|audio_pad|>" + for audio in audios: + if isinstance(audio, torch.Tensor): + audio = audio.detach().cpu().numpy() + audio_np = np.asarray(audio, dtype=np.float32) + + if min_samples > 0 and audio_np.shape[0] < min_samples: + audio_np = np.pad( + audio_np, (0, min_samples - audio_np.shape[0]), mode="constant" + ) + + wavs.append(audio_np) + num_frames = int( + (float(audio_np.shape[0]) / float(sr)) * float(self.info.token_fps) + ) + speech_strs.append(pad_token * max(1, int(num_frames))) + + audio_group_size = self.info.get_audio_group_size() + speech_inputs = speech_tokenizer( + speech_strs, + return_attention_mask=True, + return_token_type_ids=False, + padding=True, + pad_to_multiple_of=audio_group_size, + return_tensors="pt", + ) + + wav_inputs = feature_extractor( + wavs, + sampling_rate=sr, + return_attention_mask=True, + padding="max_length", + return_tensors="pt", + ) + + mm_inputs: dict[str, torch.Tensor] = { + "speech_ids": speech_inputs["input_ids"], + "speech_attention_mask": speech_inputs["attention_mask"], + "input_features": wav_inputs["input_features"], + "feature_attention_mask": wav_inputs["attention_mask"], + "feature_exist_mask": torch.ones((len(wavs),), dtype=torch.bool), + } + + return BatchFeature({"input_ids": input_ids, **mm_inputs}) + + def _hf_processor_applies_updates( + self, + prompt_text: str, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + tokenization_kwargs: Mapping[str, object], + ) -> bool: + return False + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return { + "speech_ids": MultiModalFieldConfig.batched("audio"), + "speech_attention_mask": MultiModalFieldConfig.batched("audio"), + "input_features": 
MultiModalFieldConfig.batched("audio"), + "feature_attention_mask": MultiModalFieldConfig.batched("audio"), + "feature_exist_mask": MultiModalFieldConfig.batched("audio"), + } + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargsItems, + ) -> Sequence[PromptUpdate]: + tokenizer = self.info.get_tokenizer() + vocab = tokenizer.get_vocab() + + audio_token = "<|AUDIO|>" + audio_token_id = vocab[audio_token] + + out_mm_data = out_mm_kwargs.get_data() + speech_attention_mask = out_mm_data.get("speech_attention_mask") + if speech_attention_mask is None: + audio_output_lengths: list[int] = [] + else: + assert isinstance(speech_attention_mask, torch.Tensor) + speech_lengths = speech_attention_mask.sum(-1) + group_size = self.info.get_audio_group_size() + audio_output_lengths = ( + (speech_lengths + group_size - 1) // group_size + ).tolist() + + def get_replacement_funaudiochat(item_idx: int): + num_features = ( + int(audio_output_lengths[item_idx]) if audio_output_lengths else 1 + ) + if num_features <= 0: + audios = mm_items.get_items("audio", AudioProcessorItems) + audio_len = audios.get_audio_length(item_idx) + raise ValueError( + f"The audio (len={audio_len}) is too short to be " + "represented inside the model" + ) + + audio_tokens = [audio_token_id] * num_features + return PromptUpdateDetails.select_token_id( + audio_tokens, + embed_token_id=audio_token_id, + ) + + return [ + PromptReplacement( + modality="audio", + target=audio_token, + replacement=get_replacement_funaudiochat, + ) + ] + + +@MULTIMODAL_REGISTRY.register_processor( + FunAudioChatMultiModalProcessor, + info=FunAudioChatProcessingInfo, + dummy_inputs=FunAudioChatDummyInputsBuilder, +) +class FunAudioChatForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> str | None: + if modality.startswith("audio"): + return "<|audio_bos|><|AUDIO|><|audio_eos|>" + + raise ValueError("Only audio modality is supported") + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + multimodal_config = vllm_config.model_config.multimodal_config + self.config = config + self.multimodal_config = multimodal_config + self.quant_config = quant_config + + with self._mark_tower_model(vllm_config, "audio"): + self.continuous_audio_tower = FunAudioChatAudioEncoder(config.audio_config) + self.audio_tower = FunAudioChatDiscreteEncoder(config.audio_config) + + with self._mark_language_model(vllm_config): + self.language_model = init_vllm_registered_model( + vllm_config=vllm_config, + hf_config=config.text_config, + prefix=maybe_prefix(prefix, "language_model"), + architectures=["Qwen3ForCausalLM"], + ) + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors + ) + + def get_language_model(self) -> torch.nn.Module: + return self.language_model + + def _get_continuous_audio_features( + self, + input_features: torch.Tensor, + feature_attention_mask: torch.Tensor, + speech_maxlen: int, + ) -> tuple[torch.Tensor, torch.Tensor]: + # Align mask and features to avoid indexing errors when padding differs. 
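+        # `input_features` is (batch, num_mel_bins, mel_len) and the mask is
+        # (batch, mask_len); the two lengths can disagree when the feature
+        # extractor pads to a fixed max length, so truncate both to the
+        # shorter one before masking.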
+ if ( + input_features.dim() == 3 + and feature_attention_mask.shape[1] != input_features.shape[-1] + ): + min_len = min( + int(feature_attention_mask.shape[1]), int(input_features.shape[-1]) + ) + feature_attention_mask = feature_attention_mask[:, :min_len] + input_features = input_features[:, :, :min_len] + + feature_lens = torch.sum(feature_attention_mask, dim=1) + + flat_features = input_features.permute(0, 2, 1)[ + feature_attention_mask.bool() + ].permute(1, 0) + + audio_feat_lengths, audio_output_lengths = ( + self.continuous_audio_tower._get_feat_extract_output_lengths(feature_lens) + ) + + audio_outputs = self.continuous_audio_tower( + flat_features, + feature_lens=feature_lens, + aftercnn_lens=audio_feat_lengths, + speech_maxlen=speech_maxlen, + ) + return audio_outputs.last_hidden_state, audio_output_lengths + + def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: + speech_ids = kwargs.get("speech_ids") + speech_attention_mask = kwargs.get("speech_attention_mask") + input_features = kwargs.get("input_features") + feature_attention_mask = kwargs.get("feature_attention_mask") + feature_exist_mask = kwargs.get("feature_exist_mask") + + if speech_ids is None: + return [] + + pad_id = int(getattr(self.audio_tower, "padding_idx", 0)) + + if not isinstance(speech_ids, torch.Tensor): + if ( + isinstance(speech_ids, (list, tuple)) + and len(speech_ids) > 0 + and all(isinstance(t, torch.Tensor) for t in speech_ids) + ): + speech_ids_tensors = [] + for t in speech_ids: + if t.dim() == 2 and t.shape[0] == 1: + t = t.squeeze(0) + if t.dim() != 1: + raise TypeError( + "FunAudioChat speech_ids must be a 1D tensor per item " + f"(got shape={tuple(t.shape)})" + ) + speech_ids_tensors.append(t) + speech_ids = nn.utils.rnn.pad_sequence( + speech_ids_tensors, + batch_first=True, + padding_value=pad_id, + ) + else: + raise TypeError( + "FunAudioChat speech_ids must be a Tensor or a sequence of Tensors " + f"(got {type(speech_ids)})" + ) + + if speech_attention_mask is None: + speech_attention_mask = speech_ids.ne(pad_id).to(dtype=torch.int64) + + if not isinstance(speech_attention_mask, torch.Tensor): + if ( + isinstance(speech_attention_mask, (list, tuple)) + and len(speech_attention_mask) > 0 + and all(isinstance(t, torch.Tensor) for t in speech_attention_mask) + ): + mask_tensors = [] + for t in speech_attention_mask: + if t.dim() == 2 and t.shape[0] == 1: + t = t.squeeze(0) + if t.dim() != 1: + raise TypeError( + "FunAudioChat speech_attention_mask must be a 1D tensor " + f"per item (got shape={tuple(t.shape)})" + ) + mask_tensors.append(t) + speech_attention_mask = nn.utils.rnn.pad_sequence( + mask_tensors, + batch_first=True, + padding_value=0, + ) + else: + raise TypeError( + "FunAudioChat speech_attention_mask must be a Tensor or a " + f"sequence of Tensors (got {type(speech_attention_mask)})" + ) + + debug = os.getenv("VLLM_FUN_AUDIOCHAT_DEBUG", "") == "1" + if debug: + print( + f"[FunAudioChat] embed_multimodal speech_ids={tuple(speech_ids.shape)} " + f"speech_attention_mask={tuple(speech_attention_mask.shape)}", + flush=True, + ) + attn_impl = getattr( + self.continuous_audio_tower.config, "_attn_implementation", None + ) + print( + f"[FunAudioChat] audio_attn_impl={attn_impl}", + flush=True, + ) + if hasattr(self.continuous_audio_tower, "conv1"): + conv1_w = self.continuous_audio_tower.conv1.weight + print( + f"[FunAudioChat] conv1_w_norm={float(conv1_w.norm().item()):.6g}", + flush=True, + ) + try: + attn0 = self.continuous_audio_tower.layers[0].self_attn + q_norm = 
float(attn0.qkv_proj.weight.norm().item())
+                # Only the fused projections exist on the vLLM attention
+                # module (there are no separate q/k/v projections), so log
+                # the qkv_proj / out_proj weight norms.
+                o_norm = float(attn0.out_proj.weight.norm().item())
+                print(
+                    f"[FunAudioChat] attn0_qkv_norm={q_norm:.6g} "
+                    f"o_norm={o_norm:.6g}",
+                    flush=True,
+                )
+            except Exception:
+                pass
+            if isinstance(input_features, torch.Tensor):
+                print(
+                    f"[FunAudioChat] input_features={tuple(input_features.shape)}",
+                    flush=True,
+                )
+            if isinstance(feature_attention_mask, torch.Tensor):
+                print(
+                    "[FunAudioChat] feature_attention_mask="
+                    f"{tuple(feature_attention_mask.shape)}",
+                    flush=True,
+                )
+
+        group_size = int(self.audio_tower.group_size)
+        speech_maxlen = int(speech_ids.shape[-1])
+
+        # Ensure token length is divisible by group_size.
+        target_len = ((speech_maxlen + group_size - 1) // group_size) * group_size
+        if target_len > speech_maxlen:
+            pad_id = int(self.audio_tower.padding_idx)
+            pad_len = target_len - speech_maxlen
+            speech_ids = nn.functional.pad(speech_ids, (0, pad_len), value=pad_id)
+            speech_attention_mask = nn.functional.pad(
+                speech_attention_mask, (0, pad_len), value=0
+            )
+            speech_maxlen = int(speech_ids.shape[-1])
+
+        continuous_audio_features = None
+        continuous_audio_output_lengths = None
+        if input_features is not None and feature_attention_mask is not None:
+            assert isinstance(input_features, torch.Tensor)
+            assert isinstance(feature_attention_mask, torch.Tensor)
+            continuous_audio_features, continuous_audio_output_lengths = (
+                self._get_continuous_audio_features(
+                    input_features=input_features,
+                    feature_attention_mask=feature_attention_mask,
+                    speech_maxlen=speech_maxlen,
+                )
+            )
+
+        if feature_exist_mask is None:
+            feature_exist_mask = torch.ones(
+                (speech_ids.shape[0],), dtype=torch.bool, device=speech_ids.device
+            )
+        assert isinstance(feature_exist_mask, torch.Tensor)
+
+        audio_features = self.audio_tower(
+            speech_ids,
+            continuous_audio_features=continuous_audio_features,
+            continuous_audio_output_lengths=continuous_audio_output_lengths,
+            feature_exist_mask=feature_exist_mask,
+        )
+
+        _, audio_output_lengths = self.audio_tower._get_feat_extract_output_lengths(
+            speech_attention_mask.sum(-1)
+        )
+        lengths = audio_output_lengths.tolist()
+
+        embeds = tuple(
+            audio_features[i, : int(length)] for i, length in enumerate(lengths)
+        )
+        if debug:
+            embed_lens = [int(t.shape[0]) for t in embeds]
+            print(f"[FunAudioChat] embed_multimodal out_lens={embed_lens}", flush=True)
+            if embeds:
+                t0 = embeds[0]
+                print(
+                    f"[FunAudioChat] embed0 dtype={t0.dtype} device={t0.device} "
+                    f"nan={bool(torch.isnan(t0).any())} "
+                    f"norm={float(t0.norm().item()):.6g}",
+                    flush=True,
+                )
+            dump_path = os.getenv("VLLM_FUN_AUDIOCHAT_DUMP_PATH", "")
+            if (
+                dump_path
+                and speech_ids.shape[0] == 1
+                and len(embeds) == 1
+                and embed_lens[0] > 10
+            ):
+                if not os.path.exists(dump_path):
+                    np.save(dump_path, embeds[0].detach().float().cpu().numpy())
+                    print(f"[FunAudioChat] dumped embeds to {dump_path}", flush=True)
+                cont_path = dump_path.replace(".npy", "_cont.npy")
+                if continuous_audio_features is not None and not os.path.exists(
+                    cont_path
+                ):
+                    np.save(
+                        cont_path,
+                        continuous_audio_features.detach().float().cpu().numpy(),
+                    )
+                    print(
+                        f"[FunAudioChat] dumped continuous to {cont_path}", flush=True
+                    )
+        return embeds
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        intermediate_tensors: IntermediateTensors | None = None,
+        inputs_embeds: torch.Tensor 
| None = None, + **kwargs: object, + ) -> torch.Tensor | IntermediateTensors: + del kwargs + if intermediate_tensors is not None: + inputs_embeds = None + + return self.language_model.model( + input_ids, + positions, + intermediate_tensors, + inputs_embeds=inputs_embeds, + ) + + def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor | None: + return self.language_model.compute_logits(hidden_states) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader(self, skip_prefixes=["audio_invert_tower."]) + return loader.load_weights(weights) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 16edade2b..3ae0716bb 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -312,6 +312,10 @@ _MULTIMODAL_MODELS = { "ernie45_vl", "Ernie4_5_VLMoeForConditionalGeneration", ), + "FunAudioChatForConditionalGeneration": ( + "funaudiochat", + "FunAudioChatForConditionalGeneration", + ), "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"), "Gemma3ForConditionalGeneration": ("gemma3_mm", "Gemma3ForConditionalGeneration"), # noqa: E501 "Gemma3nForConditionalGeneration": ( diff --git a/vllm/multimodal/audio.py b/vllm/multimodal/audio.py index f0a499f8d..cccf7d1a6 100644 --- a/vllm/multimodal/audio.py +++ b/vllm/multimodal/audio.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import math from dataclasses import dataclass from enum import Enum from typing import Literal @@ -195,6 +196,13 @@ class AudioResampler: raise RuntimeError( "Audio resampling is not supported when `target_sr` is not provided" ) + if math.isclose( + float(orig_sr), + float(self.target_sr), + rel_tol=0.0, + abs_tol=1e-6, + ): + return audio if self.method == "librosa": return resample_audio_librosa( audio, orig_sr=orig_sr, target_sr=self.target_sr diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 4646948d4..b8edc5769 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -77,6 +77,7 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict( deepseek_vl_v2="DeepseekVLV2Config", deepseek_v32="DeepseekV3Config", flex_olmo="FlexOlmoConfig", + funaudiochat="FunAudioChatConfig", hunyuan_vl="HunYuanVLConfig", isaac="IsaacConfig", kimi_linear="KimiLinearConfig", diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index bfb9c1758..0c1f665fc 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -22,6 +22,8 @@ _CLASS_TO_MODULE: dict[str, str] = { "DotsOCRConfig": "vllm.transformers_utils.configs.dotsocr", "EAGLEConfig": "vllm.transformers_utils.configs.eagle", "FlexOlmoConfig": "vllm.transformers_utils.configs.flex_olmo", + "FunAudioChatConfig": "vllm.transformers_utils.configs.funaudiochat", + "FunAudioChatAudioEncoderConfig": "vllm.transformers_utils.configs.funaudiochat", "HunYuanVLConfig": "vllm.transformers_utils.configs.hunyuan_vl", "HunYuanVLTextConfig": "vllm.transformers_utils.configs.hunyuan_vl", "HunYuanVLVisionConfig": "vllm.transformers_utils.configs.hunyuan_vl", @@ -65,6 +67,8 @@ __all__ = [ "DotsOCRConfig", "EAGLEConfig", "FlexOlmoConfig", + "FunAudioChatConfig", + "FunAudioChatAudioEncoderConfig", "HunYuanVLConfig", "HunYuanVLTextConfig", "HunYuanVLVisionConfig", diff --git 
a/vllm/transformers_utils/configs/funaudiochat.py b/vllm/transformers_utils/configs/funaudiochat.py new file mode 100644 index 000000000..04505b273 --- /dev/null +++ b/vllm/transformers_utils/configs/funaudiochat.py @@ -0,0 +1,124 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from __future__ import annotations + +from transformers import PretrainedConfig + +# NOTE: Temporary shim for FunAudioChat checkpoints. +# These checkpoints use `model_type="funaudiochat"`, which is not currently +# recognized by released Transformers, and the public checkpoint does not +# provide an `auto_map` to enable `trust_remote_code=True`. +# Remove this file once Transformers adds native support (or the checkpoint +# provides an `auto_map`) and vLLM can rely on `AutoConfig.from_pretrained()`. + + +class FunAudioChatAudioEncoderConfig(PretrainedConfig): + model_type = "funaudiochat_audio_encoder" + + def __init__( + self, + _attn_implementation: str | None = None, + num_mel_bins: int = 128, + encoder_layers: int = 32, + encoder_attention_heads: int = 20, + encoder_ffn_dim: int = 5120, + d_model: int = 1280, + dropout: float = 0.0, + attention_dropout: float = 0.0, + activation_function: str = "gelu", + activation_dropout: float = 0.0, + scale_embedding: bool = False, + initializer_range: float = 0.02, + max_source_positions: int = 1500, + n_window: int = 100, + output_dim: int = 3584, + bos_token_id: int | None = None, + codebook_size: int | None = None, + continuous_features_mode: str = "replace", + crq_transformer_config: dict | None = None, + eos_token_id: int | None = None, + group_size: int = 5, + enable_audio_invert_tower: bool = True, + pad_token_id: int | None = None, + **kwargs, + ) -> None: + attn_impl = kwargs.pop("_attn_implementation", None) or _attn_implementation + super().__init__(**kwargs) + # Match HF default for attention implementation selection. 
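+        # The kwarg is popped before the super().__init__() call above so the
+        # base class does not set it first; "sdpa" is assumed here as the
+        # usual HF fallback when no attention implementation is requested.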
+        self._attn_implementation = attn_impl or "sdpa"
+
+        self.num_mel_bins = num_mel_bins
+        self.d_model = d_model
+        self.encoder_layers = encoder_layers
+        self.encoder_attention_heads = encoder_attention_heads
+        self.encoder_ffn_dim = encoder_ffn_dim
+        self.dropout = dropout
+        self.attention_dropout = attention_dropout
+        self.activation_function = activation_function
+        self.activation_dropout = activation_dropout
+        self.num_hidden_layers = encoder_layers
+        self.initializer_range = initializer_range
+        self.scale_embedding = scale_embedding
+        self.max_source_positions = max_source_positions
+        self.n_window = n_window
+        self.output_dim = output_dim
+
+        self.bos_token_id = bos_token_id
+        self.codebook_size = codebook_size
+        self.continuous_features_mode = continuous_features_mode
+        self.crq_transformer_config = crq_transformer_config
+        self.eos_token_id = eos_token_id
+        self.group_size = group_size
+        self.enable_audio_invert_tower = enable_audio_invert_tower
+        self.pad_token_id = pad_token_id
+
+
+class FunAudioChatConfig(PretrainedConfig):
+    model_type = "funaudiochat"
+    attribute_map = {
+        "audio_token_id": "audio_token_index",
+    }
+
+    def __init__(
+        self,
+        audio_config: PretrainedConfig | dict | None = None,
+        text_config: PretrainedConfig | dict | None = None,
+        audio_token_index: int = 151646,
+        ignore_index: int = -100,
+        hidden_size: int | None = None,
+        **kwargs,
+    ) -> None:
+        self.audio_token_index = audio_token_index
+        self.ignore_index = ignore_index
+
+        if isinstance(audio_config, dict):
+            audio_config.setdefault(
+                "model_type", FunAudioChatAudioEncoderConfig.model_type
+            )
+            audio_config = FunAudioChatAudioEncoderConfig(**audio_config)
+        elif audio_config is None:
+            audio_config = FunAudioChatAudioEncoderConfig()
+        self.audio_config = audio_config
+
+        if isinstance(text_config, dict):
+            # Default to qwen2 for backward compatibility with older configs;
+            # recent FunAudioChat checkpoints use qwen3.
+            text_config.setdefault("model_type", "qwen2")
+            import transformers
+
+            text_cls = transformers.CONFIG_MAPPING[text_config["model_type"]]
+            text_config = text_cls(**text_config)
+        elif text_config is None:
+            import transformers
+
+            text_config = transformers.CONFIG_MAPPING["qwen2"]()
+        self.text_config = text_config
+
+        self.hidden_size = (
+            int(self.text_config.hidden_size)
+            if hidden_size is None
+            else int(hidden_size)
+        )
+
+        super().__init__(**kwargs)
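
For reference, a minimal offline smoke test of this integration. This is a sketch, not part of the change: "/path/to/funaudiochat" is a placeholder for a local checkpoint directory (the model is not on the HuggingFace Hub), and it mirrors what examples/offline_inference/audio_language.py does via --model-type funaudiochat --model <path>:

# smoke_test_funaudiochat.py -- illustrative only.
from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset

llm = LLM(
    model="/path/to/funaudiochat",  # placeholder: pass a local checkpoint dir
    max_model_len=4096,
    limit_mm_per_prompt={"audio": 1},
    enforce_eager=True,
)

# (np.ndarray, sample_rate) tuple; the processor resamples to 16 kHz.
audio = AudioAsset("mary_had_lamb").audio_and_sample_rate

outputs = llm.generate(
    {
        # Placeholder layout must match get_placeholder_str() above.
        "prompt": "<|audio_bos|><|AUDIO|><|audio_eos|>\nWhat is recited in the audio?",
        "multi_modal_data": {"audio": audio},
    },
    SamplingParams(temperature=0.2, max_tokens=64),
)
print(outputs[0].outputs[0].text)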