diff --git a/docs/models/supported_models.md b/docs/models/supported_models.md index 534411c63..98d2a08d9 100644 --- a/docs/models/supported_models.md +++ b/docs/models/supported_models.md @@ -793,6 +793,7 @@ Speech2Text models trained specifically for Automatic Speech Recognition. | Architecture | Models | Example HF Models | [LoRA](../features/lora.md) | [PP](../serving/parallelism_scaling.md) | |--------------|--------|-------------------|----------------------|---------------------------| +| `FireRedASR2ForConditionalGeneration` | FireRedASR2 | `allendou/FireRedASR2-LLM-vllm`, etc. | | | | `FunASRForConditionalGeneration` | FunASR | `allendou/Fun-ASR-Nano-2512-vllm`, etc. | | | | `Gemma3nForConditionalGeneration` | Gemma3n | `google/gemma-3n-E2B-it`, `google/gemma-3n-E4B-it`, etc. | | | | `GlmAsrForConditionalGeneration` | GLM-ASR | `zai-org/GLM-ASR-Nano-2512` | ✅︎ | ✅︎ | diff --git a/requirements/common.txt b/requirements/common.txt index ec7ce5df9..9ee1b7151 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -57,3 +57,4 @@ opentelemetry-sdk >= 1.27.0 opentelemetry-api >= 1.27.0 opentelemetry-exporter-otlp >= 1.27.0 opentelemetry-semantic-conventions-ai >= 0.4.1 +kaldi-native-fbank >= 1.18.7 diff --git a/tests/models/registry.py b/tests/models/registry.py index 08f1a14d7..88017805f 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -743,6 +743,9 @@ _MULTIMODAL_EXAMPLE_MODELS = { "baidu/ERNIE-4.5-VL-28B-A3B-PT", trust_remote_code=True, ), + "FireRedASR2ForConditionalGeneration": _HfExamplesInfo( + "allendou/FireRedASR2-LLM-vllm", + ), "FunASRForConditionalGeneration": _HfExamplesInfo( "allendou/Fun-ASR-Nano-2512-vllm", ), diff --git a/vllm/model_executor/models/fireredasr2.py b/vllm/model_executor/models/fireredasr2.py new file mode 100644 index 000000000..f0d3e124c --- /dev/null +++ b/vllm/model_executor/models/fireredasr2.py @@ -0,0 +1,829 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright 
contributors to the vLLM project +import math +from collections.abc import Iterable, Mapping, Sequence +from typing import Annotated, Literal, cast + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn +from transformers import ( + BatchFeature, + Qwen2Config, +) + +from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig +from vllm.config.multimodal import BaseDummyOptions +from vllm.inputs.data import PromptType +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import _ACTIVATION_REGISTRY +from vllm.model_executor.layers.linear import ( + ReplicatedLinear, +) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.models.whisper_utils import ( + ISO639_1_SUPPORTED_LANGS, +) +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, +) +from vllm.multimodal.parse import MultiModalDataItems, MultiModalDataParser +from vllm.multimodal.processing import ( + BaseDummyInputsBuilder, + BaseMultiModalProcessor, + BaseProcessingInfo, + PromptReplacement, + PromptUpdate, + PromptUpdateDetails, +) +from vllm.transformers_utils.processor import cached_processor_from_config +from vllm.transformers_utils.processors.fireredasr2_processor import ( + FireRedASR2FeatureExtractor, +) +from vllm.utils.tensor_schema import TensorSchema, TensorShape + +from .interfaces import ( + MultiModalEmbeddings, + SupportsMultiModal, + SupportsTranscription, + _require_is_multimodal, +) +from .qwen2 import Qwen2ForCausalLM +from .utils import ( + AutoWeightsLoader, + WeightsMapper, + _merge_multimodal_embeddings, + maybe_prefix, +) + +logger = init_logger(__name__) + + +class FireRedASR2AudioInputs(TensorSchema): + """ + Dimensions: + - b: Batch size + - nmb: Number of mel bins + - t: Time frames (M) + """ + + input_features: Annotated[ + list[torch.Tensor] | None, + 
TensorShape("b", "nmb", "t"), + ] + speech_lengths: Annotated[ + list[torch.Tensor] | None, + TensorShape("b"), + ] + fake_token_lengths: Annotated[ + list[torch.Tensor] | None, + TensorShape("b"), + ] + + +class Swish(nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x * torch.sigmoid(x) + + +class Conv2dSubsampling(nn.Module): + def __init__(self, idim: int, d_model: int, out_channels: int = 32): + super().__init__() + self.conv = nn.Sequential( + nn.Conv2d(1, out_channels, 3, 2), + nn.ReLU(), + nn.Conv2d(out_channels, out_channels, 3, 2), + nn.ReLU(), + ) + subsample_idim = ((idim - 1) // 2 - 1) // 2 + self.out = ReplicatedLinear( + input_size=out_channels * subsample_idim, + output_size=d_model, + bias=True, + ) + + self.subsampling = 4 + left_context = right_context = 3 # both exclude currect frame + self.context = left_context + 1 + right_context # 7 + + def forward( + self, x: torch.Tensor, x_mask: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + x = x.unsqueeze(1) + x = self.conv(x) + N, C, T, D = x.size() + x, _ = self.out(x.transpose(1, 2).contiguous().view(N, T, C * D)) + mask = x_mask[:, :, :-2:2][:, :, :-2:2] + input_lengths = mask[:, -1, :].sum(dim=-1) + return x, input_lengths, mask + + +class RelPositionalEncoding(nn.Module): + def __init__(self, d_model: int, max_len: int = 5000): + super().__init__() + pe_positive = torch.zeros(max_len, d_model, requires_grad=False) + pe_negative = torch.zeros(max_len, d_model, requires_grad=False) + position = torch.arange(0, max_len).unsqueeze(1).float() + div_term = torch.exp( + torch.arange(0, d_model, 2).float() + * -(torch.log(torch.tensor(10000.0)).item() / d_model) + ) + pe_positive[:, 0::2] = torch.sin(position * div_term) + pe_positive[:, 1::2] = torch.cos(position * div_term) + pe_negative[:, 0::2] = torch.sin(-1 * position * div_term) + pe_negative[:, 1::2] = torch.cos(-1 * position * div_term) + + pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0) + 
pe_negative = pe_negative[1:].unsqueeze(0) + self.pe = torch.cat([pe_positive, pe_negative], dim=1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Tmax = 2 * max_len - 1 + Tmax, T = self.pe.size(1), x.size(1) + pos_emb = self.pe[:, Tmax // 2 - T + 1 : Tmax // 2 + T].clone().detach() + return pos_emb + + +class ConformerFeedForward(nn.Module): + def __init__(self, d_model: int): + super().__init__() + self.pre_layer_norm = nn.LayerNorm(d_model) + self.linear_expand = ReplicatedLinear( + input_size=d_model, + output_size=d_model * 4, + bias=True, + ) + self.nonlinear = Swish() + self.linear_project = ReplicatedLinear( + input_size=d_model * 4, + output_size=d_model, + bias=True, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + residual = x + x = self.pre_layer_norm(x) + x, _ = self.linear_expand(x) + x = self.nonlinear(x) + x, _ = self.linear_project(x) + output = x + residual + return output + + +class EncoderMultiHeadAttention(nn.Module): + def __init__(self, n_head: int, d_model: int): + super().__init__() + assert d_model % n_head == 0 + self.n_head = n_head + self.d_k = d_model // n_head + self.d_v = self.d_k + + self.w_qs = ReplicatedLinear( + input_size=d_model, output_size=n_head * self.d_k, bias=False + ) + self.w_ks = ReplicatedLinear( + input_size=d_model, output_size=n_head * self.d_k, bias=False + ) + self.w_vs = ReplicatedLinear( + input_size=d_model, output_size=n_head * self.d_v, bias=False + ) + + self.layer_norm_q = nn.LayerNorm(d_model) + self.layer_norm_k = nn.LayerNorm(d_model) + self.layer_norm_v = nn.LayerNorm(d_model) + + self.fc = ReplicatedLinear( + input_size=n_head * self.d_v, output_size=d_model, bias=False + ) + + def forward_qkv( + self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + d_k, d_v, n_head = self.d_k, self.d_v, self.n_head + sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1) + + q = self.layer_norm_q(q) + k = 
self.layer_norm_k(k) + v = self.layer_norm_v(v) + + q = self.w_qs(q)[0].view(sz_b, len_q, n_head, d_k) + k = self.w_ks(k)[0].view(sz_b, len_k, n_head, d_k) + v = self.w_vs(v)[0].view(sz_b, len_v, n_head, d_v) + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + return q, k, v + + def forward_output( + self, output: torch.Tensor, residual: torch.Tensor, sz_b: int, len_q: int + ) -> torch.Tensor: + output = output.transpose(1, 2).contiguous().view(sz_b, len_q, -1) + fc_out, _ = self.fc(output) + output = fc_out + output = output + residual + return output + + def forward_attention( + self, attn: torch.Tensor, v: torch.Tensor, mask: torch.Tensor | None = None + ) -> tuple[torch.Tensor, torch.Tensor]: + if mask is not None: + mask = mask.unsqueeze(1) + mask = mask.eq(0) + attn = attn.masked_fill(mask, -float("inf")) + attn = torch.softmax(attn, dim=-1).masked_fill(mask, 0.0) + else: + attn = torch.softmax(attn, dim=-1) + + d_attn = attn + output = torch.matmul(d_attn, v) + + return output, attn + + +class RelPosMultiHeadAttention(EncoderMultiHeadAttention): + def __init__(self, n_head: int, d_model: int): + super().__init__(n_head, d_model) + d_k = d_model // n_head + self.scale = 1.0 / (d_k**0.5) + self.linear_pos = ReplicatedLinear( + input_size=d_model, output_size=n_head * d_k, bias=False + ) + self.pos_bias_u = nn.Parameter(torch.empty([n_head, d_k])) + self.pos_bias_v = nn.Parameter(torch.empty([n_head, d_k])) + + def _rel_shift(self, x): + N, H, T1, T2 = x.size() + zero_pad = torch.zeros((N, H, T1, 1), device=x.device, dtype=x.dtype) + x_padded = torch.cat([zero_pad, x], dim=-1) + + x_padded = x_padded.view(N, H, T2 + 1, T1) + x = x_padded[:, :, 1:].view_as(x) + x = x[:, :, :, : x.size(-1) // 2 + 1] + return x + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + pos_emb: torch.Tensor, + mask: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + sz_b, len_q = q.size(0), q.size(1) + + residual 
= q + q, k, v = self.forward_qkv(q, k, v) + + q = q.transpose(1, 2) + n_batch_pos = pos_emb.size(0) + p = self.linear_pos(pos_emb)[0].view(n_batch_pos, -1, self.n_head, self.d_k) + p = p.transpose(1, 2) + + q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) + q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) + + matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1)) + + matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) + matrix_bd = self._rel_shift(matrix_bd) + + attn_scores = matrix_ac + matrix_bd + attn_scores.mul_(self.scale) + + output, attn = self.forward_attention(attn_scores, v, mask=mask) + + output = self.forward_output(output, residual, sz_b, len_q) + return output, attn + + +class ConformerConvolution(nn.Module): + def __init__(self, d_model: int, kernel_size: int = 33): + super().__init__() + assert kernel_size % 2 == 1 + self.pre_layer_norm = nn.LayerNorm(d_model) + self.pointwise_conv1 = nn.Conv1d( + d_model, d_model * 4, kernel_size=1, bias=False + ) + self.padding = (kernel_size - 1) // 2 + self.depthwise_conv = nn.Conv1d( + d_model * 2, + d_model * 2, + kernel_size, + stride=1, + padding=self.padding, + groups=d_model * 2, + bias=False, + ) + self.batch_norm = nn.LayerNorm(d_model * 2) + self.swish = Swish() + self.pointwise_conv2 = nn.Conv1d( + d_model * 2, d_model, kernel_size=1, bias=False + ) + + def forward( + self, x: torch.Tensor, mask: torch.Tensor | None = None + ) -> torch.Tensor: + residual = x + out = self.pre_layer_norm(x) + out = out.transpose(1, 2) + if mask is not None: + out.masked_fill_(mask.ne(1), 0.0) + out = self.pointwise_conv1(out) + out = F.glu(out, dim=1) + out = self.depthwise_conv(out) + + out = out.transpose(1, 2) + out = self.swish(self.batch_norm(out)) + out = out.transpose(1, 2) + + out = self.pointwise_conv2(out) + if mask is not None: + out.masked_fill_(mask.ne(1), 0.0) + out = out.transpose(1, 2) + return out + residual + + +class RelPosEmbConformerBlock(nn.Module): + def __init__(self, 
d_model, n_head, kernel_size=33): + super().__init__() + self.ffn1 = ConformerFeedForward(d_model) + self.mhsa = RelPosMultiHeadAttention(n_head, d_model) + self.conv = ConformerConvolution(d_model, kernel_size) + self.ffn2 = ConformerFeedForward(d_model) + self.layer_norm = nn.LayerNorm(d_model) + + def forward( + self, + x: torch.Tensor, + pos_emb: torch.Tensor, + slf_attn_mask: torch.Tensor | None = None, + pad_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + out = 0.5 * x + 0.5 * self.ffn1(x) + out = self.mhsa(out, out, out, pos_emb, mask=slf_attn_mask)[0] + out = self.conv(out, pad_mask) + out = 0.5 * out + 0.5 * self.ffn2(out) + out = self.layer_norm(out) + return out + + +class ConformerEncoder(nn.Module): + def __init__( + self, + idim: int, + n_layers_enc: int, + n_head: int, + d_model: int, + kernel_size: int = 33, + pe_maxlen: int = 5000, + ): + super().__init__() + self.odim = d_model + + self.input_preprocessor = Conv2dSubsampling(idim, d_model) + self.positional_encoding = RelPositionalEncoding(d_model) + + self.layer_stack = nn.ModuleList() + for _ in range(n_layers_enc): + block = RelPosEmbConformerBlock(d_model, n_head, kernel_size) + self.layer_stack.append(block) + + def forward( + self, padded_input: torch.Tensor, input_lengths: torch.Tensor, pad: bool = True + ): + if pad: + padded_input = F.pad( + padded_input, + (0, 0, 0, self.input_preprocessor.context - 1), + "constant", + 0.0, + ) + src_mask = self.padding_position_is_0(padded_input, input_lengths) + + embed_output, input_lengths, src_mask = self.input_preprocessor( + padded_input, src_mask + ) + enc_output = embed_output + + pos_emb = self.positional_encoding(embed_output) + + enc_outputs = [] + for enc_layer in self.layer_stack: + enc_output = enc_layer( + enc_output, pos_emb, slf_attn_mask=src_mask, pad_mask=src_mask + ) + enc_outputs.append(enc_output) + + return enc_output, input_lengths, src_mask + + def padding_position_is_0( + self, padded_input: torch.Tensor, 
input_lengths: torch.Tensor + ) -> torch.Tensor: + N, T = padded_input.size()[:2] + mask = torch.ones((N, T)).to(padded_input.device) + for i in range(N): + mask[i, input_lengths[i] :] = 0 + mask = mask.unsqueeze(dim=1) + return mask.to(torch.uint8) + + +class FireRedASR2Adapter(nn.Module): + def __init__(self, encoder_dim: int, llm_dim: int, downsample_rate: int = 2): + super().__init__() + self.ds = downsample_rate + self.linear1 = ReplicatedLinear( + input_size=encoder_dim * downsample_rate, + output_size=llm_dim, + bias=True, + ) + self.relu = _ACTIVATION_REGISTRY["relu"] + self.linear2 = ReplicatedLinear( + input_size=llm_dim, + output_size=llm_dim, + bias=True, + ) + + def forward(self, x, x_lens): + batch_size, seq_len, feat_dim = x.size() + num_frames_to_discard = seq_len % self.ds + if num_frames_to_discard > 0: + x = x[:, :-num_frames_to_discard, :] + seq_len = x.size(1) + + x = x.contiguous() + x = x.view(batch_size, seq_len // self.ds, feat_dim * self.ds) + + x, _ = self.linear1(x) + x = self.relu(x) + x, _ = self.linear2(x) + + new_x_lens = torch.clamp(x_lens, max=seq_len) // self.ds + return x, new_x_lens + + +class FireRedASR2Encoder(nn.Module): + def __init__( + self, + *, + vllm_config: VllmConfig, + ): + super().__init__() + self.audio_encoder = ConformerEncoder( + **vllm_config.model_config.hf_config.audio_encoder_conf + ) + + +class FireRedASR2Model(nn.Module): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + self.encoder = FireRedASR2Encoder( + vllm_config=vllm_config, + ) + encoder_dim = self.encoder.audio_encoder.odim + llm_dim = vllm_config.model_config.hf_config.hidden_size + self.encoder_projector = FireRedASR2Adapter( + encoder_dim, + llm_dim, + vllm_config.model_config.hf_config.encoder_downsample_rate, + ) + + self.decoder = Qwen2ForCausalLM( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "decoder") + ) + + def forward( + self, + input_ids: torch.Tensor | None, + positions: torch.Tensor, 
+ inputs_embeds: torch.Tensor | None = None, + ) -> torch.Tensor: + decoder_outputs = self.decoder( + input_ids=input_ids, + positions=positions, + inputs_embeds=inputs_embeds, + ) + return decoder_outputs + + def get_encoder_outputs( + self, + speech: torch.Tensor | list[torch.Tensor] | None, + speech_lengths: torch.Tensor | list[torch.Tensor] | None, + ) -> torch.Tensor | None: + encoder_outs, enc_lengths, enc_mask = self.encoder.audio_encoder( + speech, speech_lengths + ) + speech_features, speech_lens = self.encoder_projector(encoder_outs, enc_lengths) + return speech_features + + +class FireRedASR2ProcessingInfo(BaseProcessingInfo): + def get_hf_config(self) -> Qwen2Config: + return self.ctx.get_hf_config(Qwen2Config) + + def get_supported_mm_limits(self) -> Mapping[str, int | None]: + return {"audio": 1} + + def get_feature_extractor(self, **kwargs: object) -> FireRedASR2FeatureExtractor: + hf_processor = self.get_hf_processor(**kwargs) + feature_extractor = hf_processor.feature_extractor # type: ignore + assert isinstance(feature_extractor, FireRedASR2FeatureExtractor) + return feature_extractor + + def get_data_parser(self) -> MultiModalDataParser: + feature_extractor = self.get_feature_extractor() + return MultiModalDataParser( + target_sr=feature_extractor.sampling_rate, + target_channels=self.get_target_channels(), + ) + + def get_target_channels(self) -> int: + return 1 + + +class FireRedASR2DummyInputsBuilder(BaseDummyInputsBuilder[FireRedASR2ProcessingInfo]): + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_audios = mm_counts.get("audio", 0) + + return "<|AUDIO|>" * num_audios + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + mm_options: Mapping[str, BaseDummyOptions], + ) -> MultiModalDataDict: + feature_extractor = self.info.get_feature_extractor() + + sampling_rate = feature_extractor.sampling_rate + audio_len = feature_extractor.chunk_length * sampling_rate + num_audios = 
mm_counts.get("audio", 0) + + audio_overrides = mm_options.get("audio") + + ret = { + "audio": self._get_dummy_audios( + length=audio_len, num_audios=num_audios, overrides=audio_overrides + ) + } + return ret + + +class FireRedASR2MultiModalProcessor( + BaseMultiModalProcessor[FireRedASR2ProcessingInfo] +): + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], + ) -> BatchFeature: + if mm_data: + feature_extractor = self.info.get_feature_extractor(**mm_kwargs) + mm_data = dict(audio=mm_data.pop("audios")) + mm_kwargs = dict( + **mm_kwargs, + sampling_rate=feature_extractor.sampling_rate, + ) + processed_outputs = super()._call_hf_processor( + prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, + ) + if "labels" in processed_outputs: + processed_outputs["input_ids"] = processed_outputs.pop("labels") + return processed_outputs + + def _get_mm_fields_config( + self, + hf_inputs: BatchFeature, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict( + input_features=MultiModalFieldConfig.batched("audio"), + speech_lengths=MultiModalFieldConfig.batched("audio"), + fake_token_lengths=MultiModalFieldConfig.batched("audio"), + ) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargsItems, + ) -> Sequence[PromptUpdate]: + processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + tokenizer = self.info.get_tokenizer() + vocab = tokenizer.get_vocab() + + audio_token = getattr(processor, "audio_token", "<|AUDIO|>") + + audio_token_id = vocab[audio_token] + + out_mm_data = out_mm_kwargs.get_data() + + fake_token_lengths = out_mm_data.get("fake_token_lengths") + + if fake_token_lengths is None: + audio_output_lengths = [] + else: + assert isinstance(fake_token_lengths, torch.Tensor) + + 
audio_output_lengths = fake_token_lengths.tolist() + + def get_replacement_fireredasr2_audio(item_idx: int): + num_features = audio_output_lengths[item_idx] + + audio_tokens = [audio_token_id] * int(num_features) + + return PromptUpdateDetails.select_token_id( + audio_tokens, + embed_token_id=audio_token_id, + ) + + return [ + PromptReplacement( + modality="audio", + target=[audio_token_id], + replacement=get_replacement_fireredasr2_audio, + ) + ] + + +@MULTIMODAL_REGISTRY.register_processor( + FireRedASR2MultiModalProcessor, + info=FireRedASR2ProcessingInfo, + dummy_inputs=FireRedASR2DummyInputsBuilder, +) +class FireRedASR2ForConditionalGeneration( + nn.Module, SupportsTranscription, SupportsMultiModal +): + packed_modules_mapping = { + "self_attn.qkv_proj": [ + "self_attn.q_proj", + "self_attn.k_proj", + "self_attn.v_proj", + ], + "encoder_attn.kv_proj": ["encoder_attn.k_proj", "encoder_attn.v_proj"], + } + + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_substr={ + "llm.": "model.decoder.", + "encoder.": "model.encoder.audio_encoder.", + "encoder_projector.": "model.encoder_projector.", + "net.0": "pre_layer_norm", + "net.1": "linear_expand", + "net.4": "linear_project", + } + ) + + supports_transcription_only = True + supports_segment_timestamp = True + supported_languages = ISO639_1_SUPPORTED_LANGS + + @classmethod + def validate_language(cls, language: str | None) -> str | None: + if language is None: + # TODO language should be optional and can be guessed. + # For now we default to en. See + # https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/generation_whisper.py#L1520 + logger.warning( + "Defaulting to language='en'. If you wish to transcribe " + "audio in a different language, pass the `language` field " + "in the TranscriptionRequest." 
+ ) + language = "en" + return super().validate_language(language) + + @classmethod + def get_generation_prompt( + cls, + audio: np.ndarray, + model_config: ModelConfig, # not needed here + stt_config: SpeechToTextConfig, + language: str | None, + task_type: Literal["transcribe", "translate"], + request_prompt: str, + to_language: str | None, + ) -> PromptType: + if language is None: + raise ValueError( + "Language must be specified when creating the fireredasr2 prompt" + ) + + prompt_str = "<|im_start|>user\n<|AUDIO|>请转写音频为文字<|im_end|>\n<|im_start|>assistant\n" # noqa: E501 + prompt = { + "prompt": prompt_str, + "multi_modal_data": { + "audio": (audio, stt_config.sample_rate), + }, + } + return cast(PromptType, prompt) + + @classmethod + def get_speech_to_text_config( + cls, model_config: ModelConfig, task_type: str + ) -> SpeechToTextConfig: + processor = cached_processor_from_config(model_config) + + return SpeechToTextConfig( + max_audio_clip_s=processor.feature_extractor.chunk_length, + sample_rate=processor.feature_extractor.sampling_rate, + ) + + @classmethod + def get_num_audio_tokens( + cls, + audio_duration_s: float, + stt_config: SpeechToTextConfig, + model_config: ModelConfig, + ) -> int | None: + processor = cached_processor_from_config(model_config) + hop_length = processor.feature_extractor.hop_length + assert hop_length is not None + return math.ceil(audio_duration_s * stt_config.sample_rate / hop_length) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + self.config = config + self.dtype = vllm_config.model_config.dtype + + self.model = FireRedASR2Model( + vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model"), + ) + logit_scale = getattr(config, "logit_scale", 1.0) + + self.logits_processor = LogitsProcessor(config.vocab_size, scale=logit_scale) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + inputs_embeds: torch.Tensor 
| None = None, + **kwargs, + ) -> torch.Tensor: + decoder_outputs = self.model( + input_ids=input_ids, + positions=positions, + inputs_embeds=inputs_embeds, + ) + return decoder_outputs + + def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: + audio_input = self._parse_and_validate_audio_input(**kwargs) + + speech = audio_input["input_features"] + speech_lengths = audio_input["speech_lengths"].to(torch.int32) + enc_output = self.model.get_encoder_outputs( + speech=speech, speech_lengths=speech_lengths + ) + + return enc_output + + def embed_input_ids( + self, + input_ids: torch.Tensor, + multimodal_embeddings: MultiModalEmbeddings | None = None, + *, + is_multimodal: torch.Tensor | None = None, + handle_oov_mm_token: bool = False, + ) -> torch.Tensor: + inputs_embeds = self.model.decoder.embed_input_ids(input_ids) + + ret = _merge_multimodal_embeddings( + inputs_embeds=inputs_embeds, + multimodal_embeddings=multimodal_embeddings, + is_multimodal=_require_is_multimodal(is_multimodal), + ) + return ret + + def _parse_and_validate_audio_input( + self, **kwargs: object + ) -> FireRedASR2AudioInputs: + input_features = kwargs.pop("input_features", None) + speech_lengths = kwargs.pop("speech_lengths", None) + fake_token_lengths = kwargs.pop("fake_token_lengths", None) + + return FireRedASR2AudioInputs( + input_features=input_features, + speech_lengths=speech_lengths, + fake_token_lengths=fake_token_lengths, + ) + + def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: + logits = self.logits_processor(self.model.decoder.lm_head, hidden_states) + return logits + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + loader = AutoWeightsLoader( + self, skip_prefixes=["model.encoder.audio_encoder.positional_encoding.pe"] + ) + + return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 
7f6b7e300..1e5accaf3 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -341,6 +341,10 @@ _MULTIMODAL_MODELS = { "ernie45_vl", "Ernie4_5_VLMoeForConditionalGeneration", ), + "FireRedASR2ForConditionalGeneration": ( + "fireredasr2", + "FireRedASR2ForConditionalGeneration", + ), "FunASRForConditionalGeneration": ("funasr", "FunASRForConditionalGeneration"), # noqa: E501 "FunAudioChatForConditionalGeneration": ( "funaudiochat", diff --git a/vllm/transformers_utils/processors/__init__.py b/vllm/transformers_utils/processors/__init__.py index d726fd39a..0660a62ea 100644 --- a/vllm/transformers_utils/processors/__init__.py +++ b/vllm/transformers_utils/processors/__init__.py @@ -10,6 +10,9 @@ reasons: from vllm.transformers_utils.processors.bagel import BagelProcessor from vllm.transformers_utils.processors.deepseek_vl2 import DeepseekVLV2Processor +from vllm.transformers_utils.processors.fireredasr2_processor import ( + FireRedASR2Processor, +) from vllm.transformers_utils.processors.funasr_processor import FunASRProcessor from vllm.transformers_utils.processors.hunyuan_vl import HunYuanVLProcessor from vllm.transformers_utils.processors.hunyuan_vl_image import HunYuanVLImageProcessor @@ -19,6 +22,7 @@ from vllm.transformers_utils.processors.ovis2_5 import Ovis2_5Processor __all__ = [ "BagelProcessor", "DeepseekVLV2Processor", + "FireRedASR2Processor", "FunASRProcessor", "HunYuanVLProcessor", "HunYuanVLImageProcessor", diff --git a/vllm/transformers_utils/processors/fireredasr2_processor.py b/vllm/transformers_utils/processors/fireredasr2_processor.py new file mode 100644 index 000000000..67c74ab15 --- /dev/null +++ b/vllm/transformers_utils/processors/fireredasr2_processor.py @@ -0,0 +1,341 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + + +import kaldi_native_fbank as knf +import numpy as np +import torch +import torch.nn.functional as F +from 
transformers import ( + AutoFeatureExtractor, + AutoProcessor, + BatchFeature, +) +from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor +from transformers.processing_utils import ProcessorMixin +from transformers.utils import TensorType + +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class CMVN: + def __init__(self, dim, means, inverse_std_variences): + self.dim, self.means, self.inverse_std_variences = ( + dim, + np.array(means), + np.array(inverse_std_variences), + ) + + def __call__(self, x): + assert x.shape[-1] == self.dim, "CMVN dim mismatch" + out = x - self.means + out = out * self.inverse_std_variences + return out + + +class KaldifeatFbank: + def __init__(self, num_mel_bins=80, frame_length=25, frame_shift=10, dither=1.0): + self.dither = dither + opts = knf.FbankOptions() + opts.frame_opts.dither = dither + opts.mel_opts.num_bins = num_mel_bins + opts.frame_opts.snip_edges = True + opts.mel_opts.debug_mel = False + self.opts = opts + + def __call__(self, sample_rate, wav_np, is_train=False): + dither = self.dither if is_train else 0.0 + self.opts.frame_opts.dither = dither + fbank = knf.OnlineFbank(self.opts) + + fbank.accept_waveform(sample_rate, wav_np.tolist()) + feat = [] + for i in range(fbank.num_frames_ready): + feat.append(fbank.get_frame(i)) + if len(feat) == 0: + print("Check data, len(feat) == 0", wav_np, flush=True) + return np.zeros((0, self.opts.mel_opts.num_bins)) + feat = np.vstack(feat) + return feat + + +class FireRedASR2FeatureExtractor(SequenceFeatureExtractor): + r""" + Constructs a FireRedASR2 feature extractor. + + This feature extractor inherits from [`~feature_extraction_sequence_ + utils.SequenceFeatureExtractor`] which contains most of the main + methods. Users should refer to this superclass for more information + regarding those methods. 
+ + This class extracts mel-filter bank features from raw speech with + `kaldi_native_fbank` (see `KaldifeatFbank`) and then normalizes them with + `CMVN`, using the configured `means` and `inverse_std_variences`. + + Args: + feature_size (`int`, *optional*, defaults to 80): + The feature dimension of the extracted features. + sampling_rate (`int`, *optional*, defaults to 16000): + The sampling rate at which the audio files should be digitalized + expressed in hertz (Hz). + chunk_length (`int`, *optional*, defaults to 30): + The maximum number of chunks of `sampling_rate` samples used to + trim and pad longer or shorter audio sequences. + padding_value (`float`, *optional*, defaults to 0.0): + Padding value used to pad the audio. Should correspond to silences. + dither (`float`, *optional*, defaults to 0.0): + Adds dithering. In other words, adds a small Gaussian noise to each frame. + E.g. use 0.0001 to add dithering with a normal distribution centered + around 0.0 with standard deviation 0.0001 (assuming [-1,+1] range + of raw_speech). The value 0.0 means no dithering. + Dithering has similar effect as `spectrogram(mel_floor=...)`. It reduces + the high log_mel_fbank values for signals with hard-zero sections, + when VAD cutoff is present in the signal. + Note that `KaldifeatFbank` only applies the dither when it is called + with `is_train=True`, which `__call__` of this class does not do.
+ """ + + model_input_names = ["input_features"] + + def __init__( + self, + feature_size=80, + sampling_rate=16000, + chunk_length=30, + padding_value=0.0, + return_attention_mask=False, + dim=80, + means=None, + inverse_std_variences=None, + num_mel_bins=80, + frame_length=25, + frame_shift=10, + dither=0.0, + max_length=3000, + downsample_rate=2, + left_context=3, + right_context=3, + **kwargs, + ): + super().__init__( + feature_size=feature_size, + sampling_rate=sampling_rate, + padding_value=padding_value, + return_attention_mask=return_attention_mask, + **kwargs, + ) + self.chunk_length = chunk_length + self.max_length = max_length + self.dim = dim + self.means = means + self.inverse_std_variences = inverse_std_variences + self.num_mel_bins = num_mel_bins + self.frame_length = frame_length + self.frame_shift = frame_shift + self.dither = dither + self.sampling_rate = sampling_rate + self.downsample_rate = downsample_rate + self.context = left_context + 1 + right_context + + def __call__( + self, + raw_speech: np.ndarray | list[float] | list[np.ndarray] | list[list[float]], + truncation: bool = True, + pad_to_multiple_of: int | None = None, + return_tensors: str | TensorType | None = None, + return_attention_mask: bool | None = None, + padding: str | None = "max_length", + max_length: int | None = None, + sampling_rate: int | None = None, + do_normalize: bool | None = None, + **kwargs, + ) -> BatchFeature: + if sampling_rate != self.sampling_rate: + raise ValueError( + f"The model corresponding to this feature extractor: " + f"{self.__class__.__name__} was trained using a sampling " + f"rate of {self.sampling_rate}. Please make sure that the " + f"provided `raw_speech` input was sampled with " + f"{self.sampling_rate} and not {sampling_rate}." 
+ ) + + def padding_position_is_0(padded_input, input_lengths): + N, T = padded_input.size()[:2] + mask = torch.ones((N, T)).to(padded_input.device) + for i in range(N): + mask[i, input_lengths[i] :] = 0 + mask = mask.unsqueeze(dim=1) + return mask.to(torch.uint8) + + # initialize the CMVN and Fbank objects + self.cmvn = CMVN(self.dim, self.means, self.inverse_std_variences) + self.fbank = KaldifeatFbank( + num_mel_bins=self.num_mel_bins, + frame_length=self.frame_length, + frame_shift=self.frame_shift, + dither=self.dither, + ) + + feats = [] + speech_lengths = [] + fake_token_lengths = [] + for speech in raw_speech: + """ + We must multiply by 32768 here because FireRedASR2 loads audio data + using kaldiio.load_mat, while vLLM loads audio data using librosa. + """ + speech = speech * 32768 + fbank = self.fbank(sampling_rate, speech) + fbank = self.cmvn(fbank) + fbank = torch.from_numpy(fbank).float() + length = fbank.size(0) + feats.append(fbank) + speech_lengths.append(length) + padded_input2 = fbank + padded_input2 = F.pad( + padded_input2, (0, 0, 0, self.context - 1), "constant", 0.0 + ) + src_mask = padding_position_is_0( + padded_input2[None, :, :], torch.tensor([length], dtype=torch.int32) + ) + x_mask = src_mask + mask = x_mask[:, :, :-2:2][:, :, :-2:2] + input_lengths = mask[:, -1, :].sum(dim=-1) + input_lengths = input_lengths // self.downsample_rate + fake_token_len = torch.clamp(input_lengths, min=1) + fake_token_lengths.append(fake_token_len) + + feats = torch.stack(feats, dim=0) + batched_speech = self.pad( + BatchFeature({"input_features": feats}), + padding=padding, + max_length=max_length if max_length else self.max_length, + truncation=truncation, + pad_to_multiple_of=pad_to_multiple_of, + return_attention_mask=return_attention_mask or do_normalize, + ) + + if return_tensors is not None: + batched_speech = batched_speech.convert_to_tensors(return_tensors) + + batched_speech["speech_lengths"] = torch.tensor(speech_lengths) + 
batched_speech["fake_token_lengths"] = torch.concat(fake_token_lengths) + return batched_speech + + +class FireRedASR2Processor(ProcessorMixin): + r""" + Constructs a FireRedASR2 processor which wraps a FireRedASR2 feature extractor and + a FireRedASR2 tokenizer into a single processor. + + [`FireRedASR2Processor`] offers all the functionalities of + [`FireRedASR2FeatureExtractor`] and [`Qwen2Tokenizer`]. See the + [`~FireRedASR2Processor.__call__`] and [`~FireRedASR2Processor.decode`] for more + information. + + Args: + feature_extractor (`FireRedASR2FeatureExtractor`): An instance of + [`FireRedASR2FeatureExtractor`]. + The feature extractor is a required input. + tokenizer (`Qwen2Tokenizer`): + An instance of [`Qwen2Tokenizer`]. The tokenizer is a required + input. + """ + + feature_extractor_class = "FireRedASR2FeatureExtractor" + tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast") + + def __init__( + self, + feature_extractor, + tokenizer, + audio_token="<|AUDIO|>", + ): + super().__init__(feature_extractor, tokenizer) + self.current_processor = self.feature_extractor + self._in_target_context_manager = False + self.audio_token = ( + tokenizer.audio_token if hasattr(tokenizer, "audio_token") else audio_token + ) + self.audio_token_id = tokenizer.convert_tokens_to_ids(self.audio_token) + + def get_decoder_prompt_ids(self, task=None, language=None, no_timestamps=True): + return self.tokenizer.get_decoder_prompt_ids( + task=task, language=language, no_timestamps=no_timestamps + ) + + def __call__(self, *args, **kwargs): + """ + Forwards the `audio` argument to FireRedASR2FeatureExtractor's + [`~FireRedASR2FeatureExtractor.__call__`] and the `text` argument to + [`~Qwen2Tokenizer.__call__`]. Please refer to the docstring of the + above two methods for more information. 
+ """ + if self._in_target_context_manager: + return self.current_processor(*args, **kwargs) + + audio = kwargs.pop("audio", None) + sampling_rate = kwargs.pop("sampling_rate", None) + text = kwargs.pop("text", None) + if len(args) > 0: + audio = args[0] + args = args[1:] + + if text is None: + raise ValueError("You need to specify `text` input to process.") + elif isinstance(text, str): + text = [text] + elif not isinstance(text, list) and not isinstance(text[0], str): + raise ValueError( + "Invalid input text. Please provide a string, or a list of strings" + ) + + if audio is not None: + # ensure we have as much audios as audio tokens + num_audio_tokens = sum(sample.count(self.audio_token) for sample in text) + num_audios = 1 if type(audio) is np.ndarray else len(audio) + if num_audio_tokens != num_audios: + raise ValueError( + f"Found {num_audio_tokens} {self.audio_token} token{'s' if num_audio_tokens > 1 else ''} in provided text but received {num_audios} audio{'s' if num_audios > 1 else ''}" # noqa: E501 + ) + inputs = self.feature_extractor( + audio, *args, sampling_rate=sampling_rate, **kwargs + ) + + expanded_text = [] + for sample in text: + replace_str = [] + while self.audio_token in sample: + num_audio_tokens = int(inputs["fake_token_lengths"].item()) + + expanded_audio_token = self.audio_token * num_audio_tokens + + replace_str.append(expanded_audio_token) + sample = sample.replace(self.audio_token, "", 1) + + while "" in sample: + sample = sample.replace("", replace_str.pop(0), 1) + expanded_text.append(sample) + text = expanded_text + + if text is not None: + encodings = self.tokenizer(text, **kwargs) + + if text is None: + return inputs + + elif audio is None: + return encodings + else: + inputs["labels"] = encodings["input_ids"] + + return inputs + + def get_prompt_ids(self, text: str, return_tensors="np"): + return self.tokenizer.get_prompt_ids(text, return_tensors=return_tensors) + + +AutoFeatureExtractor.register( + 
"FireRedASR2FeatureExtractor", FireRedASR2FeatureExtractor +) +AutoProcessor.register("FireRedASR2Processor", FireRedASR2Processor)