Files
vllm/vllm/model_executor/models/cohere_asr.py
2026-03-25 16:13:51 -07:00

2219 lines
78 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from collections.abc import Iterable, Mapping, Sequence
from typing import Literal
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from transformers import PretrainedConfig
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.inputs import MultiModalDataDict, PromptType, TextPrompt
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.attention import (
Attention,
CrossAttention,
)
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (
MultiModalFieldConfig,
MultiModalKwargsItems,
)
from vllm.multimodal.parse import (
AudioProcessorItems,
MultiModalDataItems,
MultiModalDataParser,
)
from vllm.multimodal.processing import (
BaseDummyInputsBuilder,
BaseProcessingInfo,
EncDecMultiModalProcessor,
PromptReplacement,
PromptUpdate,
)
from vllm.renderers import TokenizeParams
from vllm.transformers_utils.processors.cohere_asr import (
INF_VAL,
CohereASRFeatureExtractor,
CohereASRProcessor,
)
from vllm.v1.attention.backend import (
AttentionType,
)
from .interfaces import (
MultiModalEmbeddings,
SupportsMultiModal,
SupportsTranscription,
)
from .utils import AutoWeightsLoader, WeightsMapper, make_layers
# Module-level logger, named after this module.
logger = init_logger(__name__)
# ISO 639-1 language code -> English language name for each supported
# language. Naming follows the OpenAI speech-to-text language list:
# https://platform.openai.com/docs/guides/speech-to-text/supported-languages
ISO639_1_SUPPORTED_LANGS = dict(
    en="English",
    fr="French",
    de="German",
    es="Spanish",
    pt="Portuguese",
    it="Italian",
    nl="Dutch",
    pl="Polish",
    el="Greek",
    ar="Arabic",
    ko="Korean",
    ja="Japanese",
    vi="Vietnamese",
    zh="Chinese",
)
class CohereASRAttention(nn.Module):
    """Multi-head attention used by the Cohere ASR transformer decoder.

    Inputs go through a fused QKV projection (``qkv_proj``) and then vLLM's
    ``Attention`` (decoder self-attention) or ``CrossAttention``
    (encoder-decoder) layer. The result is projected back with
    ``out_projection``. Heads are split evenly across tensor-parallel ranks.

    Args:
        embed_dim: model hidden size; must be divisible by ``num_heads``.
        num_heads: total number of heads across all tensor-parallel ranks.
        bias: whether the projections carry a bias.
        attn_type: ``AttentionType.DECODER`` or
            ``AttentionType.ENCODER_DECODER``.
        cache_config: KV-cache config forwarded to the attention layer.
        quant_config: optional quantization config for the linear layers.
        prefix: module path used to name the sub-layers.

    Raises:
        NotImplementedError: if ``attn_type`` is ``AttentionType.ENCODER``.
        ValueError: if ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        attn_type: AttentionType = AttentionType.DECODER,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        # Reject the unsupported mode before any projection weights are
        # allocated.
        if attn_type == AttentionType.ENCODER:
            raise NotImplementedError(
                "CohereASRAttention does not support Encoder Self-Attention yet."
            )
        self.embed_dim = embed_dim
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        # Heads are partitioned evenly across TP ranks. Divisibility is
        # required, so the "fewer heads than ranks" (replication) case cannot
        # occur and needs no separate branch.
        assert self.total_num_heads % tp_size == 0, (
            f"num_heads ({self.total_num_heads}) must be divisible by the "
            f"tensor parallel size ({tp_size})"
        )
        self.num_heads = self.total_num_heads // tp_size
        # K/V use as many heads as Q (see ``total_num_kv_heads`` in _init_qkv).
        self.num_kv_heads = max(1, self.total_num_heads // tp_size)
        self.head_dim = self.embed_dim // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.attn_type = attn_type
        if (self.head_dim * num_heads) != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: "
                f"{self.embed_dim} and `num_heads`: {num_heads})."
            )
        self.scaling = self.head_dim**-0.5
        # Subclasses may override _init_qkv to build different projections.
        self._init_qkv(embed_dim, bias, quant_config, prefix=prefix)
        self.out_projection = RowParallelLinear(
            input_size=embed_dim,
            output_size=embed_dim,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.out_projection",
        )
        if self.attn_type == AttentionType.ENCODER_DECODER:
            self.attn = CrossAttention(
                self.num_heads,
                self.head_dim,
                self.scaling,
                num_kv_heads=self.num_kv_heads,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=f"{prefix}.attn",
                attn_type=self.attn_type,
            )
        else:  # AttentionType.DECODER (regular decoder self-attention)
            self.attn = Attention(
                self.num_heads,
                self.head_dim,
                self.scaling,
                num_kv_heads=self.num_kv_heads,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=f"{prefix}.attn",
                attn_type=self.attn_type,
            )

    def _init_qkv(
        self,
        embed_dim: int,
        bias: bool = True,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Create the fused Q/K/V projection (``qkv_proj``)."""
        self.qkv_proj = QKVParallelLinear(
            hidden_size=embed_dim,
            head_size=self.head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Self-attend over ``hidden_states`` and return the projected output."""
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        attn_output = self.attn(q, k, v)
        output, _ = self.out_projection(attn_output)
        return output
class CohereASRCrossAttention(CohereASRAttention):
    """Encoder-decoder cross-attention for the Cohere ASR decoder.

    Queries are projected from the decoder stream (``q_proj``). Keys and
    values are projected from the encoder output (``kv_proj``), and only
    when that output is supplied.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__(
            embed_dim=embed_dim,
            num_heads=num_heads,
            bias=bias,
            attn_type=AttentionType.ENCODER_DECODER,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=prefix,
        )

    def _init_qkv(
        self,
        embed_dim: int,
        bias: bool = True,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        """Build separate query and key/value projections."""
        self.q_proj = ColumnParallelLinear(
            input_size=embed_dim,
            output_size=embed_dim,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.q_proj",
        )
        # total_num_heads=0: the fused output only holds the K and V slices,
        # which forward() splits in two.
        self.kv_proj = QKVParallelLinear(
            hidden_size=embed_dim,
            head_size=self.head_dim,
            total_num_heads=0,
            total_num_kv_heads=self.total_num_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=f"{prefix}.kv_proj",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor | None,
    ) -> torch.Tensor:
        """Attend from ``hidden_states`` to the encoder output.

        ``encoder_hidden_states`` is None once K/V are already cached; the
        attention layer then receives ``None`` for keys and values.
        """
        query, _ = self.q_proj(hidden_states)
        key = value = None
        if encoder_hidden_states is not None:
            # Encoder K/V are only projected during prefill; later steps
            # read them from the KV cache.
            fused_kv, _ = self.kv_proj(encoder_hidden_states)
            key, value = fused_kv.split([self.kv_size, self.kv_size], dim=-1)
        context = self.attn(query, key, value)
        projected, _ = self.out_projection(context)
        return projected
# ----- Decoder START -----
class CohereASRMLP(nn.Module):
    """Position-wise feed-forward block: ``dense_in`` -> activation -> ``dense_out``.

    Args:
        embed_dim: input/output hidden size.
        ffn_dim: inner (expanded) size.
        act_fn: activation name resolved with ``get_act_fn``.
        quant_config: optional quantization config for the linear layers.
        prefix: module path of this MLP.
    """

    def __init__(
        self,
        embed_dim: int,
        ffn_dim: int,
        act_fn: str,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.activation_fn = get_act_fn(act_fn)
        # Layer prefixes must equal the real module paths (attribute names
        # ``dense_in``/``dense_out``). vLLM identifies layers by prefix, e.g.
        # for per-layer quantization, so "fc1"/"fc2" names would never match.
        self.dense_in = ColumnParallelLinear(
            input_size=embed_dim,
            output_size=ffn_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.dense_in",
        )
        self.dense_out = RowParallelLinear(
            input_size=ffn_dim,
            output_size=embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.dense_out",
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states, _ = self.dense_in(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states, _ = self.dense_out(hidden_states)
        return hidden_states
class FixedPositionalEncoding(nn.Module):
    """
    Fixed positional encoding (embedding layer) from sine and cosine functions
    of different frequencies according to https://arxiv.org/abs/1706.03762

    Even columns hold sine terms, odd columns cosine terms. The whole table is
    divided by sqrt(hidden_size) and stored in the ``pos_enc`` buffer
    (max_sequence_length, hidden_size). Odd ``hidden_size`` is supported.

    Args:
        hidden_size: size of the embeddings in the model, also known as d_model
        max_sequence_length: maximum allowed length of the input sequence
    """

    def __init__(self, hidden_size: int, max_sequence_length: int = 512) -> None:
        super().__init__()
        self._hidden_size = hidden_size
        self._max_sequence_length = max_sequence_length
        self._build_pos_enc(
            hidden_size=self._hidden_size, max_sequence_length=self._max_sequence_length
        )

    def _build_pos_enc(self, hidden_size: int, max_sequence_length: int) -> None:
        """Builds/replaces pre-computed positional encoding."""
        pos_enc = torch.zeros(max_sequence_length, hidden_size)
        position = torch.arange(0.0, max_sequence_length).unsqueeze(1)
        coef = -math.log(10000.0) / hidden_size
        div_term = torch.exp(coef * torch.arange(0.0, hidden_size, 2))
        angles = position * div_term  # (max_sequence_length, ceil(hidden/2))
        pos_enc[:, 0::2] = torch.sin(angles)
        # With odd hidden_size there is one fewer odd column than frequencies;
        # drop the last cosine so the assignment shapes agree.
        pos_enc[:, 1::2] = torch.cos(angles[:, : hidden_size // 2])
        pos_enc.div_(math.sqrt(hidden_size))
        self.register_buffer("pos_enc", pos_enc)

    def forward(self, position_ids: torch.Tensor) -> torch.Tensor:
        """Look up encodings; output shape is position_ids.shape + (hidden_size,)."""
        embeddings = torch.embedding(self.pos_enc, position_ids)
        return embeddings
class CohereASRDecoderLayer(nn.Module):
    """One pre-LayerNorm block of the Cohere ASR decoder.

    Three residual sub-layers run in order:
    - decoder self-attention (``first_sub_layer``),
    - cross-attention to the encoder output (``second_sub_layer``),
    - a feed-forward MLP (``third_sub_layer``).

    Each sub-layer reads a LayerNorm-ed copy of the stream, and its output is
    added back to the un-normalized stream. Hyper-parameters come from
    ``hf_config.transf_decoder["config_dict"]``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        decoder_cfg = vllm_config.model_config.hf_config.transf_decoder["config_dict"]
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.hidden_dim = decoder_cfg.get("hidden_size")
        self.ffn_dim = decoder_cfg.get("inner_size")
        self.act_fn = decoder_cfg.get("hidden_act")
        self.num_heads = decoder_cfg.get("num_attention_heads")

        # Sub-layer 1: decoder self-attention.
        self.layer_norm_1 = nn.LayerNorm(self.hidden_dim)
        self.first_sub_layer = CohereASRAttention(
            embed_dim=self.hidden_dim,
            num_heads=self.num_heads,
            attn_type=AttentionType.DECODER,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.first_sub_layer",
        )

        # Sub-layer 2: cross-attention over the encoder output.
        self.layer_norm_2 = nn.LayerNorm(self.hidden_dim)
        self.second_sub_layer = CohereASRCrossAttention(
            embed_dim=self.hidden_dim,
            num_heads=self.num_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.second_sub_layer",
        )

        # Sub-layer 3: feed-forward MLP.
        self.layer_norm_3 = nn.LayerNorm(self.hidden_dim)
        self.third_sub_layer = CohereASRMLP(
            embed_dim=self.hidden_dim,
            ffn_dim=self.ffn_dim,
            act_fn=self.act_fn,
            quant_config=quant_config,
            prefix=f"{prefix}.third_sub_layer",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor | None,
    ) -> torch.Tensor:
        """Apply the three residual sub-layers.

        ``encoder_hidden_states`` is passed straight to the cross-attention
        sub-layer and may be None.
        """
        self_attn_out = self.first_sub_layer(
            hidden_states=self.layer_norm_1(hidden_states)
        )
        hidden_states = hidden_states + self_attn_out

        cross_attn_out = self.second_sub_layer(
            hidden_states=self.layer_norm_2(hidden_states),
            encoder_hidden_states=encoder_hidden_states,
        )
        hidden_states = hidden_states + cross_attn_out

        mlp_out = self.third_sub_layer(self.layer_norm_3(hidden_states))
        return hidden_states + mlp_out
class TransformerEmbedding(nn.Module):
    """Decoder input embedding.

    Sums a learned token embedding and a fixed sinusoidal position embedding,
    then applies LayerNorm.

    Args:
        vocab_size: number of token ids.
        hidden_size: embedding width.
        max_target_positions: size of the position table.
        padding_idx: token id passed to ``nn.Embedding`` as ``padding_idx``.
    """

    def __init__(
        self,
        vocab_size: int,
        hidden_size: int,
        max_target_positions: int,
        padding_idx: int,
    ) -> None:
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, hidden_size, padding_idx)
        self.position_embedding = FixedPositionalEncoding(
            hidden_size=hidden_size,
            max_sequence_length=max_target_positions,
        )
        self.layer_norm = nn.LayerNorm(hidden_size)

    def forward(self, input_ids: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
        summed = self.token_embedding(input_ids) + self.position_embedding(positions)
        return self.layer_norm(summed)
@support_torch_compile(dynamic_arg_dims={"input_ids": 0, "positions": -1})
class CohereASRDecoder(nn.Module):
    """Cohere ASR transformer decoder.

    Embeds tokens (token embedding + fixed sinusoidal positions + LayerNorm),
    runs the stack of ``CohereASRDecoderLayer`` blocks with cross-attention to
    the encoder output, and applies a final LayerNorm. Sizes are read from
    ``hf_config.transf_decoder["config_dict"]``; the vocabulary size comes from
    ``hf_config.head["num_classes"]``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        # NOTE(review): hard-coded padding token id for the token embedding;
        # presumably the tokenizer's pad id — confirm against the checkpoint.
        self.padding_idx = 2
        config_dict = config.transf_decoder["config_dict"]
        self.max_target_positions = config_dict.get("max_sequence_length")
        self.hidden_size = config_dict.get("hidden_size")
        self.num_decoder_layers = config_dict.get("num_layers")
        self.vocab_size = config.head["num_classes"]
        self.embedding = TransformerEmbedding(
            vocab_size=self.vocab_size,
            hidden_size=self.hidden_size,
            max_target_positions=self.max_target_positions,
            padding_idx=self.padding_idx,
        )
        # make_layers hands each layer its own "<prefix>.layers.<idx>" path,
        # so that prefix is used as-is. Appending ".layers" again would name
        # the sub-layers "...layers.<idx>.layers...", which doesn't match the
        # real module path.
        self.start_layer, self.end_layer, self.layers = make_layers(
            self.num_decoder_layers,
            lambda prefix: CohereASRDecoderLayer(vllm_config=vllm_config, prefix=prefix),
            prefix=f"{prefix}.layers",
        )
        self.final_layer_norm = nn.LayerNorm(self.hidden_size)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        encoder_hidden_states: torch.Tensor | None,
    ) -> torch.Tensor:
        """Decode ``input_ids`` at ``positions``.

        ``encoder_hidden_states`` is forwarded to every layer's
        cross-attention and may be None.
        """
        hidden_states = self.get_input_embeddings(input_ids, positions)
        for decoder_layer in self.layers:
            hidden_states = decoder_layer(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
            )
        hidden_states = self.final_layer_norm(hidden_states)
        return hidden_states

    def get_input_embeddings(
        self, input_ids: torch.Tensor, positions: torch.Tensor
    ) -> torch.Tensor:
        """Token + positional embedding followed by LayerNorm."""
        return self.embedding(input_ids, positions)
# ----- Decoder END -----
# ----- Encoder START -----
class MaskedConvSequential(nn.Sequential):
    """``nn.Sequential`` of 2-D conv/activation layers that keeps padded time
    frames zeroed.

    ``forward`` adds a singleton channel axis to ``x`` (batch, time, features)
    and runs the layers in turn. After every layer it zeroes the time frames
    beyond each sample's valid length. Valid lengths are recomputed after
    every layer whose ``stride`` is not ``(1, 1)``.
    """

    def forward(
        self, x: torch.Tensor, lengths: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run the layers with length-aware masking.

        Args:
            x: (batch, time, features) input. It is not modified.
            lengths: (batch,) number of valid time steps per sample.

        Returns:
            ``(output, lengths)``: ``output`` is the masked output of the last
            layer, shaped (batch, channels, time', features'). ``lengths`` are
            the updated valid lengths as int64.
        """
        x = x.unsqueeze(1)  # (batch, 1, time, features)
        current_lengths = lengths.clone().float()
        mask = self._create_mask(x, current_lengths.long())
        # Out-of-place here: ``x`` is still a view of the caller's tensor, and
        # an in-place multiply would zero the caller's padded frames. After
        # this point ``x`` never aliases caller memory, so in-place is safe.
        x = x * mask
        for layer in self:
            x = layer(x)
            # Strided layers shrink the time axis: recompute valid lengths
            # using the layer's own padding and rebuild the mask.
            if hasattr(layer, "stride") and layer.stride != (1, 1):
                current_lengths = self.calculate_conv_output_size(
                    current_lengths,
                    layer.kernel_size[0],
                    layer.stride[0],
                    self._time_padding(layer),
                )
                mask = self._create_mask(x, current_lengths.long())
            x = self.apply_channel_mask(x, mask)
        return x, current_lengths.long()

    @staticmethod
    def _time_padding(layer: nn.Module) -> tuple[int, int]:
        """Return the (left, right) padding ``layer`` applies on the time axis."""
        if hasattr(layer, "_left_padding"):
            # CausalConv2D-style layers store explicit (possibly asymmetric)
            # time padding.
            return (layer._left_padding, layer._right_padding)
        # nn.Conv2d stores padding as (time, freq) and applies each entry on
        # both sides of its axis, so the time axis is padded by padding[0]
        # twice. padding[1] is the frequency padding and must not be used here.
        return (layer.padding[0], layer.padding[0])

    def _create_mask(self, tensor: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
        """Create broadcastable mask from per-sample lengths.

        Returns a (B, 1, T, 1) mask that broadcasts over channels and
        features without materializing a full (B, C, T, F) tensor.
        """
        batch_size, channels, time, features = tensor.shape
        time_mask = torch.arange(time, device=tensor.device).expand(
            batch_size, time
        ) < lengths.unsqueeze(1)
        return time_mask.to(tensor.dtype).unsqueeze(1).unsqueeze(-1)

    def apply_channel_mask(
        self, tensor: torch.Tensor, mask: torch.Tensor
    ) -> torch.Tensor:
        """Apply mask in-place via broadcasting.

        tensor: (B, C, T, F), mask: (B, 1, T, 1)
        """
        tensor.mul_(mask)
        return tensor

    def calculate_conv_output_size(
        self,
        input_size: torch.Tensor,
        kernel_size: int,
        stride: int,
        padding: tuple[int, int],
    ):
        """Calculate exact output size after convolution.

        ``padding`` is the (left, right) padding along the convolved axis.
        """
        return (input_size + padding[0] + padding[1] - kernel_size) // stride + 1
class ConvSubsampling(nn.Module):
def __init__(
self,
subsampling: str,
subsampling_factor: int,
feat_in: int,
feat_out: int,
conv_channels: int,
subsampling_conv_chunking_factor: int = 1,
activation: nn.Module | None = None,
is_causal: bool = False,
) -> None:
super().__init__()
if activation is None:
activation = nn.ReLU()
if subsampling_factor % 2 != 0:
raise ValueError("Sampling factor should be a multiply of 2!")
self._sampling_num = int(math.log(subsampling_factor, 2))
if (
subsampling_conv_chunking_factor != -1
and subsampling_conv_chunking_factor != 1
and subsampling_conv_chunking_factor % 2 != 0
):
raise ValueError(
"subsampling_conv_chunking_factor should be -1, 1, or a power of 2"
)
in_channels = 1
layers = []
assert subsampling == "dw_striding"
self._stride = 2
self._kernel_size = 3
self._ceil_mode = False
assert not is_causal
self._left_padding = (self._kernel_size - 1) // 2
self._right_padding = (self._kernel_size - 1) // 2
# Layer 1
# [1, T, num_melspec] -> [conv_channels, T//2, num_melspec//2]
layers.append(
torch.nn.Conv2d(
in_channels=in_channels,
out_channels=conv_channels,
kernel_size=self._kernel_size,
stride=self._stride,
padding=self._left_padding,
)
)
in_channels = conv_channels
layers.append(activation)
for i in range(self._sampling_num - 1):
# [conv_channels, T//2^i, num_melspec//2^i] ->
# [conv_channels, T//2^(i+1), num_melspec//2^(i+1)]
# depthwise conv
layers.append(
torch.nn.Conv2d(
in_channels=in_channels,
out_channels=in_channels,
kernel_size=self._kernel_size,
stride=self._stride,
padding=self._left_padding,
groups=in_channels,
)
)
# [conv_channels, T//2^(i+1), num_melspec//2^(i+1)]
# -> [conv_channels, T//2^(i+1), num_melspec//2^(i+1)]
# pointwise conv
layers.append(
torch.nn.Conv2d(
in_channels=in_channels,
out_channels=conv_channels,
kernel_size=1,
stride=1,
padding=0,
groups=1,
)
)
layers.append(activation)
in_channels = conv_channels
in_length = torch.tensor(feat_in, dtype=torch.float)
out_length = self.calc_length(
lengths=in_length,
all_paddings=self._left_padding + self._right_padding,
kernel_size=self._kernel_size,
stride=self._stride,
ceil_mode=self._ceil_mode,
repeat_num=self._sampling_num,
)
# reshape:
# [conv_channels, T//sub_factor, num_melspec//sub_factor]
# -> [T//sub_factor, conv_channels * (num_melspec//sub_factor)]
# mlp:
# [T//sub_factor, conv_channels * (num_melspec//sub_factor)]
# -> [T//sub_factor, feat_out]
self.out = torch.nn.Linear(conv_channels * int(out_length), feat_out)
self.conv2d_subsampling = True
self.conv = MaskedConvSequential(*layers)
def calc_length(
self,
lengths: torch.Tensor,
all_paddings: int,
kernel_size: int,
stride: int,
ceil_mode: bool,
repeat_num: int = 1,
) -> torch.Tensor:
"""Calculates the output length of a Tensor passed
through a convolution or max pooling layer"""
add_pad: float = all_paddings - kernel_size
one: float = 1.0
for i in range(repeat_num):
lengths = torch.div(lengths.to(dtype=torch.float) + add_pad, stride) + one
lengths = torch.ceil(lengths) if ceil_mode else torch.floor(lengths)
return lengths.to(dtype=torch.int)
def forward(
self, x: torch.Tensor, lengths: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
x, lengths = self.conv(x, lengths)
if self.conv2d_subsampling:
b, c, t, f = x.size()
x = self.out(x.transpose(1, 2).reshape(b, t, -1))
# Transpose to Channel Last mode
else:
x = x.transpose(1, 2)
return x, lengths
class PositionalEncoding(torch.nn.Module):
"""Fixed sinusoidal positional encoding.
Args:
d_model (int): embedding dim
max_len (int): maximum input length
xscale (bool): whether to scale the input by sqrt(d_model)
"""
def __init__(
self, d_model: int, max_len: int = 5000, xscale: float | None = None
) -> None:
super().__init__()
self.d_model = d_model
self.xscale = xscale
self.max_len = max_len
def create_pe(self, positions: torch.Tensor, dtype: torch.dtype) -> None:
pos_length = positions.size(0)
pe = torch.zeros(pos_length, self.d_model, device=positions.device)
div_term = torch.exp(
torch.arange(
0, self.d_model, 2, dtype=torch.float32, device=positions.device
)
* -(math.log(10000.0) / self.d_model)
)
pe[:, 0::2] = torch.sin(positions * div_term)
pe[:, 1::2] = torch.cos(positions * div_term)
pe = pe.unsqueeze(0).to(dtype)
if hasattr(self, "pe"):
self.pe = pe
else:
self.register_buffer("pe", pe, persistent=False)
def forward(
self, x: torch.Tensor, cache_len: int = 0
) -> tuple[torch.Tensor, torch.Tensor]:
"""Adds positional encoding.
Args:
x (torch.Tensor): Input. Its shape is (batch, time, feature_size)
cache_len (int): the size of the cache which is used to shift positions
Returns:
x+pos_emb (torch.Tensor): Its shape is (batch, time, feature_size)
pos_emb (torch.Tensor): Its shape is (1, time, feature_size)
"""
input_len = x.size(1) + cache_len
if self.xscale:
x = x * self.xscale
pos_emb = self.pe[:, :input_len]
x = x + pos_emb
return x, pos_emb
class RelPositionalEncoding(PositionalEncoding):
    """Relative positional encoding for TransformerXL's layers.

    See Appendix B in https://arxiv.org/abs/1901.02860. The table holds
    encodings for relative positions ``length-1`` down to ``-(length-1)``.

    Args:
        d_model (int): embedding dim
        max_len (int): maximum input length
        xscale (float | None): if truthy, the input is multiplied by it
    """

    def extend_pe(self, length: int, device: torch.device, dtype: torch.dtype) -> None:
        """Ensure the table covers relative positions for ``length`` steps.

        Does nothing if the existing table already has at least
        ``2 * length - 1`` rows.
        """
        if hasattr(self, "pe") and self.pe.size(1) >= 2 * length - 1:
            return
        rel_positions = torch.arange(
            length - 1, -length, -1, dtype=torch.float32, device=device
        )
        self.create_pe(positions=rel_positions.unsqueeze(1), dtype=dtype)

    def forward(
        self, x: torch.Tensor, cache_len: int = 0
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the (optionally scaled) input and matching relative encodings.

        Args:
            x (torch.Tensor): (batch, time, feature_size)
            cache_len (int): extra positions to include beyond ``time``

        Returns:
            ``x`` (scaled by ``xscale`` if set, otherwise unchanged) and
            ``pos_emb`` of shape (1, 2 * (time + cache_len) - 1, feature_size),
            centered on relative position 0. ``x`` is not summed with it.
        """
        if self.xscale:
            x = x * self.xscale
        span = x.size(1) + cache_len
        center = self.pe.size(1) // 2 + 1
        pos_emb = self.pe[:, center - span : center + span - 1]
        return x, pos_emb
class Swish(nn.SiLU):
    """
    Swish activation, x * sigmoid(x), introduced in 'https://arxiv.org/abs/1710.05941'.

    Mathematically identical to SiLU; this subclass adds no behavior of its
    own. See note in nn.SiLU for references.
    """
class ConformerFeedForward(nn.Module):
"""
feed-forward module of Conformer model.
use_bias (bool): Apply bias to all Linear and Conv1d
layers to improve activation flow and stabilize
training of huge models.
"""
def __init__(
self,
d_model: int,
d_ff: int,
activation: nn.Module | None = None,
use_bias: bool = True,
) -> None:
super().__init__()
if activation is None:
activation = Swish()
self.linear1 = nn.Linear(d_model, d_ff, bias=use_bias)
self.activation = activation
self.linear2 = nn.Linear(d_ff, d_model, bias=use_bias)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.linear1(x)
x = self.activation(x)
x = self.linear2(x)
return x
class CausalConv1D(nn.Conv1d):
"""
A causal version of nn.Conv1d where each step would
have limited access to locations on its right or left.
All arguments are the same as nn.Conv1d except padding.
If padding is set None, then paddings are set
automatically to make it a causal convolution where
each location would not see any steps on its right.
If padding is set as a list (size of 2), then
padding[0] would be used as left padding and
padding[1] as right padding. It would make it possible
to control the number of steps to be accessible on the
right and left. This mode is not supported when
stride > 1. padding[0]+padding[1] should be equal to
(kernel_size - 1).
"""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int = 1,
padding: str | int = 0,
dilation: int = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = "zeros",
device=None,
dtype=None,
) -> None:
if padding is None:
self._left_padding = kernel_size - 1
self._right_padding = stride - 1
else:
if stride != 1 and padding != kernel_size - 1:
raise ValueError("No striding allowed for non-symmetric convolutions!")
if isinstance(padding, int):
self._left_padding = padding
self._right_padding = padding
elif (
isinstance(padding, list)
and len(padding) == 2
and padding[0] + padding[1] == kernel_size - 1
):
self._left_padding = padding[0]
self._right_padding = padding[1]
else:
raise ValueError(f"Invalid padding param: {padding}!")
super().__init__(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
padding=0,
dilation=dilation,
groups=groups,
bias=bias,
padding_mode=padding_mode,
device=device,
dtype=dtype,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = F.pad(x, pad=(self._left_padding, self._right_padding))
return super().forward(x)
class ConformerConvolution(nn.Module):
    """The convolution module for the Conformer model.

    Pipeline:
    - pointwise conv (d_model -> 2 * d_model), then GLU;
    - optional zeroing of padded frames;
    - depthwise conv, BatchNorm, Swish;
    - pointwise conv (-> d_model).

    Args:
        d_model (int): hidden dimension
        kernel_size (int): kernel size of the depthwise convolution; must be
            odd (asserted)
        norm_type (str): only "batch_norm" is accepted (asserted)
        conv_context_size (int | None): padding handed to the depthwise
            ``CausalConv1D``; defaults to ``(kernel_size - 1) // 2``
        pointwise_activation (str): only the GLU default key ``glu_`` from
            the paper is accepted (asserted)
        use_bias (bool): use bias in all Conv1d layers. Defaults to True
    """

    def __init__(
        self,
        d_model: int,
        kernel_size: int,
        norm_type: str = "batch_norm",
        conv_context_size: int | None = None,
        pointwise_activation: str = "glu_",
        use_bias: bool = True,
    ) -> None:
        super().__init__()
        assert (kernel_size - 1) % 2 == 0
        context = (kernel_size - 1) // 2 if conv_context_size is None else conv_context_size
        assert pointwise_activation == "glu_"
        channels = d_model
        # Doubles the channels so the following GLU can halve them.
        self.pointwise_conv1 = nn.Conv1d(
            in_channels=d_model,
            out_channels=2 * d_model,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=use_bias,
        )
        self.depthwise_conv = CausalConv1D(
            in_channels=channels,
            out_channels=channels,
            kernel_size=kernel_size,
            stride=1,
            padding=context,
            groups=channels,
            bias=use_bias,
        )
        assert norm_type == "batch_norm"
        self.batch_norm = nn.BatchNorm1d(channels)
        self.activation = Swish()
        self.pointwise_conv2 = nn.Conv1d(
            in_channels=channels,
            out_channels=d_model,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=use_bias,
        )

    def forward(
        self, x: torch.Tensor, pad_mask: torch.Tensor | None = None
    ) -> torch.Tensor:
        """Apply the convolution module.

        Args:
            x: (batch, time, d_model).
            pad_mask: optional (batch, time) boolean mask. Frames where it is
                True are zeroed before the depthwise convolution.

        Returns:
            (batch, time, d_model).
        """
        # Conv1d works channel-first: (batch, d_model, time).
        h = nn.functional.glu(self.pointwise_conv1(x.transpose(1, 2)), dim=1)
        if pad_mask is not None:
            h = h.masked_fill(pad_mask.unsqueeze(1), 0.0)
        h = self.activation(self.batch_norm(self.depthwise_conv(h)))
        return self.pointwise_conv2(h).transpose(1, 2)
class CohereASRMultiHeadAttention(nn.Module):
"""Multi-Head Attention layer of Transformer.
Args:
n_head (int): number of heads
n_feat (int): size of the features
use_bias (bool): whether to remove bias in linear and conv layers
"""
def __init__(
self,
n_head: int,
n_feat: int,
use_bias: bool = True,
) -> None:
"""Construct an MultiHeadedAttention object."""
super().__init__()
assert n_feat % n_head == 0
self.d_k = n_feat // n_head
self.s_d_k = math.sqrt(self.d_k)
self.h = n_head
self.linear_q = nn.Linear(n_feat, n_feat, bias=use_bias)
self.linear_k = nn.Linear(n_feat, n_feat, bias=use_bias)
self.linear_v = nn.Linear(n_feat, n_feat, bias=use_bias)
self.linear_out = nn.Linear(n_feat, n_feat, bias=use_bias)
def forward_qkv(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Transforms query, key and value.
Args:
query (torch.Tensor): (batch, time1, size)
key (torch.Tensor): (batch, time2, size)
value (torch.Tensor): (batch, time2, size)
returns:
q (torch.Tensor): (batch, head, time1, size)
k (torch.Tensor): (batch, head, time2, size)
v (torch.Tensor): (batch, head, time2, size)
"""
n_batch = query.size(0)
q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
return q, k, v
def forward_attention(
self,
value: torch.Tensor,
scores: torch.Tensor,
mask: torch.Tensor | None,
) -> torch.Tensor:
"""Compute attention context vector.
Args:
value (torch.Tensor): (batch, time2, size)
scores(torch.Tensor): (batch, time1, time2)
mask(torch.Tensor): (batch, time1, time2)
returns:
value (torch.Tensor): transformed `value`
(batch, time2, d_model) weighted by the
attention scores
"""
n_batch = value.size(0)
if mask is not None:
mask = mask.unsqueeze(1) # (batch, 1, time1, time2)
scores = scores.masked_fill(mask, -INF_VAL)
attn = torch.softmax(scores, dim=-1).masked_fill(
mask, 0.0
) # (batch, head, time1, time2)
else:
attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
x = torch.matmul(attn, value) # (batch, head, time1, d_k)
x = x.transpose(1, 2).reshape(
n_batch, -1, self.h * self.d_k
) # (batch, time1, d_model)
return self.linear_out(x) # (batch, time1, d_model)
def forward(
self,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
mask: torch.Tensor | None,
pos_emb: torch.Tensor | None = None,
) -> torch.Tensor:
"""Compute 'Scaled Dot Product Attention'.
Args:
query (torch.Tensor): (batch, time1, size)
key (torch.Tensor): (batch, time2, size)
value(torch.Tensor): (batch, time2, size)
mask (torch.Tensor): (batch, time1, time2)
returns:
output (torch.Tensor): transformed `value`
(batch, time1, d_model) weighted by the
query dot key attention
"""
q, k, v = self.forward_qkv(query, key, value)
scores = torch.matmul(q, k.transpose(-2, -1)) / self.s_d_k
return self.forward_attention(v, scores, mask)
class RelPositionMultiHeadAttention(CohereASRMultiHeadAttention):
    """Multi-head attention of Transformer-XL with relative positional encoding.

    Paper: https://arxiv.org/abs/1901.02860 (Section 3.3).

    Args:
        n_head (int): number of heads
        n_feat (int): size of the features
        pos_bias_u: content bias ("u" in the paper). A zero, non-trainable
            (head, d_k) parameter is created unless both biases are given.
        pos_bias_v: position bias ("v" in the paper); same rule as above.
        use_bias (bool): whether the q/k/v/out projections carry a bias
    """

    def __init__(
        self,
        n_head: int,
        n_feat: int,
        pos_bias_u: nn.Parameter | torch.Tensor | None,
        pos_bias_v: nn.Parameter | torch.Tensor | None,
        use_bias: bool = True,
    ) -> None:
        """Construct an RelPositionMultiHeadedAttention object."""
        super().__init__(
            n_head=n_head,
            n_feat=n_feat,
            use_bias=use_bias,
        )
        # Projection of the positional encodings (no bias).
        self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
        if pos_bias_u is not None and pos_bias_v is not None:
            self.pos_bias_u = pos_bias_u
            self.pos_bias_v = pos_bias_v
        else:
            self.pos_bias_u = nn.Parameter(
                torch.zeros(self.h, self.d_k), requires_grad=False
            )
            self.pos_bias_v = nn.Parameter(
                torch.zeros(self.h, self.d_k), requires_grad=False
            )

    def rel_shift(self, x: torch.Tensor) -> torch.Tensor:
        """Shift position scores so each query row is aligned to relative offsets.

        Args:
            x (torch.Tensor): (batch, nheads, time, 2*time-1)
        """
        b, h, t1, t2 = x.size()
        # Prepend a zero column, refold so each row moves by one, then drop
        # the first (spurious) row.
        padded = torch.nn.functional.pad(x, pad=(1, 0))  # (b, h, t1, t2+1)
        folded = padded.view(b, h, -1, t1)  # (b, h, t2+1, t1)
        return folded[:, :, 1:].view(b, h, t1, t2)  # (b, h, t1, t2)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor | None,
        pos_emb: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Scaled dot-product attention with relative positional scores.

        Args:
            query (torch.Tensor): (batch, time1, size)
            key (torch.Tensor): (batch, time2, size)
            value (torch.Tensor): (batch, time2, size)
            mask (torch.Tensor): (batch, time1, time2)
            pos_emb (torch.Tensor): positional encodings (batch_pos, pos_len,
                size). Required despite the None default; ``rel_shift``
                expects pos_len == 2*time1-1.

        Returns:
            (batch, time1, d_model).
        """
        q, k, v = self.forward_qkv(query, key, value)
        # Heads on the second-to-last axis so the per-head biases broadcast.
        q_time_major = q.transpose(1, 2)  # (batch, time1, head, d_k)
        q_plus_u = (q_time_major + self.pos_bias_u).transpose(1, 2)
        q_plus_v = (q_time_major + self.pos_bias_v).transpose(1, 2)

        p = self.linear_pos(pos_emb).view(pos_emb.size(0), -1, self.h, self.d_k)
        p = p.transpose(1, 2)  # (batch_pos, head, pos_len, d_k)

        # Terms (a) + (c): content-based scores, (batch, head, time1, time2).
        matrix_ac = torch.matmul(q_plus_u, k.transpose(-2, -1))
        # Terms (b) + (d): position-based scores, shifted to relative offsets
        # and trimmed to matrix_ac's width.
        matrix_bd = self.rel_shift(torch.matmul(q_plus_v, p.transpose(-2, -1)))
        matrix_bd = matrix_bd[:, :, :, : matrix_ac.size(-1)]

        scores = (matrix_ac + matrix_bd) / self.s_d_k
        return self.forward_attention(v, scores, mask)
class ConformerLayer(torch.nn.Module):
"""A single block of the Conformer encoder.
Args:
d_model (int): input dimension of
MultiheadAttentionMechanism and
PositionwiseFeedForward
d_ff (int): hidden dimension of
PositionwiseFeedForward
self_attention_model (str): type of the attention
layer and positional encoding
n_heads (int): number of heads for multi-head
attention
conv_kernel_size (int): kernel size for depthwise
convolution in convolution module
use_bias (bool): Apply bias to all Linear and
Conv1d layers from each ConformerLayer to
improve activation flow and stabilize training
of huge models. Defaults to True.
"""
def __init__(
self,
d_model: int,
d_ff: int,
self_attention_model: str = "rel_pos",
n_heads: int = 4,
conv_kernel_size: int = 31,
conv_norm_type: str = "batch_norm",
conv_context_size: int | None = None,
pos_bias_u: nn.Parameter | torch.Tensor | None = None,
pos_bias_v: nn.Parameter | torch.Tensor | None = None,
att_context_size: list[int] | None = None,
use_bias: bool = True,
) -> None:
super().__init__()
if att_context_size is None:
att_context_size = [-1, -1]
self.self_attention_model = self_attention_model
self.fc_factor = 0.5
# first feed forward module
self.norm_feed_forward1 = nn.LayerNorm(d_model)
self.feed_forward1 = ConformerFeedForward(
d_model=d_model, d_ff=d_ff, use_bias=use_bias
)
# convolution module
self.norm_conv = nn.LayerNorm(d_model)
self.conv = ConformerConvolution(
d_model=d_model,
kernel_size=conv_kernel_size,
norm_type=conv_norm_type,
conv_context_size=conv_context_size,
use_bias=use_bias,
)
# multi-headed self-attention module
self.norm_self_att = nn.LayerNorm(d_model)
assert self_attention_model == "rel_pos"
self.self_attn = RelPositionMultiHeadAttention(
n_head=n_heads,
n_feat=d_model,
pos_bias_u=pos_bias_u,
pos_bias_v=pos_bias_v,
use_bias=use_bias,
)
# second feed forward module
self.norm_feed_forward2 = nn.LayerNorm(d_model)
self.feed_forward2 = ConformerFeedForward(
d_model=d_model, d_ff=d_ff, use_bias=use_bias
)
self.norm_out = nn.LayerNorm(d_model)
def forward(
self,
x: torch.Tensor,
att_mask: torch.Tensor | None = None,
pos_emb: torch.Tensor | None = None,
pad_mask: torch.Tensor | None = None,
) -> torch.Tensor:
"""
Args:
x (torch.Tensor): input signals (B, T, d_model)
att_mask (torch.Tensor): attention masks(B, T, T)
pos_emb (torch.Tensor): (L, 1, d_model)
pad_mask (torch.tensor): padding mask
Returns:
x (torch.Tensor): (B, T, d_model)
"""
residual = x
x = self.norm_feed_forward1(x)
x = self.feed_forward1(x)
residual = residual + x * self.fc_factor
x = self.norm_self_att(residual)
if self.self_attention_model == "rel_pos":
x = self.self_attn(
query=x,
key=x,
value=x,
mask=att_mask,
pos_emb=pos_emb,
)
elif self.self_attention_model == "rel_pos_local_attn":
x = self.self_attn(
query=x,
key=x,
value=x,
pad_mask=pad_mask,
pos_emb=pos_emb,
)
elif self.self_attention_model == "abs_pos":
x = self.self_attn(query=x, key=x, value=x, mask=att_mask)
else:
x = None
residual = residual + x
x = self.norm_conv(residual)
x = self.conv(x, pad_mask=pad_mask)
residual = residual + x
x = self.norm_feed_forward2(residual)
x = self.feed_forward2(x)
residual = residual + x * self.fc_factor
x = self.norm_out(residual)
return x
class ConformerEncoder(nn.Module):
    """
    The encoder for ASR model of Conformer.
    Based on this paper:
    'Conformer: Convolution-augmented Transformer for
    Speech Recognition' by Anmol Gulati et al.
    https://arxiv.org/abs/2005.08100

    Hyper-parameters are read from ``hf_config.encoder``. Only the
    ``rel_pos`` self-attention model and ``dw_striding`` subsampling are
    supported; other values raise ``ValueError`` at construction time.
    """

    def __init__(self, *, vllm_config: VllmConfig):
        super().__init__()
        self.hf_config = vllm_config.model_config.hf_config
        encoder_cfg = self.hf_config.encoder

        feat_in = encoder_cfg["feat_in"]
        n_layers = encoder_cfg["n_layers"]
        d_model = encoder_cfg["d_model"]
        feat_out = encoder_cfg["feat_out"]
        causal_downsampling = encoder_cfg["causal_downsampling"]
        subsampling = encoder_cfg["subsampling"]
        subsampling_factor = encoder_cfg["subsampling_factor"]
        subsampling_conv_chunking_factor = encoder_cfg.get(
            "subsampling_conv_chunking_factor", 1
        )
        subsampling_conv_channels = encoder_cfg["subsampling_conv_channels"]
        ff_expansion_factor = encoder_cfg["ff_expansion_factor"]
        self_attention_model = encoder_cfg["self_attention_model"]
        n_heads = encoder_cfg["n_heads"]
        att_context_size = encoder_cfg["att_context_size"]
        att_context_probs = encoder_cfg.get("att_context_probs", None)
        att_context_style = encoder_cfg.get("att_context_style", "regular")
        xscaling = encoder_cfg["xscaling"]
        untie_biases = encoder_cfg["untie_biases"]
        pos_emb_max_len = encoder_cfg["pos_emb_max_len"]
        conv_kernel_size = encoder_cfg["conv_kernel_size"]
        conv_norm_type = encoder_cfg["conv_norm_type"]
        conv_context_size = encoder_cfg["conv_context_size"]
        use_bias = encoder_cfg.get("use_bias", True)

        d_ff = d_model * ff_expansion_factor
        self.d_model = d_model
        self._feat_in = feat_in
        self.att_context_style = att_context_style
        self.subsampling_factor = subsampling_factor
        self.self_attention_model = self_attention_model

        # Normalize the attention and convolution context configuration.
        (
            _,
            self.att_context_size,
            _,
            self.conv_context_size,
        ) = self._calc_context_sizes(
            att_context_style=att_context_style,
            att_context_size=att_context_size,
            att_context_probs=att_context_probs,
            conv_context_size=conv_context_size,
            conv_kernel_size=conv_kernel_size,
        )

        # Optional sqrt(d_model) scale, passed to the positional encoding.
        self.xscale = math.sqrt(d_model) if xscaling else None

        # Subsampling; -1 channels means "use d_model".
        if subsampling_conv_channels == -1:
            subsampling_conv_channels = d_model
        if subsampling != "dw_striding" or not subsampling_factor > 1:
            raise ValueError(
                "Only 'dw_striding' subsampling with subsampling_factor > 1 "
                f"is supported, got subsampling={subsampling!r}, "
                f"subsampling_factor={subsampling_factor}."
            )
        self.pre_encode = ConvSubsampling(
            subsampling=subsampling,
            subsampling_factor=subsampling_factor,
            feat_in=feat_in,
            feat_out=d_model,
            conv_channels=subsampling_conv_channels,
            subsampling_conv_chunking_factor=subsampling_conv_chunking_factor,
            activation=nn.ReLU(True),
            is_causal=causal_downsampling,
        )
        self._feat_out = d_model

        # Biases for relative positional encoding; the same tensors are
        # handed to every layer below.
        if not untie_biases and self_attention_model == "rel_pos":
            d_head = d_model // n_heads
            # Register as buffers instead of parameters since they're not
            # trainable and need to respect dtype during weight loading
            self.register_buffer(
                "pos_bias_u", torch.zeros(n_heads, d_head), persistent=True
            )
            self.register_buffer(
                "pos_bias_v", torch.zeros(n_heads, d_head), persistent=True
            )
            pos_bias_u = self.pos_bias_u
            pos_bias_v = self.pos_bias_v
        else:
            pos_bias_u = None
            pos_bias_v = None

        # Positional encodings
        self.pos_emb_max_len = pos_emb_max_len
        if self_attention_model != "rel_pos":
            raise ValueError(
                "Only self_attention_model='rel_pos' is supported, "
                f"got {self_attention_model!r}."
            )
        self.pos_enc = RelPositionalEncoding(
            d_model=d_model,
            max_len=pos_emb_max_len,
            xscale=self.xscale,
        )

        self.layers = nn.ModuleList()
        for _ in range(n_layers):
            layer = ConformerLayer(
                d_model=d_model,
                d_ff=d_ff,
                self_attention_model=self_attention_model,
                n_heads=n_heads,
                conv_kernel_size=conv_kernel_size,
                conv_norm_type=conv_norm_type,
                conv_context_size=self.conv_context_size,
                pos_bias_u=pos_bias_u,
                pos_bias_v=pos_bias_v,
                att_context_size=self.att_context_size,
                use_bias=use_bias,
            )
            self.layers.append(layer)

        if feat_out > 0 and feat_out != self._feat_out:
            self.out_proj = nn.Linear(self._feat_out, feat_out)
            self._feat_out = feat_out
        else:
            self.out_proj = None
            self._feat_out = d_model
        self.set_max_audio_length(self.pos_emb_max_len)

    def get_num_encoder_cross_attn_tokens(self, num_encoder_input_tokens: int) -> int:
        """Return ceil(num_encoder_input_tokens / subsampling_factor)."""
        num_encoder_cross_attn_tokens = math.ceil(
            num_encoder_input_tokens / self.subsampling_factor
        )
        return num_encoder_cross_attn_tokens

    def set_max_audio_length(self, max_audio_length: int) -> None:
        """
        Sets maximum input length.
        Calls `pos_enc.extend_pe` with the device and dtype of this module's
        first parameter.
        Args:
            max_audio_length (int): New maximum sequence length.
        """
        device = next(self.parameters()).device
        dtype = next(self.parameters()).dtype
        self.pos_enc.extend_pe(max_audio_length, device, dtype)

    def forward(
        self,
        audio_signal: torch.Tensor,
        length: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Validate the feature dimension and run the encoder.

        Args:
            audio_signal: features of shape (batch, feat_in, n_frame).
            length: number of valid frames per batch item.
        Returns:
            (encoded features, encoded lengths), see `forward_internal`.
        Raises:
            ValueError: if dimension -2 of ``audio_signal`` is not feat_in.
        """
        if audio_signal.shape[-2] != self._feat_in:
            raise ValueError(
                f"audio_signal should have shape "
                f"(batch, {self._feat_in}, n_frame) but "
                f"got feature dimension (dim -2) of size "
                f"{audio_signal.shape[-2]}."
            )
        return self.forward_internal(
            audio_signal,
            length,
        )

    def forward_internal(
        self,
        audio_signal: torch.Tensor,
        length: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run subsampling, positional encoding and all Conformer layers.

        Args:
            audio_signal: features of shape (batch, feat_in, n_frame).
            length: valid frames per batch item; when None every item is
                assumed to span the full last dimension.
        Returns:
            (audio_signal, length): audio_signal is
            (batch, self._feat_out, time after subsampling) and length is
            the int64 per-item length returned by the subsampling module.
        """
        if length is None:
            length = audio_signal.new_full(
                (audio_signal.size(0),),
                audio_signal.size(-1),
                dtype=torch.int64,
                device=audio_signal.device,
            )
        # (B, D, T) -> (B, T, D) for subsampling and the layers.
        audio_signal = torch.transpose(audio_signal, 1, 2)
        audio_signal, length = self.pre_encode(x=audio_signal, lengths=length)
        length = length.to(torch.int64)
        max_audio_length = audio_signal.size(1)
        audio_signal, pos_emb = self.pos_enc(x=audio_signal, cache_len=0)
        pad_mask, att_mask = self._create_masks(
            att_context_size=self.att_context_size,
            padding_length=length,
            max_audio_length=max_audio_length,
            offset=None,
            device=audio_signal.device,
        )
        for layer in self.layers:
            audio_signal = layer(
                x=audio_signal,
                att_mask=att_mask,
                pos_emb=pos_emb,
                pad_mask=pad_mask,
            )
        if self.out_proj is not None:
            audio_signal = self.out_proj(audio_signal)
        # Back to (B, D, T).
        audio_signal = torch.transpose(audio_signal, 1, 2)
        return audio_signal, length

    def _create_masks(
        self,
        att_context_size: list[int],
        padding_length: torch.Tensor,
        max_audio_length: int,
        offset: torch.Tensor | None,
        device: torch.device,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Build the padding mask and the self-attention mask.

        Both returned masks are inverted before returning: True marks
        positions to ignore (padding, or keys outside the allowed context).
        att_mask is None for the "rel_pos_local_attn" model.

        Returns:
            pad_mask: (batch, max_audio_length) bool.
            att_mask: (batch, max_audio_length, max_audio_length) bool or None.
        """
        if self.self_attention_model != "rel_pos_local_attn":
            att_mask = torch.ones(
                1, max_audio_length, max_audio_length, dtype=torch.bool, device=device
            )
            if self.att_context_style == "regular":
                if att_context_size[0] >= 0:
                    att_mask = att_mask.triu(diagonal=-att_context_size[0])
                if att_context_size[1] >= 0:
                    att_mask = att_mask.tril(diagonal=att_context_size[1])
            elif self.att_context_style == "chunked_limited":
                # When right context is unlimited, just the
                # left side of masking needs to get updated
                if att_context_size[1] == -1:
                    if att_context_size[0] >= 0:
                        att_mask = att_mask.triu(diagonal=-att_context_size[0])
                else:
                    chunk_size = att_context_size[1] + 1
                    # left_chunks_num specifies the number
                    # of chunks to be visible by each chunk
                    # on the left side
                    if att_context_size[0] >= 0:
                        left_chunks_num = att_context_size[0] // chunk_size
                    else:
                        left_chunks_num = 10000
                    chunk_idx = torch.arange(
                        0, max_audio_length, dtype=torch.int, device=att_mask.device
                    )
                    chunk_idx = torch.div(chunk_idx, chunk_size, rounding_mode="trunc")
                    diff_chunks = chunk_idx.unsqueeze(1) - chunk_idx.unsqueeze(0)
                    chunked_limited_mask = torch.logical_and(
                        torch.le(diff_chunks, left_chunks_num), torch.ge(diff_chunks, 0)
                    )
                    att_mask = torch.logical_and(
                        att_mask, chunked_limited_mask.unsqueeze(0)
                    )
        else:
            att_mask = None

        # pad_mask is the masking to be used to ignore paddings
        # (True for valid frames until it is inverted below).
        pad_mask = torch.arange(0, max_audio_length, device=device).expand(
            padding_length.size(0), -1
        ) < padding_length.unsqueeze(-1)

        if offset is not None:
            pad_mask_off = torch.arange(0, max_audio_length, device=device).expand(
                padding_length.size(0), -1
            ) >= offset.unsqueeze(-1)
            pad_mask = pad_mask_off.logical_and(pad_mask)

        if att_mask is not None:
            # pad_mask_for_att_mask is the mask which helps to ignore paddings
            pad_mask_for_att_mask = pad_mask.unsqueeze(1).repeat(
                [1, max_audio_length, 1]
            )
            pad_mask_for_att_mask = torch.logical_and(
                pad_mask_for_att_mask, pad_mask_for_att_mask.transpose(1, 2)
            )
            # att_mask is the masking to be used by MHA
            # layers to ignore tokens not supposed to be
            # visible
            att_mask = att_mask[:, :max_audio_length, :max_audio_length]
            # paddings should also get ignored, so
            # pad_mask_for_att_mask is used to ignore their
            # corresponding scores
            att_mask = torch.logical_and(
                pad_mask_for_att_mask, att_mask.to(pad_mask_for_att_mask.device)
            )
            att_mask = ~att_mask

        pad_mask = ~pad_mask
        return pad_mask, att_mask

    def _calc_context_sizes(
        self,
        att_context_size: list[int] | list[list[int]] | None,
        att_context_probs: list[float] | None,
        att_context_style: str,
        conv_context_size: list[int] | str | None,
        conv_kernel_size: int,
    ) -> tuple[list[list[int]], list[int], list[float], list[int]]:
        """Normalize and validate the attention/convolution context settings.

        Returns:
            (all attention context sizes as a list of [left, right] pairs,
            the first of those pairs, the probabilities for each pair,
            the [left, right] convolution context).
        Raises:
            ValueError: on an inconsistent or malformed configuration.
        """
        # convert att_context_size to a standard list of lists
        if att_context_size:
            att_context_size_all = list(att_context_size)
            if isinstance(att_context_size_all[0], int):
                att_context_size_all = [att_context_size_all]
            for i, att_cs in enumerate(att_context_size_all):
                if att_context_style == "chunked_limited":
                    if att_cs[1] < 0 and len(att_context_size_all) <= 1:
                        raise ValueError(
                            f"Right context "
                            f"(att_context_size[{i}][1])"
                            f" can not be unlimited for"
                            f" chunked_limited style!"
                        )
                    # The divisibility check only applies to a finite right
                    # context; att_cs[1] == -1 would divide by zero.
                    if (
                        att_cs[1] >= 0
                        and att_cs[0] > 0
                        and att_cs[0] % (att_cs[1] + 1) > 0
                    ):
                        raise ValueError(
                            f"att_context_size[{i}][0] % "
                            f"(att_context_size[{i}][1]"
                            f" + 1) should be zero!"
                        )
        else:
            att_context_size_all = [[-1, -1]]

        if att_context_probs:
            if len(att_context_probs) != len(att_context_size_all):
                raise ValueError(
                    "The size of the att_context_probs "
                    "should be the same as att_context_size."
                )
            att_context_probs = list(att_context_probs)
            # Compare with a tolerance: e.g. sum([0.1] * 10) != 1.0 exactly.
            if not math.isclose(sum(att_context_probs), 1.0):
                raise ValueError(
                    "The sum of numbers in "
                    "att_context_probs should be equal "
                    "to one to be a distribution."
                )
        else:
            att_context_probs = [1.0 / len(att_context_size_all)] * len(
                att_context_size_all
            )

        if conv_context_size is None:
            half_context = (conv_kernel_size - 1) // 2
            conv_context_size = [half_context, half_context]
        elif conv_context_size == "causal":
            conv_context_size = [conv_kernel_size - 1, 0]
        elif (
            isinstance(conv_context_size, (list, tuple))
            and len(conv_context_size) == 2
        ):
            conv_context_size = list(conv_context_size)
            total = conv_context_size[0] + conv_context_size[1] + 1
            if total != conv_kernel_size:
                # Report the local value: self.conv_context_size is only
                # assigned from this method's return value.
                raise ValueError(f"Invalid conv_context_size: {conv_context_size}!")
        else:
            raise ValueError(
                "Invalid conv_context_size! It should "
                "be the string 'causal' or a list of "
                "two integers."
            )

        return (
            att_context_size_all,
            att_context_size_all[0],
            att_context_probs,
            conv_context_size,
        )
# ----- Encoder END -----
# This subclass is specific to vLLM in order for
# `_mark_composite_model` to target this module
class CohereASRProjector(nn.Linear):
    """Plain linear projection from the encoder width to the decoder width.

    Behaves exactly like ``nn.Linear``; it exists as a distinct type so that
    vLLM's ``_mark_composite_model`` can target this module.
    """
class CohereASRModel(nn.Module):
    """Conformer encoder plus Cohere ASR decoder.

    When the encoder width differs from the decoder hidden size, a
    `CohereASRProjector` maps encoder outputs to the decoder width.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        self.encoder = ConformerEncoder(vllm_config=vllm_config)
        self.decoder = CohereASRDecoder(
            vllm_config=vllm_config, prefix=f"{prefix}.decoder"
        )
        if self.encoder.d_model != self.decoder.hidden_size:
            self.encoder_decoder_proj = CohereASRProjector(
                self.encoder.d_model, self.decoder.hidden_size
            )

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        encoder_outputs: list[torch.Tensor],
    ) -> torch.Tensor:
        """Run the decoder, cross-attending to the concatenated encoder outputs.

        An empty ``encoder_outputs`` list passes ``None`` as the encoder
        hidden states.
        """
        enc_states = torch.cat(encoder_outputs, dim=0) if len(encoder_outputs) else None
        decoder_outputs = self.decoder(
            input_ids=input_ids,
            positions=positions,
            encoder_hidden_states=enc_states,
        )
        return decoder_outputs

    def get_encoder_outputs(
        self,
        input_features: torch.Tensor | list[torch.Tensor] | None,
        seq_lens: torch.Tensor | None,
    ) -> list[torch.Tensor] | None:
        """Encode a padded feature batch into per-item encoder states.

        Args:
            input_features: padded features (batch, feat_in, n_frame).
            seq_lens: valid frames per batch item.
        Returns:
            One (encoded_len_i, hidden) tensor per batch item, or None when
            ``input_features`` is None.
        Raises:
            NotImplementedError: for list inputs.
        """
        if input_features is None:
            return None
        if not isinstance(input_features, torch.Tensor):
            raise NotImplementedError("List input_features not supported")

        out, encoder_output_length = self.encoder(
            input_features, length=seq_lens
        )  # B x D x T
        out = out.permute(0, 2, 1)  # B x T x D
        if hasattr(self, "encoder_decoder_proj"):
            out = self.encoder_decoder_proj(out)
        # Convert padded tensor to packed: trim each item to its valid length.
        return [
            feat[:feat_len, :] for feat, feat_len in zip(out, encoder_output_length)
        ]

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, fusing q/k/v and k/v projections.

        Returns:
            The set of parameter/buffer names that were loaded.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".first_sub_layer.qkv_proj", ".first_sub_layer.query_net", "q"),
            (".first_sub_layer.qkv_proj", ".first_sub_layer.key_net", "k"),
            (".first_sub_layer.qkv_proj", ".first_sub_layer.value_net", "v"),
            (".second_sub_layer.kv_proj", ".second_sub_layer.key_net", "k"),
            (".second_sub_layer.kv_proj", ".second_sub_layer.value_net", "v"),
        ]
        # Buffers are included so that non-parameter tensors such as the
        # pos_bias_* buffers can be loaded as well.
        params_dict = dict(self.named_parameters())
        buffers_dict = dict(self.named_buffers())
        params_dict.update(buffers_dict)
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                # Convert buffer dtype to match loaded weight for pos_bias tensors
                if "pos_bias" in name and param.dtype != loaded_weight.dtype:
                    logger.info(
                        "Converting buffer %s dtype from %s to %s for loading.",
                        name,
                        param.dtype,
                        loaded_weight.dtype,
                    )
                    param.data = param.data.to(loaded_weight.dtype)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
class CohereASRProcessingInfo(BaseProcessingInfo):
    """Config, processor and token-count helpers for Cohere ASR inputs."""

    # Preprocessor settings forwarded verbatim to the feature extractor:
    # (feature-extractor keyword, preprocessor config key, default value).
    _PREPROC_PASSTHROUGH = (
        ("feature_size", "features", 64),
        ("padding_value", "pad_value", 0.0),
        ("window", "window", "hann"),
        ("normalize", "normalize", "per_feature"),
        ("n_fft", "n_fft", None),
        ("preemph", "preemph", 0.97),
        ("lowfreq", "lowfreq", 0),
        ("highfreq", "highfreq", None),
        ("log", "log", True),
        ("log_zero_guard_type", "log_zero_guard_type", "add"),
        ("log_zero_guard_value", "log_zero_guard_value", 2**-24),
        ("dither", "dither", 1e-05),
        ("pad_to", "pad_to", 16),
        ("frame_splicing", "frame_splicing", 1),
        ("exact_pad", "exact_pad", False),
        ("mag_power", "mag_power", 2.0),
        ("mel_norm", "mel_norm", "slaney"),
        ("stft_exact_pad", "stft_exact_pad", False),
        ("stft_conv", "stft_conv", False),
    )

    def get_hf_config(self) -> PretrainedConfig:
        """Return the model's Hugging Face config."""
        return self.ctx.get_hf_config()

    def get_default_tok_params(self) -> TokenizeParams:
        """Tokenize prompts without adding special tokens."""
        # The task/language special tokens come from the user's prompt, and
        # letting the tokenizer append an EOS token would disrupt generation.
        base_params = super().get_default_tok_params()
        return base_params.with_kwargs(add_special_tokens=False)

    def get_hf_processor(self, **kwargs: object) -> CohereASRProcessor:
        """Build the processor from the config on first use, then reuse it."""
        if hasattr(self, "_cached_hf_processor"):
            return self._cached_hf_processor

        hf_config = self.get_hf_config()
        preproc = hf_config.preprocessor
        sr = preproc.get("sample_rate", 16000)
        # Window length and hop are configured in seconds; convert to samples.
        win_samples = int(preproc.get("window_size", 0.02) * sr)
        hop_samples = int(preproc.get("window_stride", 0.01) * sr)
        passthrough_kwargs = {
            kwarg: preproc.get(key, default)
            for kwarg, key, default in self._PREPROC_PASSTHROUGH
        }
        feature_extractor = CohereASRFeatureExtractor(
            sampling_rate=sr,
            max_duration=hf_config.max_audio_clip_s,
            n_window_size=win_samples,
            n_window_stride=hop_samples,
            device="cpu",
            **passthrough_kwargs,
        )
        self._cached_hf_processor = CohereASRProcessor(
            feature_extractor=feature_extractor,
            tokenizer=self.ctx.tokenizer,
        )
        return self._cached_hf_processor

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        """At most one audio item per prompt."""
        return {"audio": 1}

    def get_data_parser(self) -> MultiModalDataParser:
        """Parser that resamples audio to the feature extractor's rate."""
        target_rate = self.get_feature_extractor().sampling_rate
        return MultiModalDataParser(target_sr=target_rate)

    def get_feature_extractor(self, **kwargs: object) -> CohereASRFeatureExtractor:
        """Return the feature extractor held by the (cached) processor."""
        extractor = self.get_hf_processor(**kwargs).feature_extractor
        assert isinstance(extractor, CohereASRFeatureExtractor)
        return extractor

    def get_num_audio_tokens(self, num_samples: int) -> int:
        """Decoder-visible audio tokens for ``num_samples`` raw samples."""
        num_frames = self.get_feature_extractor().get_seq_len(num_samples)
        factor = self.get_hf_config().encoder["subsampling_factor"]
        return math.ceil(num_frames / factor)
class CohereASRDummyInputsBuilder(BaseDummyInputsBuilder[CohereASRProcessingInfo]):
    """Builds dummy (profiling) inputs for Cohere ASR."""

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        """Return one ``<|startoftranscript|>`` token per dummy audio."""
        num_audios = mm_counts.get("audio", 0)
        return "<|startoftranscript|>" * num_audios

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options=None,
        mm_processor_kwargs=None,
    ) -> MultiModalDataDict:
        """Return ``mm_counts["audio"]`` dummy clips of max_duration seconds."""
        feature_extractor = self.info.get_feature_extractor()
        sampling_rate = feature_extractor.sampling_rate
        # max_duration comes from the model config and may be a float; the
        # dummy audio length is a sample count, so make it an int.
        audio_len = int(feature_extractor.max_duration * sampling_rate)
        num_audios = mm_counts.get("audio", 0)
        return {
            "audio": self._get_dummy_audios(length=audio_len, num_audios=num_audios)
        }
class CohereASRMultiModalProcessor(EncDecMultiModalProcessor[CohereASRProcessingInfo]):
    """Encoder-decoder multimodal processor for Cohere ASR audio inputs.

    The encoder prompt is a single placeholder token ``0``, which
    `_get_prompt_updates` expands into one ``0`` per audio token.
    """

    skip_decoder_start_token: bool = True

    @property
    def pad_dummy_encoder_prompt(self) -> bool:
        return True

    def create_encoder_prompt(
        self,
        prompt: str | list[int],
        mm_items: MultiModalDataItems,
    ) -> str | list[int]:
        # The user prompt is decoder-side; the encoder side only carries a
        # placeholder that is replaced according to the audio length.
        return [0]

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ):
        """Call the HF processor, renaming "audios" to "audio".

        Builds new dicts instead of popping from the caller's mappings, and
        overrides any user-supplied ``sampling_rate`` with the extractor's
        rate (rather than failing with a duplicate-keyword TypeError).
        Any "labels" output is exposed as "input_ids".
        """
        if mm_data:
            feature_extractor = self.info.get_feature_extractor(**mm_kwargs)
            mm_data = {"audio": mm_data["audios"]}
            mm_kwargs = {
                **mm_kwargs,
                "sampling_rate": feature_extractor.sampling_rate,
            }
        processed_outputs = super()._call_hf_processor(
            prompt=prompt,
            mm_data=mm_data,
            mm_kwargs=mm_kwargs,
            tok_kwargs=tok_kwargs,
        )
        if "labels" in processed_outputs:
            processed_outputs["input_ids"] = processed_outputs.pop("labels")
        return processed_outputs

    def _get_mm_fields_config(
        self,
        hf_inputs,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return dict(
            input_features=MultiModalFieldConfig.batched("audio"),
            length=MultiModalFieldConfig.batched("audio"),
        )

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> Sequence[PromptUpdate]:
        """Replace the ``[0]`` placeholder with one token per audio token."""

        def get_audio_replacement_cohere_asr(item_idx: int):
            audios = mm_items.get_items("audio", AudioProcessorItems)
            audio_len = audios.get_audio_length(item_idx)
            num_tokens = self.info.get_num_audio_tokens(num_samples=audio_len)
            return [0] * num_tokens

        return [
            PromptReplacement(
                modality="audio",
                target=[0],
                replacement=get_audio_replacement_cohere_asr,
            )
        ]
@MULTIMODAL_REGISTRY.register_processor(
    CohereASRMultiModalProcessor,
    info=CohereASRProcessingInfo,
    dummy_inputs=CohereASRDummyInputsBuilder,
)
class CohereAsrForConditionalGeneration(
    nn.Module, SupportsTranscription, SupportsMultiModal
):
    """Cohere ASR speech-to-text model (transcription only)."""

    packed_modules_mapping = {
        "self_attn.qkv_proj": [
            "self_attn.q_proj",
            "self_attn.k_proj",
            "self_attn.v_proj",
        ],
        "encoder_attn.kv_proj": ["encoder_attn.k_proj", "encoder_attn.v_proj"],
    }

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_substr={".fc1.": ".mlp.fc1.", ".fc2.": ".mlp.fc2."}
    )

    supports_transcription_only = True
    supported_languages = ISO639_1_SUPPORTED_LANGS
    skip_warmup_audio_preprocessing = True

    @classmethod
    def validate_language(cls, language: str | None) -> str | None:
        """Default to English (with a warning) and validate the language."""
        if language is None:
            logger.warning(
                "Defaulting to language='en'. If you wish to transcribe "
                "audio in a different language, pass the `language` field "
                "in the TranscriptionRequest."
            )
            language = "en"
        return super().validate_language(language)

    @classmethod
    def get_generation_prompt(
        cls,
        audio: np.ndarray,
        model_config: ModelConfig,  # not needed here
        stt_config: SpeechToTextConfig,
        language: str | None,
        task_type: Literal["transcribe", "translate"],
        request_prompt: str,
        to_language: str | None,
    ) -> PromptType:
        """Build the decoder text prompt and attach the audio.

        ``request_prompt`` overrides the default special-token prompt when
        non-empty.

        Raises:
            ValueError: if ``language`` is None.
        """
        if language is None:
            raise ValueError(
                "Language must be specified when creating the CohereASR prompt"
            )

        # NOTE: this function is used only by online inference and not offline inference
        # CohereASR doesnt have encoder prompt
        language_tag = f"<|{language}|><|{language}|>"
        pnc = True  # TODO(ekagra): make this configurable later
        pnc_tag = "<|pnc|>" if pnc else "<|nopnc|>"
        default_prompt = (
            f"<|startofcontext|><|startoftranscript|>"
            f"<|emo:undefined|>{language_tag}{pnc_tag}"
            f"<|noitn|><|notimestamp|><|nodiarize|>"
        )
        prompt_text = request_prompt if request_prompt else default_prompt
        return TextPrompt(
            prompt=prompt_text,
            multi_modal_data={"audio": (audio, stt_config.sample_rate)},
        )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        # Required as part of SupportsMultiModal interface.
        if modality.startswith("audio"):
            return None

        raise ValueError("Only audio modality is supported")

    @classmethod
    def get_speech_to_text_config(
        cls, model_config: ModelConfig, task_type: str
    ) -> SpeechToTextConfig:
        """Build the speech-to-text config from the HF config.

        Raises:
            ValueError: if the configured sample rate is not 16000.
        """
        sampling_rate = model_config.hf_config.sample_rate
        if sampling_rate != 16000:
            raise ValueError(
                f"CohereASR expects a 16000 Hz sample rate, got {sampling_rate}."
            )
        max_audio_clip_s = model_config.hf_config.max_audio_clip_s
        overlap_chunk_second = model_config.hf_config.overlap_chunk_second

        return SpeechToTextConfig(
            max_audio_clip_s=max_audio_clip_s,
            overlap_chunk_second=overlap_chunk_second,
            sample_rate=sampling_rate,
        )

    @classmethod
    def get_num_audio_tokens(
        cls,
        audio_duration_s: float,
        stt_config: SpeechToTextConfig,
        model_config: ModelConfig,
    ) -> int | None:
        """Estimate the number of feature frames for ``audio_duration_s``.

        The preprocessor's ``window_stride`` is in seconds (the processor
        converts it with ``int(window_stride * sample_rate)``), so it is
        converted to a hop length in samples before dividing. Dividing by
        the raw value inflated the estimate by a factor of ``sample_rate``.
        """
        window_stride_s = model_config.hf_config.preprocessor.get(
            "window_stride", 0.01
        )
        hop_length = int(window_stride_s * stt_config.sample_rate)
        if hop_length <= 0:
            raise ValueError(
                f"Invalid window_stride {window_stride_s!r}: the hop length "
                "in samples must be positive."
            )
        return math.ceil(audio_duration_s * stt_config.sample_rate / hop_length)

    def get_num_encoder_cross_attn_tokens(self, num_encoder_input_tokens: int) -> int:
        return self.model.encoder.get_num_encoder_cross_attn_tokens(
            num_encoder_input_tokens
        )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.dtype = vllm_config.model_config.dtype

        with self._mark_composite_model(
            vllm_config,
            language_targets=CohereASRDecoder,
            tower_targets={"audio": (ConformerEncoder, CohereASRProjector)},
        ):
            self.model = CohereASRModel(vllm_config=vllm_config, prefix=prefix)

        head_config = config.head
        self.proj_out = ParallelLMHead(
            head_config["num_classes"],
            head_config["hidden_size"],
            quant_config=quant_config,
            bias=True,
        )  # NOTE: bias is True

        # head_config is a dict (it is subscripted above); getattr() on a
        # dict never sees its keys and always returned the 1.0 default.
        logit_scale = head_config.get("logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(
            head_config["num_classes"], scale=logit_scale
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        encoder_outputs: list[torch.Tensor] | None = None,
        **kwargs,
    ) -> torch.Tensor:
        if encoder_outputs is None:
            encoder_outputs = []
        decoder_outputs = self.model(
            input_ids=input_ids,
            positions=positions,
            encoder_outputs=encoder_outputs,
        )
        return decoder_outputs

    def get_language_model(self) -> torch.nn.Module:
        # Required as part of SupportsMultiModal interface.
        return self.model.decoder

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        # Required as part of SupportsMultiModal interface.
        # _parse_and_validate_audio_input always returns a padded tensor,
        # so the encoder is called directly with the per-item lengths.
        input_features, seq_lens = self._parse_and_validate_audio_input(**kwargs)
        return self.model.get_encoder_outputs(input_features, seq_lens)

    def _parse_and_validate_audio_input(
        self, **kwargs: object
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return padded features (batch, feat_in, n_frame) and lengths.

        Raises:
            ValueError: if features or lengths are missing or mistyped.
        """
        input_features = kwargs.pop("input_features", None)
        length = kwargs.pop("length", None)

        if input_features is None:
            raise ValueError("Audio features are required for CohereASR model.")

        if not isinstance(input_features, (torch.Tensor, list)):
            raise ValueError(
                f"Incorrect type of audio features. Got type: {type(input_features)}"
            )

        if length is None:
            raise ValueError(
                "Audio feature lengths are required for CohereASR model."
            )
        if isinstance(length, list):
            # Lengths may arrive as a list of per-item tensors.
            length = torch.cat([torch.as_tensor(item).reshape(-1) for item in length])
        if not isinstance(length, torch.Tensor):
            raise ValueError(
                f"Incorrect type of audio feature lengths. Got type: {type(length)}"
            )
        seq_lens = length.reshape(-1)

        if isinstance(input_features, torch.Tensor):
            # Cast like the list branch does, so both paths match the model dtype.
            input_features = input_features.to(self.dtype)
        else:
            input_features = [
                feat.to(self.dtype).squeeze(0).transpose(1, 0)
                for feat in input_features
            ]
            input_features = torch.nn.utils.rnn.pad_sequence(
                input_features, batch_first=True, padding_value=0.0
            )
            input_features = input_features.transpose(1, 2)

        return input_features, seq_lens

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        logits = self.logits_processor(self.proj_out, hidden_states, self.proj_out.bias)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Rename checkpoint keys to vLLM module names and load them."""

        def transform(inputs):
            name, loaded_weight = inputs
            if name.startswith("transf_decoder._decoder"):
                name = name.replace("transf_decoder._decoder", "decoder")
            if name.startswith("transf_decoder._embedding"):
                name = name.replace("transf_decoder._embedding", "decoder.embedding")
            if "second_sub_layer.query_net" in name:
                name = name.replace(
                    "second_sub_layer.query_net", "second_sub_layer.q_proj"
                )
            # The classifier layer becomes the LM head; everything else lives
            # under self.model.
            if name in ["log_softmax.mlp.layer0.weight", "log_softmax.mlp.layer0.bias"]:
                name = name.replace("log_softmax.mlp.layer0", "proj_out")
            else:
                name = "model." + name
            return name, loaded_weight

        loader = AutoWeightsLoader(
            self,
            skip_prefixes=[
                "model.preprocessor.featurizer.fb",
                "model.preprocessor.featurizer.window",
            ],
            skip_substrs=["model.conv.batch_norm.num_batches_tracked"],
        )

        return loader.load_weights(
            map(transform, weights), mapper=self.hf_to_vllm_mapper
        )