Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -14,15 +14,24 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor, nn
|
||||
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
|
||||
CheckpointWrapper)
|
||||
from torch.distributed.fsdp.fully_sharded_data_parallel import (
|
||||
FullyShardedDataParallel)
|
||||
CheckpointWrapper,
|
||||
)
|
||||
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.model_executor.models.phi4mm_utils import (
|
||||
AbsolutePositionalEncoding, ConvModule, FeedForward, MeanVarianceNormLayer,
|
||||
MultiHeadedAttention, MultiSequential, NemoConvSubsampling,
|
||||
T5RelativeAttentionLogitBias, adaptive_enc_mask, get_offset, unfold_tensor)
|
||||
AbsolutePositionalEncoding,
|
||||
ConvModule,
|
||||
FeedForward,
|
||||
MeanVarianceNormLayer,
|
||||
MultiHeadedAttention,
|
||||
MultiSequential,
|
||||
NemoConvSubsampling,
|
||||
T5RelativeAttentionLogitBias,
|
||||
adaptive_enc_mask,
|
||||
get_offset,
|
||||
unfold_tensor,
|
||||
)
|
||||
|
||||
_AUDIO_PLACEHOLDER_TOKEN_ID = 200011 # <|endoftext11|>
|
||||
|
||||
@@ -40,9 +49,9 @@ class ConformerEncoderLayer(nn.Module):
|
||||
if > 0, ext_pw_out_channel is a dim channel size
|
||||
for the last pointwise conv after swish activation.
|
||||
depthwise_seperable_out_channel: int
|
||||
if set different to 0, the number of
|
||||
if set different to 0, the number of
|
||||
depthwise_seperable_out_channel will be used as a
|
||||
channel_out of the second conv1d layer.
|
||||
channel_out of the second conv1d layer.
|
||||
otherwise, it equals to 0, the second conv1d layer is skipped.
|
||||
depthwise_multiplier: int
|
||||
number of input_dim channels duplication. this value
|
||||
@@ -119,10 +128,10 @@ class ConformerEncoderLayer(nn.Module):
|
||||
and allow the onnx conversion for inference.
|
||||
default False.
|
||||
use_pt_scaled_dot_product_attention: bool, optional
|
||||
if set to True, use pytorch's scaled dot product attention
|
||||
if set to True, use pytorch's scaled dot product attention
|
||||
implementation in training.
|
||||
attn_group_sizes: int, optional
|
||||
the number of groups to use for attention, default 1
|
||||
the number of groups to use for attention, default 1
|
||||
(Multi-Head Attention),
|
||||
1 = typical Multi-Head Attention,
|
||||
1 < attn_group_sizes < attention_heads = Grouped-Query Attention
|
||||
@@ -173,8 +182,7 @@ class ConformerEncoderLayer(nn.Module):
|
||||
attention_inner_dim,
|
||||
attention_glu_type,
|
||||
bias_in_glu,
|
||||
use_pt_scaled_dot_product_attention=
|
||||
use_pt_scaled_dot_product_attention,
|
||||
use_pt_scaled_dot_product_attention=use_pt_scaled_dot_product_attention,
|
||||
group_size=attn_group_sizes,
|
||||
)
|
||||
self.conv = ConvModule(
|
||||
@@ -296,7 +304,7 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
|
||||
(Q*K^T + B) implemented in cmb.basics.embedding.
|
||||
[T5/ALiBi]RelativeAttentionLogitBias
|
||||
usage: relative_attention_bias_args={"type": t5/alibi}
|
||||
additional method-specific arguments can be provided (see
|
||||
additional method-specific arguments can be provided (see
|
||||
transformer_base.py)
|
||||
positional_dropout_rate: float, optional
|
||||
dropout rate after positional encoding. default 0.0
|
||||
@@ -310,10 +318,10 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
|
||||
supraframe utts in batch.
|
||||
Default: none
|
||||
attention_group_size: int, optional
|
||||
the number of groups to use for attention, default 1
|
||||
the number of groups to use for attention, default 1
|
||||
(Multi-Head Attention),
|
||||
1 = typical Multi-Head Attention,
|
||||
1 < attention_group_size < attention_heads = Grouped-Query
|
||||
1 < attention_group_size < attention_heads = Grouped-Query
|
||||
Attention
|
||||
attention_group_size = attention_heads = Multi-Query Attention
|
||||
"""
|
||||
@@ -334,8 +342,7 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
|
||||
relative_attention_bias_args: Optional[dict[str, Any]] = None,
|
||||
positional_dropout_rate: float = 0.0,
|
||||
nemo_conv_settings: Optional[dict[str, Any]] = None,
|
||||
conv2d_extra_padding: Literal["feat", "feat_time", "none",
|
||||
True] = "none",
|
||||
conv2d_extra_padding: Literal["feat", "feat_time", "none", True] = "none",
|
||||
attention_group_size: int = 1,
|
||||
encoder_embedding_config: Optional[dict[str, Any]] = None,
|
||||
) -> None:
|
||||
@@ -366,70 +373,77 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
|
||||
if nemo_conv_settings:
|
||||
default_nemo_conv_settings.update(nemo_conv_settings)
|
||||
for i in ["subsampling_factor", "feat_in", "feat_out"]:
|
||||
assert (
|
||||
i not in nemo_conv_settings
|
||||
), "{i} should be specified outside of the NeMo dictionary"
|
||||
assert i not in nemo_conv_settings, (
|
||||
"{i} should be specified outside of the NeMo dictionary"
|
||||
)
|
||||
|
||||
self.embed = NemoConvSubsampling(**default_nemo_conv_settings, )
|
||||
self.embed = NemoConvSubsampling(
|
||||
**default_nemo_conv_settings,
|
||||
)
|
||||
else:
|
||||
raise ValueError("unknown input_layer: " + input_layer)
|
||||
|
||||
self.pos_emb = AbsolutePositionalEncoding(attention_dim,
|
||||
positional_dropout_rate)
|
||||
self.pos_emb = AbsolutePositionalEncoding(
|
||||
attention_dim, positional_dropout_rate
|
||||
)
|
||||
|
||||
self.relative_attention_bias_type = (
|
||||
relative_attention_bias_args.get("type")
|
||||
if relative_attention_bias_args else None)
|
||||
if relative_attention_bias_args
|
||||
else None
|
||||
)
|
||||
if self.relative_attention_bias_type == "t5":
|
||||
assert (self.num_heads % self.attention_group_size == 0
|
||||
), "attention_group_size must divide n_head"
|
||||
assert self.num_heads % self.attention_group_size == 0, (
|
||||
"attention_group_size must divide n_head"
|
||||
)
|
||||
self.relative_attention_bias_layer = T5RelativeAttentionLogitBias(
|
||||
self.num_heads // self.attention_group_size,
|
||||
max_distance=relative_attention_bias_args.get(
|
||||
"t5_bias_max_distance", 1000),
|
||||
symmetric=relative_attention_bias_args.get(
|
||||
"t5_bias_symmetric", False),
|
||||
"t5_bias_max_distance", 1000
|
||||
),
|
||||
symmetric=relative_attention_bias_args.get("t5_bias_symmetric", False),
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
self.encoder_embedding = MeanVarianceNormLayer(
|
||||
self.encoder_embedding_config["input_size"])
|
||||
self.encoder_embedding_config["input_size"]
|
||||
)
|
||||
|
||||
def compute_lens_change(
|
||||
self,
|
||||
feature_lens: Union[int,
|
||||
torch.Tensor]) -> Union[int, torch.Tensor]:
|
||||
self, feature_lens: Union[int, torch.Tensor]
|
||||
) -> Union[int, torch.Tensor]:
|
||||
"""feature_lens: int
|
||||
return updated feature lens.
|
||||
|
||||
This used to return a different lambda function for each case that
|
||||
computed the right thing. That does not work within Torchscript.
|
||||
This used to return a different lambda function for each case that
|
||||
computed the right thing. That does not work within Torchscript.
|
||||
If you really need this to be faster, create nn.Module()-s for all
|
||||
the cases and return one of them. Torchscript does support that.
|
||||
"""
|
||||
if self.input_layer == "nemo_conv":
|
||||
# Handle the special causal case
|
||||
subsampling_causal_cond = self.nemo_conv_settings.get(
|
||||
"subsampling", "dw_striding") in [
|
||||
"dw_striding",
|
||||
"striding",
|
||||
"striding_conv1d",
|
||||
]
|
||||
"subsampling", "dw_striding"
|
||||
) in [
|
||||
"dw_striding",
|
||||
"striding",
|
||||
"striding_conv1d",
|
||||
]
|
||||
is_causal = self.nemo_conv_settings.get("is_causal", False)
|
||||
if is_causal and subsampling_causal_cond:
|
||||
lens_change = (torch.ceil(feature_lens /
|
||||
self.time_reduction).long()
|
||||
if isinstance(feature_lens, Tensor) else
|
||||
math.ceil(feature_lens / self.time_reduction))
|
||||
lens_change = (
|
||||
torch.ceil(feature_lens / self.time_reduction).long()
|
||||
if isinstance(feature_lens, Tensor)
|
||||
else math.ceil(feature_lens / self.time_reduction)
|
||||
)
|
||||
feature_lens_remainder = feature_lens % self.time_reduction
|
||||
if isinstance(feature_lens, Tensor):
|
||||
lens_change[feature_lens_remainder != 1] += 1
|
||||
elif feature_lens_remainder != 1:
|
||||
lens_change += 1
|
||||
return lens_change
|
||||
ceil_func = (math.ceil
|
||||
if isinstance(feature_lens, int) else torch.ceil)
|
||||
ceil_func = math.ceil if isinstance(feature_lens, int) else torch.ceil
|
||||
return ceil_func(feature_lens / self.time_reduction)
|
||||
|
||||
@abc.abstractmethod
|
||||
@@ -437,10 +451,10 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
|
||||
"""Abstract forward method implementation."""
|
||||
|
||||
def _chunk_size_selection(
|
||||
self,
|
||||
chunk_size: Optional[Union[int, list[int]]] = None,
|
||||
left_chunk: Optional[Union[int,
|
||||
list[int]]] = None) -> tuple[int, int]:
|
||||
self,
|
||||
chunk_size: Optional[Union[int, list[int]]] = None,
|
||||
left_chunk: Optional[Union[int, list[int]]] = None,
|
||||
) -> tuple[int, int]:
|
||||
"""If chunk size is a list, we will randomly select a chunk size."""
|
||||
|
||||
if chunk_size is None:
|
||||
@@ -450,15 +464,16 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
|
||||
if isinstance(chunk_size, list):
|
||||
# Variable chunk size during training
|
||||
chunk_size_index = int(
|
||||
torch.randint(low=0, high=len(chunk_size), size=(1, )))
|
||||
torch.randint(low=0, high=len(chunk_size), size=(1,))
|
||||
)
|
||||
chunk_size_train_eff = chunk_size[chunk_size_index]
|
||||
if not isinstance(left_chunk, list):
|
||||
raise ValueError(
|
||||
"Since chunk_size is a list, left_chunk must be a list")
|
||||
"Since chunk_size is a list, left_chunk must be a list"
|
||||
)
|
||||
if len(left_chunk) != len(chunk_size):
|
||||
raise ValueError(
|
||||
"The length of left_chunk must be the same as length of "\
|
||||
"chunk_size."
|
||||
"The length of left_chunk must be the same as length of chunk_size."
|
||||
)
|
||||
left_chunk_train_eff = left_chunk[chunk_size_index]
|
||||
else:
|
||||
@@ -479,8 +494,8 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
|
||||
return embed_class
|
||||
|
||||
def _forward_embeddings_core(
|
||||
self, input_tensor: torch.Tensor,
|
||||
masks: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
self, input_tensor: torch.Tensor, masks: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
embed_class = self._get_embed_class(self.embed)
|
||||
assert isinstance(embed_class, NemoConvSubsampling)
|
||||
input_tensor, masks = self.embed(input_tensor, masks)
|
||||
@@ -493,23 +508,32 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
|
||||
pos_v = None
|
||||
if self.relative_attention_bias_layer is None:
|
||||
input_tensor = self.pos_emb(
|
||||
input_tensor) # default to add abs sinusoid embedding
|
||||
input_tensor
|
||||
) # default to add abs sinusoid embedding
|
||||
return pos_k, pos_v
|
||||
|
||||
def _streaming_mask(self, seq_len: int, batch_size: int,
|
||||
chunk_size: Union[int, list[int]],
|
||||
left_chunk: Union[int, list[int]]) -> torch.Tensor:
|
||||
chunk_size_train_eff, left_chunk_train_eff = \
|
||||
self._chunk_size_selection(chunk_size, left_chunk)
|
||||
def _streaming_mask(
|
||||
self,
|
||||
seq_len: int,
|
||||
batch_size: int,
|
||||
chunk_size: Union[int, list[int]],
|
||||
left_chunk: Union[int, list[int]],
|
||||
) -> torch.Tensor:
|
||||
chunk_size_train_eff, left_chunk_train_eff = self._chunk_size_selection(
|
||||
chunk_size, left_chunk
|
||||
)
|
||||
|
||||
# Create mask matrix for streaming
|
||||
# S stores start index. if chunksize is 18, s is [0,18,36,....]
|
||||
chunk_start_idx = np.arange(0, seq_len, chunk_size_train_eff)
|
||||
|
||||
enc_streaming_mask = (adaptive_enc_mask(
|
||||
seq_len, chunk_start_idx,
|
||||
left_window=left_chunk_train_eff).unsqueeze(0).expand(
|
||||
[batch_size, -1, -1]))
|
||||
enc_streaming_mask = (
|
||||
adaptive_enc_mask(
|
||||
seq_len, chunk_start_idx, left_window=left_chunk_train_eff
|
||||
)
|
||||
.unsqueeze(0)
|
||||
.expand([batch_size, -1, -1])
|
||||
)
|
||||
return enc_streaming_mask
|
||||
|
||||
def forward_embeddings(
|
||||
@@ -517,12 +541,24 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
|
||||
xs_pad: torch.Tensor,
|
||||
masks: torch.Tensor,
|
||||
chunk_size_nc: Optional[Union[int, list[int]]] = None,
|
||||
left_chunk_nc: Optional[Union[int, list[int]]] = None
|
||||
) -> Union[tuple[torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[torch.Tensor], torch.Tensor, torch.Tensor],
|
||||
tuple[torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[torch.Tensor], torch.Tensor, torch.Tensor,
|
||||
torch.Tensor]]:
|
||||
left_chunk_nc: Optional[Union[int, list[int]]] = None,
|
||||
) -> Union[
|
||||
tuple[
|
||||
torch.Tensor,
|
||||
Optional[torch.Tensor],
|
||||
Optional[torch.Tensor],
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
],
|
||||
tuple[
|
||||
torch.Tensor,
|
||||
Optional[torch.Tensor],
|
||||
Optional[torch.Tensor],
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
torch.Tensor,
|
||||
],
|
||||
]:
|
||||
"""Forwarding the inputs through the top embedding layers
|
||||
|
||||
Args:
|
||||
@@ -530,7 +566,7 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
|
||||
input tensor
|
||||
masks: torch.Tensor
|
||||
input mask
|
||||
chunk_size_nc: (optional, default is None) chunk size for
|
||||
chunk_size_nc: (optional, default is None) chunk size for
|
||||
non-causal layers
|
||||
left_chunk_nc: (optional, default is None) # of left chunks for
|
||||
non-causal layers
|
||||
@@ -543,21 +579,21 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
|
||||
f"""The sequence length after time reduction is invalid:
|
||||
{seq_len}. Your input feature is too short. Consider
|
||||
filtering out the very short sentence from data
|
||||
loader""", )
|
||||
loader""",
|
||||
)
|
||||
|
||||
batch_size = xs_pad.shape[0]
|
||||
|
||||
enc_streaming_mask = self._streaming_mask(seq_len, batch_size,
|
||||
self.chunk_size,
|
||||
self.left_chunk)
|
||||
enc_streaming_mask = self._streaming_mask(
|
||||
seq_len, batch_size, self.chunk_size, self.left_chunk
|
||||
)
|
||||
|
||||
if xs_pad.is_cuda:
|
||||
enc_streaming_mask = enc_streaming_mask.cuda()
|
||||
xs_pad = xs_pad.cuda()
|
||||
|
||||
input_tensor = xs_pad
|
||||
input_tensor, masks = self._forward_embeddings_core(
|
||||
input_tensor, masks)
|
||||
input_tensor, masks = self._forward_embeddings_core(input_tensor, masks)
|
||||
|
||||
streaming_mask = enc_streaming_mask
|
||||
if streaming_mask is not None and masks is not None:
|
||||
@@ -569,7 +605,8 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
|
||||
|
||||
if chunk_size_nc is not None:
|
||||
enc_streaming_mask_nc = self._streaming_mask(
|
||||
seq_len, batch_size, chunk_size_nc, left_chunk_nc)
|
||||
seq_len, batch_size, chunk_size_nc, left_chunk_nc
|
||||
)
|
||||
if xs_pad.is_cuda:
|
||||
enc_streaming_mask_nc = enc_streaming_mask_nc.cuda()
|
||||
if masks is not None:
|
||||
@@ -622,8 +659,8 @@ class ConformerEncoder(TransformerEncoderBase):
|
||||
left_chunk = 6
|
||||
left_chunk = [12, 9, 6, 3]
|
||||
num_lang: int
|
||||
This parameter is used to store the number of languages in the
|
||||
lang_dict, only used for multiseed/multilingual models.
|
||||
This parameter is used to store the number of languages in the
|
||||
lang_dict, only used for multiseed/multilingual models.
|
||||
default None.
|
||||
attention_dim: int, optional
|
||||
attention dimension. default 256.
|
||||
@@ -721,16 +758,16 @@ class ConformerEncoder(TransformerEncoderBase):
|
||||
extra_layer_output_idx: int
|
||||
the layer index to be exposed.
|
||||
relative_attention_bias_args: dict, optional
|
||||
use more efficient scalar bias-based relative multihead attention
|
||||
use more efficient scalar bias-based relative multihead attention
|
||||
(Q*K^T + B) implemented in cmb.basics.embedding.
|
||||
[T5/ALiBi]RelativeAttentionLogitBias
|
||||
usage: relative_attention_bias_args={"type": t5/alibi}
|
||||
additional method-specific arguments can be provided (see
|
||||
additional method-specific arguments can be provided (see
|
||||
transformer_base.py)
|
||||
time_reduction: int optional
|
||||
time reduction factor
|
||||
default 4
|
||||
use_pt_scaled_dot_product_attention: whether to use pytorch scaled
|
||||
use_pt_scaled_dot_product_attention: whether to use pytorch scaled
|
||||
dot product attention in training.
|
||||
Default: False
|
||||
nemo_conv_settings: dict, optional
|
||||
@@ -748,12 +785,12 @@ class ConformerEncoder(TransformerEncoderBase):
|
||||
Add extra padding in conv2d subsampling layers. Choices are
|
||||
(feat, feat_time, none, True)
|
||||
Default: none
|
||||
replication_pad_for_subsample_embedding: For batched-streaming
|
||||
replication_pad_for_subsample_embedding: For batched-streaming
|
||||
decoding, use "replication" padding for the cache at start of
|
||||
utterance.
|
||||
Default: False
|
||||
attention_group_size: int, optional
|
||||
the number of groups to use for attention, default 1
|
||||
the number of groups to use for attention, default 1
|
||||
(Multi-Head Attention),
|
||||
1 = typical Multi-Head Attention,
|
||||
1 < attention_group_size < attention_heads = Grouped-Query
|
||||
@@ -799,8 +836,7 @@ class ConformerEncoder(TransformerEncoderBase):
|
||||
time_reduction: int = 4,
|
||||
use_pt_scaled_dot_product_attention: bool = False,
|
||||
nemo_conv_settings: Optional[dict[str, Any]] = None,
|
||||
conv2d_extra_padding: Literal["feat", "feat_time", "none",
|
||||
True] = "none",
|
||||
conv2d_extra_padding: Literal["feat", "feat_time", "none", True] = "none",
|
||||
replication_pad_for_subsample_embedding: bool = False,
|
||||
attention_group_size: int = 1,
|
||||
encoder_embedding_config: Optional[dict[str, Any]] = None,
|
||||
@@ -827,39 +863,43 @@ class ConformerEncoder(TransformerEncoderBase):
|
||||
self.num_lang = num_lang
|
||||
self.kernel_size = kernel_size
|
||||
self.replication_pad_for_subsample_embedding: bool = (
|
||||
replication_pad_for_subsample_embedding)
|
||||
assert (self.num_heads % attention_group_size == 0
|
||||
), "attention_group_size must divide n_head"
|
||||
replication_pad_for_subsample_embedding
|
||||
)
|
||||
assert self.num_heads % attention_group_size == 0, (
|
||||
"attention_group_size must divide n_head"
|
||||
)
|
||||
self.num_heads_k = self.num_heads // attention_group_size
|
||||
|
||||
self.encoders = MultiSequential(*[
|
||||
ConformerEncoderLayer(
|
||||
d_model=attention_dim,
|
||||
ext_pw_out_channel=ext_pw_out_channel,
|
||||
depthwise_seperable_out_channel=depthwise_seperable_out_channel,
|
||||
depthwise_multiplier=depthwise_multiplier,
|
||||
n_head=attention_heads,
|
||||
d_ffn=linear_units,
|
||||
ext_pw_kernel_size=ext_pw_kernel_size,
|
||||
kernel_size=kernel_size,
|
||||
dropout_rate=dropout_rate,
|
||||
causal=causal,
|
||||
batch_norm=batch_norm,
|
||||
activation=activation,
|
||||
chunk_se=chunk_se,
|
||||
chunk_size=chunk_size,
|
||||
conv_activation=conv_activation,
|
||||
conv_glu_type=conv_glu_type,
|
||||
bias_in_glu=bias_in_glu,
|
||||
linear_glu_in_convm=linear_glu_in_convm,
|
||||
attention_glu_type=attention_glu_type,
|
||||
activation_checkpointing=activation_checkpointing,
|
||||
export=export,
|
||||
use_pt_scaled_dot_product_attention=
|
||||
use_pt_scaled_dot_product_attention,
|
||||
attn_group_sizes=attention_group_size,
|
||||
) for _ in range(num_blocks)
|
||||
])
|
||||
self.encoders = MultiSequential(
|
||||
*[
|
||||
ConformerEncoderLayer(
|
||||
d_model=attention_dim,
|
||||
ext_pw_out_channel=ext_pw_out_channel,
|
||||
depthwise_seperable_out_channel=depthwise_seperable_out_channel,
|
||||
depthwise_multiplier=depthwise_multiplier,
|
||||
n_head=attention_heads,
|
||||
d_ffn=linear_units,
|
||||
ext_pw_kernel_size=ext_pw_kernel_size,
|
||||
kernel_size=kernel_size,
|
||||
dropout_rate=dropout_rate,
|
||||
causal=causal,
|
||||
batch_norm=batch_norm,
|
||||
activation=activation,
|
||||
chunk_se=chunk_se,
|
||||
chunk_size=chunk_size,
|
||||
conv_activation=conv_activation,
|
||||
conv_glu_type=conv_glu_type,
|
||||
bias_in_glu=bias_in_glu,
|
||||
linear_glu_in_convm=linear_glu_in_convm,
|
||||
attention_glu_type=attention_glu_type,
|
||||
activation_checkpointing=activation_checkpointing,
|
||||
export=export,
|
||||
use_pt_scaled_dot_product_attention=use_pt_scaled_dot_product_attention,
|
||||
attn_group_sizes=attention_group_size,
|
||||
)
|
||||
for _ in range(num_blocks)
|
||||
]
|
||||
)
|
||||
self.extra_layer_output_idx = extra_layer_output_idx
|
||||
self.extra_multi_layer_output_idxs = extra_multi_layer_output_idxs
|
||||
# Make a zeros scalar we can use in get_initial_state to determine
|
||||
@@ -867,34 +907,36 @@ class ConformerEncoder(TransformerEncoderBase):
|
||||
self.register_buffer("dev_type", torch.zeros(()), persistent=False)
|
||||
|
||||
def init_relative_attention_bias(
|
||||
self, input_tensor: torch.Tensor) -> Optional[torch.Tensor]:
|
||||
self, input_tensor: torch.Tensor
|
||||
) -> Optional[torch.Tensor]:
|
||||
if self.relative_attention_bias_layer:
|
||||
return self.relative_attention_bias_layer(input_tensor)
|
||||
|
||||
def calculate_hs_mask(self, xs_pad: torch.Tensor, device: torch.device,
|
||||
mask: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
def calculate_hs_mask(
|
||||
self, xs_pad: torch.Tensor, device: torch.device, mask: Optional[torch.Tensor]
|
||||
) -> torch.Tensor:
|
||||
max_audio_length = xs_pad.shape[1]
|
||||
batch_size = xs_pad.shape[0]
|
||||
enc_streaming_mask = self._streaming_mask(max_audio_length, batch_size,
|
||||
self.chunk_size,
|
||||
self.left_chunk)
|
||||
enc_streaming_mask = self._streaming_mask(
|
||||
max_audio_length, batch_size, self.chunk_size, self.left_chunk
|
||||
)
|
||||
enc_streaming_mask = enc_streaming_mask.to(device)
|
||||
if mask is None:
|
||||
return enc_streaming_mask
|
||||
|
||||
feature_lens = mask.sum(1)
|
||||
padding_length = feature_lens
|
||||
pad_mask = (torch.arange(0, max_audio_length,
|
||||
device=device).expand(padding_length.size(0),
|
||||
-1)
|
||||
< padding_length.unsqueeze(1))
|
||||
pad_mask = torch.arange(0, max_audio_length, device=device).expand(
|
||||
padding_length.size(0), -1
|
||||
) < padding_length.unsqueeze(1)
|
||||
pad_mask = pad_mask.unsqueeze(1)
|
||||
pad_mask = pad_mask & enc_streaming_mask
|
||||
return pad_mask
|
||||
|
||||
@torch.jit.ignore
|
||||
def forward(self, xs_pad: torch.Tensor,
|
||||
masks: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
def forward(
|
||||
self, xs_pad: torch.Tensor, masks: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Conformer Forward function
|
||||
|
||||
Args:
|
||||
@@ -905,11 +947,12 @@ class ConformerEncoder(TransformerEncoderBase):
|
||||
"""
|
||||
xs_pad = self.encoder_embedding(xs_pad)
|
||||
input_tensor, pos_k, pos_v, hs_mask, masks = self.forward_embeddings(
|
||||
xs_pad, masks)
|
||||
xs_pad, masks
|
||||
)
|
||||
|
||||
unfolded = False
|
||||
ori_bz, seq_len, D = input_tensor.shape
|
||||
max_seq_len = 500 #maximum position for absolute positional encoding
|
||||
max_seq_len = 500 # maximum position for absolute positional encoding
|
||||
if seq_len > max_seq_len:
|
||||
# audio sequence is longer than max_seq_len, unfold it into chunks
|
||||
# of max_seq_len
|
||||
@@ -921,26 +964,29 @@ class ConformerEncoder(TransformerEncoderBase):
|
||||
else:
|
||||
chunk_pad_size = 0
|
||||
if chunk_pad_size > 0:
|
||||
input_tensor_pad = F.pad(input_tensor,
|
||||
(0, 0, 0, chunk_pad_size), "constant",
|
||||
0)
|
||||
input_tensor_pad = F.pad(
|
||||
input_tensor, (0, 0, 0, chunk_pad_size), "constant", 0
|
||||
)
|
||||
input_tensor = input_tensor_pad.to(input_tensor.device)
|
||||
input_tensor = unfold_tensor(input_tensor, max_seq_len)
|
||||
if masks is not None:
|
||||
# revise hs_mask here because the previous calculated hs_mask
|
||||
# did not consider extra pad
|
||||
subsampled_pad_mask = masks.squeeze(
|
||||
1) # [bz, subsampled_unmask_seq_len]
|
||||
1
|
||||
) # [bz, subsampled_unmask_seq_len]
|
||||
extra_padded_subsamlped_pad_mask = F.pad(
|
||||
subsampled_pad_mask, (0, chunk_pad_size), "constant",
|
||||
False) # extra padding to the pad mask
|
||||
extra_padded_subsamlped_pad_mask = \
|
||||
subsampled_pad_mask, (0, chunk_pad_size), "constant", False
|
||||
) # extra padding to the pad mask
|
||||
extra_padded_subsamlped_pad_mask = (
|
||||
extra_padded_subsamlped_pad_mask.unsqueeze(-1).float()
|
||||
)
|
||||
masks_unfold = unfold_tensor(
|
||||
extra_padded_subsamlped_pad_mask, max_seq_len
|
||||
) # unfold the pad mask like we did to the input tensor
|
||||
masks_unfold = masks_unfold.squeeze(
|
||||
-1).bool() # unfold op does not support bool tensor
|
||||
-1
|
||||
).bool() # unfold op does not support bool tensor
|
||||
else:
|
||||
masks_unfold = None
|
||||
hs_mask = self.calculate_hs_mask(
|
||||
@@ -949,15 +995,14 @@ class ConformerEncoder(TransformerEncoderBase):
|
||||
|
||||
# layer_emb = None
|
||||
|
||||
relative_attention_bias = self.init_relative_attention_bias(
|
||||
input_tensor)
|
||||
relative_attention_bias = self.init_relative_attention_bias(input_tensor)
|
||||
|
||||
_simplified_path = (self.extra_layer_output_idx == -1
|
||||
and relative_attention_bias is None)
|
||||
_simplified_path = (
|
||||
self.extra_layer_output_idx == -1 and relative_attention_bias is None
|
||||
)
|
||||
|
||||
if _simplified_path:
|
||||
input_tensor, *_ = self.encoders(input_tensor, pos_k, pos_v,
|
||||
hs_mask)
|
||||
input_tensor, *_ = self.encoders(input_tensor, pos_k, pos_v, hs_mask)
|
||||
else:
|
||||
for i, layer in enumerate(self.encoders):
|
||||
input_tensor, _, _, _ = layer(
|
||||
@@ -997,28 +1042,32 @@ class WindowQformer(nn.Module):
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.decoders = nn.ModuleList([
|
||||
nn.TransformerDecoderLayer(
|
||||
d_model=attention_dim,
|
||||
nhead=attention_heads,
|
||||
dim_feedforward=linear_units,
|
||||
dropout=dropout_rate,
|
||||
activation="relu",
|
||||
batch_first=True,
|
||||
norm_first=normalize_before, # TODO need to verify
|
||||
) for _ in range(num_blocks)
|
||||
])
|
||||
self.decoders = nn.ModuleList(
|
||||
[
|
||||
nn.TransformerDecoderLayer(
|
||||
d_model=attention_dim,
|
||||
nhead=attention_heads,
|
||||
dim_feedforward=linear_units,
|
||||
dropout=dropout_rate,
|
||||
activation="relu",
|
||||
batch_first=True,
|
||||
norm_first=normalize_before, # TODO need to verify
|
||||
)
|
||||
for _ in range(num_blocks)
|
||||
]
|
||||
)
|
||||
|
||||
self.queries = nn.Parameter(torch.zeros(1, num_queries, attention_dim))
|
||||
self.after_norm = (nn.LayerNorm(attention_dim, eps=1e-12)
|
||||
if normalize_before else None)
|
||||
self.after_norm = (
|
||||
nn.LayerNorm(attention_dim, eps=1e-12) if normalize_before else None
|
||||
)
|
||||
self.window_size = window_size
|
||||
|
||||
def forward(
|
||||
self,
|
||||
audio_embed: torch.Tensor,
|
||||
mask: Optional[torch.Tensor],
|
||||
embed_len: Optional[int] = None
|
||||
self,
|
||||
audio_embed: torch.Tensor,
|
||||
mask: Optional[torch.Tensor],
|
||||
embed_len: Optional[int] = None,
|
||||
) -> tuple[torch.Tensor, Optional[int]]:
|
||||
"""forward decoder"""
|
||||
# audio_embed: N x T x D => N x D x T
|
||||
@@ -1027,8 +1076,9 @@ class WindowQformer(nn.Module):
|
||||
# audio_embed: N x D x 1 x T => N x DK x T'
|
||||
padding = audio_embed.shape[-1] % self.window_size
|
||||
if padding > 0:
|
||||
audio_embed = F.pad(audio_embed, (0, self.window_size - padding),
|
||||
"constant", 0)
|
||||
audio_embed = F.pad(
|
||||
audio_embed, (0, self.window_size - padding), "constant", 0
|
||||
)
|
||||
|
||||
embed_chunk = F.unfold(
|
||||
audio_embed[..., None, :],
|
||||
@@ -1045,10 +1095,7 @@ class WindowQformer(nn.Module):
|
||||
# NT' x 1 x D
|
||||
q = self.queries.expand(bsz * slen, -1, -1)
|
||||
for layer in self.decoders:
|
||||
q = layer(tgt=q,
|
||||
memory=embed_chunk,
|
||||
tgt_mask=None,
|
||||
memory_mask=mask)
|
||||
q = layer(tgt=q, memory=embed_chunk, tgt_mask=None, memory_mask=mask)
|
||||
|
||||
if self.after_norm is not None:
|
||||
q = self.after_norm(q)
|
||||
@@ -1068,8 +1115,7 @@ class AudioEmbedding(nn.Module):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
# n_embed or hidden_size for text LM
|
||||
hidden_size = (config.n_embd
|
||||
if hasattr(config, "n_embd") else config.hidden_size)
|
||||
hidden_size = config.n_embd if hasattr(config, "n_embd") else config.hidden_size
|
||||
|
||||
# self.wte = nn.Embedding(config.vocab_size, hidden_size)
|
||||
|
||||
@@ -1078,8 +1124,10 @@ class AudioEmbedding(nn.Module):
|
||||
)
|
||||
self.layer_idx = -2
|
||||
|
||||
if (isinstance(config.audio_processor, dict)
|
||||
and config.audio_processor.get("name", None) == "cascades"):
|
||||
if (
|
||||
isinstance(config.audio_processor, dict)
|
||||
and config.audio_processor.get("name", None) == "cascades"
|
||||
):
|
||||
encoder_config = config.audio_processor.get("config", None)
|
||||
assert encoder_config is not None
|
||||
self.encoder = ConformerEncoder(**encoder_config)
|
||||
@@ -1089,13 +1137,11 @@ class AudioEmbedding(nn.Module):
|
||||
else:
|
||||
raise NotImplementedError("")
|
||||
|
||||
assert (audio_dim_out
|
||||
is not None), "Remember to set values for audio_dim_out"
|
||||
assert audio_dim_out is not None, "Remember to set values for audio_dim_out"
|
||||
self.audio_dim_out = audio_dim_out
|
||||
self.audio_dim_in = n_mels
|
||||
|
||||
self.freeze_audio_processor = kwargs.get("freeze_audio_processor",
|
||||
False)
|
||||
self.freeze_audio_processor = kwargs.get("freeze_audio_processor", False)
|
||||
|
||||
self.downsample_rate = kwargs.get("downsample_rate", 1)
|
||||
|
||||
@@ -1107,8 +1153,9 @@ class AudioEmbedding(nn.Module):
|
||||
self.qformer = None
|
||||
|
||||
if kwargs.get("use_conv_downsample", False):
|
||||
assert (self.qformer is None
|
||||
), "don't support use qformer and conv downsample together"
|
||||
assert self.qformer is None, (
|
||||
"don't support use qformer and conv downsample together"
|
||||
)
|
||||
nemo_conv_settings = kwargs.get("nemo_conv_settings", {})
|
||||
default_nemo_conv_settings = {
|
||||
"subsampling": "dw_striding",
|
||||
@@ -1124,11 +1171,13 @@ class AudioEmbedding(nn.Module):
|
||||
if nemo_conv_settings:
|
||||
default_nemo_conv_settings.update(nemo_conv_settings)
|
||||
for i in ["subsampling_factor", "feat_in", "feat_out"]:
|
||||
assert (
|
||||
i not in nemo_conv_settings
|
||||
), "{i} should be specified outside of the NeMo dictionary"
|
||||
assert i not in nemo_conv_settings, (
|
||||
"{i} should be specified outside of the NeMo dictionary"
|
||||
)
|
||||
|
||||
self.conv_ds = NemoConvSubsampling(**default_nemo_conv_settings, )
|
||||
self.conv_ds = NemoConvSubsampling(
|
||||
**default_nemo_conv_settings,
|
||||
)
|
||||
else:
|
||||
self.conv_ds = None
|
||||
|
||||
@@ -1140,30 +1189,26 @@ class AudioEmbedding(nn.Module):
|
||||
# (do not use image_projection and image_proj_norm)
|
||||
dim_projection = hidden_size
|
||||
depth = 2
|
||||
self.linear_downsample_rate = (1 if (self.qformer or self.conv_ds)
|
||||
else self.downsample_rate)
|
||||
self.linear_downsample_rate = (
|
||||
1 if (self.qformer or self.conv_ds) else self.downsample_rate
|
||||
)
|
||||
layers = [
|
||||
nn.Linear(audio_dim_out * self.linear_downsample_rate,
|
||||
dim_projection)
|
||||
nn.Linear(audio_dim_out * self.linear_downsample_rate, dim_projection)
|
||||
]
|
||||
for _ in range(1, depth):
|
||||
layers.extend(
|
||||
[nn.GELU(),
|
||||
nn.Linear(dim_projection, dim_projection)])
|
||||
layers.extend([nn.GELU(), nn.Linear(dim_projection, dim_projection)])
|
||||
self.audio_projection = nn.Sequential(*layers)
|
||||
# NOTE vision-speech tasks use a separate projection layer
|
||||
layers = [
|
||||
nn.Linear(audio_dim_out * self.linear_downsample_rate,
|
||||
dim_projection)
|
||||
nn.Linear(audio_dim_out * self.linear_downsample_rate, dim_projection)
|
||||
]
|
||||
for _ in range(1, depth):
|
||||
layers.extend(
|
||||
[nn.GELU(),
|
||||
nn.Linear(dim_projection, dim_projection)])
|
||||
layers.extend([nn.GELU(), nn.Linear(dim_projection, dim_projection)])
|
||||
self.audio_projection_for_vision = nn.Sequential(*layers)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"projection_cls = {projection_cls}, not implemented")
|
||||
f"projection_cls = {projection_cls}, not implemented"
|
||||
)
|
||||
|
||||
# TODO: audio sequence compression - Qformer
|
||||
self.vocab_size = config.vocab_size
|
||||
@@ -1188,11 +1233,9 @@ class AudioEmbedding(nn.Module):
|
||||
"""
|
||||
if self.freeze_audio_processor:
|
||||
with torch.no_grad():
|
||||
audio_features, masks = self.encoder(input_embeds,
|
||||
audio_attention_mask)
|
||||
audio_features, masks = self.encoder(input_embeds, audio_attention_mask)
|
||||
else:
|
||||
audio_features, masks = self.encoder(input_embeds,
|
||||
audio_attention_mask)
|
||||
audio_features, masks = self.encoder(input_embeds, audio_attention_mask)
|
||||
|
||||
if self.qformer is not None:
|
||||
audio_features, _ = self.qformer(audio_features, mask=None)
|
||||
@@ -1221,14 +1264,13 @@ class AudioEmbedding(nn.Module):
|
||||
feat_dim * self.linear_downsample_rate,
|
||||
)
|
||||
|
||||
if audio_projection_mode == 'speech':
|
||||
if audio_projection_mode == "speech":
|
||||
audio_set_tensor = self.audio_projection(audio_features)
|
||||
elif audio_projection_mode == 'vision':
|
||||
elif audio_projection_mode == "vision":
|
||||
audio_set_tensor = self.audio_projection_for_vision(audio_features)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"audio_projection_mode = {audio_projection_mode} not "\
|
||||
"implemented"
|
||||
f"audio_projection_mode = {audio_projection_mode} not implemented"
|
||||
)
|
||||
|
||||
return audio_set_tensor
|
||||
@@ -1242,7 +1284,7 @@ class AudioEmbedding(nn.Module):
|
||||
"""
|
||||
arguments:
|
||||
audio_features: audio features (T, D)
|
||||
|
||||
|
||||
returns:
|
||||
audio_embeds: audio embeddings (num_audio_tokens, hidden_dim)
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user