diff --git a/examples/offline_inference/audio_language.py b/examples/offline_inference/audio_language.py index f7292c468..780ddb90e 100755 --- a/examples/offline_inference/audio_language.py +++ b/examples/offline_inference/audio_language.py @@ -70,6 +70,29 @@ def run_audioflamingo3(question: str, audio_count: int) -> ModelRequestData: ) +# CohereASR +def run_cohere_asr(question: str, audio_count: int) -> ModelRequestData: + assert audio_count == 1, "CohereASR only supports a single audio input per prompt" + # TODO (ekagra): add HF ckpt after asr release + model_name = "/host/engines/vllm/audio/2b-release" + + prompt = ( + "<|startofcontext|><|startoftranscript|>" + "<|emo:undefined|><|en|><|en|><|pnc|><|noitn|>" + "<|notimestamp|><|nodiarize|>" + ) + engine_args = EngineArgs( + model=model_name, + limit_mm_per_prompt={"audio": audio_count}, + trust_remote_code=True, + ) + + return ModelRequestData( + engine_args=engine_args, + prompt=prompt, + ) + + # MusicFlamingo def run_musicflamingo(question: str, audio_count: int) -> ModelRequestData: model_name = "nvidia/music-flamingo-2601-hf" @@ -508,14 +531,15 @@ def run_whisper(question: str, audio_count: int) -> ModelRequestData: model_example_map = { "audioflamingo3": run_audioflamingo3, - "musicflamingo": run_musicflamingo, + "cohere_asr": run_cohere_asr, + "funaudiochat": run_funaudiochat, "gemma3n": run_gemma3n, "glmasr": run_glmasr, - "funaudiochat": run_funaudiochat, "granite_speech": run_granite_speech, "kimi_audio": run_kimi_audio, "midashenglm": run_midashenglm, "minicpmo": run_minicpmo, + "musicflamingo": run_musicflamingo, "phi4_mm": run_phi4mm, "qwen2_audio": run_qwen2_audio, "qwen2_5_omni": run_qwen2_5_omni, diff --git a/tests/entrypoints/openai/correctness/test_transcription_api_correctness.py b/tests/entrypoints/openai/correctness/test_transcription_api_correctness.py index 2725a1295..c4c7b8b7f 100644 --- a/tests/entrypoints/openai/correctness/test_transcription_api_correctness.py +++ b/tests/entrypoints/openai/correctness/test_transcription_api_correctness.py @@ -19,8 +19,10 @@ import soundfile import torch from datasets import load_dataset from evaluate import load -from transformers import AutoTokenizer +from vllm.tokenizers import get_tokenizer + +from ....models.registry import HF_EXAMPLE_MODELS from ....utils import RemoteOpenAIServer @@ -64,8 +66,12 @@ async def bound_transcribe(sem, client, tokenizer, audio, reference): async def process_dataset(model, client, data, concurrent_request): sem = asyncio.Semaphore(concurrent_request) - # Load tokenizer once outside the loop - tokenizer = AutoTokenizer.from_pretrained(model) + model_info = HF_EXAMPLE_MODELS.find_hf_info(model) + tokenizer = get_tokenizer( + model, + tokenizer_mode=model_info.tokenizer_mode, + trust_remote_code=model_info.trust_remote_code, + ) # Warmup call as the first `librosa.load` server-side is quite slow. audio, sr = data[0]["audio"]["array"], data[0]["audio"]["sampling_rate"] @@ -144,20 +150,35 @@ def run_evaluation( # alternatives "openai/whisper-large-v2", "openai/whisper-large-v3-turbo".. -@pytest.mark.parametrize("model_name", ["openai/whisper-large-v3"]) +# NOTE: Expected WER measured with equivalent hf.transformers args: +# whisper-large-v3 + esb-datasets-earnings22-validation-tiny-filtered.
+@pytest.mark.parametrize( + "model_config", + [ + ("openai/whisper-large-v3", 12.744980), + # TODO (ekagra): add HF ckpt after asr release + # ("/host/engines/vllm/audio/2b-release", 11.73), + ], +) # Original dataset is 20GB+ in size, hence we use a pre-filtered slice. @pytest.mark.parametrize( "dataset_repo", ["D4nt3/esb-datasets-earnings22-validation-tiny-filtered"] ) -# NOTE: Expected WER measured with equivalent hf.transformers args: -# whisper-large-v3 + esb-datasets-earnings22-validation-tiny-filtered. -@pytest.mark.parametrize("expected_wer", [12.744980]) def test_wer_correctness( - model_name, dataset_repo, expected_wer, n_examples=-1, max_concurrent_request=None + model_config, dataset_repo, n_examples=-1, max_concurrent_request=None ): + model_name, expected_wer = model_config + model_info = HF_EXAMPLE_MODELS.find_hf_info(model_name) # TODO refactor to use `ASRDataset` + server_args = [ + "--enforce-eager", + f"--tokenizer_mode={model_info.tokenizer_mode}", + ] + if model_info.trust_remote_code: + server_args.append("--trust-remote-code") with RemoteOpenAIServer( - model_name, ["--enforce-eager"], max_wait_seconds=480 + model_name, + server_args, ) as remote_server: dataset = load_hf_dataset(dataset_repo) @@ -167,7 +188,14 @@ def test_wer_correctness( client = remote_server.get_async_client() wer = run_evaluation( - model_name, client, dataset, max_concurrent_request, n_examples + model_name, + client, + dataset, + max_concurrent_request, + n_examples, ) + + print(f"Expected WER: {expected_wer}, Actual WER: {wer}") + if expected_wer: torch.testing.assert_close(wer, expected_wer, atol=1e-1, rtol=1e-2) diff --git a/tests/models/registry.py b/tests/models/registry.py index 7f806064f..fe5585f85 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -1116,6 +1116,11 @@ _MULTIMODAL_EXAMPLE_MODELS = { tokenizer_mode="mistral", ), # [Encoder-decoder] + "CohereASRForConditionalGeneration": _HfExamplesInfo( + "/host/engines/vllm/audio/2b-release", + trust_remote_code=True, + is_available_online=False, # TODO (ekagra): revert after asr release + ), "NemotronParseForConditionalGeneration": _HfExamplesInfo( "nvidia/NVIDIA-Nemotron-Parse-v1.1", trust_remote_code=True ), diff --git a/vllm/benchmarks/datasets.py b/vllm/benchmarks/datasets.py index 7e7e56dc6..edd84403f 100644 --- a/vllm/benchmarks/datasets.py +++ b/vllm/benchmarks/datasets.py @@ -3157,7 +3157,7 @@ class ASRDataset(HuggingFaceDataset): **kwargs, ) -> list: output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN - if "openai" in tokenizer.name_or_path: + if "openai" in getattr(tokenizer, "name_or_path", ""): prompt = "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>" else: prompt = "" diff --git a/vllm/entrypoints/openai/speech_to_text/protocol.py b/vllm/entrypoints/openai/speech_to_text/protocol.py index ed32db2f0..a8d978e33 100644 --- a/vllm/entrypoints/openai/speech_to_text/protocol.py +++ b/vllm/entrypoints/openai/speech_to_text/protocol.py @@ -107,7 +107,7 @@ class TranscriptionRequest(OpenAIBaseModel): stream_include_usage: bool | None = False stream_continuous_usage_stats: bool | None = False - vllm_xargs: dict[str, str | int | float] | None = Field( + vllm_xargs: dict[str, str | int | float | bool] | None = Field( default=None, description=( "Additional request parameters with string or " diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index d9fb78b5c..a3d3e2198 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -365,6 +365,7 @@ def build_enc_dec_inputs( 
encoder_inputs: SingletonInputs, decoder_inputs: SingletonInputs | None, decoder_start_token_id: int, + skip_decoder_start_token: bool = False, ) -> EncoderDecoderInputs: enc_inputs = _validate_enc_inputs(encoder_inputs) @@ -396,10 +397,11 @@ def build_enc_dec_inputs( else: assert_never(enc_inputs) - dec_inputs_new["prompt_token_ids"] = _prepare_decoder_input_ids_for_generation( - dec_inputs_new["prompt_token_ids"], - decoder_start_token_id, - ) + if not skip_decoder_start_token: + dec_inputs_new["prompt_token_ids"] = _prepare_decoder_input_ids_for_generation( + dec_inputs_new["prompt_token_ids"], + decoder_start_token_id, + ) if cache_salt := enc_inputs.get("cache_salt"): dec_inputs_new["cache_salt"] = cache_salt diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index b67493932..a722bb3bf 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -261,6 +261,15 @@ class InputPreprocessor: encoder_prompt = prompt["encoder_prompt"] decoder_prompt = prompt["decoder_prompt"] + skip_decoder_start_token = False + if self.renderer.mm_processor is not None: + from vllm.multimodal.processing import EncDecMultiModalProcessor + + if isinstance(self.renderer.mm_processor, EncDecMultiModalProcessor): + skip_decoder_start_token = ( + self.renderer.mm_processor.skip_decoder_start_token + ) + return build_enc_dec_inputs( encoder_inputs=self._prompt_to_llm_inputs( encoder_prompt, @@ -275,6 +284,7 @@ class InputPreprocessor: ) ), decoder_start_token_id=self.renderer.get_dec_start_token_id(), + skip_decoder_start_token=skip_decoder_start_token, ) def _process_decoder_only_prompt( diff --git a/vllm/model_executor/models/cohere_asr.py b/vllm/model_executor/models/cohere_asr.py new file mode 100644 index 000000000..21b38f37f --- /dev/null +++ b/vllm/model_executor/models/cohere_asr.py @@ -0,0 +1,2209 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import math +from collections.abc import Iterable, Mapping, Sequence +from typing import Literal, cast + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn +from transformers import PretrainedConfig + +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.inputs.data import PromptType +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.attention import ( + Attention, + CrossAttention, +) +from vllm.model_executor.layers.linear import ( + ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.inputs import ( + MultiModalDataDict, + MultiModalFieldConfig, + MultiModalKwargsItems, +) +from vllm.multimodal.parse import ( + AudioProcessorItems, + MultiModalDataItems, + MultiModalDataParser, +) +from vllm.multimodal.processing import ( + BaseDummyInputsBuilder, + BaseProcessingInfo, + EncDecMultiModalProcessor, + PromptReplacement, + PromptUpdate, +) +from vllm.renderers import TokenizeParams +from 
vllm.transformers_utils.processors.cohere_asr import ( + INF_VAL, + CohereASRFeatureExtractor, + CohereASRProcessor, +) +from vllm.v1.attention.backend import ( + AttentionType, +) + +from .interfaces import ( + MultiModalEmbeddings, + SupportsMultiModal, + SupportsTranscription, +) +from .utils import AutoWeightsLoader, WeightsMapper, make_layers + +logger = init_logger(__name__) + +# From https://platform.openai.com/docs/guides/speech-to-text/supported-languages + +ISO639_1_SUPPORTED_LANGS = { + "en": "English", + "fr": "French", + "de": "German", + "es": "Spanish", + "pt": "Portuguese", + "it": "Italian", + "nl": "Dutch", + "pl": "Polish", + "el": "Greek", + "ar": "Arabic", + "ko": "Korean", + "ja": "Japanese", + "vi": "Vietnamese", + "zh": "Chinese", +} + + +class CohereASRAttention(nn.Module): + def __init__( + self, + embed_dim: int, + num_heads: int, + bias: bool = True, + attn_type: AttentionType = AttentionType.DECODER, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ): + super().__init__() + self.embed_dim = embed_dim + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + if self.total_num_heads >= tp_size: + # Number of heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_heads % tp_size == 0 + else: + # Number of heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_heads == 0 + self.num_kv_heads = max(1, self.total_num_heads // tp_size) + self.head_dim = self.embed_dim // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.attn_type = attn_type + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: " + f"{self.embed_dim} and `num_heads`: {num_heads})." + ) + self.scaling = self.head_dim**-0.5 + + self._init_qkv(embed_dim, bias, quant_config, prefix=prefix) + + self.out_projection = RowParallelLinear( + input_size=embed_dim, + output_size=embed_dim, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.out_projection", + ) + if attn_type == AttentionType.ENCODER: + raise NotImplementedError( + "CohereASRAttention does not support Encoder Self-Attention yet." 
+ ) + + elif self.attn_type == AttentionType.ENCODER_DECODER: + self.attn = CrossAttention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + attn_type=self.attn_type, + ) + else: # AttentionType.DECODER (regular decoder self-attention) + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + attn_type=self.attn_type, + ) + + def _init_qkv( + self, + embed_dim: int, + bias: bool = True, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ) -> None: + self.qkv_proj = QKVParallelLinear( + hidden_size=embed_dim, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_heads, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + + def forward( + self, + hidden_states: torch.Tensor, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + + attn_output = self.attn(q, k, v) + + output, _ = self.out_projection(attn_output) + + return output + + +class CohereASRCrossAttention(CohereASRAttention): + def __init__( + self, + embed_dim: int, + num_heads: int, + bias: bool = True, + cache_config: CacheConfig | None = None, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ): + super().__init__( + embed_dim=embed_dim, + num_heads=num_heads, + bias=bias, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + attn_type=AttentionType.ENCODER_DECODER, + ) + + def _init_qkv( + self, + embed_dim: int, + bias: bool = True, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ) -> None: + self.q_proj = ColumnParallelLinear( + input_size=embed_dim, + output_size=embed_dim, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.q_proj", + ) + self.kv_proj = QKVParallelLinear( + hidden_size=embed_dim, + head_size=self.head_dim, + total_num_heads=0, + total_num_kv_heads=self.total_num_heads, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.kv_proj", + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | None, + ) -> torch.Tensor: + q, _ = self.q_proj(hidden_states) + + # Encoder hidden states are only computed once during prefill phase. + # Afterwards, the keys and values should be available in the kv-cache. 
+ if encoder_hidden_states is not None: + kv, _ = self.kv_proj(encoder_hidden_states) + k, v = kv.split([self.kv_size, self.kv_size], dim=-1) + else: + k = v = None + + attn_output = self.attn(q, k, v) + + output, _ = self.out_projection(attn_output) + + return output + + +# ----- Decoder START ----- +class CohereASRMLP(nn.Module): + def __init__( + self, + embed_dim: int, + ffn_dim: int, + act_fn: str, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ): + super().__init__() + + self.activation_fn = get_act_fn(act_fn) + self.dense_in = ColumnParallelLinear( + input_size=embed_dim, + output_size=ffn_dim, + quant_config=quant_config, + prefix=f"{prefix}.fc1", + ) + self.dense_out = RowParallelLinear( + input_size=ffn_dim, + output_size=embed_dim, + quant_config=quant_config, + prefix=f"{prefix}.fc2", + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.dense_in(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states, _ = self.dense_out(hidden_states) + return hidden_states + + +class FixedPositionalEncoding(nn.Module): + """ + Fixed positional encoding (embedding layer) from sine and cosine functions + of different frequencies according to https://arxiv.org/abs/1706.03762 + + Args: + hidden_size: size of the embeddings in the model, also known as d_model + max_sequence_length: maximum allowed length of the input sequence + """ + + def __init__(self, hidden_size: int, max_sequence_length: int = 512) -> None: + super().__init__() + + self._hidden_size = hidden_size + self._max_sequence_length = max_sequence_length + self._build_pos_enc( + hidden_size=self._hidden_size, max_sequence_length=self._max_sequence_length + ) + + def _build_pos_enc(self, hidden_size: int, max_sequence_length: int) -> None: + """Builds/replaces pre-computed positional encoding.""" + pos_enc = torch.zeros(max_sequence_length, hidden_size) + position = torch.arange(0.0, max_sequence_length).unsqueeze(1) + coef = -math.log(10000.0) / hidden_size + div_term = torch.exp(coef * torch.arange(0.0, hidden_size, 2)) + pos_enc[:, 0::2] = torch.sin(position * div_term) + pos_enc[:, 1::2] = torch.cos(position * div_term) + pos_enc.div_(math.sqrt(hidden_size)) + self.register_buffer("pos_enc", pos_enc) + + def forward(self, position_ids: torch.Tensor) -> torch.Tensor: + embeddings = torch.embedding(self.pos_enc, position_ids) + return embeddings + + +class CohereASRDecoderLayer(nn.Module): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config.transf_decoder["config_dict"] + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + self.hidden_dim = config.get("hidden_size") + self.ffn_dim = config.get("inner_size") + self.act_fn = config.get("hidden_act") + self.num_heads = config.get("num_attention_heads") + + # self_attn + self.layer_norm_1 = nn.LayerNorm(self.hidden_dim) + self.first_sub_layer = CohereASRAttention( + embed_dim=self.hidden_dim, + num_heads=self.num_heads, + attn_type=AttentionType.DECODER, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.first_sub_layer", + ) + + # cross attn to attend to encoder + self.layer_norm_2 = nn.LayerNorm(self.hidden_dim) + self.second_sub_layer = CohereASRCrossAttention( + embed_dim=self.hidden_dim, + num_heads=self.num_heads, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.second_sub_layer", + ) + + self.layer_norm_3 = 
nn.LayerNorm(self.hidden_dim) + self.third_sub_layer = CohereASRMLP( + embed_dim=self.hidden_dim, + ffn_dim=self.ffn_dim, + act_fn=self.act_fn, + quant_config=quant_config, + prefix=f"{prefix}.third_sub_layer", + ) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor | None, + ) -> torch.Tensor: + residual = hidden_states + hidden_states = self.layer_norm_1(hidden_states) + hidden_states = self.first_sub_layer(hidden_states=hidden_states) + + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.layer_norm_2(hidden_states) + hidden_states = self.second_sub_layer( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + ) + + hidden_states = residual + hidden_states + residual = hidden_states + hidden_states = self.layer_norm_3(hidden_states) + hidden_states = self.third_sub_layer(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class TransformerEmbedding(nn.Module): + def __init__( + self, + vocab_size: int, + hidden_size: int, + max_target_positions: int, + padding_idx: int, + ) -> None: + super().__init__() + self.token_embedding = nn.Embedding(vocab_size, hidden_size, padding_idx) + self.position_embedding = FixedPositionalEncoding( + hidden_size=hidden_size, + max_sequence_length=max_target_positions, + ) + self.layer_norm = nn.LayerNorm(hidden_size) + + def forward(self, input_ids: torch.Tensor, positions: torch.Tensor) -> torch.Tensor: + inputs_embeds = self.token_embedding(input_ids) + positions = self.position_embedding(positions) + embeddings = inputs_embeds + positions + embeddings = self.layer_norm(embeddings) + return embeddings + + +@support_torch_compile(dynamic_arg_dims={"input_ids": 0, "positions": -1}) +class CohereASRDecoder(nn.Module): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + self.padding_idx = 2 + config_dict = config.transf_decoder["config_dict"] + self.max_target_positions = config_dict.get("max_sequence_length") + self.hidden_size = config_dict.get("hidden_size") + self.num_decoder_layers = config_dict.get("num_layers") + self.vocab_size = config.head["num_classes"] + + self.embedding = TransformerEmbedding( + vocab_size=self.vocab_size, + hidden_size=self.hidden_size, + max_target_positions=self.max_target_positions, + padding_idx=self.padding_idx, + ) + + self.start_layer, self.end_layer, self.layers = make_layers( + self.num_decoder_layers, + lambda prefix: CohereASRDecoderLayer( + vllm_config=vllm_config, prefix=f"{prefix}.layers" + ), + prefix=f"{prefix}.layers", + ) + self.final_layer_norm = nn.LayerNorm(self.hidden_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + encoder_hidden_states: torch.Tensor | None, + ) -> torch.Tensor: + hidden_states = self.get_input_embeddings(input_ids, positions) + for decoder_layer in self.layers: + hidden_states = decoder_layer( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + ) + + hidden_states = self.final_layer_norm(hidden_states) + return hidden_states + + def get_input_embeddings( + self, input_ids: torch.Tensor, positions: torch.Tensor + ) -> torch.Tensor: + return self.embedding(input_ids, positions) + + +# ----- Decoder END ----- + + +# ----- Encoder START ----- +class MaskedConvSequential(nn.Sequential): + def forward( + self, x: torch.Tensor, lengths: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + x = x.unsqueeze(1) # 
(batch, 1, time, features) + current_lengths = lengths.clone().float() + mask = self._create_mask(x, current_lengths.long()) + + # Process through each layer with mask propagation + for i, layer in enumerate(self): + # Apply current mask before layer + x = self.apply_channel_mask(x, mask) + + # Apply layer + x = layer(x) + + # Update lengths for stride operations with proper padding + if hasattr(layer, "stride") and layer.stride != (1, 1): + if hasattr(layer, "_left_padding"): + padding = ( + layer._left_padding, + layer._right_padding, + ) # CausalConv2D + else: + padding = layer.padding + current_lengths = self.calculate_conv_output_size( + current_lengths, layer.kernel_size[0], layer.stride[0], padding + ) + mask = self._create_mask(x, current_lengths.long()) + + # Final masking + x = self.apply_channel_mask(x, mask) + return x, current_lengths.long() + + def _create_mask(self, tensor: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor: + """Create broadcastable mask from per-sample lengths. + + Returns a (B, 1, T, 1) mask that broadcasts over channels and + features without materializing a full (B, C, T, F) tensor. + """ + batch_size, channels, time, features = tensor.shape + time_mask = torch.arange(time, device=tensor.device).expand( + batch_size, time + ) < lengths.unsqueeze(1) + return time_mask.to(tensor.dtype).unsqueeze(1).unsqueeze(-1) + + def apply_channel_mask( + self, tensor: torch.Tensor, mask: torch.Tensor + ) -> torch.Tensor: + """Apply mask in-place via broadcasting. + + tensor: (B, C, T, F), mask: (B, 1, T, 1) + """ + tensor.mul_(mask) + return tensor + + def calculate_conv_output_size( + self, + input_size: torch.Tensor, + kernel_size: int, + stride: int, + padding: tuple[int, int], + ): + """Calculate exact output size after convolution.""" + return (input_size + padding[0] + padding[1] - kernel_size) // stride + 1 + + +class ConvSubsampling(nn.Module): + def __init__( + self, + subsampling: str, + subsampling_factor: int, + feat_in: int, + feat_out: int, + conv_channels: int, + subsampling_conv_chunking_factor: int = 1, + activation: nn.Module | None = None, + is_causal: bool = False, + ) -> None: + super().__init__() + if activation is None: + activation = nn.ReLU() + + if subsampling_factor % 2 != 0: + raise ValueError("Sampling factor should be a multiply of 2!") + self._sampling_num = int(math.log(subsampling_factor, 2)) + + if ( + subsampling_conv_chunking_factor != -1 + and subsampling_conv_chunking_factor != 1 + and subsampling_conv_chunking_factor % 2 != 0 + ): + raise ValueError( + "subsampling_conv_chunking_factor should be -1, 1, or a power of 2" + ) + + in_channels = 1 + layers = [] + + assert subsampling == "dw_striding" + self._stride = 2 + self._kernel_size = 3 + self._ceil_mode = False + + assert not is_causal + + self._left_padding = (self._kernel_size - 1) // 2 + self._right_padding = (self._kernel_size - 1) // 2 + + # Layer 1 + # [1, T, num_melspec] -> [conv_channels, T//2, num_melspec//2] + layers.append( + torch.nn.Conv2d( + in_channels=in_channels, + out_channels=conv_channels, + kernel_size=self._kernel_size, + stride=self._stride, + padding=self._left_padding, + ) + ) + in_channels = conv_channels + layers.append(activation) + + for i in range(self._sampling_num - 1): + # [conv_channels, T//2^i, num_melspec//2^i] -> + # [conv_channels, T//2^(i+1), num_melspec//2^(i+1)] + # depthwise conv + layers.append( + torch.nn.Conv2d( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=self._kernel_size, + stride=self._stride, + 
padding=self._left_padding, + groups=in_channels, + ) + ) + + # [conv_channels, T//2^(i+1), num_melspec//2^(i+1)] + # -> [conv_channels, T//2^(i+1), num_melspec//2^(i+1)] + # pointwise conv + layers.append( + torch.nn.Conv2d( + in_channels=in_channels, + out_channels=conv_channels, + kernel_size=1, + stride=1, + padding=0, + groups=1, + ) + ) + layers.append(activation) + in_channels = conv_channels + + in_length = torch.tensor(feat_in, dtype=torch.float) + out_length = self.calc_length( + lengths=in_length, + all_paddings=self._left_padding + self._right_padding, + kernel_size=self._kernel_size, + stride=self._stride, + ceil_mode=self._ceil_mode, + repeat_num=self._sampling_num, + ) + + # reshape: + # [conv_channels, T//sub_factor, num_melspec//sub_factor] + # -> [T//sub_factor, conv_channels * (num_melspec//sub_factor)] + # mlp: + # [T//sub_factor, conv_channels * (num_melspec//sub_factor)] + # -> [T//sub_factor, feat_out] + self.out = torch.nn.Linear(conv_channels * int(out_length), feat_out) + self.conv2d_subsampling = True + self.conv = MaskedConvSequential(*layers) + + def calc_length( + self, + lengths: torch.Tensor, + all_paddings: int, + kernel_size: int, + stride: int, + ceil_mode: bool, + repeat_num: int = 1, + ) -> torch.Tensor: + """Calculates the output length of a Tensor passed + through a convolution or max pooling layer""" + add_pad: float = all_paddings - kernel_size + one: float = 1.0 + for i in range(repeat_num): + lengths = torch.div(lengths.to(dtype=torch.float) + add_pad, stride) + one + lengths = torch.ceil(lengths) if ceil_mode else torch.floor(lengths) + return lengths.to(dtype=torch.int) + + def forward( + self, x: torch.Tensor, lengths: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + x, lengths = self.conv(x, lengths) + + if self.conv2d_subsampling: + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).reshape(b, t, -1)) + # Transpose to Channel Last mode + else: + x = x.transpose(1, 2) + + return x, lengths + + +class PositionalEncoding(torch.nn.Module): + """Fixed sinusoidal positional encoding. + Args: + d_model (int): embedding dim + max_len (int): maximum input length + xscale (bool): whether to scale the input by sqrt(d_model) + """ + + def __init__( + self, d_model: int, max_len: int = 5000, xscale: float | None = None + ) -> None: + super().__init__() + self.d_model = d_model + self.xscale = xscale + self.max_len = max_len + + def create_pe(self, positions: torch.Tensor, dtype: torch.dtype) -> None: + pos_length = positions.size(0) + pe = torch.zeros(pos_length, self.d_model, device=positions.device) + div_term = torch.exp( + torch.arange( + 0, self.d_model, 2, dtype=torch.float32, device=positions.device + ) + * -(math.log(10000.0) / self.d_model) + ) + pe[:, 0::2] = torch.sin(positions * div_term) + pe[:, 1::2] = torch.cos(positions * div_term) + pe = pe.unsqueeze(0).to(dtype) + if hasattr(self, "pe"): + self.pe = pe + else: + self.register_buffer("pe", pe, persistent=False) + + def forward( + self, x: torch.Tensor, cache_len: int = 0 + ) -> tuple[torch.Tensor, torch.Tensor]: + """Adds positional encoding. + Args: + x (torch.Tensor): Input. 
Its shape is (batch, time, feature_size) + cache_len (int): the size of the cache which is used to shift positions + Returns: + x+pos_emb (torch.Tensor): Its shape is (batch, time, feature_size) + pos_emb (torch.Tensor): Its shape is (1, time, feature_size) + """ + input_len = x.size(1) + cache_len + if self.xscale: + x = x * self.xscale + pos_emb = self.pe[:, :input_len] + x = x + pos_emb + return x, pos_emb + + +class RelPositionalEncoding(PositionalEncoding): + """Relative positional encoding for TransformerXL's layers + See : Appendix B in https://arxiv.org/abs/1901.02860 + Args: + d_model (int): embedding dim + max_len (int): maximum input length + xscale (bool): whether to scale the input by sqrt(d_model) + """ + + def extend_pe(self, length: int, device: torch.device, dtype: torch.dtype) -> None: + """Reset and extend the positional encodings if needed.""" + needed_size = 2 * length - 1 + if hasattr(self, "pe") and self.pe.size(1) >= needed_size: + return + positions = torch.arange( + length - 1, -length, -1, dtype=torch.float32, device=device + ).unsqueeze(1) + self.create_pe(positions=positions, dtype=dtype) + + def forward( + self, x: torch.Tensor, cache_len: int = 0 + ) -> tuple[torch.Tensor, torch.Tensor]: + """Compute positional encoding. + Args: + x (torch.Tensor): Input. Its shape is (batch, time, feature_size) + cache_len (int): the size of the cache which is used to shift positions + Returns: + x (torch.Tensor): Its shape is (batch, time, feature_size) + pos_emb (torch.Tensor): Its shape is (1, time, feature_size) + """ + + if self.xscale: + x = x * self.xscale + + input_len = x.size(1) + cache_len + center_pos = self.pe.size(1) // 2 + 1 + start_pos = center_pos - input_len + end_pos = center_pos + input_len - 1 + pos_emb = self.pe[:, start_pos:end_pos] + + return x, pos_emb + + +class Swish(nn.SiLU): + """ + Swish activation function introduced in 'https://arxiv.org/abs/1710.05941' + Mathematically identical to SiLU. See note in nn.SiLU for references. + """ + + +class ConformerFeedForward(nn.Module): + """ + feed-forward module of Conformer model. + use_bias (bool): Apply bias to all Linear and Conv1d + layers to improve activation flow and stabilize + training of huge models. + """ + + def __init__( + self, + d_model: int, + d_ff: int, + activation: nn.Module | None = None, + use_bias: bool = True, + ) -> None: + super().__init__() + if activation is None: + activation = Swish() + self.linear1 = nn.Linear(d_model, d_ff, bias=use_bias) + self.activation = activation + self.linear2 = nn.Linear(d_ff, d_model, bias=use_bias) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.linear1(x) + x = self.activation(x) + x = self.linear2(x) + return x + + +class CausalConv1D(nn.Conv1d): + """ + A causal version of nn.Conv1d where each step would + have limited access to locations on its right or left. + All arguments are the same as nn.Conv1d except padding. + + If padding is set None, then paddings are set + automatically to make it a causal convolution where + each location would not see any steps on its right. + + If padding is set as a list (size of 2), then + padding[0] would be used as left padding and + padding[1] as right padding. It would make it possible + to control the number of steps to be accessible on the + right and left. This mode is not supported when + stride > 1. padding[0]+padding[1] should be equal to + (kernel_size - 1). 
+ """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: str | int = 0, + dilation: int = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = "zeros", + device=None, + dtype=None, + ) -> None: + if padding is None: + self._left_padding = kernel_size - 1 + self._right_padding = stride - 1 + else: + if stride != 1 and padding != kernel_size - 1: + raise ValueError("No striding allowed for non-symmetric convolutions!") + if isinstance(padding, int): + self._left_padding = padding + self._right_padding = padding + elif ( + isinstance(padding, list) + and len(padding) == 2 + and padding[0] + padding[1] == kernel_size - 1 + ): + self._left_padding = padding[0] + self._right_padding = padding[1] + else: + raise ValueError(f"Invalid padding param: {padding}!") + + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=0, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=padding_mode, + device=device, + dtype=dtype, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = F.pad(x, pad=(self._left_padding, self._right_padding)) + return super().forward(x) + + +class ConformerConvolution(nn.Module): + """The convolution module for the Conformer model. + Args: + d_model (int): hidden dimension + kernel_size (int): kernel size for depthwise convolution + pointwise_activation (str): name of the activation + function to be used for the pointwise conv. + Note that Conformer uses a special key `glu_` + which is treated as the original default from + the paper. + use_bias (bool): Use bias in all Linear and Conv1d + layers to improve activation flow and stabilize + training of huge models. Defaults to True + """ + + def __init__( + self, + d_model: int, + kernel_size: int, + norm_type: str = "batch_norm", + conv_context_size: int | None = None, + pointwise_activation: str = "glu_", + use_bias: bool = True, + ) -> None: + super().__init__() + assert (kernel_size - 1) % 2 == 0 + + if conv_context_size is None: + conv_context_size = (kernel_size - 1) // 2 + + assert pointwise_activation == "glu_" + dw_conv_input_dim = d_model + + self.pointwise_conv1 = nn.Conv1d( + in_channels=d_model, + out_channels=d_model * 2, + kernel_size=1, + stride=1, + padding=0, + bias=use_bias, + ) + + self.depthwise_conv = CausalConv1D( + in_channels=dw_conv_input_dim, + out_channels=dw_conv_input_dim, + kernel_size=kernel_size, + stride=1, + padding=conv_context_size, + groups=dw_conv_input_dim, + bias=use_bias, + ) + + assert norm_type == "batch_norm" + self.batch_norm = nn.BatchNorm1d(dw_conv_input_dim) + + self.activation = Swish() + self.pointwise_conv2 = nn.Conv1d( + in_channels=dw_conv_input_dim, + out_channels=d_model, + kernel_size=1, + stride=1, + padding=0, + bias=use_bias, + ) + + def forward( + self, x: torch.Tensor, pad_mask: torch.Tensor | None = None + ) -> torch.Tensor: + x = x.transpose(1, 2) + x = self.pointwise_conv1(x) + + x = nn.functional.glu(x, dim=1) + + if pad_mask is not None: + x = x.masked_fill(pad_mask.unsqueeze(1), 0.0) + + x = self.depthwise_conv(x) + + x = self.batch_norm(x) + + x = self.activation(x) + x = self.pointwise_conv2(x) + x = x.transpose(1, 2) + return x + + +class CohereASRMultiHeadAttention(nn.Module): + """Multi-Head Attention layer of Transformer. 
+ Args: + n_head (int): number of heads + n_feat (int): size of the features + use_bias (bool): whether to apply bias in linear and conv layers + """ + + def __init__( + self, + n_head: int, + n_feat: int, + use_bias: bool = True, + ) -> None: + """Construct a MultiHeadedAttention object.""" + super().__init__() + + assert n_feat % n_head == 0 + self.d_k = n_feat // n_head + self.s_d_k = math.sqrt(self.d_k) + self.h = n_head + self.linear_q = nn.Linear(n_feat, n_feat, bias=use_bias) + self.linear_k = nn.Linear(n_feat, n_feat, bias=use_bias) + self.linear_v = nn.Linear(n_feat, n_feat, bias=use_bias) + self.linear_out = nn.Linear(n_feat, n_feat, bias=use_bias) + + def forward_qkv( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Transforms query, key and value. + Args: + query (torch.Tensor): (batch, time1, size) + key (torch.Tensor): (batch, time2, size) + value (torch.Tensor): (batch, time2, size) + returns: + q (torch.Tensor): (batch, head, time1, size) + k (torch.Tensor): (batch, head, time2, size) + v (torch.Tensor): (batch, head, time2, size) + """ + n_batch = query.size(0) + q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k) + k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k) + v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k) + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + return q, k, v + + def forward_attention( + self, + value: torch.Tensor, + scores: torch.Tensor, + mask: torch.Tensor | None, + ) -> torch.Tensor: + """Compute attention context vector. + Args: + value (torch.Tensor): (batch, time2, size) + scores (torch.Tensor): (batch, head, time1, time2) + mask (torch.Tensor): (batch, time1, time2) + returns: + value (torch.Tensor): transformed `value` + (batch, time1, d_model) weighted by the + attention scores + """ + n_batch = value.size(0) + if mask is not None: + mask = mask.unsqueeze(1) # (batch, 1, time1, time2) + scores = scores.masked_fill(mask, -INF_VAL) + attn = torch.softmax(scores, dim=-1).masked_fill( + mask, 0.0 + ) # (batch, head, time1, time2) + else: + attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) + + x = torch.matmul(attn, value) # (batch, head, time1, d_k) + x = x.transpose(1, 2).reshape( + n_batch, -1, self.h * self.d_k + ) # (batch, time1, d_model) + + return self.linear_out(x) # (batch, time1, d_model) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + mask: torch.Tensor | None, + pos_emb: torch.Tensor | None = None, + ) -> torch.Tensor: + """Compute 'Scaled Dot Product Attention'. + Args: + query (torch.Tensor): (batch, time1, size) + key (torch.Tensor): (batch, time2, size) + value(torch.Tensor): (batch, time2, size) + mask (torch.Tensor): (batch, time1, time2) + + returns: + output (torch.Tensor): transformed `value` + (batch, time1, d_model) weighted by the + query dot key attention + """ + q, k, v = self.forward_qkv(query, key, value) + + scores = torch.matmul(q, k.transpose(-2, -1)) / self.s_d_k + return self.forward_attention(v, scores, mask) + + +class RelPositionMultiHeadAttention(CohereASRMultiHeadAttention): + """Multi-Head Attention layer of Transformer-XL with + support of relative positional encoding.
+ Paper: https://arxiv.org/abs/1901.02860 + Args: + n_head (int): number of heads + n_feat (int): size of the features + use_bias (bool): whether to apply bias in linear + and conv layers of MultiHeadAttention + """ + + def __init__( + self, + n_head: int, + n_feat: int, + pos_bias_u: nn.Parameter | torch.Tensor | None, + pos_bias_v: nn.Parameter | torch.Tensor | None, + use_bias: bool = True, + ) -> None: + """Construct an RelPositionMultiHeadedAttention object.""" + super().__init__( + n_head=n_head, + n_feat=n_feat, + use_bias=use_bias, + ) + # linear transformation for positional encoding + self.linear_pos = nn.Linear(n_feat, n_feat, bias=False) + # these two learnable biases are used in matrix c and matrix d + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + if pos_bias_u is None or pos_bias_v is None: + self.pos_bias_u = nn.Parameter( + torch.zeros(self.h, self.d_k), requires_grad=False + ) + self.pos_bias_v = nn.Parameter( + torch.zeros(self.h, self.d_k), requires_grad=False + ) + else: + self.pos_bias_u = pos_bias_u + self.pos_bias_v = pos_bias_v + + def rel_shift(self, x: torch.Tensor) -> torch.Tensor: + """Compute relative positional encoding. + Args: + x (torch.Tensor): (batch, nheads, time, 2*time-1) + """ + b, h, qlen, pos_len = x.size() # (b, h, t1, t2) + # need to add a column of zeros on the left side of + # last dimension to perform the relative shifting + x = torch.nn.functional.pad(x, pad=(1, 0)) # (b, h, t1, t2+1) + x = x.view(b, h, -1, qlen) # (b, h, t2+1, t1) + # need to drop the first row + x = x[:, :, 1:].view(b, h, qlen, pos_len) # (b, h, t1, t2) + return x + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + mask: torch.Tensor | None, + pos_emb: torch.Tensor | None = None, + ) -> torch.Tensor: + """Compute 'Scaled Dot Product Attention' with rel. positional encoding. + Args: + query (torch.Tensor): (batch, time1, size) + key (torch.Tensor): (batch, time2, size) + value(torch.Tensor): (batch, time2, size) + mask (torch.Tensor): (batch, time1, time2) + pos_emb (torch.Tensor) : (batch, time1, size) + + Returns: + output (torch.Tensor): transformed `value` + (batch, time1, d_model) weighted by the + query dot key attention + """ + q, k, v = self.forward_qkv(query, key, value) + q = q.transpose(1, 2) # (batch, time1, head, d_k) + + n_batch_pos = pos_emb.size(0) + p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k) + p = p.transpose(1, 2) # (batch, head, time1, d_k) + + # (batch, head, time1, d_k) + q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2) + # (batch, head, time1, d_k) + q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2) + + # compute attention score + # first compute matrix a and matrix c + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + # (batch, head, time1, time2) + + # compute matrix b and matrix d + # (batch, head, time1, time2) + matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1)) + matrix_bd = self.rel_shift(matrix_bd) + + # drops extra elements in the matrix_bd to match the matrix_ac's size + matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1)) + matrix_bd = matrix_bd[:, :, :, : matrix_ac.size(-1)] + scores = (matrix_ac + matrix_bd) / self.s_d_k # (batch, head, time1, time2) + return self.forward_attention(v, scores, mask) + + +class ConformerLayer(torch.nn.Module): + """A single block of the Conformer encoder. 
+ + Args: + d_model (int): input dimension of + MultiheadAttentionMechanism and + PositionwiseFeedForward + d_ff (int): hidden dimension of + PositionwiseFeedForward + self_attention_model (str): type of the attention + layer and positional encoding + n_heads (int): number of heads for multi-head + attention + conv_kernel_size (int): kernel size for depthwise + convolution in convolution module + use_bias (bool): Apply bias to all Linear and + Conv1d layers from each ConformerLayer to + improve activation flow and stabilize training + of huge models. Defaults to True. + """ + + def __init__( + self, + d_model: int, + d_ff: int, + self_attention_model: str = "rel_pos", + n_heads: int = 4, + conv_kernel_size: int = 31, + conv_norm_type: str = "batch_norm", + conv_context_size: int | None = None, + pos_bias_u: nn.Parameter | torch.Tensor | None = None, + pos_bias_v: nn.Parameter | torch.Tensor | None = None, + att_context_size: list[int] | None = None, + use_bias: bool = True, + ) -> None: + super().__init__() + if att_context_size is None: + att_context_size = [-1, -1] + + self.self_attention_model = self_attention_model + self.fc_factor = 0.5 + + # first feed forward module + self.norm_feed_forward1 = nn.LayerNorm(d_model) + self.feed_forward1 = ConformerFeedForward( + d_model=d_model, d_ff=d_ff, use_bias=use_bias + ) + + # convolution module + self.norm_conv = nn.LayerNorm(d_model) + self.conv = ConformerConvolution( + d_model=d_model, + kernel_size=conv_kernel_size, + norm_type=conv_norm_type, + conv_context_size=conv_context_size, + use_bias=use_bias, + ) + + # multi-headed self-attention module + self.norm_self_att = nn.LayerNorm(d_model) + + assert self_attention_model == "rel_pos" + + self.self_attn = RelPositionMultiHeadAttention( + n_head=n_heads, + n_feat=d_model, + pos_bias_u=pos_bias_u, + pos_bias_v=pos_bias_v, + use_bias=use_bias, + ) + + # second feed forward module + self.norm_feed_forward2 = nn.LayerNorm(d_model) + self.feed_forward2 = ConformerFeedForward( + d_model=d_model, d_ff=d_ff, use_bias=use_bias + ) + + self.norm_out = nn.LayerNorm(d_model) + + def forward( + self, + x: torch.Tensor, + att_mask: torch.Tensor | None = None, + pos_emb: torch.Tensor | None = None, + pad_mask: torch.Tensor | None = None, + ) -> torch.Tensor: + """ + Args: + x (torch.Tensor): input signals (B, T, d_model) + att_mask (torch.Tensor): attention masks(B, T, T) + pos_emb (torch.Tensor): (L, 1, d_model) + pad_mask (torch.tensor): padding mask + Returns: + x (torch.Tensor): (B, T, d_model) + """ + residual = x + x = self.norm_feed_forward1(x) + x = self.feed_forward1(x) + residual = residual + x * self.fc_factor + + x = self.norm_self_att(residual) + if self.self_attention_model == "rel_pos": + x = self.self_attn( + query=x, + key=x, + value=x, + mask=att_mask, + pos_emb=pos_emb, + ) + elif self.self_attention_model == "rel_pos_local_attn": + x = self.self_attn( + query=x, + key=x, + value=x, + pad_mask=pad_mask, + pos_emb=pos_emb, + ) + elif self.self_attention_model == "abs_pos": + x = self.self_attn(query=x, key=x, value=x, mask=att_mask) + else: + x = None + + residual = residual + x + + x = self.norm_conv(residual) + x = self.conv(x, pad_mask=pad_mask) + residual = residual + x + + x = self.norm_feed_forward2(residual) + x = self.feed_forward2(x) + residual = residual + x * self.fc_factor + + x = self.norm_out(residual) + + return x + + +class ConformerEncoder(nn.Module): + """ + The encoder for ASR model of Conformer. 
+ Based on this paper: + 'Conformer: Convolution-augmented Transformer for + Speech Recognition' by Anmol Gulati et al. + https://arxiv.org/abs/2005.08100 + """ + + def __init__(self, *, vllm_config: VllmConfig): + super().__init__() + + self.hf_config = vllm_config.model_config.hf_config + + feat_in = self.hf_config.encoder["feat_in"] + n_layers = self.hf_config.encoder["n_layers"] + d_model = self.hf_config.encoder["d_model"] + feat_out = self.hf_config.encoder["feat_out"] + causal_downsampling = self.hf_config.encoder["causal_downsampling"] + subsampling = self.hf_config.encoder["subsampling"] + subsampling_factor = self.hf_config.encoder["subsampling_factor"] + subsampling_conv_chunking_factor = self.hf_config.encoder.get( + "subsampling_conv_chunking_factor", 1 + ) + subsampling_conv_channels = self.hf_config.encoder["subsampling_conv_channels"] + ff_expansion_factor = self.hf_config.encoder["ff_expansion_factor"] + self_attention_model = self.hf_config.encoder["self_attention_model"] + n_heads = self.hf_config.encoder["n_heads"] + att_context_size = self.hf_config.encoder["att_context_size"] + att_context_probs = self.hf_config.encoder.get("att_context_probs", None) + att_context_style = self.hf_config.encoder.get("att_context_style", "regular") + xscaling = self.hf_config.encoder["xscaling"] + untie_biases = self.hf_config.encoder["untie_biases"] + pos_emb_max_len = self.hf_config.encoder["pos_emb_max_len"] + conv_kernel_size = self.hf_config.encoder["conv_kernel_size"] + conv_norm_type = self.hf_config.encoder["conv_norm_type"] + conv_context_size = self.hf_config.encoder["conv_context_size"] + use_bias = self.hf_config.encoder.get("use_bias", True) + + d_ff = d_model * ff_expansion_factor + self.d_model = d_model + self._feat_in = feat_in + self.att_context_style = att_context_style + self.subsampling_factor = subsampling_factor + + self.self_attention_model = self_attention_model + + # Setting up the att_context_size + ( + _, + self.att_context_size, + _, + self.conv_context_size, + ) = self._calc_context_sizes( + att_context_style=att_context_style, + att_context_size=att_context_size, + att_context_probs=att_context_probs, + conv_context_size=conv_context_size, + conv_kernel_size=conv_kernel_size, + ) + + if xscaling: + self.xscale = math.sqrt(d_model) + else: + self.xscale = None + + # Subsampling + if subsampling_conv_channels == -1: + subsampling_conv_channels = d_model + assert subsampling and subsampling_factor > 1 and subsampling == "dw_striding" + + self.pre_encode = ConvSubsampling( + subsampling=subsampling, + subsampling_factor=subsampling_factor, + feat_in=feat_in, + feat_out=d_model, + conv_channels=subsampling_conv_channels, + subsampling_conv_chunking_factor=subsampling_conv_chunking_factor, + activation=nn.ReLU(True), + is_causal=causal_downsampling, + ) + + self._feat_out = d_model + + # Biases for relative positional encoding + if not untie_biases and self_attention_model == "rel_pos": + d_head = d_model // n_heads + # Register as buffers instead of parameters since they're not trainable + # and need to respect dtype during weight loading + self.register_buffer( + "pos_bias_u", torch.zeros(n_heads, d_head), persistent=True + ) + self.register_buffer( + "pos_bias_v", torch.zeros(n_heads, d_head), persistent=True + ) + pos_bias_u = self.pos_bias_u + pos_bias_v = self.pos_bias_v + else: + pos_bias_u = None + pos_bias_v = None + + # Positional encodings + self.pos_emb_max_len = pos_emb_max_len + assert self_attention_model == "rel_pos" + self.pos_enc = 
RelPositionalEncoding( + d_model=d_model, + max_len=pos_emb_max_len, + xscale=self.xscale, + ) + + self.layers = nn.ModuleList() + for i in range(n_layers): + layer = ConformerLayer( + d_model=d_model, + d_ff=d_ff, + self_attention_model=self_attention_model, + n_heads=n_heads, + conv_kernel_size=conv_kernel_size, + conv_norm_type=conv_norm_type, + conv_context_size=self.conv_context_size, + pos_bias_u=pos_bias_u, + pos_bias_v=pos_bias_v, + att_context_size=self.att_context_size, + use_bias=use_bias, + ) + self.layers.append(layer) + + if feat_out > 0 and feat_out != self._feat_out: + self.out_proj = nn.Linear(self._feat_out, feat_out) + self._feat_out = feat_out + else: + self.out_proj = None + self._feat_out = d_model + self.set_max_audio_length(self.pos_emb_max_len) + + def get_num_encoder_cross_attn_tokens(self, num_encoder_input_tokens: int) -> int: + num_encoder_cross_attn_tokens = math.ceil( + num_encoder_input_tokens / self.subsampling_factor + ) + return num_encoder_cross_attn_tokens + + def set_max_audio_length(self, max_audio_length: int) -> None: + """ + Sets maximum input length. + Pre-calculates internal seq_range mask. + + Args: + max_audio_length (int): New maximum sequence length. + """ + device = next(self.parameters()).device + dtype = next(self.parameters()).dtype + self.pos_enc.extend_pe(max_audio_length, device, dtype) + + def forward( + self, + audio_signal: torch.Tensor, + length: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + if audio_signal.shape[-2] != self._feat_in: + raise ValueError( + f"audio_signal should have shape " + f"(batch, {self._feat_in}, n_frame) but " + f"got last dimension " + f"{audio_signal.shape[-2]}." + ) + + return self.forward_internal( + audio_signal, + length, + ) + + def forward_internal( + self, + audio_signal: torch.Tensor, + length: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + if length is None: + length = audio_signal.new_full( + (audio_signal.size(0),), + audio_signal.size(-1), + dtype=torch.int64, + device=audio_signal.device, + ) + + cur_att_context_size = self.att_context_size + audio_signal = torch.transpose(audio_signal, 1, 2) + + audio_signal, length = self.pre_encode(x=audio_signal, lengths=length) + length = length.to(torch.int64) + + max_audio_length = audio_signal.size(1) + + padding_length = length + + audio_signal, pos_emb = self.pos_enc(x=audio_signal, cache_len=0) + + pad_mask, att_mask = self._create_masks( + att_context_size=cur_att_context_size, + padding_length=padding_length, + max_audio_length=max_audio_length, + offset=None, + device=audio_signal.device, + ) + + for lth, layer in enumerate(self.layers): + audio_signal = layer( + x=audio_signal, + att_mask=att_mask, + pos_emb=pos_emb, + pad_mask=pad_mask, + ) + + if self.out_proj is not None: + audio_signal = self.out_proj(audio_signal) + + audio_signal = torch.transpose(audio_signal, 1, 2) + length = length.to(dtype=torch.int64) + + return audio_signal, length + + def _create_masks( + self, + att_context_size: list[int], + padding_length: torch.Tensor, + max_audio_length: int, + offset: torch.Tensor | None, + device: torch.device, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + if self.self_attention_model != "rel_pos_local_attn": + att_mask = torch.ones( + 1, max_audio_length, max_audio_length, dtype=torch.bool, device=device + ) + + if self.att_context_style == "regular": + if att_context_size[0] >= 0: + att_mask = att_mask.triu(diagonal=-att_context_size[0]) + if att_context_size[1] >= 0: + att_mask = 
att_mask.tril(diagonal=att_context_size[1]) + elif self.att_context_style == "chunked_limited": + # When right context is unlimited, just the + # left side of masking needs to get updated + if att_context_size[1] == -1: + if att_context_size[0] >= 0: + att_mask = att_mask.triu(diagonal=-att_context_size[0]) + else: + chunk_size = att_context_size[1] + 1 + # left_chunks_num specifies the number + # of chunks to be visible by each chunk + # on the left side + if att_context_size[0] >= 0: + left_chunks_num = att_context_size[0] // chunk_size + else: + left_chunks_num = 10000 + + chunk_idx = torch.arange( + 0, max_audio_length, dtype=torch.int, device=att_mask.device + ) + chunk_idx = torch.div(chunk_idx, chunk_size, rounding_mode="trunc") + diff_chunks = chunk_idx.unsqueeze(1) - chunk_idx.unsqueeze(0) + chunked_limited_mask = torch.logical_and( + torch.le(diff_chunks, left_chunks_num), torch.ge(diff_chunks, 0) + ) + att_mask = torch.logical_and( + att_mask, chunked_limited_mask.unsqueeze(0) + ) + else: + att_mask = None + + # pad_mask is the masking to be used to ignore paddings + pad_mask = torch.arange(0, max_audio_length, device=device).expand( + padding_length.size(0), -1 + ) < padding_length.unsqueeze(-1) + + if offset is not None: + pad_mask_off = torch.arange(0, max_audio_length, device=device).expand( + padding_length.size(0), -1 + ) >= offset.unsqueeze(-1) + pad_mask = pad_mask_off.logical_and(pad_mask) + + if att_mask is not None: + # pad_mask_for_att_mask is the mask which helps to ignore paddings + pad_mask_for_att_mask = pad_mask.unsqueeze(1).repeat( + [1, max_audio_length, 1] + ) + pad_mask_for_att_mask = torch.logical_and( + pad_mask_for_att_mask, pad_mask_for_att_mask.transpose(1, 2) + ) + # att_mask is the masking to be used by MHA + # layers to ignore tokens not supposed to be + # visible + att_mask = att_mask[:, :max_audio_length, :max_audio_length] + # paddings should also get ignored, so + # pad_mask_for_att_mask is used to ignore their + # corresponding scores + att_mask = torch.logical_and( + pad_mask_for_att_mask, att_mask.to(pad_mask_for_att_mask.device) + ) + att_mask = ~att_mask + + pad_mask = ~pad_mask + return pad_mask, att_mask + + def _calc_context_sizes( + self, + att_context_size: list[int] | list[list[int]] | None, + att_context_probs: list[float] | None, + att_context_style: str, + conv_context_size: list[int] | str | None, + conv_kernel_size: int, + ) -> tuple[list[list[int]], list[int], list[float], list[int]]: + # convert att_context_size to a standard list of lists + if att_context_size: + att_context_size_all = list(att_context_size) + if isinstance(att_context_size_all[0], int): + att_context_size_all = [att_context_size_all] + for i, att_cs in enumerate(att_context_size_all): + if att_context_style == "chunked_limited": + if att_cs[0] > 0 and att_cs[0] % (att_cs[1] + 1) > 0: + raise ValueError( + f"att_context_size[{i}][0] % " + f"(att_context_size[{i}][1]" + f" + 1) should be zero!" + ) + if att_cs[1] < 0 and len(att_context_size_all) <= 1: + raise ValueError( + f"Right context " + f"(att_context_size[{i}][1])" + f" can not be unlimited for" + f" chunked_limited style!" + ) + else: + att_context_size_all = [[-1, -1]] + + if att_context_probs: + if len(att_context_probs) != len(att_context_size_all): + raise ValueError( + "The size of the att_context_probs " + "should be the same as att_context_size." 
+ ) + att_context_probs = list(att_context_probs) + if sum(att_context_probs) != 1: + raise ValueError( + "The sum of numbers in " + "att_context_probs should be equal " + "to one to be a distribution." + ) + else: + att_context_probs = [1.0 / len(att_context_size_all)] * len( + att_context_size_all + ) + + if conv_context_size is not None: + if not isinstance(conv_context_size, list) and not isinstance( + conv_context_size, str + ): + raise ValueError( + "Invalid conv_context_size! It should " + "be the string 'causal' or a list of " + "two integers." + ) + if conv_context_size == "causal": + conv_context_size = [conv_kernel_size - 1, 0] + else: + total = conv_context_size[0] + conv_context_size[1] + 1 + if total != conv_kernel_size: + raise ValueError( + f"Invalid conv_context_size: {self.conv_context_size}!" + ) + else: + conv_context_size = [ + (conv_kernel_size - 1) // 2, + (conv_kernel_size - 1) // 2, + ] + return ( + att_context_size_all, + att_context_size_all[0], + att_context_probs, + conv_context_size, + ) + + +# ----- Encoder END ----- + + +class CohereASRModel(nn.Module): + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + self.encoder = ConformerEncoder(vllm_config=vllm_config) + + self.decoder = CohereASRDecoder( + vllm_config=vllm_config, prefix=f"{prefix}.decoder" + ) + + if self.encoder.d_model != self.decoder.hidden_size: + self.encoder_decoder_proj = torch.nn.Linear( + self.encoder.d_model, self.decoder.hidden_size + ) + + def forward( + self, + input_ids: torch.Tensor | None, + positions: torch.Tensor, + encoder_outputs: list[torch.Tensor], + ) -> torch.Tensor: + enc_states = torch.cat(encoder_outputs, dim=0) if len(encoder_outputs) else None + decoder_outputs = self.decoder( + input_ids=input_ids, + positions=positions, + encoder_hidden_states=enc_states, + ) + + return decoder_outputs + + def get_encoder_outputs( + self, + input_features: torch.Tensor | list[torch.Tensor] | None, + seq_lens: torch.Tensor | None, + ) -> torch.Tensor | None: + if input_features is None: + return None + + if isinstance(input_features, torch.Tensor): + encoder_input_length = seq_lens + out, encoder_output_length = self.encoder( + input_features, length=encoder_input_length + ) # B x D x T + out = out.permute(0, 2, 1) + + if hasattr(self, "encoder_decoder_proj"): + out = self.encoder_decoder_proj(out) + + # Convert padded tensor to packed + outs = [] + for i, feat in enumerate(out): + feat_len = encoder_output_length[i] + outs.append(feat[:feat_len, :]) + + return outs + else: + raise NotImplementedError("List input_features not supported") + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".first_sub_layer.qkv_proj", ".first_sub_layer.query_net", "q"), + (".first_sub_layer.qkv_proj", ".first_sub_layer.key_net", "k"), + (".first_sub_layer.qkv_proj", ".first_sub_layer.value_net", "v"), + (".second_sub_layer.kv_proj", ".second_sub_layer.key_net", "k"), + (".second_sub_layer.kv_proj", ".second_sub_layer.value_net", "v"), + ] + params_dict = dict(self.named_parameters()) + buffers_dict = dict(self.named_buffers()) + params_dict.update(buffers_dict) + + loaded_params: set[str] = set() + for name, loaded_weight in weights: + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. 
+ # if name.endswith(".bias") and name not in params_dict: + # continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", default_weight_loader) + + # Convert buffer dtype to match loaded weight for pos_bias tensors + if "pos_bias" in name and param.dtype != loaded_weight.dtype: + logger.info( + "Converting buffer %s dtype from %s to %s for loading.", + name, + param.dtype, + loaded_weight.dtype, + ) + param.data = param.data.to(loaded_weight.dtype) + + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class CohereASRProcessingInfo(BaseProcessingInfo): + def get_hf_config(self) -> PretrainedConfig: + return self.ctx.get_hf_config() + + def get_default_tok_params(self) -> TokenizeParams: + # Special tokens should be provided by the user based on the + # task and language of their request. Also needed to avoid + # appending an EOS token to the prompt which disrupts generation. + return super().get_default_tok_params().with_kwargs(add_special_tokens=False) + + def get_hf_processor(self, **kwargs: object) -> CohereASRProcessor: + if not hasattr(self, "_cached_hf_processor"): + hf_config = self.get_hf_config() + preproc = hf_config.preprocessor + + sample_rate = preproc.get("sample_rate", 16000) + window_size = preproc.get("window_size", 0.02) + window_stride = preproc.get("window_stride", 0.01) + + feature_extractor = CohereASRFeatureExtractor( + feature_size=preproc.get("features", 64), + sampling_rate=sample_rate, + padding_value=preproc.get("pad_value", 0.0), + max_duration=hf_config.max_audio_clip_s, + n_window_size=int(window_size * sample_rate), + n_window_stride=int(window_stride * sample_rate), + window=preproc.get("window", "hann"), + normalize=preproc.get("normalize", "per_feature"), + n_fft=preproc.get("n_fft", None), + preemph=preproc.get("preemph", 0.97), + lowfreq=preproc.get("lowfreq", 0), + highfreq=preproc.get("highfreq", None), + log=preproc.get("log", True), + log_zero_guard_type=preproc.get("log_zero_guard_type", "add"), + log_zero_guard_value=preproc.get("log_zero_guard_value", 2**-24), + dither=preproc.get("dither", 1e-05), + pad_to=preproc.get("pad_to", 16), + frame_splicing=preproc.get("frame_splicing", 1), + exact_pad=preproc.get("exact_pad", False), + mag_power=preproc.get("mag_power", 2.0), + mel_norm=preproc.get("mel_norm", "slaney"), + stft_exact_pad=preproc.get("stft_exact_pad", False), + stft_conv=preproc.get("stft_conv", False), + device="cpu", + ) + + tokenizer = self.ctx.tokenizer + self._cached_hf_processor = CohereASRProcessor( + feature_extractor=feature_extractor, + tokenizer=tokenizer, + ) + return self._cached_hf_processor + + def get_supported_mm_limits(self) -> Mapping[str, int | None]: + return {"audio": 1} + + def get_data_parser(self) -> MultiModalDataParser: + feature_extractor = self.get_feature_extractor() + return MultiModalDataParser(target_sr=feature_extractor.sampling_rate) + + def get_feature_extractor(self, **kwargs: object) -> CohereASRFeatureExtractor: + hf_processor = self.get_hf_processor(**kwargs) + feature_extractor = hf_processor.feature_extractor + assert isinstance(feature_extractor, CohereASRFeatureExtractor) + return feature_extractor + + def get_num_audio_tokens(self, num_samples: int) -> int: + num_tokens = 
self.get_feature_extractor().get_seq_len(num_samples) + config = self.get_hf_config() + subsampling_factor = config.encoder["subsampling_factor"] + num_tokens = math.ceil(num_tokens / subsampling_factor) + return num_tokens + + +class CohereASRDummyInputsBuilder(BaseDummyInputsBuilder[CohereASRProcessingInfo]): + def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str: + num_audios = mm_counts.get("audio", 0) + + return "<|startoftranscript|>" * num_audios + + def get_dummy_mm_data( + self, + seq_len: int, + mm_counts: Mapping[str, int], + mm_options=None, + mm_processor_kwargs=None, + ) -> MultiModalDataDict: + feature_extractor = self.info.get_feature_extractor() + + sampling_rate = feature_extractor.sampling_rate + audio_len = feature_extractor.max_duration * sampling_rate + num_audios = mm_counts.get("audio", 0) + + return { + "audio": self._get_dummy_audios(length=audio_len, num_audios=num_audios) + } + + +class CohereASRMultiModalProcessor(EncDecMultiModalProcessor[CohereASRProcessingInfo]): + skip_decoder_start_token: bool = True + + @property + def pad_dummy_encoder_prompt(self) -> bool: + return True + + def create_encoder_prompt( + self, + prompt: str | list[int], + mm_items: MultiModalDataItems, + ) -> str | list[int]: + return [0] + + def _call_hf_processor( + self, + prompt: str, + mm_data: Mapping[str, object], + mm_kwargs: Mapping[str, object], + tok_kwargs: Mapping[str, object], + ): + if mm_data: + feature_extractor = self.info.get_feature_extractor(**mm_kwargs) + mm_data = dict(audio=mm_data.pop("audios")) + mm_kwargs = dict( + **mm_kwargs, + sampling_rate=feature_extractor.sampling_rate, + ) + processed_outputs = super()._call_hf_processor( + prompt=prompt, + mm_data=mm_data, + mm_kwargs=mm_kwargs, + tok_kwargs=tok_kwargs, + ) + if "labels" in processed_outputs: + processed_outputs["input_ids"] = processed_outputs.pop("labels") + return processed_outputs + + def _get_mm_fields_config( + self, + hf_inputs, + hf_processor_mm_kwargs: Mapping[str, object], + ) -> Mapping[str, MultiModalFieldConfig]: + return dict( + input_features=MultiModalFieldConfig.batched("audio"), + length=MultiModalFieldConfig.batched("audio"), + ) + + def _get_prompt_updates( + self, + mm_items: MultiModalDataItems, + hf_processor_mm_kwargs: Mapping[str, object], + out_mm_kwargs: MultiModalKwargsItems, + ) -> Sequence[PromptUpdate]: + def get_audio_replacement_cohere_asr(item_idx: int): + audios = mm_items.get_items("audio", AudioProcessorItems) + audio_len = audios.get_audio_length(item_idx) + num_tokens = self.info.get_num_audio_tokens(num_samples=audio_len) + return [0] * num_tokens + + return [ + PromptReplacement( + modality="audio", + target=[0], + replacement=get_audio_replacement_cohere_asr, + ) + ] + + +@MULTIMODAL_REGISTRY.register_processor( + CohereASRMultiModalProcessor, + info=CohereASRProcessingInfo, + dummy_inputs=CohereASRDummyInputsBuilder, +) +class CohereASRForConditionalGeneration( + nn.Module, SupportsTranscription, SupportsMultiModal +): + packed_modules_mapping = { + "self_attn.qkv_proj": [ + "self_attn.q_proj", + "self_attn.k_proj", + "self_attn.v_proj", + ], + "encoder_attn.kv_proj": ["encoder_attn.k_proj", "encoder_attn.v_proj"], + } + + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_substr={".fc1.": ".mlp.fc1.", ".fc2.": ".mlp.fc2."} + ) + + supports_transcription_only = True + supported_languages = ISO639_1_SUPPORTED_LANGS + skip_warmup_audio_preprocessing = True + + @classmethod + def validate_language(cls, language: str | None) -> str | None: + if language is 
None: + logger.warning( + "Defaulting to language='en'. If you wish to transcribe " + "audio in a different language, pass the `language` field " + "in the TranscriptionRequest." + ) + language = "en" + return super().validate_language(language) + + @classmethod + def get_generation_prompt( + cls, + audio: np.ndarray, + model_config: ModelConfig, # not needed here + stt_config: SpeechToTextConfig, + language: str | None, + task_type: Literal["transcribe", "translate"], + request_prompt: str, + to_language: str | None, + ) -> PromptType: + if language is None: + raise ValueError( + "Language must be specified when creating the CohereASR prompt" + ) + + # NOTE: this function is used only by online inference and not offline inference + # CohereASR doesnt have encoder prompt + language_tag = f"<|{language}|><|{language}|>" + pnc = True # TODO(ekagra): make this configurable later + pnc_tag = "<|pnc|>" if pnc else "<|nopnc|>" + default_prompt = ( + f"<|startofcontext|><|startoftranscript|>" + f"<|emo:undefined|>{language_tag}{pnc_tag}" + f"<|noitn|><|notimestamp|><|nodiarize|>" + ) + prompt_text = request_prompt if request_prompt else default_prompt + prompt = { + "prompt": prompt_text, + "multi_modal_data": { + "audio": (audio, stt_config.sample_rate), + }, + } + + return cast(PromptType, prompt) + + @classmethod + def get_placeholder_str(cls, modality: str, i: int) -> str | None: + # Required as part of SupportsMultiModal interface. + if modality.startswith("audio"): + return None + + raise ValueError("Only audio modality is supported") + + @classmethod + def get_speech_to_text_config( + cls, model_config: ModelConfig, task_type: str + ) -> SpeechToTextConfig: + sampling_rate = model_config.hf_config.sample_rate + assert sampling_rate == 16000 + max_audio_clip_s = model_config.hf_config.max_audio_clip_s + overlap_chunk_second = model_config.hf_config.overlap_chunk_second + + return SpeechToTextConfig( + max_audio_clip_s=max_audio_clip_s, + overlap_chunk_second=overlap_chunk_second, + sample_rate=sampling_rate, + ) + + @classmethod + def get_num_audio_tokens( + cls, + audio_duration_s: float, + stt_config: SpeechToTextConfig, + model_config: ModelConfig, + ) -> int | None: + hop_length = model_config.hf_config.preprocessor.get("window_stride") + assert hop_length is not None + return math.ceil(audio_duration_s * stt_config.sample_rate / hop_length) + + def get_num_encoder_cross_attn_tokens(self, num_encoder_input_tokens: int) -> int: + return self.model.encoder.get_num_encoder_cross_attn_tokens( + num_encoder_input_tokens + ) + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.dtype = vllm_config.model_config.dtype + + self.model = CohereASRModel(vllm_config=vllm_config, prefix=prefix) + lm_head_config = config.head + self.unpadded_vocab_size = lm_head_config["num_classes"] + self.proj_out = ParallelLMHead( + lm_head_config["num_classes"], + lm_head_config["hidden_size"], + quant_config=quant_config, + bias=True, + ) # NOTE: bias is True + logit_scale = getattr(lm_head_config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor( + self.unpadded_vocab_size, lm_head_config["num_classes"], logit_scale + ) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + encoder_outputs: list[torch.Tensor] | None = None, + **kwargs, + ) -> torch.Tensor: + if encoder_outputs is None: + encoder_outputs = [] + 
decoder_outputs = self.model( + input_ids=input_ids, + positions=positions, + encoder_outputs=encoder_outputs, + ) + + return decoder_outputs + + def get_language_model(self) -> torch.nn.Module: + # Required as part of SupportsMultiModal interface. + return self.model.decoder + + def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings: + # Required as part of SupportsMultiModal interface. + audio_input, seq_lens = self._parse_and_validate_audio_input(**kwargs) + + if hasattr(audio_input, "input_features"): + out = self.model.get_encoder_outputs(audio_input["input_features"]) + else: + out = self.model.get_encoder_outputs(audio_input, seq_lens) + + return out + + def _parse_and_validate_audio_input( + self, **kwargs: object + ) -> tuple[torch.Tensor, torch.Tensor]: + input_features = kwargs.pop("input_features", None) + length = kwargs.pop("length", None) + + if input_features is None: + raise ValueError("Audio features are required for CohereASR model.") + + if not isinstance(input_features, (torch.Tensor, list)): + raise ValueError( + f"Incorrect type of audio features. Got type: {type(input_features)}" + ) + + if isinstance(input_features, torch.Tensor): + seq_lens = length.reshape(-1) + else: + input_features = [ + feat.to(self.dtype).squeeze(0).transpose(1, 0) + for feat in input_features + ] + seq_lens = length.reshape(-1) + input_features = torch.nn.utils.rnn.pad_sequence( + input_features, batch_first=True, padding_value=0.0 + ) + input_features = input_features.transpose(1, 2) + + return input_features, seq_lens + + def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor: + logits = self.logits_processor(self.proj_out, hidden_states, self.proj_out.bias) + return logits + + def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: + def transform(inputs): + name, loaded_weight = inputs + + if name.startswith("transf_decoder._decoder"): + name = name.replace("transf_decoder._decoder", "decoder") + if name.startswith("transf_decoder._embedding"): + name = name.replace("transf_decoder._embedding", "decoder.embedding") + if "second_sub_layer.query_net" in name: + name = name.replace( + "second_sub_layer.query_net", "second_sub_layer.q_proj" + ) + + if name in ["log_softmax.mlp.layer0.weight", "log_softmax.mlp.layer0.bias"]: + name = name.replace("log_softmax.mlp.layer0", "proj_out") + else: + name = "model." 
+ name + + return name, loaded_weight + + loader = AutoWeightsLoader( + self, + skip_prefixes=[ + "model.preprocessor.featurizer.fb", + "model.preprocessor.featurizer.window", + ], + skip_substrs=["model.conv.batch_norm.num_batches_tracked"], + ) + + return loader.load_weights( + map(transform, weights), mapper=self.hf_to_vllm_mapper + ) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 51f370bcc..7e83af3fd 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -534,6 +534,10 @@ _MULTIMODAL_MODELS = { "VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"), # noqa: E501 "VoxtralRealtimeGeneration": ("voxtral_realtime", "VoxtralRealtimeGeneration"), # noqa: E501 # [Encoder-decoder] + "CohereASRForConditionalGeneration": ( + "cohere_asr", + "CohereASRForConditionalGeneration", + ), "NemotronParseForConditionalGeneration": ( "nemotron_parse", "NemotronParseForConditionalGeneration", diff --git a/vllm/multimodal/processing/processor.py b/vllm/multimodal/processing/processor.py index 839128fbf..f26c17964 100644 --- a/vllm/multimodal/processing/processor.py +++ b/vllm/multimodal/processing/processor.py @@ -1682,6 +1682,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]): class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]): + skip_decoder_start_token: bool = False + @abstractmethod def create_encoder_prompt( self, diff --git a/vllm/renderers/base.py b/vllm/renderers/base.py index 1db6149b0..b468712ad 100644 --- a/vllm/renderers/base.py +++ b/vllm/renderers/base.py @@ -700,12 +700,20 @@ class BaseRenderer(ABC, Generic[_T]): enc_prompt = prompt["encoder_prompt"] dec_prompt = prompt["decoder_prompt"] + skip_decoder_start_token = False + if self.mm_processor is not None: + from vllm.multimodal.processing import EncDecMultiModalProcessor + + if isinstance(self.mm_processor, EncDecMultiModalProcessor): + skip_decoder_start_token = self.mm_processor.skip_decoder_start_token + return build_enc_dec_inputs( encoder_inputs=self._process_singleton(enc_prompt), decoder_inputs=( None if dec_prompt is None else self._process_singleton(dec_prompt) ), decoder_start_token_id=self.get_dec_start_token_id(), + skip_decoder_start_token=skip_decoder_start_token, ) def process_for_engine( diff --git a/vllm/transformers_utils/model_arch_config_convertor.py b/vllm/transformers_utils/model_arch_config_convertor.py index b01592aa3..26fc04042 100644 --- a/vllm/transformers_utils/model_arch_config_convertor.py +++ b/vllm/transformers_utils/model_arch_config_convertor.py @@ -300,6 +300,28 @@ class ModelArchConfigConvertorBase: return model_arch_config +class CohereAsrModelArchConfigConvertor(ModelArchConfigConvertorBase): + def get_total_num_attention_heads(self) -> int: + return self.hf_text_config.transf_decoder["config_dict"]["num_attention_heads"] + + def get_head_size(self) -> int: + hidden_size = self.hf_text_config.transf_decoder["config_dict"]["hidden_size"] + num_attention_heads = self.hf_text_config.transf_decoder["config_dict"][ + "num_attention_heads" + ] + return hidden_size // num_attention_heads + + def get_total_num_kv_heads(self) -> int: + enc_num_kv_heads = self.hf_text_config.encoder["n_heads"] + dec_num_kv_heads = self.hf_text_config.transf_decoder["config_dict"][ + "num_attention_heads" + ] + assert enc_num_kv_heads == dec_num_kv_heads, ( + "Encoder and decoder must have the same number of kv heads" + ) + return enc_num_kv_heads + + class 
MambaModelArchConfigConvertor(ModelArchConfigConvertorBase): def get_head_size(self) -> int: return 0 @@ -425,6 +447,7 @@ class LongCatFlashMTPModelArchConfigConvertor(ModelArchConfigConvertorBase): # hf_config.model_type -> convertor class MODEL_ARCH_CONFIG_CONVERTORS = { + "cohere_asr": CohereAsrModelArchConfigConvertor, "mamba": MambaModelArchConfigConvertor, "falcon_mamba": MambaModelArchConfigConvertor, "timm_wrapper": TerratorchModelArchConfigConvertor, diff --git a/vllm/transformers_utils/processors/__init__.py b/vllm/transformers_utils/processors/__init__.py index 9c393b700..fe34327d2 100644 --- a/vllm/transformers_utils/processors/__init__.py +++ b/vllm/transformers_utils/processors/__init__.py @@ -12,6 +12,7 @@ import importlib __all__ = [ "BagelProcessor", + "CohereASRProcessor", "DeepseekVLV2Processor", "Eagle2_5_VLProcessor", "FireRedASR2Processor", @@ -38,6 +39,7 @@ __all__ = [ _CLASS_TO_MODULE: dict[str, str] = { "BagelProcessor": "vllm.transformers_utils.processors.bagel", + "CohereASRProcessor": "vllm.transformers_utils.processors.cohere_asr", "DeepseekVLV2Processor": "vllm.transformers_utils.processors.deepseek_vl2", "Eagle2_5_VLProcessor": "vllm.transformers_utils.processors.eagle2_5_vl", "FireRedASR2Processor": "vllm.transformers_utils.processors.fireredasr2", diff --git a/vllm/transformers_utils/processors/cohere_asr.py b/vllm/transformers_utils/processors/cohere_asr.py new file mode 100644 index 000000000..f742074a4 --- /dev/null +++ b/vllm/transformers_utils/processors/cohere_asr.py @@ -0,0 +1,575 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import logging +import math +import random + +import librosa +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn +from transformers import AutoFeatureExtractor, AutoProcessor, BatchFeature +from transformers.feature_extraction_sequence_utils import ( + SequenceFeatureExtractor, +) +from transformers.processing_utils import ProcessorMixin + +logger = logging.getLogger(__name__) + +CONSTANT = 1e-5 +INF_VAL = 10000.0 + + +class FilterbankFeatures(nn.Module): + """Featurizer that converts wavs to Mel Spectrograms. + See AudioToMelSpectrogramPreprocessor for args. + """ + + window: torch.Tensor + fb: torch.Tensor + + def __init__( + self, + sample_rate=16000, + n_window_size=320, + n_window_stride=160, + window="hann", + normalize="per_feature", + n_fft=None, + preemph=0.97, + nfilt=64, + lowfreq=0, + highfreq=None, + log=True, + log_zero_guard_type="add", + log_zero_guard_value=2**-24, + dither=CONSTANT, + pad_to=16, + max_duration=30, + frame_splicing=1, + exact_pad=False, + pad_value=0, + mag_power=2.0, + use_grads=False, + rng=None, + nb_augmentation_prob=0.0, + nb_max_freq=4000, + mel_norm="slaney", + stft_exact_pad=False, + stft_conv=False, + device="cpu", + ): + super().__init__() + if stft_conv or stft_exact_pad: + logger.warning( + "Using torch_stft is deprecated and has been removed. " + "The values have been forcibly set to False for " + "FilterbankFeatures and AudioToMelSpectrogramPreprocessor. " + "Please set exact_pad to True as needed." + ) + if exact_pad and n_window_stride % 2 == 1: + raise NotImplementedError( + f"{self} received exact_pad == True, but hop_size was odd. " + "If audio_length % hop_size == 0, the returned spectrogram " + "would not be of length audio_length // hop_size. " + "Please use an even hop_size." 
+ ) + self.log_zero_guard_value = log_zero_guard_value + if ( + n_window_size is None + or n_window_stride is None + or not isinstance(n_window_size, int) + or not isinstance(n_window_stride, int) + or n_window_size <= 0 + or n_window_stride <= 0 + ): + raise ValueError( + f"{self} got an invalid value for either n_window_size or " + f"n_window_stride. Both must be positive ints." + ) + + self.sample_rate = sample_rate + self.win_length = n_window_size + self.hop_length = n_window_stride + self.n_fft = n_fft or 2 ** math.ceil(math.log2(self.win_length)) + self.stft_pad_amount = ( + (self.n_fft - self.hop_length) // 2 if exact_pad else None + ) + self.exact_pad = exact_pad + self.sample_rate = sample_rate + self.max_duration = max_duration + + if exact_pad: + logger.info("STFT using exact pad") + torch_windows = { + "hann": torch.hann_window, + "hamming": torch.hamming_window, + "blackman": torch.blackman_window, + "bartlett": torch.bartlett_window, + "none": None, + } + window_fn = torch_windows.get(window) + window_tensor = ( + window_fn(self.win_length, periodic=False) if window_fn else None + ) + self.register_buffer("window", window_tensor) + + self.normalize = normalize + self.log = log + self.dither = dither + self.frame_splicing = frame_splicing + self.nfilt = nfilt + self.preemph = preemph + self.pad_to = pad_to + highfreq = highfreq or sample_rate / 2 + self.sample_rate = sample_rate + # disable pad min duration + # self.pad_min_duration = 1.0 + self.pad_min_duration = 0.0 + self.pad_direction = "both" + + filterbanks = torch.tensor( + librosa.filters.mel( + sr=sample_rate, + n_fft=self.n_fft, + n_mels=nfilt, + fmin=lowfreq, + fmax=highfreq, + norm=mel_norm, + ), + dtype=torch.float, + ).unsqueeze(0) + self.register_buffer("fb", filterbanks) + + # Calculate maximum sequence length + max_length = self.get_seq_len( + torch.tensor(max_duration * sample_rate, dtype=torch.float) + ) + max_pad = pad_to - (max_length % pad_to) if pad_to > 0 else 0 + self.max_length = max_length + max_pad + self.pad_value = pad_value + self.mag_power = mag_power + + # We want to avoid taking the log of zero + # There are two options: either adding or clamping to a small value + if log_zero_guard_type not in ["add", "clamp"]: + raise ValueError( + f"{self} received {log_zero_guard_type} for the " + f"log_zero_guard_type parameter. It must be either 'add' or " + f"'clamp'." 
+            )
+
+        self.use_grads = use_grads
+        if not use_grads:
+            self.forward = torch.no_grad()(self.forward)
+        self._rng = random.Random() if rng is None else rng
+        self.nb_augmentation_prob = nb_augmentation_prob
+        if self.nb_augmentation_prob > 0.0:
+            if nb_max_freq >= sample_rate / 2:
+                self.nb_augmentation_prob = 0.0
+            else:
+                self._nb_max_fft_bin = int((nb_max_freq / sample_rate) * n_fft)
+
+        # log_zero_guard_value is the small value we want to use; we support
+        # an actual number, or "tiny", or "eps"
+        self.log_zero_guard_type = log_zero_guard_type
+
+        assert self.window is not None
+        assert self.fb is not None
+        self.window = self.window.to(dtype=torch.bfloat16)
+        self.fb = self.fb.to(dtype=torch.bfloat16)
+
+        self.generator = torch.Generator(device=device)
+        self.generator.manual_seed(0)
+
+    @torch._dynamo.disable
+    def stft(self, x):
+        # disable autocast to get full range of stft values
+        with torch.amp.autocast(x.device.type, enabled=False):
+            return torch.stft(
+                x,
+                n_fft=self.n_fft,
+                hop_length=self.hop_length,
+                win_length=self.win_length,
+                center=not self.exact_pad,
+                window=self.window.to(dtype=torch.float, device=x.device),
+                return_complex=True,
+                pad_mode="constant",
+            )
+
+    def log_zero_guard_value_fn(self, x):
+        if isinstance(self.log_zero_guard_value, str):
+            if self.log_zero_guard_value == "tiny":
+                return torch.finfo(x.dtype).tiny
+            elif self.log_zero_guard_value == "eps":
+                return torch.finfo(x.dtype).eps
+            else:
+                raise ValueError(
+                    f"{self} received {self.log_zero_guard_value} for the "
+                    f"log_zero_guard_value parameter. It must be either a "
+                    f"number, 'tiny', or 'eps'"
+                )
+        else:
+            return self.log_zero_guard_value
+
+    def get_seq_len(self, seq_len):
+        # Assumes center=True padding (n_fft // 2 per side) when
+        # stft_pad_amount is None.
+        pad_amount = (
+            self.stft_pad_amount * 2
+            if self.stft_pad_amount is not None
+            else self.n_fft // 2 * 2
+        )
+        seq_len = torch.floor_divide(
+            (seq_len + pad_amount - self.n_fft), self.hop_length
+        )
+        return seq_len.to(dtype=torch.long)
+
+    @property
+    def filter_banks(self):
+        return self.fb
+
+    def splice_frames(self, x, frame_splicing):
+        """Stacks frames together across feature dim
+
+        input is batch_size, feature_dim, num_frames
+        output is batch_size, feature_dim*frame_splicing, num_frames
+
+        """
+        seq = [x]
+        for n in range(1, frame_splicing):
+            seq.append(torch.cat([x[:, :, :n], x[:, :, n:]], dim=2))
+        return torch.cat(seq, dim=1)
+
+    def normalize_batch(self, x, seq_len, normalize_type):
+        x_mean = None
+        x_std = None
+        if normalize_type == "per_feature":
+            batch_size = x.shape[0]
+            max_time = x.shape[2]
+
+            # When doing stream capture to a graph, item() is not allowed
+            # because it calls cudaStreamSynchronize(). Therefore, we are
+            # sacrificing some error checking when running with cuda graphs.
+            # if (
+            #     torch.cuda.is_available()
+            #     and not torch.cuda.is_current_stream_capturing()
+            #     and torch.any(seq_len == 1).item()
+            # ):
+            #     raise ValueError(
+            #         "normalize_batch with `per_feature` normalize_type "
+            #         "received a tensor of length 1. This will result in "
+            #         "torch.std() returning nan. Make sure your audio length "
+            #         "has enough samples for a single feature (ex. at least "
+            #         "`hop_length` for Mel Spectrograms)."
+ # ) + time_steps = ( + torch.arange(max_time, device=x.device) + .unsqueeze(0) + .expand(batch_size, max_time) + ) + valid_mask = time_steps < seq_len.unsqueeze(1) + x_mean_numerator = torch.where(valid_mask.unsqueeze(1), x, 0.0).sum(axis=2) + x_mean_denominator = valid_mask.sum(axis=1) + x_mean = x_mean_numerator / x_mean_denominator.unsqueeze(1) + + # Subtract 1 in the denominator to correct for the bias. + x_std = torch.sqrt( + torch.sum( + torch.where(valid_mask.unsqueeze(1), x - x_mean.unsqueeze(2), 0.0) + ** 2, + axis=2, + ) + / (x_mean_denominator.unsqueeze(1) - 1.0) + ) + x_std = x_std.masked_fill( + x_std.isnan(), 0.0 + ) # edge case: only 1 frame in denominator + # make sure x_std is not zero + x_std += CONSTANT + return (x - x_mean.unsqueeze(2)) / x_std.unsqueeze(2), x_mean, x_std + elif normalize_type == "all_features": + x_mean = torch.zeros(seq_len.shape, dtype=x.dtype, device=x.device) + x_std = torch.zeros(seq_len.shape, dtype=x.dtype, device=x.device) + for i in range(x.shape[0]): + x_mean[i] = x[i, :, : seq_len[i].item()].mean() + x_std[i] = x[i, :, : seq_len[i].item()].std() + # make sure x_std is not zero + x_std += CONSTANT + return (x - x_mean.view(-1, 1, 1)) / x_std.view(-1, 1, 1), x_mean, x_std + elif "fixed_mean" in normalize_type and "fixed_std" in normalize_type: + x_mean = torch.tensor(normalize_type["fixed_mean"], device=x.device) + x_std = torch.tensor(normalize_type["fixed_std"], device=x.device) + return ( + (x - x_mean.view(x.shape[0], x.shape[1]).unsqueeze(2)) + / x_std.view(x.shape[0], x.shape[1]).unsqueeze(2), + x_mean, + x_std, + ) + else: + return x, x_mean, x_std + + @torch.compile + def forward(self, x, seq_len, linear_spec=False): + if x.shape[1] < self.sample_rate * self.pad_min_duration: + pad_amount = int(self.sample_rate * self.pad_min_duration) - x.shape[1] + if self.pad_direction == "right": + x = F.pad(x, (0, pad_amount), value=self.pad_value) + elif self.pad_direction == "left": + x = F.pad(x, (pad_amount, 0), value=self.pad_value) + elif self.pad_direction == "both": + left_pad = pad_amount // 2 + right_pad = pad_amount - left_pad + x = F.pad(x, (left_pad, right_pad), value=self.pad_value) + else: + raise ValueError( + f"{self} received an invalid pad_direction: {self.pad_direction}. " + f"It must be one of 'left', 'right', or 'both'." 
+ ) + seq_len = torch.tensor([x.shape[1]], dtype=torch.float, device=x.device) + + seq_len_time = seq_len + seq_len_unfixed = self.get_seq_len(seq_len) + + # fix for seq_len = 0 for streaming; if size was 0, it is always padded + # to 1, and normalizer fails + seq_len = torch.where( + seq_len == 0, torch.zeros_like(seq_len_unfixed), seq_len_unfixed + ) + + if self.stft_pad_amount is not None: + x = torch.nn.functional.pad( + x.unsqueeze(1), (self.stft_pad_amount, self.stft_pad_amount), "constant" + ).squeeze(1) + + # use dither for inference as well + if self.dither > 0: + x += self.dither * torch.randn( + x.shape, dtype=x.dtype, device=x.device, generator=self.generator + ) + + # do preemphasis + if self.preemph is not None: + timemask = torch.arange(x.shape[1], device=x.device).unsqueeze( + 0 + ) < seq_len_time.unsqueeze(1) + x = torch.cat( + (x[:, 0].unsqueeze(1), x[:, 1:] - self.preemph * x[:, :-1]), dim=1 + ) + + x = x.masked_fill(~timemask, 0.0) + + x = self.stft(x) + + # torch stft returns complex tensor (of shape [B,N,T]); so convert to magnitude + # guard is needed for sqrt if grads are passed through + guard = 0 if not self.use_grads else CONSTANT + x = torch.view_as_real(x) + x = torch.sqrt(x.pow(2).sum(-1) + guard) + + # get power spectrum + if self.mag_power != 1.0: + x = x.pow(self.mag_power) + + # return plain spectrogram if required + if linear_spec: + return x, seq_len + + # disable autocast, otherwise it might be automatically casted to fp16 + # on fp16 compatible GPUs and get NaN values for input value of 65520 + with torch.amp.autocast(x.device.type, enabled=False): + # dot with filterbank energies + x = torch.matmul(self.fb.to(x.dtype), x) + + # log features if required + if self.log: + if self.log_zero_guard_type == "add": + x = torch.log(x + self.log_zero_guard_value_fn(x)) + elif self.log_zero_guard_type == "clamp": + x = torch.log(torch.clamp(x, min=self.log_zero_guard_value_fn(x))) + else: + raise ValueError("log_zero_guard_type was not understood") + + # frame splicing if required + if self.frame_splicing > 1: + x = self.splice_frames(x, self.frame_splicing) + + # normalize if required + if self.normalize: + x, _, _ = self.normalize_batch(x, seq_len, normalize_type=self.normalize) + + # mask to zero any values beyond seq_len in batch, pad to multiple of + # `pad_to` (for efficiency) + max_len = x.size(-1) + mask = torch.arange(max_len, device=x.device) + mask = mask.repeat(x.size(0), 1) >= seq_len.unsqueeze(1) + x = x.masked_fill( + mask.unsqueeze(1).type(torch.bool).to(device=x.device), self.pad_value + ) + + del mask + pad_to = self.pad_to + if pad_to == "max": + x = nn.functional.pad( + x, (0, self.max_length - x.size(-1)), value=self.pad_value + ) + elif pad_to > 0: + pad_amt = x.size(-1) % pad_to + if pad_amt != 0: + x = nn.functional.pad(x, (0, pad_to - pad_amt), value=self.pad_value) + + return x, seq_len + + +class CohereASRFeatureExtractor(SequenceFeatureExtractor): + """HF-compatible feature extractor wrapping FilterbankFeatures.""" + + model_input_names = ["input_features"] + + def __init__( + self, + feature_size=64, + sampling_rate=16000, + padding_value=0.0, + max_duration=30, + n_window_size=320, + n_window_stride=160, + window="hann", + normalize="per_feature", + n_fft=None, + preemph=0.97, + lowfreq=0, + highfreq=None, + log=True, + log_zero_guard_type="add", + log_zero_guard_value=2**-24, + dither=CONSTANT, + pad_to=16, + frame_splicing=1, + exact_pad=False, + mag_power=2.0, + nb_augmentation_prob=0.0, + nb_max_freq=4000, + mel_norm="slaney", + 
stft_exact_pad=False, + stft_conv=False, + device="cpu", + **kwargs, + ): + super().__init__( + feature_size=feature_size, + sampling_rate=sampling_rate, + padding_value=padding_value, + **kwargs, + ) + self.max_duration = max_duration + self.hop_length = n_window_stride + self._device = torch.device(device) + self._fb_config = dict( + sample_rate=sampling_rate, + n_window_size=n_window_size, + n_window_stride=n_window_stride, + window=window, + normalize=normalize, + n_fft=n_fft, + preemph=preemph, + nfilt=feature_size, + lowfreq=lowfreq, + highfreq=highfreq, + log=log, + log_zero_guard_type=log_zero_guard_type, + log_zero_guard_value=log_zero_guard_value, + dither=dither, + pad_to=pad_to, + max_duration=max_duration, + frame_splicing=frame_splicing, + exact_pad=exact_pad, + pad_value=padding_value, + mag_power=mag_power, + nb_augmentation_prob=nb_augmentation_prob, + nb_max_freq=nb_max_freq, + mel_norm=mel_norm, + stft_exact_pad=stft_exact_pad, + stft_conv=stft_conv, + device=device, + ) + self._filterbank: FilterbankFeatures | None = None + + @property + def filterbank(self) -> FilterbankFeatures: + if self._filterbank is None: + fb = FilterbankFeatures(**self._fb_config) + fb.eval() + self._filterbank = fb.to(self._device) + return self._filterbank + + def get_seq_len(self, seq_len): + return self.filterbank.get_seq_len(seq_len) + + def __call__( + self, + raw_speech, + sampling_rate=None, + return_tensors=None, + **kwargs, + ) -> BatchFeature: + if isinstance(raw_speech, np.ndarray): + raw_speech = [raw_speech] + + seq_len = torch.tensor([s.shape[0] for s in raw_speech]) + + max_len = max(s.shape[0] for s in raw_speech) + padded = np.zeros((len(raw_speech), max_len), dtype=np.float32) + for i, s in enumerate(raw_speech): + padded[i, : s.shape[0]] = s + + audio_tensor = torch.from_numpy(padded).to(self._device) + seq_len = seq_len.to(self._device) + + with torch.no_grad(): + input_features, length = self.filterbank(audio_tensor, seq_len) + + result = BatchFeature( + {"input_features": input_features.cpu(), "length": length.cpu()} + ) + if return_tensors is not None: + result = result.convert_to_tensors(return_tensors) + return result + + +class CohereASRProcessor(ProcessorMixin): + """HF-compatible processor combining CohereASRFeatureExtractor and a + tokenizer.""" + + feature_extractor_class = "CohereASRFeatureExtractor" + tokenizer_class = "AutoTokenizer" + + def __init__(self, feature_extractor, tokenizer): + super().__init__(feature_extractor, tokenizer) + + def __call__( + self, + text=None, + audio=None, + sampling_rate=None, + return_tensors=None, + **kwargs, + ): + if audio is not None: + result = self.feature_extractor( + audio, + sampling_rate=sampling_rate, + return_tensors=return_tensors, + ) + else: + result = BatchFeature() + + if text is not None: + text_inputs = self.tokenizer(text, return_tensors=return_tensors, **kwargs) + result["input_ids"] = text_inputs["input_ids"] + + return result + + +AutoFeatureExtractor.register("CohereASRFeatureExtractor", CohereASRFeatureExtractor) +AutoProcessor.register("CohereASRProcessor", CohereASRProcessor)
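
The feature extractor registered above can be exercised on its own. The following is a minimal sketch, not part of the change itself, assuming this branch of vLLM is installed and running on CPU; the synthetic audio and printed shapes are illustrative only.

import numpy as np

from vllm.transformers_utils.processors.cohere_asr import CohereASRFeatureExtractor

# Defaults mirror the preprocessor config read by CohereASRProcessingInfo
# (64 mel bins, 16 kHz audio, 20 ms window, 10 ms hop).
extractor = CohereASRFeatureExtractor(feature_size=64, sampling_rate=16000)

# One second of synthetic audio at 16 kHz.
audio = np.random.randn(16000).astype(np.float32)

features = extractor(audio, sampling_rate=16000)
# input_features: (batch, n_mels, frames); length: valid frame count per clip.
print(features["input_features"].shape, features["length"])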