[Docs] Fix warnings in mkdocs build (continued) (#24092)
Signed-off-by: Zerohertz <ohg3417@gmail.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
This commit is contained in:
@@ -7,7 +7,7 @@
|
||||
#!/usr/bin/env python3
|
||||
import abc
|
||||
import math
|
||||
from typing import Literal, Optional
|
||||
from typing import Any, Literal, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -131,31 +131,31 @@ class ConformerEncoderLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
d_model=512,
|
||||
ext_pw_out_channel=0,
|
||||
depthwise_seperable_out_channel=256,
|
||||
depthwise_multiplier=1,
|
||||
n_head=4,
|
||||
d_ffn=2048,
|
||||
ext_pw_kernel_size=1,
|
||||
kernel_size=3,
|
||||
dropout_rate=0.1,
|
||||
causal=False,
|
||||
batch_norm=False,
|
||||
activation="relu",
|
||||
chunk_se=0,
|
||||
chunk_size=18,
|
||||
conv_activation="relu",
|
||||
conv_glu_type="sigmoid",
|
||||
bias_in_glu=True,
|
||||
linear_glu_in_convm=False,
|
||||
attention_inner_dim=-1,
|
||||
attention_glu_type="swish",
|
||||
activation_checkpointing="",
|
||||
export=False,
|
||||
use_pt_scaled_dot_product_attention=False,
|
||||
d_model: int = 512,
|
||||
ext_pw_out_channel: int = 0,
|
||||
depthwise_seperable_out_channel: int = 256,
|
||||
depthwise_multiplier: int = 1,
|
||||
n_head: int = 4,
|
||||
d_ffn: int = 2048,
|
||||
ext_pw_kernel_size: int = 1,
|
||||
kernel_size: int = 3,
|
||||
dropout_rate: float = 0.1,
|
||||
causal: bool = False,
|
||||
batch_norm: bool = False,
|
||||
activation: str = "relu",
|
||||
chunk_se: int = 0,
|
||||
chunk_size: int = 18,
|
||||
conv_activation: str = "relu",
|
||||
conv_glu_type: str = "sigmoid",
|
||||
bias_in_glu: bool = True,
|
||||
linear_glu_in_convm: bool = False,
|
||||
attention_inner_dim: int = -1,
|
||||
attention_glu_type: str = "swish",
|
||||
activation_checkpointing: str = "",
|
||||
export: bool = False,
|
||||
use_pt_scaled_dot_product_attention: bool = False,
|
||||
attn_group_sizes: int = 1,
|
||||
):
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.feed_forward_in = FeedForward(
|
||||
@@ -209,24 +209,21 @@ class ConformerEncoderLayer(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
pos_k,
|
||||
pos_v,
|
||||
mask,
|
||||
x: torch.Tensor,
|
||||
pos_k: torch.Tensor,
|
||||
pos_v: torch.Tensor,
|
||||
mask: torch.Tensor,
|
||||
relative_attention_bias: Optional[Tensor] = None,
|
||||
):
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""ConformerEncoder forward.
|
||||
|
||||
Args:
|
||||
x: torch.Tensor
|
||||
input feature of shape (batch, max_time_in, size)
|
||||
pos_k: torch.Tensor
|
||||
positional key embedding.
|
||||
mask: torch.Tensor
|
||||
mask for x (batch, max_time_in)
|
||||
relative_attention_bias: Optional[torch.Tensor]
|
||||
bias added to attention logits w.r.t. relative positions
|
||||
(1, n_head, time1, time2)
|
||||
x: input feature of shape (batch, max_time_in, size)
|
||||
pos_k: positional key embedding.
|
||||
pos_v: positional value embedding.
|
||||
mask: mask for x (batch, max_time_in)
|
||||
relative_attention_bias: bias added to attention logits w.r.t.
|
||||
relative positions (1, n_head, time1, time2)
|
||||
"""
|
||||
x = x + 0.5 * self.feed_forward_in(x)
|
||||
norm_x = self.layer_norm_att(x)
|
||||
@@ -323,25 +320,25 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size,
|
||||
chunk_size,
|
||||
left_chunk,
|
||||
attention_dim=256,
|
||||
attention_heads=4,
|
||||
input_layer="nemo_conv",
|
||||
cnn_out=-1,
|
||||
cnn_layer_norm=False,
|
||||
time_reduction=4,
|
||||
dropout_rate=0.0,
|
||||
padding_idx=-1,
|
||||
relative_attention_bias_args=None,
|
||||
positional_dropout_rate=0.0,
|
||||
nemo_conv_settings=None,
|
||||
input_size: int,
|
||||
chunk_size: Union[int, list[int]],
|
||||
left_chunk: Union[int, list[int]],
|
||||
attention_dim: int = 256,
|
||||
attention_heads: int = 4,
|
||||
input_layer: str = "nemo_conv",
|
||||
cnn_out: int = -1,
|
||||
cnn_layer_norm: bool = False,
|
||||
time_reduction: int = 4,
|
||||
dropout_rate: float = 0.0,
|
||||
padding_idx: int = -1,
|
||||
relative_attention_bias_args: Optional[dict[str, Any]] = None,
|
||||
positional_dropout_rate: float = 0.0,
|
||||
nemo_conv_settings: Optional[dict[str, Any]] = None,
|
||||
conv2d_extra_padding: Literal["feat", "feat_time", "none",
|
||||
True] = "none",
|
||||
attention_group_size=1,
|
||||
encoder_embedding_config=None,
|
||||
):
|
||||
attention_group_size: int = 1,
|
||||
encoder_embedding_config: Optional[dict[str, Any]] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.input_size = input_size
|
||||
self.input_layer = input_layer
|
||||
@@ -399,7 +396,10 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
|
||||
self.encoder_embedding = MeanVarianceNormLayer(
|
||||
self.encoder_embedding_config["input_size"])
|
||||
|
||||
def compute_lens_change(self, feature_lens):
|
||||
def compute_lens_change(
|
||||
self,
|
||||
feature_lens: Union[int,
|
||||
torch.Tensor]) -> Union[int, torch.Tensor]:
|
||||
"""feature_lens: int
|
||||
return updated feature lens.
|
||||
|
||||
@@ -433,10 +433,14 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
|
||||
return ceil_func(feature_lens / self.time_reduction)
|
||||
|
||||
@abc.abstractmethod
|
||||
def forward(self):
|
||||
def forward(self) -> Any:
|
||||
"""Abstract forward method implementation."""
|
||||
|
||||
def _chunk_size_selection(self, chunk_size=None, left_chunk=None):
|
||||
def _chunk_size_selection(
|
||||
self,
|
||||
chunk_size: Optional[Union[int, list[int]]] = None,
|
||||
left_chunk: Optional[Union[int,
|
||||
list[int]]] = None) -> tuple[int, int]:
|
||||
"""If chunk size is a list, we will randomly select a chunk size."""
|
||||
|
||||
if chunk_size is None:
|
||||
@@ -463,7 +467,7 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
|
||||
|
||||
return chunk_size_train_eff, left_chunk_train_eff
|
||||
|
||||
def _get_embed_class(self, embed):
|
||||
def _get_embed_class(self, embed: nn.Module) -> nn.Module:
|
||||
# pylint: disable=protected-access
|
||||
is_embed_using_act_chkpt = isinstance(embed, CheckpointWrapper)
|
||||
is_embed_fsdp_wrapped = isinstance(embed, FullyShardedDataParallel)
|
||||
@@ -474,13 +478,17 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
|
||||
embed_class = embed.module
|
||||
return embed_class
|
||||
|
||||
def _forward_embeddings_core(self, input_tensor, masks):
|
||||
def _forward_embeddings_core(
|
||||
self, input_tensor: torch.Tensor,
|
||||
masks: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
embed_class = self._get_embed_class(self.embed)
|
||||
assert isinstance(embed_class, NemoConvSubsampling)
|
||||
input_tensor, masks = self.embed(input_tensor, masks)
|
||||
return input_tensor, masks
|
||||
|
||||
def _position_embedding(self, input_tensor):
|
||||
def _position_embedding(
|
||||
self, input_tensor: torch.Tensor
|
||||
) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
pos_k = None
|
||||
pos_v = None
|
||||
if self.relative_attention_bias_layer is None:
|
||||
@@ -488,7 +496,9 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
|
||||
input_tensor) # default to add abs sinusoid embedding
|
||||
return pos_k, pos_v
|
||||
|
||||
def _streaming_mask(self, seq_len, batch_size, chunk_size, left_chunk):
|
||||
def _streaming_mask(self, seq_len: int, batch_size: int,
|
||||
chunk_size: Union[int, list[int]],
|
||||
left_chunk: Union[int, list[int]]) -> torch.Tensor:
|
||||
chunk_size_train_eff, left_chunk_train_eff = \
|
||||
self._chunk_size_selection(chunk_size, left_chunk)
|
||||
|
||||
@@ -502,11 +512,17 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
|
||||
[batch_size, -1, -1]))
|
||||
return enc_streaming_mask
|
||||
|
||||
def forward_embeddings(self,
|
||||
xs_pad,
|
||||
masks,
|
||||
chunk_size_nc=None,
|
||||
left_chunk_nc=None):
|
||||
def forward_embeddings(
|
||||
self,
|
||||
xs_pad: torch.Tensor,
|
||||
masks: torch.Tensor,
|
||||
chunk_size_nc: Optional[Union[int, list[int]]] = None,
|
||||
left_chunk_nc: Optional[Union[int, list[int]]] = None
|
||||
) -> Union[tuple[torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[torch.Tensor], torch.Tensor, torch.Tensor],
|
||||
tuple[torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[torch.Tensor], torch.Tensor, torch.Tensor,
|
||||
torch.Tensor]]:
|
||||
"""Forwarding the inputs through the top embedding layers
|
||||
|
||||
Args:
|
||||
@@ -569,7 +585,7 @@ class TransformerEncoderBase(abc.ABC, nn.Module):
|
||||
return input_tensor, pos_k, pos_v, hs_mask, masks
|
||||
return input_tensor, pos_k, pos_v, hs_mask, masks, hs_mask_nc
|
||||
|
||||
def get_offset(self):
|
||||
def get_offset(self) -> int:
|
||||
"""Returns offset used when retaining inputs for decoding.
|
||||
|
||||
This is essentially how many additional frames have to be added to
|
||||
@@ -605,8 +621,6 @@ class ConformerEncoder(TransformerEncoderBase):
|
||||
Some examples for the 2 cases:
|
||||
left_chunk = 6
|
||||
left_chunk = [12, 9, 6, 3]
|
||||
left_chunk: int
|
||||
number of chunks used for masking in streaming mode.
|
||||
num_lang: int
|
||||
This parameter is used to store the number of languages in the
|
||||
lang_dict, only used for multiseed/multilingual models.
|
||||
@@ -751,46 +765,46 @@ class ConformerEncoder(TransformerEncoderBase):
|
||||
|
||||
def __init__( # pylint: disable-all
|
||||
self,
|
||||
input_size,
|
||||
chunk_size,
|
||||
left_chunk,
|
||||
num_lang=None,
|
||||
attention_dim=256,
|
||||
attention_heads=4,
|
||||
linear_units=2048,
|
||||
num_blocks=6,
|
||||
dropout_rate=0.1,
|
||||
input_layer="nemo_conv",
|
||||
causal=True,
|
||||
batch_norm=False,
|
||||
cnn_out=-1,
|
||||
cnn_layer_norm=False,
|
||||
ext_pw_out_channel=0,
|
||||
ext_pw_kernel_size=1,
|
||||
depthwise_seperable_out_channel=256,
|
||||
depthwise_multiplier=1,
|
||||
chunk_se=0,
|
||||
kernel_size=3,
|
||||
activation="relu",
|
||||
conv_activation="relu",
|
||||
conv_glu_type="sigmoid",
|
||||
bias_in_glu=True,
|
||||
linear_glu_in_convm=False,
|
||||
attention_glu_type="swish",
|
||||
export=False,
|
||||
extra_layer_output_idx=-1,
|
||||
extra_multi_layer_output_idxs=[], # noqa
|
||||
activation_checkpointing="",
|
||||
relative_attention_bias_args=None,
|
||||
time_reduction=4,
|
||||
use_pt_scaled_dot_product_attention=False,
|
||||
nemo_conv_settings=None,
|
||||
input_size: int,
|
||||
chunk_size: Union[int, list[int]],
|
||||
left_chunk: Union[int, list[int]],
|
||||
num_lang: Optional[int] = None,
|
||||
attention_dim: int = 256,
|
||||
attention_heads: int = 4,
|
||||
linear_units: int = 2048,
|
||||
num_blocks: int = 6,
|
||||
dropout_rate: float = 0.1,
|
||||
input_layer: str = "nemo_conv",
|
||||
causal: bool = True,
|
||||
batch_norm: bool = False,
|
||||
cnn_out: int = -1,
|
||||
cnn_layer_norm: bool = False,
|
||||
ext_pw_out_channel: int = 0,
|
||||
ext_pw_kernel_size: int = 1,
|
||||
depthwise_seperable_out_channel: int = 256,
|
||||
depthwise_multiplier: int = 1,
|
||||
chunk_se: int = 0,
|
||||
kernel_size: int = 3,
|
||||
activation: str = "relu",
|
||||
conv_activation: str = "relu",
|
||||
conv_glu_type: str = "sigmoid",
|
||||
bias_in_glu: bool = True,
|
||||
linear_glu_in_convm: bool = False,
|
||||
attention_glu_type: str = "swish",
|
||||
export: bool = False,
|
||||
extra_layer_output_idx: int = -1,
|
||||
extra_multi_layer_output_idxs: list[int] = [], # noqa
|
||||
activation_checkpointing: str = "",
|
||||
relative_attention_bias_args: Optional[dict[str, Any]] = None,
|
||||
time_reduction: int = 4,
|
||||
use_pt_scaled_dot_product_attention: bool = False,
|
||||
nemo_conv_settings: Optional[dict[str, Any]] = None,
|
||||
conv2d_extra_padding: Literal["feat", "feat_time", "none",
|
||||
True] = "none",
|
||||
replication_pad_for_subsample_embedding=False,
|
||||
attention_group_size=1,
|
||||
encoder_embedding_config=None,
|
||||
):
|
||||
replication_pad_for_subsample_embedding: bool = False,
|
||||
attention_group_size: int = 1,
|
||||
encoder_embedding_config: Optional[dict[str, Any]] = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
input_size,
|
||||
chunk_size,
|
||||
@@ -852,11 +866,13 @@ class ConformerEncoder(TransformerEncoderBase):
|
||||
# the device and the needed dtype:
|
||||
self.register_buffer("dev_type", torch.zeros(()), persistent=False)
|
||||
|
||||
def init_relative_attention_bias(self, input_tensor):
|
||||
def init_relative_attention_bias(
|
||||
self, input_tensor: torch.Tensor) -> Optional[torch.Tensor]:
|
||||
if self.relative_attention_bias_layer:
|
||||
return self.relative_attention_bias_layer(input_tensor)
|
||||
|
||||
def calculate_hs_mask(self, xs_pad, device, mask):
|
||||
def calculate_hs_mask(self, xs_pad: torch.Tensor, device: torch.device,
|
||||
mask: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
max_audio_length = xs_pad.shape[1]
|
||||
batch_size = xs_pad.shape[0]
|
||||
enc_streaming_mask = self._streaming_mask(max_audio_length, batch_size,
|
||||
@@ -877,7 +893,8 @@ class ConformerEncoder(TransformerEncoderBase):
|
||||
return pad_mask
|
||||
|
||||
@torch.jit.ignore
|
||||
def forward(self, xs_pad, masks):
|
||||
def forward(self, xs_pad: torch.Tensor,
|
||||
masks: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Conformer Forward function
|
||||
|
||||
Args:
|
||||
@@ -997,7 +1014,12 @@ class WindowQformer(nn.Module):
|
||||
if normalize_before else None)
|
||||
self.window_size = window_size
|
||||
|
||||
def forward(self, audio_embed, mask, embed_len=None):
|
||||
def forward(
|
||||
self,
|
||||
audio_embed: torch.Tensor,
|
||||
mask: Optional[torch.Tensor],
|
||||
embed_len: Optional[int] = None
|
||||
) -> tuple[torch.Tensor, Optional[int]]:
|
||||
"""forward decoder"""
|
||||
# audio_embed: N x T x D => N x D x T
|
||||
|
||||
@@ -1042,7 +1064,7 @@ class WindowQformer(nn.Module):
|
||||
class AudioEmbedding(nn.Module):
|
||||
"""Image embedding."""
|
||||
|
||||
def __init__(self, config: PretrainedConfig, **kwargs) -> None:
|
||||
def __init__(self, config: PretrainedConfig, **kwargs: Any) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
# n_embed or hidden_size for text LM
|
||||
@@ -1148,19 +1170,18 @@ class AudioEmbedding(nn.Module):
|
||||
self.input_embeds = None
|
||||
self.audio_embed_sizes = None
|
||||
|
||||
def set_audio_embeds(self, input_embeds: torch.FloatTensor) -> None:
|
||||
def set_audio_embeds(self, input_embeds: torch.Tensor) -> None:
|
||||
self.input_embeds = input_embeds
|
||||
|
||||
def set_audio_embed_sizes(self,
|
||||
audio_embed_sizes: torch.LongTensor) -> None:
|
||||
def set_audio_embed_sizes(self, audio_embed_sizes: torch.Tensor) -> None:
|
||||
self.audio_embed_sizes = audio_embed_sizes
|
||||
|
||||
def get_audio_features(
|
||||
self,
|
||||
input_embeds: torch.FloatTensor,
|
||||
audio_attention_mask: torch.Tensor = None,
|
||||
input_embeds: torch.Tensor,
|
||||
audio_attention_mask: Optional[torch.Tensor] = None,
|
||||
audio_projection_mode: str = "speech",
|
||||
) -> torch.FloatTensor:
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
arguments:
|
||||
input_embeds: audio features (B, T, D); B = num audios in a sequence
|
||||
@@ -1214,10 +1235,10 @@ class AudioEmbedding(nn.Module):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
audio_features: torch.FloatTensor,
|
||||
audio_attention_mask: torch.Tensor = None,
|
||||
audio_features: torch.Tensor,
|
||||
audio_attention_mask: Optional[torch.Tensor] = None,
|
||||
audio_projection_mode: str = "speech",
|
||||
) -> torch.FloatTensor:
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
arguments:
|
||||
audio_features: audio features (T, D)
|
||||
|
||||
Reference in New Issue
Block a user