[Models] Cohere ASR (#35809)

Signed-off-by: Ekagra Ranjan <3116519+ekagra-ranjan@users.noreply.github.com>
This commit is contained in:
Ekagra Ranjan
2026-03-17 17:04:17 -04:00
committed by GitHub
parent 245758992e
commit b5ca9c3557
14 changed files with 2910 additions and 18 deletions

View File

@@ -70,6 +70,29 @@ def run_audioflamingo3(question: str, audio_count: int) -> ModelRequestData:
)
# CohereASR
def run_cohere_asr(question: str, audio_count: int) -> ModelRequestData:
assert audio_count == 1, "CohereASR only support single audio input per prompt"
# TODO (ekagra): add HF ckpt after asr release
model_name = "/host/engines/vllm/audio/2b-release"
prompt = (
"<|startofcontext|><|startoftranscript|>"
"<|emo:undefined|><|en|><|en|><|pnc|><|noitn|>"
"<|notimestamp|><|nodiarize|>"
)
engine_args = EngineArgs(
model=model_name,
limit_mm_per_prompt={"audio": audio_count},
trust_remote_code=True,
)
return ModelRequestData(
engine_args=engine_args,
prompt=prompt,
)
# MusicFlamingo
def run_musicflamingo(question: str, audio_count: int) -> ModelRequestData:
model_name = "nvidia/music-flamingo-2601-hf"
@@ -508,14 +531,15 @@ def run_whisper(question: str, audio_count: int) -> ModelRequestData:
model_example_map = {
"audioflamingo3": run_audioflamingo3,
"musicflamingo": run_musicflamingo,
"cohere_asr": run_cohere_asr,
"funaudiochat": run_funaudiochat,
"gemma3n": run_gemma3n,
"glmasr": run_glmasr,
"funaudiochat": run_funaudiochat,
"granite_speech": run_granite_speech,
"kimi_audio": run_kimi_audio,
"midashenglm": run_midashenglm,
"minicpmo": run_minicpmo,
"musicflamingo": run_musicflamingo,
"phi4_mm": run_phi4mm,
"qwen2_audio": run_qwen2_audio,
"qwen2_5_omni": run_qwen2_5_omni,

View File

@@ -19,8 +19,10 @@ import soundfile
import torch
from datasets import load_dataset
from evaluate import load
from transformers import AutoTokenizer
from vllm.tokenizers import get_tokenizer
from ....models.registry import HF_EXAMPLE_MODELS
from ....utils import RemoteOpenAIServer
@@ -64,8 +66,12 @@ async def bound_transcribe(sem, client, tokenizer, audio, reference):
async def process_dataset(model, client, data, concurrent_request):
sem = asyncio.Semaphore(concurrent_request)
# Load tokenizer once outside the loop
tokenizer = AutoTokenizer.from_pretrained(model)
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
tokenizer = get_tokenizer(
model,
tokenizer_mode=model_info.tokenizer_mode,
trust_remote_code=model_info.trust_remote_code,
)
# Warmup call as the first `librosa.load` server-side is quite slow.
audio, sr = data[0]["audio"]["array"], data[0]["audio"]["sampling_rate"]
@@ -144,20 +150,35 @@ def run_evaluation(
# alternatives "openai/whisper-large-v2", "openai/whisper-large-v3-turbo"..
@pytest.mark.parametrize("model_name", ["openai/whisper-large-v3"])
# NOTE: Expected WER measured with equivalent hf.transformers args:
# whisper-large-v3 + esb-datasets-earnings22-validation-tiny-filtered.
@pytest.mark.parametrize(
"model_config",
[
("openai/whisper-large-v3", 12.744980),
# TODO (ekagra): add HF ckpt after asr release
# ("/host/engines/vllm/audio/2b-release", 11.73),
],
)
# Original dataset is 20GB+ in size, hence we use a pre-filtered slice.
@pytest.mark.parametrize(
"dataset_repo", ["D4nt3/esb-datasets-earnings22-validation-tiny-filtered"]
)
# NOTE: Expected WER measured with equivalent hf.transformers args:
# whisper-large-v3 + esb-datasets-earnings22-validation-tiny-filtered.
@pytest.mark.parametrize("expected_wer", [12.744980])
def test_wer_correctness(
model_name, dataset_repo, expected_wer, n_examples=-1, max_concurrent_request=None
model_config, dataset_repo, n_examples=-1, max_concurrent_request=None
):
model_name, expected_wer = model_config
model_info = HF_EXAMPLE_MODELS.find_hf_info(model_name)
# TODO refactor to use `ASRDataset`
server_args = [
"--enforce-eager",
f"--tokenizer_mode={model_info.tokenizer_mode}",
]
if model_info.trust_remote_code:
server_args.append("--trust-remote-code")
with RemoteOpenAIServer(
model_name, ["--enforce-eager"], max_wait_seconds=480
model_name,
server_args,
) as remote_server:
dataset = load_hf_dataset(dataset_repo)
@@ -167,7 +188,14 @@ def test_wer_correctness(
client = remote_server.get_async_client()
wer = run_evaluation(
model_name, client, dataset, max_concurrent_request, n_examples
model_name,
client,
dataset,
max_concurrent_request,
n_examples,
)
print(f"Expected WER: {expected_wer}, Actual WER: {wer}")
if expected_wer:
torch.testing.assert_close(wer, expected_wer, atol=1e-1, rtol=1e-2)

View File

@@ -1116,6 +1116,11 @@ _MULTIMODAL_EXAMPLE_MODELS = {
tokenizer_mode="mistral",
),
# [Encoder-decoder]
"CohereASRForConditionalGeneration": _HfExamplesInfo(
"/host/engines/vllm/audio/2b-release",
trust_remote_code=True,
is_available_online=False, # TODO (ekagra): revert after asr release
),
"NemotronParseForConditionalGeneration": _HfExamplesInfo(
"nvidia/NVIDIA-Nemotron-Parse-v1.1", trust_remote_code=True
),

View File

@@ -3157,7 +3157,7 @@ class ASRDataset(HuggingFaceDataset):
**kwargs,
) -> list:
output_len = output_len if output_len is not None else self.DEFAULT_OUTPUT_LEN
if "openai" in tokenizer.name_or_path:
if "openai" in getattr(tokenizer, "name_or_path", ""):
prompt = "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>"
else:
prompt = ""

View File

@@ -107,7 +107,7 @@ class TranscriptionRequest(OpenAIBaseModel):
stream_include_usage: bool | None = False
stream_continuous_usage_stats: bool | None = False
vllm_xargs: dict[str, str | int | float] | None = Field(
vllm_xargs: dict[str, str | int | float | bool] | None = Field(
default=None,
description=(
"Additional request parameters with string or "

View File

@@ -365,6 +365,7 @@ def build_enc_dec_inputs(
encoder_inputs: SingletonInputs,
decoder_inputs: SingletonInputs | None,
decoder_start_token_id: int,
skip_decoder_start_token: bool = False,
) -> EncoderDecoderInputs:
enc_inputs = _validate_enc_inputs(encoder_inputs)
@@ -396,10 +397,11 @@ def build_enc_dec_inputs(
else:
assert_never(enc_inputs)
dec_inputs_new["prompt_token_ids"] = _prepare_decoder_input_ids_for_generation(
dec_inputs_new["prompt_token_ids"],
decoder_start_token_id,
)
if not skip_decoder_start_token:
dec_inputs_new["prompt_token_ids"] = _prepare_decoder_input_ids_for_generation(
dec_inputs_new["prompt_token_ids"],
decoder_start_token_id,
)
if cache_salt := enc_inputs.get("cache_salt"):
dec_inputs_new["cache_salt"] = cache_salt

View File

@@ -261,6 +261,15 @@ class InputPreprocessor:
encoder_prompt = prompt["encoder_prompt"]
decoder_prompt = prompt["decoder_prompt"]
skip_decoder_start_token = False
if self.renderer.mm_processor is not None:
from vllm.multimodal.processing import EncDecMultiModalProcessor
if isinstance(self.renderer.mm_processor, EncDecMultiModalProcessor):
skip_decoder_start_token = (
self.renderer.mm_processor.skip_decoder_start_token
)
return build_enc_dec_inputs(
encoder_inputs=self._prompt_to_llm_inputs(
encoder_prompt,
@@ -275,6 +284,7 @@ class InputPreprocessor:
)
),
decoder_start_token_id=self.renderer.get_dec_start_token_id(),
skip_decoder_start_token=skip_decoder_start_token,
)
def _process_decoder_only_prompt(

File diff suppressed because it is too large Load Diff

View File

@@ -534,6 +534,10 @@ _MULTIMODAL_MODELS = {
"VoxtralForConditionalGeneration": ("voxtral", "VoxtralForConditionalGeneration"), # noqa: E501
"VoxtralRealtimeGeneration": ("voxtral_realtime", "VoxtralRealtimeGeneration"), # noqa: E501
# [Encoder-decoder]
"CohereASRForConditionalGeneration": (
"cohere_asr",
"CohereASRForConditionalGeneration",
),
"NemotronParseForConditionalGeneration": (
"nemotron_parse",
"NemotronParseForConditionalGeneration",

View File

@@ -1682,6 +1682,8 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
skip_decoder_start_token: bool = False
@abstractmethod
def create_encoder_prompt(
self,

View File

@@ -700,12 +700,20 @@ class BaseRenderer(ABC, Generic[_T]):
enc_prompt = prompt["encoder_prompt"]
dec_prompt = prompt["decoder_prompt"]
skip_decoder_start_token = False
if self.mm_processor is not None:
from vllm.multimodal.processing import EncDecMultiModalProcessor
if isinstance(self.mm_processor, EncDecMultiModalProcessor):
skip_decoder_start_token = self.mm_processor.skip_decoder_start_token
return build_enc_dec_inputs(
encoder_inputs=self._process_singleton(enc_prompt),
decoder_inputs=(
None if dec_prompt is None else self._process_singleton(dec_prompt)
),
decoder_start_token_id=self.get_dec_start_token_id(),
skip_decoder_start_token=skip_decoder_start_token,
)
def process_for_engine(

View File

@@ -300,6 +300,28 @@ class ModelArchConfigConvertorBase:
return model_arch_config
class CohereAsrModelArchConfigConvertor(ModelArchConfigConvertorBase):
def get_total_num_attention_heads(self) -> int:
return self.hf_text_config.transf_decoder["config_dict"]["num_attention_heads"]
def get_head_size(self) -> int:
hidden_size = self.hf_text_config.transf_decoder["config_dict"]["hidden_size"]
num_attention_heads = self.hf_text_config.transf_decoder["config_dict"][
"num_attention_heads"
]
return hidden_size // num_attention_heads
def get_total_num_kv_heads(self) -> int:
enc_num_kv_heads = self.hf_text_config.encoder["n_heads"]
dec_num_kv_heads = self.hf_text_config.transf_decoder["config_dict"][
"num_attention_heads"
]
assert enc_num_kv_heads == dec_num_kv_heads, (
"Encoder and decoder must have the same number of kv heads"
)
return enc_num_kv_heads
class MambaModelArchConfigConvertor(ModelArchConfigConvertorBase):
def get_head_size(self) -> int:
return 0
@@ -425,6 +447,7 @@ class LongCatFlashMTPModelArchConfigConvertor(ModelArchConfigConvertorBase):
# hf_config.model_type -> convertor class
MODEL_ARCH_CONFIG_CONVERTORS = {
"cohere_asr": CohereAsrModelArchConfigConvertor,
"mamba": MambaModelArchConfigConvertor,
"falcon_mamba": MambaModelArchConfigConvertor,
"timm_wrapper": TerratorchModelArchConfigConvertor,

View File

@@ -12,6 +12,7 @@ import importlib
__all__ = [
"BagelProcessor",
"CohereASRProcessor",
"DeepseekVLV2Processor",
"Eagle2_5_VLProcessor",
"FireRedASR2Processor",
@@ -38,6 +39,7 @@ __all__ = [
_CLASS_TO_MODULE: dict[str, str] = {
"BagelProcessor": "vllm.transformers_utils.processors.bagel",
"CohereASRProcessor": "vllm.transformers_utils.processors.cohere_asr",
"DeepseekVLV2Processor": "vllm.transformers_utils.processors.deepseek_vl2",
"Eagle2_5_VLProcessor": "vllm.transformers_utils.processors.eagle2_5_vl",
"FireRedASR2Processor": "vllm.transformers_utils.processors.fireredasr2",

View File

@@ -0,0 +1,575 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import logging
import math
import random
import librosa
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from transformers import AutoFeatureExtractor, AutoProcessor, BatchFeature
from transformers.feature_extraction_sequence_utils import (
SequenceFeatureExtractor,
)
from transformers.processing_utils import ProcessorMixin
logger = logging.getLogger(__name__)
CONSTANT = 1e-5
INF_VAL = 10000.0
class FilterbankFeatures(nn.Module):
"""Featurizer that converts wavs to Mel Spectrograms.
See AudioToMelSpectrogramPreprocessor for args.
"""
window: torch.Tensor
fb: torch.Tensor
def __init__(
self,
sample_rate=16000,
n_window_size=320,
n_window_stride=160,
window="hann",
normalize="per_feature",
n_fft=None,
preemph=0.97,
nfilt=64,
lowfreq=0,
highfreq=None,
log=True,
log_zero_guard_type="add",
log_zero_guard_value=2**-24,
dither=CONSTANT,
pad_to=16,
max_duration=30,
frame_splicing=1,
exact_pad=False,
pad_value=0,
mag_power=2.0,
use_grads=False,
rng=None,
nb_augmentation_prob=0.0,
nb_max_freq=4000,
mel_norm="slaney",
stft_exact_pad=False,
stft_conv=False,
device="cpu",
):
super().__init__()
if stft_conv or stft_exact_pad:
logger.warning(
"Using torch_stft is deprecated and has been removed. "
"The values have been forcibly set to False for "
"FilterbankFeatures and AudioToMelSpectrogramPreprocessor. "
"Please set exact_pad to True as needed."
)
if exact_pad and n_window_stride % 2 == 1:
raise NotImplementedError(
f"{self} received exact_pad == True, but hop_size was odd. "
"If audio_length % hop_size == 0, the returned spectrogram "
"would not be of length audio_length // hop_size. "
"Please use an even hop_size."
)
self.log_zero_guard_value = log_zero_guard_value
if (
n_window_size is None
or n_window_stride is None
or not isinstance(n_window_size, int)
or not isinstance(n_window_stride, int)
or n_window_size <= 0
or n_window_stride <= 0
):
raise ValueError(
f"{self} got an invalid value for either n_window_size or "
f"n_window_stride. Both must be positive ints."
)
self.sample_rate = sample_rate
self.win_length = n_window_size
self.hop_length = n_window_stride
self.n_fft = n_fft or 2 ** math.ceil(math.log2(self.win_length))
self.stft_pad_amount = (
(self.n_fft - self.hop_length) // 2 if exact_pad else None
)
self.exact_pad = exact_pad
self.sample_rate = sample_rate
self.max_duration = max_duration
if exact_pad:
logger.info("STFT using exact pad")
torch_windows = {
"hann": torch.hann_window,
"hamming": torch.hamming_window,
"blackman": torch.blackman_window,
"bartlett": torch.bartlett_window,
"none": None,
}
window_fn = torch_windows.get(window)
window_tensor = (
window_fn(self.win_length, periodic=False) if window_fn else None
)
self.register_buffer("window", window_tensor)
self.normalize = normalize
self.log = log
self.dither = dither
self.frame_splicing = frame_splicing
self.nfilt = nfilt
self.preemph = preemph
self.pad_to = pad_to
highfreq = highfreq or sample_rate / 2
self.sample_rate = sample_rate
# disable pad min duration
# self.pad_min_duration = 1.0
self.pad_min_duration = 0.0
self.pad_direction = "both"
filterbanks = torch.tensor(
librosa.filters.mel(
sr=sample_rate,
n_fft=self.n_fft,
n_mels=nfilt,
fmin=lowfreq,
fmax=highfreq,
norm=mel_norm,
),
dtype=torch.float,
).unsqueeze(0)
self.register_buffer("fb", filterbanks)
# Calculate maximum sequence length
max_length = self.get_seq_len(
torch.tensor(max_duration * sample_rate, dtype=torch.float)
)
max_pad = pad_to - (max_length % pad_to) if pad_to > 0 else 0
self.max_length = max_length + max_pad
self.pad_value = pad_value
self.mag_power = mag_power
# We want to avoid taking the log of zero
# There are two options: either adding or clamping to a small value
if log_zero_guard_type not in ["add", "clamp"]:
raise ValueError(
f"{self} received {log_zero_guard_type} for the "
f"log_zero_guard_type parameter. It must be either 'add' or "
f"'clamp'."
)
self.use_grads = use_grads
if not use_grads:
self.forward = torch.no_grad()(self.forward)
self._rng = random.Random() if rng is None else rng
self.nb_augmentation_prob = nb_augmentation_prob
if self.nb_augmentation_prob > 0.0:
if nb_max_freq >= sample_rate / 2:
self.nb_augmentation_prob = 0.0
else:
self._nb_max_fft_bin = int((nb_max_freq / sample_rate) * n_fft)
# log_zero_guard_value is the the small we want to use, we support
# an actual number, or "tiny", or "eps"
self.log_zero_guard_type = log_zero_guard_type
assert self.window is not None
assert self.fb is not None
self.window = self.window.to(dtype=torch.bfloat16)
self.fb = self.fb.to(dtype=torch.bfloat16)
self.generator = torch.Generator(device=device)
self.generator.manual_seed(0)
@torch._dynamo.disable
def stft(self, x):
# disable autocast to get full range of stft values
with torch.amp.autocast(x.device.type, enabled=False):
return torch.stft(
x,
n_fft=self.n_fft,
hop_length=self.hop_length,
win_length=self.win_length,
center=not self.exact_pad,
window=self.window.to(dtype=torch.float, device=x.device),
return_complex=True,
pad_mode="constant",
)
def log_zero_guard_value_fn(self, x):
if isinstance(self.log_zero_guard_value, str):
if self.log_zero_guard_value == "tiny":
return torch.finfo(x.dtype).tiny
elif self.log_zero_guard_value == "eps":
return torch.finfo(x.dtype).eps
else:
raise ValueError(
f"{self} received {self.log_zero_guard_value} for the "
f"log_zero_guard_type parameter. It must be either a "
f"number, 'tiny', or 'eps'"
)
else:
return self.log_zero_guard_value
def get_seq_len(self, seq_len):
# Assuming that center is True is stft_pad_amount = 0
pad_amount = (
self.stft_pad_amount * 2
if self.stft_pad_amount is not None
else self.n_fft // 2 * 2
)
seq_len = torch.floor_divide(
(seq_len + pad_amount - self.n_fft), self.hop_length
)
return seq_len.to(dtype=torch.long)
@property
def filter_banks(self):
return self.fb
def splice_frames(self, x, frame_splicing):
"""Stacks frames together across feature dim
input is batch_size, feature_dim, num_frames
output is batch_size, feature_dim*frame_splicing, num_frames
"""
seq = [x]
for n in range(1, frame_splicing):
seq.append(torch.cat([x[:, :, :n], x[:, :, n:]], dim=2))
return torch.cat(seq, dim=1)
def normalize_batch(self, x, seq_len, normalize_type):
x_mean = None
x_std = None
if normalize_type == "per_feature":
batch_size = x.shape[0]
max_time = x.shape[2]
# When doing stream capture to a graph, item() is not allowed
# because it calls cudaStreamSynchronize(). Therefore, we are
# sacrificing some error checking when running with cuda graphs.
# if (
# torch.cuda.is_available()
# and not torch.cuda.is_current_stream_capturing()
# and torch.any(seq_len == 1).item()
# ):
# raise ValueError(
# "normalize_batch with `per_feature` normalize_type "
# "received a tensor of length 1. This will result in "
# "torch.std() returning nan. Make sure your audio length "
# "has enough samples for a single feature (ex. at least "
# "`hop_length` for Mel Spectrograms)."
# )
time_steps = (
torch.arange(max_time, device=x.device)
.unsqueeze(0)
.expand(batch_size, max_time)
)
valid_mask = time_steps < seq_len.unsqueeze(1)
x_mean_numerator = torch.where(valid_mask.unsqueeze(1), x, 0.0).sum(axis=2)
x_mean_denominator = valid_mask.sum(axis=1)
x_mean = x_mean_numerator / x_mean_denominator.unsqueeze(1)
# Subtract 1 in the denominator to correct for the bias.
x_std = torch.sqrt(
torch.sum(
torch.where(valid_mask.unsqueeze(1), x - x_mean.unsqueeze(2), 0.0)
** 2,
axis=2,
)
/ (x_mean_denominator.unsqueeze(1) - 1.0)
)
x_std = x_std.masked_fill(
x_std.isnan(), 0.0
) # edge case: only 1 frame in denominator
# make sure x_std is not zero
x_std += CONSTANT
return (x - x_mean.unsqueeze(2)) / x_std.unsqueeze(2), x_mean, x_std
elif normalize_type == "all_features":
x_mean = torch.zeros(seq_len.shape, dtype=x.dtype, device=x.device)
x_std = torch.zeros(seq_len.shape, dtype=x.dtype, device=x.device)
for i in range(x.shape[0]):
x_mean[i] = x[i, :, : seq_len[i].item()].mean()
x_std[i] = x[i, :, : seq_len[i].item()].std()
# make sure x_std is not zero
x_std += CONSTANT
return (x - x_mean.view(-1, 1, 1)) / x_std.view(-1, 1, 1), x_mean, x_std
elif "fixed_mean" in normalize_type and "fixed_std" in normalize_type:
x_mean = torch.tensor(normalize_type["fixed_mean"], device=x.device)
x_std = torch.tensor(normalize_type["fixed_std"], device=x.device)
return (
(x - x_mean.view(x.shape[0], x.shape[1]).unsqueeze(2))
/ x_std.view(x.shape[0], x.shape[1]).unsqueeze(2),
x_mean,
x_std,
)
else:
return x, x_mean, x_std
@torch.compile
def forward(self, x, seq_len, linear_spec=False):
if x.shape[1] < self.sample_rate * self.pad_min_duration:
pad_amount = int(self.sample_rate * self.pad_min_duration) - x.shape[1]
if self.pad_direction == "right":
x = F.pad(x, (0, pad_amount), value=self.pad_value)
elif self.pad_direction == "left":
x = F.pad(x, (pad_amount, 0), value=self.pad_value)
elif self.pad_direction == "both":
left_pad = pad_amount // 2
right_pad = pad_amount - left_pad
x = F.pad(x, (left_pad, right_pad), value=self.pad_value)
else:
raise ValueError(
f"{self} received an invalid pad_direction: {self.pad_direction}. "
f"It must be one of 'left', 'right', or 'both'."
)
seq_len = torch.tensor([x.shape[1]], dtype=torch.float, device=x.device)
seq_len_time = seq_len
seq_len_unfixed = self.get_seq_len(seq_len)
# fix for seq_len = 0 for streaming; if size was 0, it is always padded
# to 1, and normalizer fails
seq_len = torch.where(
seq_len == 0, torch.zeros_like(seq_len_unfixed), seq_len_unfixed
)
if self.stft_pad_amount is not None:
x = torch.nn.functional.pad(
x.unsqueeze(1), (self.stft_pad_amount, self.stft_pad_amount), "constant"
).squeeze(1)
# use dither for inference as well
if self.dither > 0:
x += self.dither * torch.randn(
x.shape, dtype=x.dtype, device=x.device, generator=self.generator
)
# do preemphasis
if self.preemph is not None:
timemask = torch.arange(x.shape[1], device=x.device).unsqueeze(
0
) < seq_len_time.unsqueeze(1)
x = torch.cat(
(x[:, 0].unsqueeze(1), x[:, 1:] - self.preemph * x[:, :-1]), dim=1
)
x = x.masked_fill(~timemask, 0.0)
x = self.stft(x)
# torch stft returns complex tensor (of shape [B,N,T]); so convert to magnitude
# guard is needed for sqrt if grads are passed through
guard = 0 if not self.use_grads else CONSTANT
x = torch.view_as_real(x)
x = torch.sqrt(x.pow(2).sum(-1) + guard)
# get power spectrum
if self.mag_power != 1.0:
x = x.pow(self.mag_power)
# return plain spectrogram if required
if linear_spec:
return x, seq_len
# disable autocast, otherwise it might be automatically casted to fp16
# on fp16 compatible GPUs and get NaN values for input value of 65520
with torch.amp.autocast(x.device.type, enabled=False):
# dot with filterbank energies
x = torch.matmul(self.fb.to(x.dtype), x)
# log features if required
if self.log:
if self.log_zero_guard_type == "add":
x = torch.log(x + self.log_zero_guard_value_fn(x))
elif self.log_zero_guard_type == "clamp":
x = torch.log(torch.clamp(x, min=self.log_zero_guard_value_fn(x)))
else:
raise ValueError("log_zero_guard_type was not understood")
# frame splicing if required
if self.frame_splicing > 1:
x = self.splice_frames(x, self.frame_splicing)
# normalize if required
if self.normalize:
x, _, _ = self.normalize_batch(x, seq_len, normalize_type=self.normalize)
# mask to zero any values beyond seq_len in batch, pad to multiple of
# `pad_to` (for efficiency)
max_len = x.size(-1)
mask = torch.arange(max_len, device=x.device)
mask = mask.repeat(x.size(0), 1) >= seq_len.unsqueeze(1)
x = x.masked_fill(
mask.unsqueeze(1).type(torch.bool).to(device=x.device), self.pad_value
)
del mask
pad_to = self.pad_to
if pad_to == "max":
x = nn.functional.pad(
x, (0, self.max_length - x.size(-1)), value=self.pad_value
)
elif pad_to > 0:
pad_amt = x.size(-1) % pad_to
if pad_amt != 0:
x = nn.functional.pad(x, (0, pad_to - pad_amt), value=self.pad_value)
return x, seq_len
class CohereASRFeatureExtractor(SequenceFeatureExtractor):
"""HF-compatible feature extractor wrapping FilterbankFeatures."""
model_input_names = ["input_features"]
def __init__(
self,
feature_size=64,
sampling_rate=16000,
padding_value=0.0,
max_duration=30,
n_window_size=320,
n_window_stride=160,
window="hann",
normalize="per_feature",
n_fft=None,
preemph=0.97,
lowfreq=0,
highfreq=None,
log=True,
log_zero_guard_type="add",
log_zero_guard_value=2**-24,
dither=CONSTANT,
pad_to=16,
frame_splicing=1,
exact_pad=False,
mag_power=2.0,
nb_augmentation_prob=0.0,
nb_max_freq=4000,
mel_norm="slaney",
stft_exact_pad=False,
stft_conv=False,
device="cpu",
**kwargs,
):
super().__init__(
feature_size=feature_size,
sampling_rate=sampling_rate,
padding_value=padding_value,
**kwargs,
)
self.max_duration = max_duration
self.hop_length = n_window_stride
self._device = torch.device(device)
self._fb_config = dict(
sample_rate=sampling_rate,
n_window_size=n_window_size,
n_window_stride=n_window_stride,
window=window,
normalize=normalize,
n_fft=n_fft,
preemph=preemph,
nfilt=feature_size,
lowfreq=lowfreq,
highfreq=highfreq,
log=log,
log_zero_guard_type=log_zero_guard_type,
log_zero_guard_value=log_zero_guard_value,
dither=dither,
pad_to=pad_to,
max_duration=max_duration,
frame_splicing=frame_splicing,
exact_pad=exact_pad,
pad_value=padding_value,
mag_power=mag_power,
nb_augmentation_prob=nb_augmentation_prob,
nb_max_freq=nb_max_freq,
mel_norm=mel_norm,
stft_exact_pad=stft_exact_pad,
stft_conv=stft_conv,
device=device,
)
self._filterbank: FilterbankFeatures | None = None
@property
def filterbank(self) -> FilterbankFeatures:
if self._filterbank is None:
fb = FilterbankFeatures(**self._fb_config)
fb.eval()
self._filterbank = fb.to(self._device)
return self._filterbank
def get_seq_len(self, seq_len):
return self.filterbank.get_seq_len(seq_len)
def __call__(
self,
raw_speech,
sampling_rate=None,
return_tensors=None,
**kwargs,
) -> BatchFeature:
if isinstance(raw_speech, np.ndarray):
raw_speech = [raw_speech]
seq_len = torch.tensor([s.shape[0] for s in raw_speech])
max_len = max(s.shape[0] for s in raw_speech)
padded = np.zeros((len(raw_speech), max_len), dtype=np.float32)
for i, s in enumerate(raw_speech):
padded[i, : s.shape[0]] = s
audio_tensor = torch.from_numpy(padded).to(self._device)
seq_len = seq_len.to(self._device)
with torch.no_grad():
input_features, length = self.filterbank(audio_tensor, seq_len)
result = BatchFeature(
{"input_features": input_features.cpu(), "length": length.cpu()}
)
if return_tensors is not None:
result = result.convert_to_tensors(return_tensors)
return result
class CohereASRProcessor(ProcessorMixin):
"""HF-compatible processor combining CohereASRFeatureExtractor and a
tokenizer."""
feature_extractor_class = "CohereASRFeatureExtractor"
tokenizer_class = "AutoTokenizer"
def __init__(self, feature_extractor, tokenizer):
super().__init__(feature_extractor, tokenizer)
def __call__(
self,
text=None,
audio=None,
sampling_rate=None,
return_tensors=None,
**kwargs,
):
if audio is not None:
result = self.feature_extractor(
audio,
sampling_rate=sampling_rate,
return_tensors=return_tensors,
)
else:
result = BatchFeature()
if text is not None:
text_inputs = self.tokenizer(text, return_tensors=return_tensors, **kwargs)
result["input_ids"] = text_inputs["input_ids"]
return result
AutoFeatureExtractor.register("CohereASRFeatureExtractor", CohereASRFeatureExtractor)
AutoProcessor.register("CohereASRProcessor", CohereASRProcessor)