Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Harry Mellor committed (via GitHub) on 2025-10-05 15:06:22 +01:00
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions


@@ -14,15 +14,24 @@ import torch
import torch.nn.functional as F
from torch import Tensor, nn
-from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
-    CheckpointWrapper)
-from torch.distributed.fsdp.fully_sharded_data_parallel import (
-    FullyShardedDataParallel)
+from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
+    CheckpointWrapper,
+)
+from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
from transformers import PretrainedConfig
-from vllm.model_executor.models.phi4mm_utils import (
-    AbsolutePositionalEncoding, ConvModule, FeedForward, MeanVarianceNormLayer,
-    MultiHeadedAttention, MultiSequential, NemoConvSubsampling,
-    T5RelativeAttentionLogitBias, adaptive_enc_mask, get_offset, unfold_tensor)
+from vllm.model_executor.models.phi4mm_utils import (
+    AbsolutePositionalEncoding,
+    ConvModule,
+    FeedForward,
+    MeanVarianceNormLayer,
+    MultiHeadedAttention,
+    MultiSequential,
+    NemoConvSubsampling,
+    T5RelativeAttentionLogitBias,
+    adaptive_enc_mask,
+    get_offset,
+    unfold_tensor,
+)
_AUDIO_PLACEHOLDER_TOKEN_ID = 200011 # <|endoftext11|>
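Most of the churn in this file follows from the switch named in the commit title: ruff's Black-compatible formatter replaces yapf's packed continuation style (and its import section replaces isort). It splits imports one name per line, honors the "magic trailing comma" (a trailing comma inside brackets forces one element per line, with the closing bracket dedented), and wraps long assert messages in their own parentheses. A minimal before/after sketch of the trailing-comma rule, using a toy function rather than anything from this file:

```python
def configure(alpha: int, beta: int, gamma: int) -> tuple[int, int, int]:
    """Toy function so the call below has something to format."""
    return alpha, beta, gamma

# yapf packs arguments and hugs the closing parenthesis:
#     result = configure(alpha=1, beta=2,
#                        gamma=3)
# ruff keeps the magic trailing comma: one argument per line, and the
# closing parenthesis is dedented onto its own line.
result = configure(
    alpha=1,
    beta=2,
    gamma=3,
)
```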
@@ -40,9 +49,9 @@ class ConformerEncoderLayer(nn.Module):
if > 0, ext_pw_out_channel is a dim channel size
for the last pointwise conv after swish activation.
depthwise_seperable_out_channel: int
if set different to 0, the number of
depthwise_seperable_out_channel will be used as a
channel_out of the second conv1d layer.
otherwise, it equals to 0, the second conv1d layer is skipped.
depthwise_multiplier: int
number of input_dim channels duplication. this value
@@ -119,10 +128,10 @@ class ConformerEncoderLayer(nn.Module):
and allow the onnx conversion for inference.
default False.
use_pt_scaled_dot_product_attention: bool, optional
if set to True, use pytorch's scaled dot product attention
implementation in training.
attn_group_sizes: int, optional
the number of groups to use for attention, default 1
(Multi-Head Attention),
1 = typical Multi-Head Attention,
1 < attn_group_sizes < attention_heads = Grouped-Query Attention
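The three regimes this docstring names differ only in how many key/value heads back the query heads. A quick sketch of the arithmetic, mirroring the `num_heads_k = num_heads // attention_group_size` computation that appears later in this file (head counts here are illustrative):

```python
attention_heads = 8  # illustrative, not the model's real head count

for attn_group_sizes in (1, 4, 8):
    num_heads_k = attention_heads // attn_group_sizes  # key/value heads
    if attn_group_sizes == 1:
        kind = "Multi-Head Attention"
    elif attn_group_sizes == attention_heads:
        kind = "Multi-Query Attention"
    else:
        kind = "Grouped-Query Attention"
    print(f"group size {attn_group_sizes}: {num_heads_k} KV heads -> {kind}")
```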
@@ -173,8 +182,7 @@ class ConformerEncoderLayer(nn.Module):
attention_inner_dim,
attention_glu_type,
bias_in_glu,
-                use_pt_scaled_dot_product_attention=
-                use_pt_scaled_dot_product_attention,
+                use_pt_scaled_dot_product_attention=use_pt_scaled_dot_product_attention,
group_size=attn_group_sizes,
)
self.conv = ConvModule(
@@ -296,7 +304,7 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
(Q*K^T + B) implemented in cmb.basics.embedding.
[T5/ALiBi]RelativeAttentionLogitBias
usage: relative_attention_bias_args={"type": t5/alibi}
additional method-specific arguments can be provided (see
transformer_base.py)
positional_dropout_rate: float, optional
dropout rate after positional encoding. default 0.0
@@ -310,10 +318,10 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
supraframe utts in batch.
Default: none
attention_group_size: int, optional
the number of groups to use for attention, default 1
(Multi-Head Attention),
1 = typical Multi-Head Attention,
1 < attention_group_size < attention_heads = Grouped-Query
Attention
attention_group_size = attention_heads = Multi-Query Attention
"""
@@ -334,8 +342,7 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
relative_attention_bias_args: Optional[dict[str, Any]] = None,
positional_dropout_rate: float = 0.0,
nemo_conv_settings: Optional[dict[str, Any]] = None,
-        conv2d_extra_padding: Literal["feat", "feat_time", "none",
-                                      True] = "none",
+        conv2d_extra_padding: Literal["feat", "feat_time", "none", True] = "none",
attention_group_size: int = 1,
encoder_embedding_config: Optional[dict[str, Any]] = None,
) -> None:
@@ -366,70 +373,77 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
if nemo_conv_settings:
default_nemo_conv_settings.update(nemo_conv_settings)
for i in ["subsampling_factor", "feat_in", "feat_out"]:
-            assert (
-                i not in nemo_conv_settings
-            ), "{i} should be specified outside of the NeMo dictionary"
+            assert i not in nemo_conv_settings, (
+                "{i} should be specified outside of the NeMo dictionary"
+            )
-            self.embed = NemoConvSubsampling(**default_nemo_conv_settings, )
+            self.embed = NemoConvSubsampling(
+                **default_nemo_conv_settings,
+            )
else:
raise ValueError("unknown input_layer: " + input_layer)
-        self.pos_emb = AbsolutePositionalEncoding(attention_dim,
-                                                  positional_dropout_rate)
+        self.pos_emb = AbsolutePositionalEncoding(
+            attention_dim, positional_dropout_rate
+        )
-        self.relative_attention_bias_type = (
-            relative_attention_bias_args.get("type")
-            if relative_attention_bias_args else None)
+        self.relative_attention_bias_type = (
+            relative_attention_bias_args.get("type")
+            if relative_attention_bias_args
+            else None
+        )
if self.relative_attention_bias_type == "t5":
-            assert (self.num_heads % self.attention_group_size == 0
-                    ), "attention_group_size must divide n_head"
+            assert self.num_heads % self.attention_group_size == 0, (
+                "attention_group_size must divide n_head"
+            )
self.relative_attention_bias_layer = T5RelativeAttentionLogitBias(
self.num_heads // self.attention_group_size,
-                max_distance=relative_attention_bias_args.get(
-                    "t5_bias_max_distance", 1000),
-                symmetric=relative_attention_bias_args.get(
-                    "t5_bias_symmetric", False),
+                max_distance=relative_attention_bias_args.get(
+                    "t5_bias_max_distance", 1000
+                ),
+                symmetric=relative_attention_bias_args.get("t5_bias_symmetric", False),
)
else:
raise NotImplementedError
-        self.encoder_embedding = MeanVarianceNormLayer(
-            self.encoder_embedding_config["input_size"])
+        self.encoder_embedding = MeanVarianceNormLayer(
+            self.encoder_embedding_config["input_size"]
+        )
-    def compute_lens_change(
-            self,
-            feature_lens: Union[int,
-                                torch.Tensor]) -> Union[int, torch.Tensor]:
+    def compute_lens_change(
+        self, feature_lens: Union[int, torch.Tensor]
+    ) -> Union[int, torch.Tensor]:
"""feature_lens: int
return updated feature lens.
This used to return a different lambda function for each case that
computed the right thing. That does not work within Torchscript.
If you really need this to be faster, create nn.Module()-s for all
the cases and return one of them. Torchscript does support that.
"""
if self.input_layer == "nemo_conv":
# Handle the special causal case
-            subsampling_causal_cond = self.nemo_conv_settings.get(
-                "subsampling", "dw_striding") in [
-                    "dw_striding",
-                    "striding",
-                    "striding_conv1d",
-                ]
+            subsampling_causal_cond = self.nemo_conv_settings.get(
+                "subsampling", "dw_striding"
+            ) in [
+                "dw_striding",
+                "striding",
+                "striding_conv1d",
+            ]
is_causal = self.nemo_conv_settings.get("is_causal", False)
if is_causal and subsampling_causal_cond:
-                lens_change = (torch.ceil(feature_lens /
-                                          self.time_reduction).long()
-                               if isinstance(feature_lens, Tensor) else
-                               math.ceil(feature_lens / self.time_reduction))
+                lens_change = (
+                    torch.ceil(feature_lens / self.time_reduction).long()
+                    if isinstance(feature_lens, Tensor)
+                    else math.ceil(feature_lens / self.time_reduction)
+                )
feature_lens_remainder = feature_lens % self.time_reduction
if isinstance(feature_lens, Tensor):
lens_change[feature_lens_remainder != 1] += 1
elif feature_lens_remainder != 1:
lens_change += 1
return lens_change
-        ceil_func = (math.ceil
-                     if isinstance(feature_lens, int) else torch.ceil)
+        ceil_func = math.ceil if isinstance(feature_lens, int) else torch.ceil
return ceil_func(feature_lens / self.time_reduction)
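As a quick sanity trace of the causal branch above, assuming the `time_reduction = 4` default documented later in this file: the subsampled length is `ceil(len / time_reduction)`, incremented by one whenever `len % time_reduction != 1`.

```python
import math

time_reduction = 4  # default time reduction factor documented in this file

for feature_lens in (16, 17, 18, 19):
    lens_change = math.ceil(feature_lens / time_reduction)
    if feature_lens % time_reduction != 1:
        lens_change += 1
    print(feature_lens, "->", lens_change)  # 16 -> 5, 17 -> 5, 18 -> 6, 19 -> 6
```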
@abc.abstractmethod
@@ -437,10 +451,10 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
"""Abstract forward method implementation."""
-    def _chunk_size_selection(
-            self,
-            chunk_size: Optional[Union[int, list[int]]] = None,
-            left_chunk: Optional[Union[int,
-                                       list[int]]] = None) -> tuple[int, int]:
+    def _chunk_size_selection(
+        self,
+        chunk_size: Optional[Union[int, list[int]]] = None,
+        left_chunk: Optional[Union[int, list[int]]] = None,
+    ) -> tuple[int, int]:
"""If chunk size is a list, we will randomly select a chunk size."""
if chunk_size is None:
@@ -450,15 +464,16 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
if isinstance(chunk_size, list):
# Variable chunk size during training
-            chunk_size_index = int(
-                torch.randint(low=0, high=len(chunk_size), size=(1, )))
+            chunk_size_index = int(
+                torch.randint(low=0, high=len(chunk_size), size=(1,))
+            )
chunk_size_train_eff = chunk_size[chunk_size_index]
if not isinstance(left_chunk, list):
raise ValueError(
"Since chunk_size is a list, left_chunk must be a list")
"Since chunk_size is a list, left_chunk must be a list"
)
if len(left_chunk) != len(chunk_size):
raise ValueError(
"The length of left_chunk must be the same as length of "\
"chunk_size."
"The length of left_chunk must be the same as length of chunk_size."
)
left_chunk_train_eff = left_chunk[chunk_size_index]
else:
@@ -479,8 +494,8 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
return embed_class
def _forward_embeddings_core(
-        self, input_tensor: torch.Tensor,
-        masks: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        self, input_tensor: torch.Tensor, masks: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
embed_class = self._get_embed_class(self.embed)
assert isinstance(embed_class, NemoConvSubsampling)
input_tensor, masks = self.embed(input_tensor, masks)
@@ -493,23 +508,32 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
pos_v = None
if self.relative_attention_bias_layer is None:
-            input_tensor = self.pos_emb(
-                input_tensor)  # default to add abs sinusoid embedding
+            input_tensor = self.pos_emb(
+                input_tensor
+            )  # default to add abs sinusoid embedding
return pos_k, pos_v
-    def _streaming_mask(self, seq_len: int, batch_size: int,
-                        chunk_size: Union[int, list[int]],
-                        left_chunk: Union[int, list[int]]) -> torch.Tensor:
-        chunk_size_train_eff, left_chunk_train_eff = \
-            self._chunk_size_selection(chunk_size, left_chunk)
+    def _streaming_mask(
+        self,
+        seq_len: int,
+        batch_size: int,
+        chunk_size: Union[int, list[int]],
+        left_chunk: Union[int, list[int]],
+    ) -> torch.Tensor:
+        chunk_size_train_eff, left_chunk_train_eff = self._chunk_size_selection(
+            chunk_size, left_chunk
+        )
# Create mask matrix for streaming
# S stores start index. if chunksize is 18, s is [0,18,36,....]
chunk_start_idx = np.arange(0, seq_len, chunk_size_train_eff)
-        enc_streaming_mask = (adaptive_enc_mask(
-            seq_len, chunk_start_idx,
-            left_window=left_chunk_train_eff).unsqueeze(0).expand(
-                [batch_size, -1, -1]))
+        enc_streaming_mask = (
+            adaptive_enc_mask(
+                seq_len, chunk_start_idx, left_window=left_chunk_train_eff
+            )
+            .unsqueeze(0)
+            .expand([batch_size, -1, -1])
+        )
return enc_streaming_mask
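The resulting mask is the chunk-wise streaming pattern the comment describes: each frame attends within its own chunk plus `left_window` chunks of history. A small sketch of the assumed semantics of `adaptive_enc_mask` (its real implementation lives in phi4mm_utils):

```python
import torch

# Assumption: a query frame may attend to key frames whose chunk index is at
# most left_window chunks behind its own, and never ahead of it.
seq_len, chunk_size, left_window = 6, 2, 1
chunk_idx = torch.arange(seq_len) // chunk_size          # chunk id per frame
diff = chunk_idx.unsqueeze(1) - chunk_idx.unsqueeze(0)   # query chunk - key chunk
mask = (diff >= 0) & (diff <= left_window)
print(mask.int())  # block-banded: own chunk plus one chunk of left context
```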
def forward_embeddings(
@@ -517,12 +541,24 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
xs_pad: torch.Tensor,
masks: torch.Tensor,
chunk_size_nc: Optional[Union[int, list[int]]] = None,
-        left_chunk_nc: Optional[Union[int, list[int]]] = None
-    ) -> Union[tuple[torch.Tensor, Optional[torch.Tensor],
-                     Optional[torch.Tensor], torch.Tensor, torch.Tensor],
-               tuple[torch.Tensor, Optional[torch.Tensor],
-                     Optional[torch.Tensor], torch.Tensor, torch.Tensor,
-                     torch.Tensor]]:
+        left_chunk_nc: Optional[Union[int, list[int]]] = None,
+    ) -> Union[
+        tuple[
+            torch.Tensor,
+            Optional[torch.Tensor],
+            Optional[torch.Tensor],
+            torch.Tensor,
+            torch.Tensor,
+        ],
+        tuple[
+            torch.Tensor,
+            Optional[torch.Tensor],
+            Optional[torch.Tensor],
+            torch.Tensor,
+            torch.Tensor,
+            torch.Tensor,
+        ],
+    ]:
"""Forwarding the inputs through the top embedding layers
Args:
@@ -530,7 +566,7 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
input tensor
masks: torch.Tensor
input mask
chunk_size_nc: (optional, default is None) chunk size for
non-causal layers
left_chunk_nc: (optional, default is None) # of left chunks for
non-causal layers
@@ -543,21 +579,21 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
f"""The sequence length after time reduction is invalid:
{seq_len}. Your input feature is too short. Consider
filtering out the very short sentence from data
loader""", )
loader""",
)
batch_size = xs_pad.shape[0]
-        enc_streaming_mask = self._streaming_mask(seq_len, batch_size,
-                                                  self.chunk_size,
-                                                  self.left_chunk)
+        enc_streaming_mask = self._streaming_mask(
+            seq_len, batch_size, self.chunk_size, self.left_chunk
+        )
if xs_pad.is_cuda:
enc_streaming_mask = enc_streaming_mask.cuda()
xs_pad = xs_pad.cuda()
input_tensor = xs_pad
-        input_tensor, masks = self._forward_embeddings_core(
-            input_tensor, masks)
+        input_tensor, masks = self._forward_embeddings_core(input_tensor, masks)
streaming_mask = enc_streaming_mask
if streaming_mask is not None and masks is not None:
@@ -569,7 +605,8 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
if chunk_size_nc is not None:
enc_streaming_mask_nc = self._streaming_mask(
-                seq_len, batch_size, chunk_size_nc, left_chunk_nc)
+                seq_len, batch_size, chunk_size_nc, left_chunk_nc
+            )
if xs_pad.is_cuda:
enc_streaming_mask_nc = enc_streaming_mask_nc.cuda()
if masks is not None:
@@ -622,8 +659,8 @@ class ConformerEncoder(TransformerEncoderBase):
left_chunk = 6
left_chunk = [12, 9, 6, 3]
num_lang: int
This parameter is used to store the number of languages in the
lang_dict, only used for multiseed/multilingual models.
default None.
attention_dim: int, optional
attention dimension. default 256.
@@ -721,16 +758,16 @@ class ConformerEncoder(TransformerEncoderBase):
extra_layer_output_idx: int
the layer index to be exposed.
relative_attention_bias_args: dict, optional
use more efficient scalar bias-based relative multihead attention
(Q*K^T + B) implemented in cmb.basics.embedding.
[T5/ALiBi]RelativeAttentionLogitBias
usage: relative_attention_bias_args={"type": t5/alibi}
additional method-specific arguments can be provided (see
transformer_base.py)
time_reduction: int optional
time reduction factor
default 4
use_pt_scaled_dot_product_attention: whether to use pytorch scaled
dot product attention in training.
Default: False
nemo_conv_settings: dict, optional
@@ -748,12 +785,12 @@ class ConformerEncoder(TransformerEncoderBase):
Add extra padding in conv2d subsampling layers. Choices are
(feat, feat_time, none, True)
Default: none
replication_pad_for_subsample_embedding: For batched-streaming
decoding, use "replication" padding for the cache at start of
utterance.
Default: False
attention_group_size: int, optional
the number of groups to use for attention, default 1
(Multi-Head Attention),
1 = typical Multi-Head Attention,
1 < attention_group_size < attention_heads = Grouped-Query
@@ -799,8 +836,7 @@ class ConformerEncoder(TransformerEncoderBase):
time_reduction: int = 4,
use_pt_scaled_dot_product_attention: bool = False,
nemo_conv_settings: Optional[dict[str, Any]] = None,
-        conv2d_extra_padding: Literal["feat", "feat_time", "none",
-                                      True] = "none",
+        conv2d_extra_padding: Literal["feat", "feat_time", "none", True] = "none",
replication_pad_for_subsample_embedding: bool = False,
attention_group_size: int = 1,
encoder_embedding_config: Optional[dict[str, Any]] = None,
@@ -827,39 +863,43 @@ class ConformerEncoder(TransformerEncoderBase):
self.num_lang = num_lang
self.kernel_size = kernel_size
-        self.replication_pad_for_subsample_embedding: bool = (
-            replication_pad_for_subsample_embedding)
-        assert (self.num_heads % attention_group_size == 0
-                ), "attention_group_size must divide n_head"
+        self.replication_pad_for_subsample_embedding: bool = (
+            replication_pad_for_subsample_embedding
+        )
+        assert self.num_heads % attention_group_size == 0, (
+            "attention_group_size must divide n_head"
+        )
self.num_heads_k = self.num_heads // attention_group_size
-        self.encoders = MultiSequential(*[
-            ConformerEncoderLayer(
-                d_model=attention_dim,
-                ext_pw_out_channel=ext_pw_out_channel,
-                depthwise_seperable_out_channel=depthwise_seperable_out_channel,
-                depthwise_multiplier=depthwise_multiplier,
-                n_head=attention_heads,
-                d_ffn=linear_units,
-                ext_pw_kernel_size=ext_pw_kernel_size,
-                kernel_size=kernel_size,
-                dropout_rate=dropout_rate,
-                causal=causal,
-                batch_norm=batch_norm,
-                activation=activation,
-                chunk_se=chunk_se,
-                chunk_size=chunk_size,
-                conv_activation=conv_activation,
-                conv_glu_type=conv_glu_type,
-                bias_in_glu=bias_in_glu,
-                linear_glu_in_convm=linear_glu_in_convm,
-                attention_glu_type=attention_glu_type,
-                activation_checkpointing=activation_checkpointing,
-                export=export,
-                use_pt_scaled_dot_product_attention=
-                use_pt_scaled_dot_product_attention,
-                attn_group_sizes=attention_group_size,
-            ) for _ in range(num_blocks)
-        ])
+        self.encoders = MultiSequential(
+            *[
+                ConformerEncoderLayer(
+                    d_model=attention_dim,
+                    ext_pw_out_channel=ext_pw_out_channel,
+                    depthwise_seperable_out_channel=depthwise_seperable_out_channel,
+                    depthwise_multiplier=depthwise_multiplier,
+                    n_head=attention_heads,
+                    d_ffn=linear_units,
+                    ext_pw_kernel_size=ext_pw_kernel_size,
+                    kernel_size=kernel_size,
+                    dropout_rate=dropout_rate,
+                    causal=causal,
+                    batch_norm=batch_norm,
+                    activation=activation,
+                    chunk_se=chunk_se,
+                    chunk_size=chunk_size,
+                    conv_activation=conv_activation,
+                    conv_glu_type=conv_glu_type,
+                    bias_in_glu=bias_in_glu,
+                    linear_glu_in_convm=linear_glu_in_convm,
+                    attention_glu_type=attention_glu_type,
+                    activation_checkpointing=activation_checkpointing,
+                    export=export,
+                    use_pt_scaled_dot_product_attention=use_pt_scaled_dot_product_attention,
+                    attn_group_sizes=attention_group_size,
+                )
+                for _ in range(num_blocks)
+            ]
+        )
self.extra_layer_output_idx = extra_layer_output_idx
self.extra_multi_layer_output_idxs = extra_multi_layer_output_idxs
# Make a zeros scalar we can use in get_initial_state to determine
@@ -867,34 +907,36 @@ class ConformerEncoder(TransformerEncoderBase):
self.register_buffer("dev_type", torch.zeros(()), persistent=False)
def init_relative_attention_bias(
-            self, input_tensor: torch.Tensor) -> Optional[torch.Tensor]:
+        self, input_tensor: torch.Tensor
+    ) -> Optional[torch.Tensor]:
if self.relative_attention_bias_layer:
return self.relative_attention_bias_layer(input_tensor)
-    def calculate_hs_mask(self, xs_pad: torch.Tensor, device: torch.device,
-                          mask: Optional[torch.Tensor]) -> torch.Tensor:
+    def calculate_hs_mask(
+        self, xs_pad: torch.Tensor, device: torch.device, mask: Optional[torch.Tensor]
+    ) -> torch.Tensor:
max_audio_length = xs_pad.shape[1]
batch_size = xs_pad.shape[0]
-        enc_streaming_mask = self._streaming_mask(max_audio_length, batch_size,
-                                                  self.chunk_size,
-                                                  self.left_chunk)
+        enc_streaming_mask = self._streaming_mask(
+            max_audio_length, batch_size, self.chunk_size, self.left_chunk
+        )
enc_streaming_mask = enc_streaming_mask.to(device)
if mask is None:
return enc_streaming_mask
feature_lens = mask.sum(1)
padding_length = feature_lens
-        pad_mask = (torch.arange(0, max_audio_length,
-                                 device=device).expand(padding_length.size(0),
-                                                       -1)
-                    < padding_length.unsqueeze(1))
+        pad_mask = torch.arange(0, max_audio_length, device=device).expand(
+            padding_length.size(0), -1
+        ) < padding_length.unsqueeze(1)
pad_mask = pad_mask.unsqueeze(1)
pad_mask = pad_mask & enc_streaming_mask
return pad_mask
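The `arange`-versus-lengths comparison above is the standard broadcasting trick for turning per-sample lengths into a boolean padding mask:

```python
import torch

lengths = torch.tensor([3, 5])  # per-sample valid lengths
max_audio_length = 5
pad_mask = torch.arange(0, max_audio_length).expand(
    lengths.size(0), -1
) < lengths.unsqueeze(1)
print(pad_mask)
# tensor([[ True,  True,  True, False, False],
#         [ True,  True,  True,  True,  True]])
```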
@torch.jit.ignore
-    def forward(self, xs_pad: torch.Tensor,
-                masks: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+    def forward(
+        self, xs_pad: torch.Tensor, masks: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
"""Conformer Forward function
Args:
@@ -905,11 +947,12 @@ class ConformerEncoder(TransformerEncoderBase):
"""
xs_pad = self.encoder_embedding(xs_pad)
input_tensor, pos_k, pos_v, hs_mask, masks = self.forward_embeddings(
-            xs_pad, masks)
+            xs_pad, masks
+        )
unfolded = False
ori_bz, seq_len, D = input_tensor.shape
-        max_seq_len = 500 #maximum position for absolute positional encoding
+        max_seq_len = 500  # maximum position for absolute positional encoding
if seq_len > max_seq_len:
# audio sequence is longer than max_seq_len, unfold it into chunks
# of max_seq_len
@@ -921,26 +964,29 @@ class ConformerEncoder(TransformerEncoderBase):
else:
chunk_pad_size = 0
if chunk_pad_size > 0:
-                input_tensor_pad = F.pad(input_tensor,
-                                         (0, 0, 0, chunk_pad_size), "constant",
-                                         0)
+                input_tensor_pad = F.pad(
+                    input_tensor, (0, 0, 0, chunk_pad_size), "constant", 0
+                )
input_tensor = input_tensor_pad.to(input_tensor.device)
input_tensor = unfold_tensor(input_tensor, max_seq_len)
if masks is not None:
# revise hs_mask here because the previous calculated hs_mask
# did not consider extra pad
subsampled_pad_mask = masks.squeeze(
-                    1)  # [bz, subsampled_unmask_seq_len]
+                    1
+                )  # [bz, subsampled_unmask_seq_len]
extra_padded_subsamlped_pad_mask = F.pad(
subsampled_pad_mask, (0, chunk_pad_size), "constant",
False) # extra padding to the pad mask
extra_padded_subsamlped_pad_mask = \
subsampled_pad_mask, (0, chunk_pad_size), "constant", False
) # extra padding to the pad mask
extra_padded_subsamlped_pad_mask = (
extra_padded_subsamlped_pad_mask.unsqueeze(-1).float()
)
masks_unfold = unfold_tensor(
extra_padded_subsamlped_pad_mask, max_seq_len
) # unfold the pad mask like we did to the input tensor
masks_unfold = masks_unfold.squeeze(
-                    -1).bool()  # unfold op does not support bool tensor
+                    -1
+                ).bool()  # unfold op does not support bool tensor
else:
masks_unfold = None
hs_mask = self.calculate_hs_mask(
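The long-audio path above pads the time axis up to a multiple of `max_seq_len` and unfolds it into fixed-size chunks, applying the same padding to the mask. A toy version of that bookkeeping, using a plain `reshape` as a stand-in for `unfold_tensor` (assumed here to produce non-overlapping windows):

```python
import torch
import torch.nn.functional as F

max_seq_len = 4
x = torch.randn(1, 10, 8)  # (batch, time, feature)
chunk_pad_size = (max_seq_len - x.shape[1] % max_seq_len) % max_seq_len
x = F.pad(x, (0, 0, 0, chunk_pad_size))  # right-pad the time dimension
chunks = x.reshape(-1, max_seq_len, x.shape[-1])
print(chunks.shape)  # torch.Size([3, 4, 8])
```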
@@ -949,15 +995,14 @@ class ConformerEncoder(TransformerEncoderBase):
# layer_emb = None
-        relative_attention_bias = self.init_relative_attention_bias(
-            input_tensor)
+        relative_attention_bias = self.init_relative_attention_bias(input_tensor)
-        _simplified_path = (self.extra_layer_output_idx == -1
-                            and relative_attention_bias is None)
+        _simplified_path = (
+            self.extra_layer_output_idx == -1 and relative_attention_bias is None
+        )
if _simplified_path:
-            input_tensor, *_ = self.encoders(input_tensor, pos_k, pos_v,
-                                             hs_mask)
+            input_tensor, *_ = self.encoders(input_tensor, pos_k, pos_v, hs_mask)
else:
for i, layer in enumerate(self.encoders):
input_tensor, _, _, _ = layer(
@@ -997,28 +1042,32 @@ class WindowQformer(nn.Module):
):
super().__init__()
-        self.decoders = nn.ModuleList([
-            nn.TransformerDecoderLayer(
-                d_model=attention_dim,
-                nhead=attention_heads,
-                dim_feedforward=linear_units,
-                dropout=dropout_rate,
-                activation="relu",
-                batch_first=True,
-                norm_first=normalize_before,  # TODO need to verify
-            ) for _ in range(num_blocks)
-        ])
+        self.decoders = nn.ModuleList(
+            [
+                nn.TransformerDecoderLayer(
+                    d_model=attention_dim,
+                    nhead=attention_heads,
+                    dim_feedforward=linear_units,
+                    dropout=dropout_rate,
+                    activation="relu",
+                    batch_first=True,
+                    norm_first=normalize_before,  # TODO need to verify
+                )
+                for _ in range(num_blocks)
+            ]
+        )
self.queries = nn.Parameter(torch.zeros(1, num_queries, attention_dim))
-        self.after_norm = (nn.LayerNorm(attention_dim, eps=1e-12)
-                           if normalize_before else None)
+        self.after_norm = (
+            nn.LayerNorm(attention_dim, eps=1e-12) if normalize_before else None
+        )
self.window_size = window_size
def forward(
-            self,
-            audio_embed: torch.Tensor,
-            mask: Optional[torch.Tensor],
-            embed_len: Optional[int] = None
+        self,
+        audio_embed: torch.Tensor,
+        mask: Optional[torch.Tensor],
+        embed_len: Optional[int] = None,
) -> tuple[torch.Tensor, Optional[int]]:
"""forward decoder"""
# audio_embed: N x T x D => N x D x T
@@ -1027,8 +1076,9 @@ class WindowQformer(nn.Module):
# audio_embed: N x D x 1 x T => N x DK x T'
padding = audio_embed.shape[-1] % self.window_size
if padding > 0:
-            audio_embed = F.pad(audio_embed, (0, self.window_size - padding),
-                                "constant", 0)
+            audio_embed = F.pad(
+                audio_embed, (0, self.window_size - padding), "constant", 0
+            )
embed_chunk = F.unfold(
audio_embed[..., None, :],
@@ -1045,10 +1095,7 @@ class WindowQformer(nn.Module):
# NT' x 1 x D
q = self.queries.expand(bsz * slen, -1, -1)
for layer in self.decoders:
-            q = layer(tgt=q,
-                      memory=embed_chunk,
-                      tgt_mask=None,
-                      memory_mask=mask)
+            q = layer(tgt=q, memory=embed_chunk, tgt_mask=None, memory_mask=mask)
if self.after_norm is not None:
q = self.after_norm(q)
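Putting the pieces of this forward pass together: one learned query cross-attends to each fixed-size window of audio frames, producing a single summary vector per window. A minimal sketch with illustrative dimensions, not the model's real configuration:

```python
import torch
import torch.nn as nn

layer = nn.TransformerDecoderLayer(
    d_model=16, nhead=2, dim_feedforward=32, batch_first=True
)
queries = nn.Parameter(torch.zeros(1, 1, 16))  # 1 learned query of dim 16
windows = torch.randn(6, 8, 16)                # 6 windows of 8 frames each
q = queries.expand(6, -1, -1)                  # one query per window
out = layer(tgt=q, memory=windows)             # (6, 1, 16): one vector per window
print(out.shape)
```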
@@ -1068,8 +1115,7 @@ class AudioEmbedding(nn.Module):
super().__init__()
self.config = config
# n_embed or hidden_size for text LM
-        hidden_size = (config.n_embd
-                       if hasattr(config, "n_embd") else config.hidden_size)
+        hidden_size = config.n_embd if hasattr(config, "n_embd") else config.hidden_size
# self.wte = nn.Embedding(config.vocab_size, hidden_size)
@@ -1078,8 +1124,10 @@ class AudioEmbedding(nn.Module):
)
self.layer_idx = -2
-        if (isinstance(config.audio_processor, dict)
-                and config.audio_processor.get("name", None) == "cascades"):
+        if (
+            isinstance(config.audio_processor, dict)
+            and config.audio_processor.get("name", None) == "cascades"
+        ):
encoder_config = config.audio_processor.get("config", None)
assert encoder_config is not None
self.encoder = ConformerEncoder(**encoder_config)
@@ -1089,13 +1137,11 @@ class AudioEmbedding(nn.Module):
else:
raise NotImplementedError("")
-        assert (audio_dim_out
-                is not None), "Remember to set values for audio_dim_out"
+        assert audio_dim_out is not None, "Remember to set values for audio_dim_out"
self.audio_dim_out = audio_dim_out
self.audio_dim_in = n_mels
self.freeze_audio_processor = kwargs.get("freeze_audio_processor",
False)
self.freeze_audio_processor = kwargs.get("freeze_audio_processor", False)
self.downsample_rate = kwargs.get("downsample_rate", 1)
@@ -1107,8 +1153,9 @@ class AudioEmbedding(nn.Module):
self.qformer = None
if kwargs.get("use_conv_downsample", False):
-            assert (self.qformer is None
-                    ), "don't support use qformer and conv downsample together"
+            assert self.qformer is None, (
+                "don't support use qformer and conv downsample together"
+            )
nemo_conv_settings = kwargs.get("nemo_conv_settings", {})
default_nemo_conv_settings = {
"subsampling": "dw_striding",
@@ -1124,11 +1171,13 @@ class AudioEmbedding(nn.Module):
if nemo_conv_settings:
default_nemo_conv_settings.update(nemo_conv_settings)
for i in ["subsampling_factor", "feat_in", "feat_out"]:
-                assert (
-                    i not in nemo_conv_settings
-                ), "{i} should be specified outside of the NeMo dictionary"
+                assert i not in nemo_conv_settings, (
+                    "{i} should be specified outside of the NeMo dictionary"
+                )
-            self.conv_ds = NemoConvSubsampling(**default_nemo_conv_settings, )
+            self.conv_ds = NemoConvSubsampling(
+                **default_nemo_conv_settings,
+            )
else:
self.conv_ds = None
@@ -1140,30 +1189,26 @@ class AudioEmbedding(nn.Module):
# (do not use image_projection and image_proj_norm)
dim_projection = hidden_size
depth = 2
-            self.linear_downsample_rate = (1 if (self.qformer or self.conv_ds)
-                                           else self.downsample_rate)
+            self.linear_downsample_rate = (
+                1 if (self.qformer or self.conv_ds) else self.downsample_rate
+            )
layers = [
-                nn.Linear(audio_dim_out * self.linear_downsample_rate,
-                          dim_projection)
+                nn.Linear(audio_dim_out * self.linear_downsample_rate, dim_projection)
]
for _ in range(1, depth):
-                layers.extend(
-                    [nn.GELU(),
-                     nn.Linear(dim_projection, dim_projection)])
+                layers.extend([nn.GELU(), nn.Linear(dim_projection, dim_projection)])
self.audio_projection = nn.Sequential(*layers)
# NOTE vision-speech tasks use a separate projection layer
layers = [
-                nn.Linear(audio_dim_out * self.linear_downsample_rate,
-                          dim_projection)
+                nn.Linear(audio_dim_out * self.linear_downsample_rate, dim_projection)
]
for _ in range(1, depth):
-                layers.extend(
-                    [nn.GELU(),
-                     nn.Linear(dim_projection, dim_projection)])
+                layers.extend([nn.GELU(), nn.Linear(dim_projection, dim_projection)])
self.audio_projection_for_vision = nn.Sequential(*layers)
else:
raise NotImplementedError(
f"projection_cls = {projection_cls}, not implemented")
f"projection_cls = {projection_cls}, not implemented"
)
# TODO: audio sequence compression - Qformer
self.vocab_size = config.vocab_size
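Both projection heads built above share one shape: a depth-2 MLP from the (optionally downsampled) encoder feature size into the LM hidden size. In isolation, with placeholder dimensions:

```python
import torch.nn as nn

audio_dim_out, linear_downsample_rate, dim_projection = 1024, 1, 3072  # placeholders
depth = 2

layers = [nn.Linear(audio_dim_out * linear_downsample_rate, dim_projection)]
for _ in range(1, depth):
    layers.extend([nn.GELU(), nn.Linear(dim_projection, dim_projection)])
audio_projection = nn.Sequential(*layers)
```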
@@ -1188,11 +1233,9 @@ class AudioEmbedding(nn.Module):
"""
if self.freeze_audio_processor:
with torch.no_grad():
-                audio_features, masks = self.encoder(input_embeds,
-                                                     audio_attention_mask)
+                audio_features, masks = self.encoder(input_embeds, audio_attention_mask)
else:
-            audio_features, masks = self.encoder(input_embeds,
-                                                 audio_attention_mask)
+            audio_features, masks = self.encoder(input_embeds, audio_attention_mask)
if self.qformer is not None:
audio_features, _ = self.qformer(audio_features, mask=None)
@@ -1221,14 +1264,13 @@ class AudioEmbedding(nn.Module):
feat_dim * self.linear_downsample_rate,
)
-        if audio_projection_mode == 'speech':
+        if audio_projection_mode == "speech":
audio_set_tensor = self.audio_projection(audio_features)
-        elif audio_projection_mode == 'vision':
+        elif audio_projection_mode == "vision":
audio_set_tensor = self.audio_projection_for_vision(audio_features)
else:
raise ValueError(
f"audio_projection_mode = {audio_projection_mode} not "\
"implemented"
f"audio_projection_mode = {audio_projection_mode} not implemented"
)
return audio_set_tensor
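The quote flips above are ruff's default quote normalization: like Black, it rewrites string literals to double quotes unless doing so would require escaping.

```python
mode = "speech"  # ruff rewrites 'speech' -> "speech"
quip = 'she said "hi"'  # single quotes survive when doubles would need escaping
```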
@@ -1242,7 +1284,7 @@ class AudioEmbedding(nn.Module):
"""
arguments:
audio_features: audio features (T, D)
returns:
audio_embeds: audio embeddings (num_audio_tokens, hidden_dim)
"""