diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 7ff9531c5..7f20d2052 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -790,6 +790,7 @@ Speech2Text models trained specifically for Automatic Speech Recognition. | Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | |--------------|--------|-------------------|----------------------|---------------------------| +| `FunASRForConditionalGeneration` | FunASR | `allendou/Fun-ASR-Nano-2512-vllm`, etc. | | | | `Gemma3nForConditionalGeneration` | Gemma3n | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | | | `GlmAsrForConditionalGeneration` | GLM-ASR | `zai-org/GLM-ASR-Nano-2512` | ✅︎ | ✅︎ | | `GraniteSpeechForConditionalGeneration` | Granite Speech | `ibm-granite/granite-speech-3.3-2b`, `ibm-granite/granite-speech-3.3-8b`, etc. | ✅︎ | ✅︎ | diff --git a/examples/online_serving/openai_transcription_client.py b/examples/online_serving/openai_transcription_client.py index 966bfd2a4..478a0a7ea 100644 --- a/examples/online_serving/openai_transcription_client.py +++ b/examples/online_serving/openai_transcription_client.py @@ -26,7 +26,9 @@ from openai import AsyncOpenAI, OpenAI from vllm.assets.audio import AudioAsset -def sync_openai(audio_path: str, client: OpenAI, model: str): +def sync_openai( + audio_path: str, client: OpenAI, model: str, *, repetition_penalty: float = 1.3 +): """ Perform synchronous transcription using OpenAI-compatible API. """ @@ -40,7 +42,7 @@ def sync_openai(audio_path: str, client: OpenAI, model: str): # Additional sampling params not provided by OpenAI API. extra_body=dict( seed=4419, - repetition_penalty=1.3, + repetition_penalty=repetition_penalty, ), ) print("transcription result [sync]:", transcription.text) @@ -129,7 +131,12 @@ def main(args): print(f"Using model: {model}") # Run the synchronous function - sync_openai(args.audio_path if args.audio_path else mary_had_lamb, client, model) + sync_openai( + audio_path=args.audio_path if args.audio_path else mary_had_lamb, + client=client, + model=model, + repetition_penalty=args.repetition_penalty, + ) # Run the asynchronous function if "openai" in model: @@ -161,5 +168,11 @@ if __name__ == "__main__": default=None, help="The path to the audio file to transcribe.", ) + parser.add_argument( + "--repetition_penalty", + type=float, + default=1.3, + help="repetition penalty", + ) args = parser.parse_args() main(args) diff --git a/tests/models/registry.py b/tests/models/registry.py index d2c67cf7e..abc621d8e 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -713,6 +713,10 @@ _MULTIMODAL_EXAMPLE_MODELS = { "baidu/ERNIE-4.5-VL-28B-A3B-PT", trust_remote_code=True, ), + "FunASRForConditionalGeneration": _HfExamplesInfo( + "allendou/Fun-ASR-Nano-2512-vllm", + is_available_online=False, + ), "FunAudioChatForConditionalGeneration": _HfExamplesInfo( "funaudiochat", is_available_online=False ), diff --git a/vllm/model_executor/models/funasr.py b/vllm/model_executor/models/funasr.py new file mode 100644 index 000000000..b4d4fb5b7 --- /dev/null +++ b/vllm/model_executor/models/funasr.py @@ -0,0 +1,1057 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import math +from collections.abc import Iterable, Mapping, Sequence +from typing import Annotated, Literal, cast + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn +from 
transformers import ( + BatchFeature, + Qwen3Config, +) + +from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig +from vllm.config.multimodal import BaseDummyOptions +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.inputs.data import PromptType +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY +from vllm.model_executor.layers.attention.mm_encoder_attention import ( + MMEncoderAttention, +) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.whisper_utils import ( + ISO639_1_SUPPORTED_LANGS, +) +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, +) +from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser +from vllm.multimodal.processing import ( + BaseDummyInputsBuilder, + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptReplacement, + PromptUpdate, + PromptUpdateDetails, +) +from vllm.transformers_utils.processor import cached_processor_from_config +from vllm.transformers_utils.processors.funasr_processor import FunASRFeatureExtractor +from vllm.utils.jsontree import json_map_leaves +from vllm.utils.tensor_schema import TensorSchema, TensorShape + +from .interfaces import ( + MultiModalEmbeddings, + SupportsMultiModal, + SupportsTranscription, + _require_is_multimodal, +) +from .qwen3 import Qwen3Model +from .utils import ( + AutoWeightsLoader, + WeightsMapper, + _merge_multimodal_embeddings, + maybe_prefix, +) + +logger = init_logger(__name__) + + +def sequence_mask(lengths, maxlen=None, dtype=torch.float32, device=None): + if maxlen is None: + maxlen = lengths.max() + row_vector = torch.arange(0, maxlen, 1).to(lengths.device) + matrix = torch.unsqueeze(lengths, dim=-1) + mask = row_vector < matrix + mask = mask.detach() + + return mask.type(dtype).to(device) if device is not None else mask.type(dtype) + + +class LayerNorm(torch.nn.LayerNorm): + def __init__(self, nout, dim=-1): + super().__init__(nout, eps=1e-12) + self.dim = dim + + def forward(self, x: torch.Tensor): + if self.dim == -1: + return super().forward(x) + return super().forward(x.transpose(self.dim, -1)).transpose(self.dim, -1) + + +class EncoderLayerSANM(nn.Module): + def __init__( + self, + in_size: int, + size: int, + self_attn: nn.Module, + feed_forward: nn.Module, + normalize_before=True, + ): + super().__init__() + self.self_attn = self_attn + self.feed_forward = feed_forward + self.norm1 = LayerNorm(in_size) + self.norm2 = LayerNorm(size) + self.in_size = in_size + self.size = size + self.normalize_before = normalize_before + + def forward( + self, + hidden_states: torch.Tensor, + mask: torch.Tensor | None = None, + cache=None, + mask_shfit_chunk=None, + mask_att_chunk_encoder=None, + ): + residual = hidden_states + hidden_states = self.norm1(hidden_states) + + if self.in_size == self.size: + hidden_states = residual + self.self_attn( + hidden_states, + mask, + mask_shfit_chunk=mask_shfit_chunk, + mask_att_chunk_encoder=mask_att_chunk_encoder, + ) + else: + hidden_states = self.self_attn( + hidden_states, + mask, + 
mask_shfit_chunk=mask_shfit_chunk, + mask_att_chunk_encoder=mask_att_chunk_encoder, + ) + + residual = hidden_states + hidden_states = self.norm2(hidden_states) + hidden_states = residual + self.feed_forward(hidden_states) + + return hidden_states, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder + + +class MultiHeadedAttentionSANM(nn.Module): + def __init__( + self, + n_head: int, + in_feat: int, + n_feat: int, + kernel_size: int, + sanm_shift: int = 0, + ): + super().__init__() + assert n_feat % n_head == 0 + # We assume d_v always equals d_k + self.d_k = n_feat // n_head + self.h = n_head + self.out_proj = ReplicatedLinear( + input_size=n_feat, + output_size=n_feat, + bias=True, + ) + self.linear_q_k_v = ReplicatedLinear( + input_size=in_feat, + output_size=n_feat * 3, + bias=True, + ) + self.attn = None + + self.fsmn_block = nn.Conv1d( + n_feat, n_feat, kernel_size, stride=1, padding=0, groups=n_feat, bias=False + ) + # padding + left_padding = (kernel_size - 1) // 2 + if sanm_shift > 0: + left_padding = left_padding + sanm_shift + right_padding = kernel_size - 1 - left_padding + self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0) + + def forward_fsmn( + self, + inputs: torch.Tensor, + mask: torch.Tensor, + mask_shfit_chunk: torch.Tensor = None, + ): + b, t, d = inputs.size() + if mask is not None: + mask = torch.reshape(mask, (b, -1, 1)) + if mask_shfit_chunk is not None: + mask = mask * mask_shfit_chunk + inputs = inputs * mask + + x = inputs.transpose(1, 2) + x = self.pad_fn(x) + x = self.fsmn_block(x) + x = x.transpose(1, 2) + x += inputs + if mask is not None: + x = x * mask + return x + + def forward_qkv(self, x: torch.Tensor): + b, t, d = x.size() + + q_k_v, _ = self.linear_q_k_v(x) + q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1) + q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(1, 2) + k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(1, 2) + v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(1, 2) + + return q_h, k_h, v_h, v + + def forward_attention( + self, + value: torch.Tensor, + scores: torch.Tensor, + mask: torch.Tensor, + mask_att_chunk_encoder: torch.Tensor = None, + ): + n_batch = value.size(0) + if mask is not None: + if mask_att_chunk_encoder is not None: + mask = mask * mask_att_chunk_encoder + + mask = mask.unsqueeze(1).eq(0) + + min_value = -float("inf") + scores = scores.masked_fill(mask, min_value) + attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) + else: + attn = torch.softmax(scores, dim=-1) + + p_attn = attn + x = torch.matmul(p_attn, value) + x = x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k) + + out, _ = self.out_proj(x) + return out + + def forward( + self, + hidden_states: torch.Tensor, + mask: torch.Tensor, + mask_shfit_chunk: torch.Tensor = None, + mask_att_chunk_encoder: torch.Tensor = None, + ): + q_h, k_h, v_h, v = self.forward_qkv(hidden_states) + fsmn_memory = self.forward_fsmn(v, mask, mask_shfit_chunk) + q_h = q_h * self.d_k ** (-0.5) + scores = torch.matmul(q_h, k_h.transpose(-2, -1)) + att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder) + return att_outs + fsmn_memory + + +class SinusoidalPositionEncoder(torch.nn.Module): + def __init__(self, d_model=80): + super().__init__() + + def encode( + self, + positions: torch.Tensor = None, + depth: int = None, + dtype: torch.dtype = torch.float32, + ): + batch_size = positions.size(0) + positions = positions.type(dtype) + device = positions.device + log_timescale_increment = 
torch.log( + torch.tensor([10000], dtype=dtype, device=device) + ) / (depth / 2 - 1) + inv_timescales = torch.exp( + torch.arange(depth / 2, device=device).type(dtype) + * (-log_timescale_increment) + ) + inv_timescales = torch.reshape(inv_timescales, [batch_size, -1]) + scaled_time = torch.reshape(positions, [1, -1, 1]) * torch.reshape( + inv_timescales, [1, 1, -1] + ) + encoding = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2) + return encoding.type(dtype) + + def forward(self, hidden_states: torch.Tensor): + batch_size, timesteps, input_dim = hidden_states.size() + positions = torch.arange(1, timesteps + 1, device=hidden_states.device)[None, :] + position_encoding = self.encode(positions, input_dim, hidden_states.dtype).to( + hidden_states.device + ) + + return hidden_states + position_encoding + + +class SenseVoiceEncoderSmall(nn.Module): + def __init__( + self, + input_size: int, + output_size: int = 256, + attention_heads: int = 4, + linear_units: int = 2048, + num_blocks: int = 6, + tp_blocks: int = 0, + attention_dropout_rate: float = 0.0, + normalize_before: bool = True, + kernel_size: int = 11, + sanm_shift: int = 0, + **kwargs, + ): + super().__init__() + self._output_size = output_size + self.embed = SinusoidalPositionEncoder() + + self.normalize_before = normalize_before + + positionwise_layer = PositionwiseFeedForward + positionwise_layer_args = ( + output_size, + linear_units, + ) + + encoder_selfattn_layer = MultiHeadedAttentionSANM + encoder_selfattn_layer_args0 = ( + attention_heads, + input_size, + output_size, + kernel_size, + sanm_shift, + ) + encoder_selfattn_layer_args = ( + attention_heads, + output_size, + output_size, + kernel_size, + sanm_shift, + ) + + self.encoders0 = nn.ModuleList( + [ + EncoderLayerSANM( + input_size, + output_size, + encoder_selfattn_layer(*encoder_selfattn_layer_args0), + positionwise_layer(*positionwise_layer_args), + ) + for i in range(1) + ] + ) + self.encoders = nn.ModuleList( + [ + EncoderLayerSANM( + output_size, + output_size, + encoder_selfattn_layer(*encoder_selfattn_layer_args), + positionwise_layer(*positionwise_layer_args), + ) + for i in range(num_blocks - 1) + ] + ) + + self.tp_encoders = nn.ModuleList( + [ + EncoderLayerSANM( + output_size, + output_size, + encoder_selfattn_layer(*encoder_selfattn_layer_args), + positionwise_layer(*positionwise_layer_args), + ) + for i in range(tp_blocks) + ] + ) + + self.after_norm = LayerNorm(output_size) + + self.tp_norm = LayerNorm(output_size) + + def output_size(self) -> int: + return self._output_size + + def forward( + self, + xs_pad: torch.Tensor, + ilens: torch.Tensor, + ): + maxlen = xs_pad.shape[1] + masks = sequence_mask( + ilens, maxlen=maxlen, dtype=ilens.dtype, device=ilens.device + )[:, None, :] + + xs_pad *= self.output_size() ** 0.5 + + xs_pad = self.embed(xs_pad) + + for layer_idx, encoder_layer in enumerate(self.encoders0): + encoder_outs = encoder_layer(xs_pad, masks) + xs_pad, masks = encoder_outs[0], encoder_outs[1] + + for layer_idx, encoder_layer in enumerate(self.encoders): + encoder_outs = encoder_layer(xs_pad, masks) + xs_pad, masks = encoder_outs[0], encoder_outs[1] + + xs_pad = self.after_norm(xs_pad) + + olens = masks.squeeze(1).sum(1).int() + + for layer_idx, encoder_layer in enumerate(self.tp_encoders): + encoder_outs = encoder_layer(xs_pad, masks) + xs_pad, masks = encoder_outs[0], encoder_outs[1] + + xs_pad = self.tp_norm(xs_pad) + return xs_pad, olens + + +class PositionwiseFeedForward(nn.Module): + def __init__(self, idim: int, 
hidden_units: int): + super().__init__() + self.w_1 = ColumnParallelLinear( + input_size=idim, + output_size=hidden_units, + bias=True, + ) + self.w_2 = RowParallelLinear( + input_size=hidden_units, + output_size=idim, + bias=True, + ) + self.activation = _ACTIVATION_REGISTRY["relu"] + + def forward(self, hidden_states: torch.Tensor): + hidden_states, _ = self.w_1(hidden_states) + hidden_states = self.activation(hidden_states) + hidden_states, _ = self.w_2(hidden_states) + return hidden_states + + +class EncoderLayer(nn.Module): + def __init__( + self, + size: int, + self_attn: nn.Module, + feed_forward: nn.Module, + ): + super().__init__() + self.self_attn = self_attn + self.feed_forward = feed_forward + self.norm1 = LayerNorm(size) + self.norm2 = LayerNorm(size) + + def forward(self, hidden_states: torch.Tensor): + residual = hidden_states + hidden_states = self.norm1(hidden_states) + hidden_states = residual + self.self_attn(hidden_states, None, None) + residual = hidden_states + hidden_states = self.norm2(hidden_states) + hidden_states = residual + self.feed_forward(hidden_states) + + return hidden_states + + +class FunASRAudioAttention(nn.Module): + def __init__( + self, + num_heads: int, + embed_dim: int, + prefix: str = "", + ): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.head_dim = self.embed_dim // self.num_heads + tp_size = get_tensor_model_parallel_world_size() + self.num_local_heads = self.num_heads // tp_size + + if (self.head_dim * self.num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: " + f"{self.embed_dim} and `num_heads`: {self.num_heads})." + ) + + self.scaling = self.head_dim**-0.5 + + self.qkv = QKVParallelLinear( + hidden_size=self.embed_dim, + head_size=self.head_dim, + total_num_heads=self.num_heads, + total_num_kv_heads=self.num_heads, + bias=True, + prefix=f"{prefix}.qkv", + ) + + self.out_proj = RowParallelLinear( + input_size=self.embed_dim, + output_size=self.embed_dim, + bias=True, + prefix=f"{prefix}.out_proj", + ) + + self.attn = MMEncoderAttention( + num_heads=self.num_local_heads, + head_size=self.head_dim, + scale=self.scaling, + ) + + def forward( + self, + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + max_seqlen: torch.Tensor | None, + ) -> torch.Tensor: + bs, seq_length, _ = hidden_states.size() + qkv, _ = self.qkv(hidden_states) + q, k, v = qkv.chunk(3, dim=-1) + q = q.view(bs, seq_length, -1, self.head_dim) + k = k.view(bs, seq_length, -1, self.head_dim) + v = v.view(bs, seq_length, -1, self.head_dim) + + attn_output = self.attn( + query=q, + key=k, + value=v, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + ) + + attn_output = attn_output.view(bs, seq_length, -1) + output, _ = self.out_proj(attn_output) + return output + + +class Transformer(nn.Module): + def __init__( + self, + downsample_rate=2, + encoder_dim=1280, + llm_dim=4096, + ffn_dim: int = 2048, + prefix: str = "", + **kwargs, + ): + super().__init__() + self.k = downsample_rate + self.encoder_dim = encoder_dim + self.llm_dim = llm_dim + self.linear1 = ColumnParallelLinear( + input_size=self.encoder_dim * self.k, + output_size=ffn_dim, + bias=True, + ) + self.relu = nn.ReLU() + self.linear2 = RowParallelLinear( + input_size=ffn_dim, + output_size=self.llm_dim, + bias=True, + ) + + self.blocks = None + if kwargs.get("n_layer", 2) > 0: + self.blocks = nn.ModuleList( + [ + EncoderLayer( + llm_dim, + FunASRAudioAttention( + kwargs.get("attention_heads", 8), + llm_dim, + 
prefix=f"{prefix}.self_attn", + ), + PositionwiseFeedForward( + llm_dim, + llm_dim // 4, + ), + ) + for _ in range(kwargs.get("n_layer", 2)) + ] + ) + + def forward(self, hidden_states: torch.Tensor, ilens: int = 0): + batch_size, seq_len, dim = hidden_states.size() + chunk_num = (seq_len - 1) // self.k + 1 + pad_num = chunk_num * self.k - seq_len + hidden_states = F.pad(hidden_states, (0, 0, 0, pad_num, 0, 0), value=0.0) + seq_len = hidden_states.size(1) + + hidden_states = hidden_states.contiguous() + hidden_states = hidden_states.view(batch_size, chunk_num, dim * self.k) + hidden_states, _ = self.linear1(hidden_states) + hidden_states = self.relu(hidden_states) + hidden_states, _ = self.linear2(hidden_states) + + olens = None + olens = (ilens - 1) // self.k + 1 + + if self.blocks is not None: + for layer, block in enumerate(self.blocks): + hidden_states = block(hidden_states) + return hidden_states, olens + + +class FunASRAudioInputs(TensorSchema): + """ + Dimensions: + - b: Batch size + - nmb: Number of mel bins + - t: Time frames (M) + """ + + input_features: Annotated[ + list[torch.Tensor] | None, + TensorShape("b", "nmb", "t"), + ] + speech_lengths: Annotated[ + list[torch.Tensor] | None, + TensorShape("b"), + ] + + +class FunASREncoder(nn.Module): + def __init__( + self, *, vllm_config: VllmConfig, prefix: str = "", init_in_fp32: bool = False + ): + super().__init__() + self.audio_encoder = SenseVoiceEncoderSmall( + input_size=560, **vllm_config.model_config.hf_config.audio_encoder_conf + ) + self.audio_adaptor = Transformer( + downsample_rate=1, + use_low_frame_rate=True, + ffn_dim=2048, + llm_dim=1024, + encoder_dim=512, + n_layer=2, + freeze=True, + prefix=maybe_prefix(prefix, "audio_encoder"), + ) + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + """Load weights with mapping from HuggingFace format.""" + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("self_attn.qkv.", "self_attn.q_proj.", "q"), + ("self_attn.qkv.", "self_attn.k_proj.", "k"), + ("self_attn.qkv.", "self_attn.v_proj.", "v"), + ] + params_dict = dict(self.named_parameters(remove_duplicate=False)) + loaded_params: set[str] = set() + + for name, loaded_weight in weights: + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict.get(name) + if param is not None: + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class FunASRModel(nn.Module): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + self.encoder = FunASREncoder( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "encoder") + ) + self.decoder = Qwen3Model( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "decoder") + ) + + def forward( + self, + input_ids: torch.Tensor | None, + positions: torch.Tensor, + inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor: + decoder_outputs = self.decoder( + input_ids=input_ids, + positions=positions, + inputs_embeds=inputs_embeds, + ) + return decoder_outputs + + def get_encoder_outputs( + self, + speech: torch.Tensor | list[torch.Tensor] | None, + speech_lengths: torch.Tensor | list[torch.Tensor] | None, + ) -> torch.Tensor | None: + 
self.feat_permute = False + + if self.feat_permute: + encoder_out, encoder_out_lens = self.encoder.audio_encoder( + speech.permute(0, 2, 1), speech_lengths + ) + else: + encoder_out, encoder_out_lens = self.encoder.audio_encoder( + speech, speech_lengths + ) + + encoder_out, encoder_out_lens = self.encoder.audio_adaptor( + encoder_out, encoder_out_lens + ) + return encoder_out + + +class FunASRProcessingInfo(BaseProcessingInfo): + def get_hf_config(self) -> Qwen3Config: + return self.ctx.get_hf_config(Qwen3Config) + + @property + def skip_prompt_length_check(self) -> bool: + return True # Because the encoder prompt is padded + + def get_supported_mm_limits(self) -> Mapping[str, int | None]: + return {"audio": 1} + + def get_feature_extractor(self, **kwargs: object) -> FunASRFeatureExtractor: + hf_processor = self.get_hf_processor(**kwargs) + feature_extractor = hf_processor.feature_extractor # type: ignore + assert isinstance(feature_extractor, FunASRFeatureExtractor) + return feature_extractor + + def get_target_channels(self) -> int: + return 1 + + def get_num_audio_tokens(self) -> int: + return self.get_hf_config().max_source_positions + + +class FunASRDummyInputsBuilder(BaseDummyInputsBuilder[FunASRProcessingInfo]): + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_audios = mm_counts.get("audio", 0) + + return "<|AUDIO|>" * num_audios + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + mm_options: Mapping[str, BaseDummyOptions] | None = None, + ) -> MultiModalDataDict: + feature_extractor = self.info.get_feature_extractor() + + sampling_rate = feature_extractor.sampling_rate + audio_len = feature_extractor.chunk_length * sampling_rate + num_audios = mm_counts.get("audio", 0) + + audio_overrides = mm_options.get("audio") if mm_options else None + + return { + "audio": self._get_dummy_audios( + length=audio_len, num_audios=num_audios, overrides=audio_overrides + ) + } + + +class FunASRMultiModalProcessor(BaseMultiModalProcessor[FunASRProcessingInfo]): + def _get_data_parser(self) -> MultiModalDataParser: + feature_extractor = self.info.get_feature_extractor() + return MultiModalDataParser( + target_sr=feature_extractor.sampling_rate, + target_channels=self.info.get_target_channels(), + ) + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], + ) -> BatchFeature: + if mm_data: + feature_extractor = self.info.get_feature_extractor(**mm_kwargs) + mm_data = dict(audio=mm_data.pop("audios")) + mm_kwargs = dict( + **mm_kwargs, + sampling_rate=feature_extractor.sampling_rate, + ) + processed_outputs = super()._call_hf_processor( + prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, + ) + if "labels" in processed_outputs: + processed_outputs["input_ids"] = processed_outputs.pop("labels") + return processed_outputs + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict( + input_features=MultiModalFieldConfig.batched("audio"), + speech_lengths=MultiModalFieldConfig.batched("audio"), + fake_token_len=MultiModalFieldConfig.batched("audio"), + ) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargsItems, + ) -> Sequence[PromptUpdate]: + processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + 
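+        # Each <|AUDIO|> placeholder in the prompt is replaced by
+        # `fake_token_len[item_idx]` copies of the audio token id, so the
+        # decoder prompt reserves one position per audio embedding produced
+        # by the encoder/adaptor for that clip.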
tokenizer = self.info.get_tokenizer() + vocab = tokenizer.get_vocab() + + # Use getattr with default to be compatible with transformers<4.48 + audio_token = getattr(processor, "audio_token", "<|AUDIO|>") + + audio_token_id = vocab[audio_token] + + out_mm_data = out_mm_kwargs.get_data() + + fake_token_len = out_mm_data.get("fake_token_len") + if fake_token_len is None: + audio_output_lengths = [] + else: + assert isinstance(fake_token_len, torch.Tensor) + + audio_output_lengths = fake_token_len.tolist() + + def get_replacement_qwen2_audio(item_idx: int): + if audio_output_lengths: + num_features = audio_output_lengths[item_idx] + else: + audio_embeds = out_mm_data["audio_embeds"][item_idx] + assert len(audio_embeds.shape) == 2, "audio_embeds must be a 2D tensor" + num_features = audio_embeds.shape[0] + + audio_tokens = [audio_token_id] * num_features + + return PromptUpdateDetails.select_token_id( + audio_tokens, + embed_token_id=audio_token_id, + ) + + return [ + PromptReplacement( + modality="audio", + target=audio_token, + replacement=get_replacement_qwen2_audio, + ) + ] + + +@MULTIMODAL_REGISTRY.register_processor( + FunASRMultiModalProcessor, + info=FunASRProcessingInfo, + dummy_inputs=FunASRDummyInputsBuilder, +) +class FunASRForConditionalGeneration( + nn.Module, SupportsTranscription, SupportsMultiModal +): + packed_modules_mapping = { + "self_attn.qkv_proj": [ + "self_attn.q_proj", + "self_attn.k_proj", + "self_attn.v_proj", + ], + "encoder_attn.kv_proj": ["encoder_attn.k_proj", "encoder_attn.v_proj"], + } + + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_substr={ + "linear_q.": "q_proj.", + "linear_k.": "k_proj.", + "linear_v.": "v_proj.", + "linear_out.": "out_proj.", + } + ) + + supports_transcription_only = True + supports_segment_timestamp = True + supported_languages = ISO639_1_SUPPORTED_LANGS + + @classmethod + def validate_language(cls, language: str | None) -> str | None: + if language is None: + # TODO language should be optional and can be guessed. + # For now we default to en. See + # https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/generation_whisper.py#L1520 + logger.warning( + "Defaulting to language='en'. If you wish to transcribe " + "audio in a different language, pass the `language` field " + "in the TranscriptionRequest." 
+ ) + language = "en" + return super().validate_language(language) + + @classmethod + def get_generation_prompt( + cls, + audio: np.ndarray, + model_config: ModelConfig, # not needed here + stt_config: SpeechToTextConfig, + language: str | None, + task_type: Literal["transcribe", "translate"], + request_prompt: str, + to_language: str | None, + ) -> PromptType: + if language is None: + raise ValueError( + "Language must be specified when creating the funasr prompt" + ) + + funasr_prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n语音转写:<|AUDIO|><|im_end|>\n<|im_start|>assistant\n" # noqa: E501 + prompt = { + "prompt": funasr_prompt, + "multi_modal_data": { + "audio": (audio, stt_config.sample_rate), + }, + } + return cast(PromptType, prompt) + + @classmethod + def get_speech_to_text_config( + cls, model_config: ModelConfig, task_type: str + ) -> SpeechToTextConfig: + processor = cached_processor_from_config(model_config) + + return SpeechToTextConfig( + max_audio_clip_s=processor.feature_extractor.chunk_length, + sample_rate=processor.feature_extractor.sampling_rate, + ) + + @classmethod + def get_num_audio_tokens( + cls, + audio_duration_s: float, + stt_config: SpeechToTextConfig, + model_config: ModelConfig, + ) -> int | None: + processor = cached_processor_from_config(model_config) + hop_length = processor.feature_extractor.hop_length + assert hop_length is not None + return math.ceil(audio_duration_s * stt_config.sample_rate / hop_length) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.dtype = vllm_config.model_config.dtype + + self.model = FunASRModel( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model"), + ) + logit_scale = getattr(config, "logit_scale", 1.0) + + if config.tie_word_embeddings: + self.lm_head = self.model.decoder.embed_tokens + else: + self.lm_head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), + ) + self.logits_processor = LogitsProcessor(config.vocab_size, scale=logit_scale) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + inputs_embeds: torch.Tensor | None = None, + **kwargs, + ) -> torch.Tensor: + decoder_outputs = self.model( + input_ids=input_ids, + positions=positions, + inputs_embeds=inputs_embeds, + ) + return decoder_outputs + + def get_language_model(self) -> torch.nn.Module: + return self.model.decoder + + def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: + audio_input = self._parse_and_validate_audio_input(**kwargs) + + speech = audio_input["input_features"] + speech_lengths = audio_input["speech_lengths"] + enc_output = self.model.get_encoder_outputs( + speech=speech, speech_lengths=speech_lengths + ) + + return enc_output + + def embed_input_ids( + self, + input_ids: torch.Tensor, + multimodal_embeddings: MultiModalEmbeddings | None = None, + *, + is_multimodal: torch.Tensor | None = None, + handle_oov_mm_token: bool = False, + ) -> torch.Tensor: + inputs_embeds = self.model.decoder.embed_input_ids(input_ids) + + return _merge_multimodal_embeddings( + inputs_embeds=inputs_embeds, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=_require_is_multimodal(is_multimodal), + ) + + def _parse_and_validate_audio_input(self, **kwargs: object) -> FunASRAudioInputs: + input_features = 
kwargs.pop("input_features", None) + speech_lengths = kwargs.pop("speech_lengths", None) + + if input_features is not None: + input_features = json_map_leaves(lambda x: x.to(self.dtype), input_features) + + if speech_lengths is not None: + speech_lengths = json_map_leaves(lambda x: x.to(self.dtype), speech_lengths) + + return FunASRAudioInputs( + input_features=input_features, speech_lengths=speech_lengths + ) + + def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states) + return logits + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, + ) + + # add fake zeros bias for k_proj to state_dict + weights = _create_fake_bias_for_k_proj(weights) + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) + + +def _create_fake_bias_for_k_proj( + weights: Iterable[tuple[str, torch.Tensor]], +) -> Iterable[tuple[str, torch.Tensor]]: + """ + Create full zeros bias for k_proj weight in self-attn and x-attn layers. + So that the bias for k_proj in qkv_proj can be initialized with zeros. + """ + for name, weight in weights: + if name.endswith(".k_proj.weight"): + bias = torch.zeros(weight.size(0)) + bias_name = name.replace("weight", "bias") + yield from [(name, weight), (bias_name, bias)] + else: + yield name, weight diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 1871591c9..59fcd9117 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -325,6 +325,7 @@ _MULTIMODAL_MODELS = { "ernie45_vl", "Ernie4_5_VLMoeForConditionalGeneration", ), + "FunASRForConditionalGeneration": ("funasr", "FunASRForConditionalGeneration"), # noqa: E501 "FunAudioChatForConditionalGeneration": ( "funaudiochat", "FunAudioChatForConditionalGeneration", diff --git a/vllm/transformers_utils/processors/__init__.py b/vllm/transformers_utils/processors/__init__.py index af25dbe4c..d726fd39a 100644 --- a/vllm/transformers_utils/processors/__init__.py +++ b/vllm/transformers_utils/processors/__init__.py @@ -10,6 +10,7 @@ reasons: from vllm.transformers_utils.processors.bagel import BagelProcessor from vllm.transformers_utils.processors.deepseek_vl2 import DeepseekVLV2Processor +from vllm.transformers_utils.processors.funasr_processor import FunASRProcessor from vllm.transformers_utils.processors.hunyuan_vl import HunYuanVLProcessor from vllm.transformers_utils.processors.hunyuan_vl_image import HunYuanVLImageProcessor from vllm.transformers_utils.processors.ovis import OvisProcessor @@ -18,6 +19,7 @@ from vllm.transformers_utils.processors.ovis2_5 import Ovis2_5Processor __all__ = [ "BagelProcessor", "DeepseekVLV2Processor", + "FunASRProcessor", "HunYuanVLProcessor", "HunYuanVLImageProcessor", "OvisProcessor", diff --git a/vllm/transformers_utils/processors/funasr_processor.py b/vllm/transformers_utils/processors/funasr_processor.py new file mode 100644 index 000000000..4807c87d3 --- /dev/null +++ b/vllm/transformers_utils/processors/funasr_processor.py @@ -0,0 +1,504 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + +import numpy as np +import torch +import torch.nn as nn +import torchaudio.compliance.kaldi as kaldi +from torch.nn.utils.rnn import pad_sequence +from transformers import ( + AutoFeatureExtractor, + AutoProcessor, + BatchFeature, +) +from transformers.feature_extraction_sequence_utils import 
SequenceFeatureExtractor
+from transformers.processing_utils import ProcessorMixin
+from transformers.utils import TensorType
+
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+def apply_cmvn(inputs, cmvn):  # noqa
+    """
+    Apply CMVN with mvn data
+    """
+
+    device = inputs.device
+    # dtype = inputs.dtype
+    frame, dim = inputs.shape
+
+    means = cmvn[0:1, :dim]
+    vars = cmvn[1:2, :dim]
+    inputs += means.to(device)
+    inputs *= vars.to(device)
+
+    return inputs.type(torch.float32)
+
+
+def apply_lfr(inputs, lfr_m, lfr_n):
+    # LFR_inputs = []
+    T = inputs.shape[0]
+    T_lfr = int(np.ceil(T / lfr_n))
+    left_padding = inputs[0].repeat((lfr_m - 1) // 2, 1)
+    inputs = torch.vstack((left_padding, inputs))
+    T = T + (lfr_m - 1) // 2
+    feat_dim = inputs.shape[-1]
+    strides = (lfr_n * feat_dim, 1)
+    sizes = (T_lfr, lfr_m * feat_dim)
+    last_idx = (T - lfr_m) // lfr_n + 1
+    num_padding = lfr_m - (T - last_idx * lfr_n)
+    if num_padding > 0:
+        num_padding = (
+            (2 * lfr_m - 2 * T + (T_lfr - 1 + last_idx) * lfr_n)
+            / 2
+            * (T_lfr - last_idx)
+        )
+        inputs = torch.vstack([inputs] + [inputs[-1:]] * int(num_padding))
+    LFR_outputs = inputs.as_strided(sizes, strides)
+    return LFR_outputs.clone().type(torch.float32)
+
+
+def load_cmvn(cmvn_file):
+    with open(cmvn_file, encoding="utf-8") as f:
+        lines = f.readlines()
+    means_list = []
+    vars_list = []
+    # The cmvn file uses Kaldi-style tags: the means are listed after the
+    # "<AddShift>" block and the variances after the "<Rescale>" block.
+    for i in range(len(lines)):
+        line_item = lines[i].split()
+        if line_item[0] == "<AddShift>":
+            line_item = lines[i + 1].split()
+            if line_item[0] == "<LearnRateCoef>":
+                add_shift_line = line_item[3 : (len(line_item) - 1)]
+                means_list = list(add_shift_line)
+                continue
+        elif line_item[0] == "<Rescale>":
+            line_item = lines[i + 1].split()
+            if line_item[0] == "<LearnRateCoef>":
+                rescale_line = line_item[3 : (len(line_item) - 1)]
+                vars_list = list(rescale_line)
+                continue
+    means = np.array(means_list).astype(np.float32)
+    vars = np.array(vars_list).astype(np.float32)
+    cmvn = np.array([means, vars])
+    cmvn = torch.as_tensor(cmvn, dtype=torch.float32)
+    return cmvn
+
+
+class WavFrontend(nn.Module):
+    """Conventional frontend structure for ASR."""
+
+    def __init__(
+        self,
+        cmvn_file: str = "null",
+        fs: int = 16000,
+        window: str = "hamming",
+        n_mels: int = 80,
+        frame_length: int = 25,
+        frame_shift: int = 10,
+        filter_length_min: int = -1,
+        filter_length_max: int = -1,
+        lfr_m: int = 1,
+        lfr_n: int = 1,
+        dither: float = 1.0,
+        snip_edges: bool = True,
+        upsacle_samples: bool = True,
+        **kwargs,
+    ):
+        super().__init__()
+        self.fs = fs
+        self.window = window
+        self.n_mels = n_mels
+        self.frame_length = frame_length
+        self.frame_shift = frame_shift
+        self.filter_length_min = filter_length_min
+        self.filter_length_max = filter_length_max
+        self.lfr_m = lfr_m
+        self.lfr_n = lfr_n
+        self.cmvn_file = cmvn_file
+        self.dither = dither
+        self.snip_edges = snip_edges
+        self.upsacle_samples = upsacle_samples
+        self.cmvn = None if self.cmvn_file is None else load_cmvn(self.cmvn_file)
+
+    def output_size(self) -> int:
+        return self.n_mels * self.lfr_m
+
+    def forward(
+        self,
+        input: torch.Tensor,
+        input_lengths,
+        **kwargs,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        batch_size = input.size(0)
+        feats = []
+        feats_lens = []
+        for i in range(batch_size):
+            waveform_length = input_lengths[i]
+            waveform = input[i][:waveform_length]
+            if self.upsacle_samples:
+                waveform = waveform * (1 << 15)
+            waveform = waveform.unsqueeze(0)
+            mat = kaldi.fbank(
+                waveform,
+                num_mel_bins=self.n_mels,
+                frame_length=min(self.frame_length, waveform_length / self.fs * 1000),
+                frame_shift=self.frame_shift,
+
dither=self.dither, + energy_floor=0.0, + window_type=self.window, + sample_frequency=self.fs, + snip_edges=self.snip_edges, + ) + + if self.lfr_m != 1 or self.lfr_n != 1: + mat = apply_lfr(mat, self.lfr_m, self.lfr_n) + if self.cmvn is not None: + mat = apply_cmvn(mat, self.cmvn) + feat_length = mat.size(0) + feats.append(mat) + feats_lens.append(feat_length) + + feats_lens = torch.as_tensor(feats_lens) + if batch_size == 1: + feats_pad = feats[0][None, :, :] + else: + feats_pad = pad_sequence(feats, batch_first=True, padding_value=0.0) + return feats_pad, feats_lens + + def forward_fbank( + self, input: torch.Tensor, input_lengths: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + batch_size = input.size(0) + feats = [] + feats_lens = [] + for i in range(batch_size): + waveform_length = input_lengths[i] + waveform = input[i][:waveform_length] + waveform = waveform * (1 << 15) + waveform = waveform.unsqueeze(0) + mat = kaldi.fbank( + waveform, + num_mel_bins=self.n_mels, + frame_length=self.frame_length, + frame_shift=self.frame_shift, + dither=self.dither, + energy_floor=0.0, + window_type=self.window, + sample_frequency=self.fs, + ) + + feat_length = mat.size(0) + feats.append(mat) + feats_lens.append(feat_length) + + feats_lens = torch.as_tensor(feats_lens) + feats_pad = pad_sequence(feats, batch_first=True, padding_value=0.0) + return feats_pad, feats_lens + + def forward_lfr_cmvn( + self, input: torch.Tensor, input_lengths: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + batch_size = input.size(0) + feats = [] + feats_lens = [] + for i in range(batch_size): + mat = input[i, : input_lengths[i], :] + if self.lfr_m != 1 or self.lfr_n != 1: + mat = apply_lfr(mat, self.lfr_m, self.lfr_n) + if self.cmvn is not None: + mat = apply_cmvn(mat, self.cmvn) + feat_length = mat.size(0) + feats.append(mat) + feats_lens.append(feat_length) + + feats_lens = torch.as_tensor(feats_lens) + feats_pad = pad_sequence(feats, batch_first=True, padding_value=0.0) + return feats_pad, feats_lens + + +class FunASRFeatureExtractor(SequenceFeatureExtractor): + r""" + Constructs a FunASR feature extractor. + + This feature extractor inherits from [`~feature_extraction_sequence_ + utils.SequenceFeatureExtractor`] which contains most of the main + methods. Users should refer to this superclass for more information + regarding those methods. + + This class extracts mel-filter bank features from raw speech using a custom + numpy implementation of the `Short Time Fourier Transform` which should + match pytorch's `torch.stft` equivalent. + + Args: + feature_size (`int`, *optional*, defaults to 80): + The feature dimension of the extracted features. + sampling_rate (`int`, *optional*, defaults to 16000): + The sampling rate at which the audio files should be digitalized + expressed in hertz (Hz). + hop_length (`int`, *optional*, defaults to 160): + Length of the overlapping windows for the STFT used to obtain the + Mel Frequency coefficients. + chunk_length (`int`, *optional*, defaults to 30): + The maximum number of chunks of `sampling_rate` samples used to + trim and pad longer or shorter audio sequences. + n_fft (`int`, *optional*, defaults to 400): + Size of the Fourier transform. + padding_value (`float`, *optional*, defaults to 0.0): + Padding value used to pad the audio. Should correspond to silences. + dither (`float`, *optional*, defaults to 0.0): + Adds dithering. In other words, adds a small Gaussian noise to each frame. + E.g. 
use 0.0001 to add dithering with a normal distribution centered + around 0.0 with standard deviation 0.0001 (assuming [-1,+1] range + of raw_speech). The value 0.0 means no dithering. + Dithering has similar effect as `spectrogram(mel_floor=...)`. It reduces + the high log_mel_fbank values for signals with hard-zero sections, + when VAD cutoff is present in the signal. + """ + + model_input_names = ["input_features"] + + def __init__( + self, + feature_size=80, + sampling_rate=16000, + hop_length=160, + chunk_length=30, + n_fft=400, + padding_value=0.0, + dither=0.0, + return_attention_mask=False, + **kwargs, + ): + super().__init__( + feature_size=feature_size, + sampling_rate=sampling_rate, + padding_value=padding_value, + return_attention_mask=return_attention_mask, + **kwargs, + ) + self.frontend_conf = kwargs.get("frontend_conf", {}) + self.n_fft = n_fft + self.hop_length = hop_length + self.chunk_length = chunk_length + self.n_samples = chunk_length * sampling_rate + self.nb_max_frames = self.n_samples // hop_length + self.sampling_rate = sampling_rate + self.dither = dither + + def extract_fbank( + self, data, data_len=None, data_type: str = "sound", frontend=None, **kwargs + ): + if isinstance(data, np.ndarray): + data = torch.from_numpy(data) + if len(data.shape) < 2: + data = data[None, :] # data: [batch, N] + data_len = [data.shape[1]] if data_len is None else data_len + elif isinstance(data, torch.Tensor): + if len(data.shape) < 2: + data = data[None, :] # data: [batch, N] + data_len = [data.shape[1]] if data_len is None else data_len + elif isinstance(data, (list, tuple)): + data_list, data_len = [], [] + for data_i in data: + if isinstance(data_i, np.ndarray): + data_i = torch.from_numpy(data_i) + data_list.append(data_i) + data_len.append(data_i.shape[0]) + data = pad_sequence(data_list, batch_first=True) + + data, data_len = frontend(data, data_len, **kwargs) + + if isinstance(data_len, (list, tuple)): + data_len = torch.tensor([data_len]) + return data.to(torch.float32), data_len.to(torch.int32) + + def __call__( + self, + raw_speech: np.ndarray | list[float] | list[np.ndarray] | list[list[float]], + truncation: bool = True, + pad_to_multiple_of: int | None = None, + return_tensors: str | TensorType | None = None, + return_attention_mask: bool | None = None, + padding: str | None = "max_length", + max_length: int | None = None, + sampling_rate: int | None = None, + do_normalize: bool | None = None, + device: str | None = "cpu", + return_token_timestamps: bool | None = None, + **kwargs, + ) -> BatchFeature: + is_batched = isinstance(raw_speech, (list, tuple)) and ( + isinstance(raw_speech[0], (np.ndarray, tuple, list)) + ) + + if is_batched: + raw_speech = [ + np.asarray([speech], dtype=np.float32).T for speech in raw_speech + ] + elif not is_batched and not isinstance(raw_speech, np.ndarray): + raw_speech = np.asarray(raw_speech, dtype=np.float32) + elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype( + np.float64 + ): + raw_speech = raw_speech.astype(np.float32) + + if not is_batched: + raw_speech = [np.asarray([raw_speech]).T] + + batched_speech = BatchFeature({"input_features": raw_speech}) + + padded_inputs = self.pad( + batched_speech, + padding=padding, + max_length=max_length if max_length else self.n_samples, + truncation=truncation, + pad_to_multiple_of=pad_to_multiple_of, + return_attention_mask=return_attention_mask or do_normalize, + ) + + input_features = padded_inputs.get("input_features").transpose(2, 0, 1) + + self.frontend = 
WavFrontend(**self.frontend_conf) + input_features, speech_lengths = self.extract_fbank( + input_features[0], + data_type=kwargs.get("data_type", "sound"), + frontend=self.frontend, + is_final=True, + ) + olens = 1 + (speech_lengths - 3 + 2 * 1) // 2 + olens = 1 + (olens - 3 + 2 * 1) // 2 + fake_token_len = (olens - 1) // 2 + 1 + if isinstance(input_features[0], list): + padded_inputs["input_features"] = [ + np.asarray(feature, dtype=np.float32) for feature in input_features + ] + + else: + padded_inputs["input_features"] = input_features + + if return_tensors is not None: + padded_inputs = padded_inputs.convert_to_tensors(return_tensors) + + padded_inputs["speech_lengths"] = speech_lengths + padded_inputs["fake_token_len"] = fake_token_len + + return padded_inputs + + +class FunASRProcessor(ProcessorMixin): + r""" + Constructs a FunASR processor which wraps a FunASR feature extractor and + a FunASR tokenizer into a single processor. + + [`FunASRProcessor`] offers all the functionalities of + [`FunASRFeatureExtractor`] and [`Qwen2Tokenizer`]. See the + [`~FunASRProcessor.__call__`] and [`~FunASRProcessor.decode`] for more + information. + + Args: + feature_extractor (`FunASRFeatureExtractor`): An instance of + [`FunASRFeatureExtractor`]. + The feature extractor is a required input. + tokenizer (`Qwen2Tokenizer`): + An instance of [`Qwen2Tokenizer`]. The tokenizer is a required + input. + """ + + feature_extractor_class = "FunASRFeatureExtractor" + tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast") + + def __init__( + self, + feature_extractor, + tokenizer, + audio_token="<|AUDIO|>", + ): + super().__init__(feature_extractor, tokenizer) + self.current_processor = self.feature_extractor + self._in_target_context_manager = False + self.audio_token = ( + tokenizer.audio_token if hasattr(tokenizer, "audio_token") else audio_token + ) + self.audio_token_id = tokenizer.convert_tokens_to_ids(self.audio_token) + + def get_decoder_prompt_ids(self, task=None, language=None, no_timestamps=True): + return self.tokenizer.get_decoder_prompt_ids( + task=task, language=language, no_timestamps=no_timestamps + ) + + def __call__(self, *args, **kwargs): + """ + Forwards the `audio` argument to FunASRFeatureExtractor's + [`~FunASRFeatureExtractor.__call__`] and the `text` argument to + [`~Qwen2Tokenizer.__call__`]. Please refer to the docstring of the + above two methods for more information. + """ + if self._in_target_context_manager: + return self.current_processor(*args, **kwargs) + + audio = kwargs.pop("audio", None) + sampling_rate = kwargs.pop("sampling_rate", None) + text = kwargs.pop("text", None) + if len(args) > 0: + audio = args[0] + args = args[1:] + + if text is None: + raise ValueError("You need to specify `text` input to process.") + elif isinstance(text, str): + text = [text] + elif not isinstance(text, list) and not isinstance(text[0], str): + raise ValueError( + "Invalid input text. 
Please provide a string, or a list of strings"
+            )
+
+        if audio is not None:
+            # ensure we have as many audios as audio tokens
+            num_audio_tokens = sum(sample.count(self.audio_token) for sample in text)
+            num_audios = 1 if type(audio) is np.ndarray else len(audio)
+            if num_audio_tokens != num_audios:
+                raise ValueError(
+                    f"Found {num_audio_tokens} {self.audio_token} token{'s' if num_audio_tokens > 1 else ''} in provided text but received {num_audios} audio{'s' if num_audios > 1 else ''}"  # noqa: E501
+                )
+            inputs = self.feature_extractor(
+                audio, *args, sampling_rate=sampling_rate, **kwargs
+            )
+
+            # Expand each audio placeholder to as many audio tokens as the
+            # feature extractor reports via `fake_token_len`. A temporary
+            # "<placeholder>" marker keeps already-expanded regions from being
+            # matched again while the remaining placeholders are processed.
+            expanded_text = []
+            for sample in text:
+                replace_str = []
+                while self.audio_token in sample:
+                    num_audio_tokens = inputs["fake_token_len"].item()
+
+                    expanded_audio_token = self.audio_token * num_audio_tokens
+
+                    replace_str.append(expanded_audio_token)
+                    sample = sample.replace(self.audio_token, "<placeholder>", 1)
+
+                while "<placeholder>" in sample:
+                    sample = sample.replace("<placeholder>", replace_str.pop(0), 1)
+                expanded_text.append(sample)
+            text = expanded_text
+
+        if text is not None:
+            encodings = self.tokenizer(text, **kwargs)
+
+        if text is None:
+            return inputs
+
+        elif audio is None:
+            return encodings
+        else:
+            inputs["labels"] = encodings["input_ids"]
+
+        return inputs
+
+    def get_prompt_ids(self, text: str, return_tensors="np"):
+        return self.tokenizer.get_prompt_ids(text, return_tensors=return_tensors)
+
+
+AutoFeatureExtractor.register("FunASRFeatureExtractor", FunASRFeatureExtractor)
+AutoProcessor.register("FunASRProcessor", FunASRProcessor)
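
Usage sketch (illustrative, not part of the patch): with the changes above applied,
the new model is expected to be exercised through the standard transcription flow.
The checkpoint id is the example from docs/models/supported_models.md and the
penalty value is the default added to the example client; adjust both as needed.

    vllm serve allendou/Fun-ASR-Nano-2512-vllm
    python examples/online_serving/openai_transcription_client.py --repetition_penalty 1.3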