[GDN] Eliminate GPU->CPU sync in prepare_chunk_indices during prefill (#38361)
Signed-off-by: Artem Perevedentsev <aperevedents@nvidia.com> Signed-off-by: Vadim Gimpelson <156319763+vadiklyutiy@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
bf8b022e60
commit
cb10b7e80b
@@ -16,7 +16,7 @@ from .chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd
|
||||
from .cumsum import chunk_local_cumsum
|
||||
from .l2norm import l2norm_fwd
|
||||
from .solve_tril import solve_tril
|
||||
from .utils import SUPPRESS_LEVEL, input_guard
|
||||
from .utils import FLA_CHUNK_SIZE, SUPPRESS_LEVEL, input_guard
|
||||
from .wy_fast import recompute_w_u_fwd
|
||||
|
||||
|
||||
@@ -30,13 +30,24 @@ def chunk_gated_delta_rule_fwd(
|
||||
initial_state: torch.Tensor,
|
||||
output_final_state: bool,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
chunk_indices: torch.Tensor | None = None,
|
||||
chunk_offsets: torch.Tensor | None = None,
|
||||
):
|
||||
g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
|
||||
g = chunk_local_cumsum(
|
||||
g, chunk_size=FLA_CHUNK_SIZE, cu_seqlens=cu_seqlens, chunk_indices=chunk_indices
|
||||
)
|
||||
# obtain WY representation. u is actually the new v.
|
||||
A = chunk_scaled_dot_kkt_fwd(
|
||||
k=k, beta=beta, g=g, cu_seqlens=cu_seqlens, output_dtype=torch.float32
|
||||
k=k,
|
||||
beta=beta,
|
||||
g=g,
|
||||
cu_seqlens=cu_seqlens,
|
||||
chunk_indices=chunk_indices,
|
||||
output_dtype=torch.float32,
|
||||
)
|
||||
A = solve_tril(
|
||||
A=A, cu_seqlens=cu_seqlens, chunk_indices=chunk_indices, output_dtype=k.dtype
|
||||
)
|
||||
A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype)
|
||||
w, u = recompute_w_u_fwd(
|
||||
k=k,
|
||||
v=v,
|
||||
@@ -44,6 +55,7 @@ def chunk_gated_delta_rule_fwd(
|
||||
A=A,
|
||||
g_cumsum=g,
|
||||
cu_seqlens=cu_seqlens,
|
||||
chunk_indices=chunk_indices,
|
||||
)
|
||||
h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
|
||||
k=k,
|
||||
@@ -53,6 +65,8 @@ def chunk_gated_delta_rule_fwd(
|
||||
initial_state=initial_state,
|
||||
output_final_state=output_final_state,
|
||||
cu_seqlens=cu_seqlens,
|
||||
chunk_indices=chunk_indices,
|
||||
chunk_offsets=chunk_offsets,
|
||||
)
|
||||
o = chunk_fwd_o(
|
||||
q=q,
|
||||
@@ -62,6 +76,7 @@ def chunk_gated_delta_rule_fwd(
|
||||
g=g,
|
||||
scale=scale,
|
||||
cu_seqlens=cu_seqlens,
|
||||
chunk_indices=chunk_indices,
|
||||
)
|
||||
if SUPPRESS_LEVEL < 3:
|
||||
return g, o, A, final_state, None, None, None
|
||||
@@ -84,6 +99,8 @@ class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
|
||||
initial_state: torch.Tensor,
|
||||
output_final_state: bool,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
chunk_indices: torch.Tensor | None = None,
|
||||
chunk_offsets: torch.Tensor | None = None,
|
||||
use_qk_l2norm_in_kernel: bool = False,
|
||||
):
|
||||
if use_qk_l2norm_in_kernel:
|
||||
@@ -100,6 +117,8 @@ class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
|
||||
initial_state=initial_state,
|
||||
output_final_state=output_final_state,
|
||||
cu_seqlens=cu_seqlens,
|
||||
chunk_indices=chunk_indices,
|
||||
chunk_offsets=chunk_offsets,
|
||||
)
|
||||
ctx.scale = scale
|
||||
ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel
|
||||
@@ -117,6 +136,8 @@ def chunk_gated_delta_rule(
|
||||
initial_state: torch.Tensor = None,
|
||||
output_final_state: bool = False,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
chunk_indices: torch.Tensor | None = None,
|
||||
chunk_offsets: torch.Tensor | None = None,
|
||||
use_qk_l2norm_in_kernel: bool = False,
|
||||
):
|
||||
r"""
|
||||
@@ -206,6 +227,8 @@ def chunk_gated_delta_rule(
|
||||
initial_state,
|
||||
output_final_state,
|
||||
cu_seqlens,
|
||||
chunk_indices,
|
||||
chunk_offsets,
|
||||
use_qk_l2norm_in_kernel,
|
||||
)
|
||||
return o, final_state
|
||||
|
||||
@@ -14,7 +14,7 @@ from vllm.triton_utils import tl, triton
|
||||
|
||||
from .index import prepare_chunk_indices, prepare_chunk_offsets
|
||||
from .op import exp
|
||||
from .utils import use_cuda_graph
|
||||
from .utils import FLA_CHUNK_SIZE, use_cuda_graph
|
||||
|
||||
NUM_WARPS = [2, 4, 8, 16]
|
||||
|
||||
@@ -286,9 +286,11 @@ def chunk_gated_delta_rule_fwd_h(
|
||||
gk: torch.Tensor | None = None,
|
||||
initial_state: torch.Tensor | None = None,
|
||||
output_final_state: bool = False,
|
||||
chunk_size: int = 64, # SY: remove this argument and force chunk size 64?
|
||||
chunk_size: int = FLA_CHUNK_SIZE,
|
||||
save_new_value: bool = True,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
chunk_indices: torch.Tensor | None = None,
|
||||
chunk_offsets: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# This kernel is slightly different from fla to support Q/K with different head numbers.
|
||||
# In fla, Q/K always have the same head number, so Hg is always equal to H.
|
||||
@@ -296,20 +298,15 @@ def chunk_gated_delta_rule_fwd_h(
|
||||
H = u.shape[-2]
|
||||
BT = chunk_size
|
||||
|
||||
chunk_indices = (
|
||||
prepare_chunk_indices(cu_seqlens, chunk_size)
|
||||
if cu_seqlens is not None
|
||||
else None
|
||||
)
|
||||
if chunk_indices is None and cu_seqlens is not None:
|
||||
chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
|
||||
# N: the actual number of sequences in the batch with either equal or variable lengths
|
||||
if cu_seqlens is None:
|
||||
N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
|
||||
else:
|
||||
N, NT, chunk_offsets = (
|
||||
len(cu_seqlens) - 1,
|
||||
len(chunk_indices),
|
||||
prepare_chunk_offsets(cu_seqlens, BT),
|
||||
)
|
||||
N, NT = len(cu_seqlens) - 1, len(chunk_indices)
|
||||
if chunk_offsets is None:
|
||||
chunk_offsets = prepare_chunk_offsets(cu_seqlens, BT)
|
||||
assert K <= 256, "current kernel does not support head dimension larger than 256."
|
||||
|
||||
h = k.new_empty(B, NT, H, V, K)
|
||||
|
||||
@@ -146,14 +146,14 @@ def chunk_fwd_o(
|
||||
g: torch.Tensor | None = None, # cumsum of log decay
|
||||
scale: float | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
chunk_indices: torch.Tensor | None = None,
|
||||
chunk_size: int = FLA_CHUNK_SIZE,
|
||||
) -> torch.Tensor:
|
||||
B, T, Hg, K, V = *q.shape, v.shape[-1]
|
||||
H = v.shape[-2]
|
||||
BT = chunk_size
|
||||
chunk_indices = (
|
||||
prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
|
||||
)
|
||||
if chunk_indices is None and cu_seqlens is not None:
|
||||
chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
|
||||
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
|
||||
if scale is None:
|
||||
scale = k.shape[-1] ** -0.5
|
||||
|
||||
@@ -14,6 +14,7 @@ from vllm.triton_utils import tl, triton
|
||||
|
||||
from .index import prepare_chunk_indices
|
||||
from .op import exp
|
||||
from .utils import FLA_CHUNK_SIZE
|
||||
|
||||
|
||||
@triton.heuristics(
|
||||
@@ -103,7 +104,8 @@ def chunk_scaled_dot_kkt_fwd(
|
||||
g: torch.Tensor | None = None,
|
||||
beta: torch.Tensor | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
chunk_size: int = 64,
|
||||
chunk_indices: torch.Tensor | None = None,
|
||||
chunk_size: int = FLA_CHUNK_SIZE,
|
||||
output_dtype: torch.dtype = torch.float32,
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
@@ -119,6 +121,9 @@ def chunk_scaled_dot_kkt_fwd(
|
||||
cu_seqlens (torch.Tensor):
|
||||
The cumulative sequence lengths of the input tensor.
|
||||
Default: None
|
||||
chunk_indices (torch.Tensor):
|
||||
Pre-computed chunk indices. If None and cu_seqlens is provided,
|
||||
computed internally. Default: None
|
||||
chunk_size (int):
|
||||
The chunk size. Default: 64.
|
||||
output_dtype (torch.dtype):
|
||||
@@ -132,9 +137,8 @@ def chunk_scaled_dot_kkt_fwd(
|
||||
B, T, Hg, K = k.shape
|
||||
H = beta.shape[-1]
|
||||
BT = chunk_size
|
||||
chunk_indices = (
|
||||
prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
|
||||
)
|
||||
if chunk_indices is None and cu_seqlens is not None:
|
||||
chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
|
||||
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
|
||||
|
||||
A = torch.empty(B, T, H, BT, device=k.device, dtype=output_dtype)
|
||||
|
||||
@@ -162,6 +162,7 @@ def chunk_local_cumsum_scalar(
|
||||
chunk_size: int,
|
||||
reverse: bool = False,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
chunk_indices: torch.Tensor | None = None,
|
||||
head_first: bool = False,
|
||||
output_dtype: torch.dtype | None = torch.float,
|
||||
) -> torch.Tensor:
|
||||
@@ -172,10 +173,9 @@ def chunk_local_cumsum_scalar(
|
||||
assert chunk_size == 2 ** (chunk_size.bit_length() - 1), (
|
||||
"chunk_size must be a power of 2"
|
||||
)
|
||||
if chunk_indices is None and cu_seqlens is not None:
|
||||
chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
|
||||
BT = chunk_size
|
||||
chunk_indices = (
|
||||
prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
|
||||
)
|
||||
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
|
||||
g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)
|
||||
grid = (NT, B * H)
|
||||
@@ -199,6 +199,7 @@ def chunk_local_cumsum_vector(
|
||||
chunk_size: int,
|
||||
reverse: bool = False,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
chunk_indices: torch.Tensor | None = None,
|
||||
head_first: bool = False,
|
||||
output_dtype: torch.dtype | None = torch.float,
|
||||
) -> torch.Tensor:
|
||||
@@ -206,16 +207,13 @@ def chunk_local_cumsum_vector(
|
||||
B, H, T, S = g.shape
|
||||
else:
|
||||
B, T, H, S = g.shape
|
||||
BT = chunk_size
|
||||
chunk_indices = (
|
||||
prepare_chunk_indices(cu_seqlens, chunk_size)
|
||||
if cu_seqlens is not None
|
||||
else None
|
||||
)
|
||||
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
|
||||
assert chunk_size == 2 ** (chunk_size.bit_length() - 1), (
|
||||
"chunk_size must be a power of 2"
|
||||
)
|
||||
if chunk_indices is None and cu_seqlens is not None:
|
||||
chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
|
||||
BT = chunk_size
|
||||
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
|
||||
|
||||
g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)
|
||||
|
||||
@@ -247,6 +245,7 @@ def chunk_local_cumsum(
|
||||
chunk_size: int,
|
||||
reverse: bool = False,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
chunk_indices: torch.Tensor | None = None,
|
||||
head_first: bool = False,
|
||||
output_dtype: torch.dtype | None = torch.float,
|
||||
**kwargs,
|
||||
@@ -257,11 +256,23 @@ def chunk_local_cumsum(
|
||||
)
|
||||
if len(g.shape) == 3:
|
||||
return chunk_local_cumsum_scalar(
|
||||
g, chunk_size, reverse, cu_seqlens, head_first, output_dtype
|
||||
g,
|
||||
chunk_size,
|
||||
reverse,
|
||||
cu_seqlens,
|
||||
chunk_indices,
|
||||
head_first,
|
||||
output_dtype,
|
||||
)
|
||||
elif len(g.shape) == 4:
|
||||
return chunk_local_cumsum_vector(
|
||||
g, chunk_size, reverse, cu_seqlens, head_first, output_dtype
|
||||
g,
|
||||
chunk_size,
|
||||
reverse,
|
||||
cu_seqlens,
|
||||
chunk_indices,
|
||||
head_first,
|
||||
output_dtype,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
|
||||
@@ -23,7 +23,7 @@ from .index import prepare_chunk_indices
|
||||
from .l2norm import l2norm_fwd
|
||||
from .op import exp, log
|
||||
from .solve_tril import solve_tril
|
||||
from .utils import is_amd
|
||||
from .utils import FLA_CHUNK_SIZE, is_amd
|
||||
|
||||
BT_LIST_AUTOTUNE = [32, 64, 128]
|
||||
NUM_WARPS_AUTOTUNE = [2, 4, 8, 16] if is_amd else [4, 8, 16, 32]
|
||||
@@ -721,7 +721,7 @@ def chunk_kda_scaled_dot_kkt_fwd(
|
||||
beta: torch.Tensor | None = None,
|
||||
scale: float | None = None,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
chunk_size: int = 64,
|
||||
chunk_size: int = FLA_CHUNK_SIZE,
|
||||
output_dtype: torch.dtype = torch.float32,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
r"""
|
||||
@@ -1178,7 +1178,7 @@ def chunk_kda_fwd(
|
||||
output_final_state: bool,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
):
|
||||
chunk_size = 64
|
||||
chunk_size = FLA_CHUNK_SIZE
|
||||
g = chunk_local_cumsum(g, chunk_size=chunk_size, cu_seqlens=cu_seqlens)
|
||||
# the intra Aqk is kept in fp32
|
||||
# the computation has very marginal effect on the entire throughput
|
||||
@@ -1189,6 +1189,7 @@ def chunk_kda_fwd(
|
||||
beta=beta,
|
||||
scale=scale,
|
||||
cu_seqlens=cu_seqlens,
|
||||
chunk_size=chunk_size,
|
||||
output_dtype=torch.float32,
|
||||
)
|
||||
A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype)
|
||||
|
||||
@@ -507,6 +507,7 @@ def merge_16x16_to_64x64_inverse_kernel(
|
||||
def solve_tril(
|
||||
A: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
chunk_indices: torch.Tensor | None = None,
|
||||
output_dtype: torch.dtype = torch.float,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
@@ -518,6 +519,8 @@ def solve_tril(
|
||||
[B, T, H, BT], where BT should only be 16, 32, or 64.
|
||||
cu_seqlens (torch.Tensor):
|
||||
The cumulative sequence lengths of the input tensor. Default: `None`.
|
||||
chunk_indices (torch.Tensor):
|
||||
Pre-computed chunk indices. Default: `None`.
|
||||
output_dtype (torch.dtype):
|
||||
The dtype of the output tensor. Default: `torch.float`.
|
||||
If `None`, the output dtype will be the same as the input dtype.
|
||||
@@ -529,9 +532,8 @@ def solve_tril(
|
||||
output_dtype = A.dtype if output_dtype is None else output_dtype
|
||||
|
||||
B, T, H, BT = A.shape
|
||||
chunk_indices = (
|
||||
prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
|
||||
)
|
||||
if chunk_indices is None and cu_seqlens is not None:
|
||||
chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
|
||||
NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, BT)
|
||||
|
||||
Ai = torch.zeros_like(A, dtype=output_dtype)
|
||||
|
||||
@@ -123,14 +123,14 @@ def recompute_w_u_fwd(
|
||||
g_cumsum: torch.Tensor,
|
||||
A: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor | None,
|
||||
chunk_indices: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
B, T, Hg, K, V = *k.shape, v.shape[-1]
|
||||
H = v.shape[-2]
|
||||
BT = A.shape[-1]
|
||||
|
||||
chunk_indices = (
|
||||
prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
|
||||
)
|
||||
if chunk_indices is None and cu_seqlens is not None:
|
||||
chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
|
||||
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
|
||||
BK = 64
|
||||
BV = 64
|
||||
|
||||
@@ -163,6 +163,8 @@ class ChunkGatedDeltaRule(CustomOp):
|
||||
initial_state: torch.Tensor,
|
||||
output_final_state: bool,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
chunk_indices: torch.Tensor | None = None,
|
||||
chunk_offsets: torch.Tensor | None = None,
|
||||
use_qk_l2norm_in_kernel: bool = True,
|
||||
):
|
||||
return fi_chunk_gated_delta_rule(
|
||||
@@ -187,6 +189,8 @@ class ChunkGatedDeltaRule(CustomOp):
|
||||
initial_state: torch.Tensor,
|
||||
output_final_state: bool,
|
||||
cu_seqlens: torch.Tensor | None = None,
|
||||
chunk_indices: torch.Tensor | None = None,
|
||||
chunk_offsets: torch.Tensor | None = None,
|
||||
use_qk_l2norm_in_kernel: bool = True,
|
||||
):
|
||||
return fla_chunk_gated_delta_rule(
|
||||
@@ -198,6 +202,8 @@ class ChunkGatedDeltaRule(CustomOp):
|
||||
initial_state=initial_state,
|
||||
output_final_state=output_final_state,
|
||||
cu_seqlens=cu_seqlens,
|
||||
chunk_indices=chunk_indices,
|
||||
chunk_offsets=chunk_offsets,
|
||||
use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
|
||||
)
|
||||
|
||||
@@ -959,6 +965,8 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
|
||||
initial_state=initial_state,
|
||||
output_final_state=True,
|
||||
cu_seqlens=non_spec_query_start_loc,
|
||||
chunk_indices=attn_metadata.chunk_indices,
|
||||
chunk_offsets=attn_metadata.chunk_offsets,
|
||||
use_qk_l2norm_in_kernel=False,
|
||||
)
|
||||
# Init cache
|
||||
|
||||
Reference in New Issue
Block a user