[Docs] Fix warnings in mkdocs build (continued) (#24791)
Signed-off-by: Zerohertz <ohg3417@gmail.com>
@@ -70,11 +70,15 @@ def multihead_attention(
     v: torch.Tensor,
     q_cu_seqlens: Optional[torch.Tensor] = None,
     k_cu_seqlens: Optional[torch.Tensor] = None,
-):
+) -> torch.Tensor:
     """Multi-head attention using flash attention 2.

     Args:
-        q, k, v: tensor of shape (batch_size, seqlen, num_heads, head_dim),
+        q: Query tensor of shape (batch_size, seqlen, num_heads, head_dim),
+            or (tot_seqlens, num_heads, head_dim) if packing.
+        k: Key tensor of shape (batch_size, seqlen, num_heads, head_dim),
+            or (tot_seqlens, num_heads, head_dim) if packing.
+        v: Value tensor of shape (batch_size, seqlen, num_heads, head_dim),
             or (tot_seqlens, num_heads, head_dim) if packing.
         q_cu_seqlens (torch.Tensor): cumulative sequence lengths of q.
             The first element should be 0 and the last element should be q.shape[0].
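The packing convention these docstrings describe can be checked with a small standalone sketch: sequences are concatenated along the first axis, and the cumulative-length tensor starts at 0 and ends at q.shape[0]. The sizes below are made-up illustration values, not taken from the changed file.

import torch

# Two sequences of lengths 3 and 5 packed into one
# (tot_seqlens, num_heads, head_dim) tensor.
num_heads, head_dim = 4, 8
q = torch.randn(3 + 5, num_heads, head_dim)

# Cumulative sequence lengths: first element 0, last element q.shape[0],
# exactly as the docstring requires.
q_cu_seqlens = torch.tensor([0, 3, 8], dtype=torch.int32)
assert q_cu_seqlens[0].item() == 0
assert q_cu_seqlens[-1].item() == q.shape[0]

# Individual sequences are recovered by slicing between consecutive offsets.
seqs = [q[s:e] for s, e in zip(q_cu_seqlens[:-1].tolist(), q_cu_seqlens[1:].tolist())]
print([tuple(s.shape) for s in seqs])  # [(3, 4, 8), (5, 4, 8)]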
@@ -123,8 +127,14 @@ def sdpa_attention(
     """SDPA attention.

     Args:
-        q, k, v: tensor of shape (batch_size, seqlen, num_heads, head_dim),
+        q: Query tensor of shape (batch_size, seqlen, num_heads, head_dim),
+            or (tot_seqlens, num_heads, head_dim) if packing.
+        k: Key tensor of shape (batch_size, seqlen, num_heads, head_dim),
+            or (tot_seqlens, num_heads, head_dim) if packing.
+        v: Value tensor of shape (batch_size, seqlen, num_heads, head_dim),
             or (tot_seqlens, num_heads, head_dim) if packing.
+        q_cu_seqlens: Optional cumulative sequence lengths of q.
+        k_cu_seqlens: Optional cumulative sequence lengths of k.
     """
     seq_length = q.shape[0]
     attention_mask = torch.zeros([1, seq_length, seq_length],
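The seq_length and attention_mask lines in this hunk show that the SDPA path builds a mask over the packed sequence. One way such a mask can be derived from cumulative sequence lengths and passed to torch.nn.functional.scaled_dot_product_attention is sketched below; the helper name packed_sdpa, the boolean block-diagonal mask, and the transpose layout are illustrative assumptions, not the repository's actual implementation.

import torch
import torch.nn.functional as F

def packed_sdpa(q, k, v, cu_seqlens):
    # q, k, v: (tot_seqlens, num_heads, head_dim); cu_seqlens: (num_seqs + 1,) ints.
    seq_length = q.shape[0]
    # Block-diagonal boolean mask: tokens may only attend within their own sequence.
    mask = torch.zeros(1, seq_length, seq_length, dtype=torch.bool)
    for start, end in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()):
        mask[0, start:end, start:end] = True
    # scaled_dot_product_attention expects (..., num_heads, seq, head_dim).
    q_, k_, v_ = (t.transpose(0, 1).unsqueeze(0) for t in (q, k, v))
    out = F.scaled_dot_product_attention(q_, k_, v_, attn_mask=mask.unsqueeze(1))
    # Back to the packed (tot_seqlens, num_heads, head_dim) layout.
    return out.squeeze(0).transpose(0, 1)

With the q and q_cu_seqlens values from the previous sketch, packed_sdpa(q, q, q, q_cu_seqlens) returns a tensor with the same packed shape as q.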
@@ -387,7 +397,7 @@ class MLP2(nn.Module):
     def __init__(self,
                  dims: list[int],
                  activation,
-                 bias=True,
+                 bias: bool = True,
                  prefix: str = "",
                  use_data_parallel: bool = False):
         super().__init__()
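This hunk only tightens the type annotation on bias. For readers unfamiliar with the pattern being annotated, a minimal MLP built from a list of layer widths could look like the sketch below; the class name SimpleMLP, its forward pass, and the omission of the prefix and use_data_parallel arguments are assumptions made for illustration, not the MLP2 implementation.

import torch
from torch import nn

class SimpleMLP(nn.Module):
    """Illustrative stand-in for an MLP defined by a list of layer widths."""

    def __init__(self,
                 dims: list[int],
                 activation: type[nn.Module] = nn.GELU,
                 bias: bool = True):
        super().__init__()
        layers: list[nn.Module] = []
        for in_dim, out_dim in zip(dims[:-1], dims[1:]):
            layers.append(nn.Linear(in_dim, out_dim, bias=bias))
            layers.append(activation())
        # Drop the trailing activation so the final layer is a plain projection.
        self.net = nn.Sequential(*layers[:-1])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

mlp = SimpleMLP([16, 64, 16], activation=nn.GELU, bias=True)
print(mlp(torch.randn(2, 16)).shape)  # torch.Size([2, 16])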