[GDN] Eliminate GPU->CPU sync in prepare_chunk_indices during prefill (#38361)

Signed-off-by: Artem Perevedentsev <aperevedents@nvidia.com>
Signed-off-by: Vadim Gimpelson <156319763+vadiklyutiy@users.noreply.github.com>
This commit is contained in:
Artem Perevedentsev
2026-04-03 16:38:02 +03:00
committed by GitHub
parent bf8b022e60
commit cb10b7e80b
10 changed files with 116 additions and 44 deletions

View File

@@ -16,7 +16,7 @@ from .chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd
from .cumsum import chunk_local_cumsum
from .l2norm import l2norm_fwd
from .solve_tril import solve_tril
from .utils import SUPPRESS_LEVEL, input_guard
from .utils import FLA_CHUNK_SIZE, SUPPRESS_LEVEL, input_guard
from .wy_fast import recompute_w_u_fwd
@@ -30,13 +30,24 @@ def chunk_gated_delta_rule_fwd(
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: torch.Tensor | None = None,
chunk_indices: torch.Tensor | None = None,
chunk_offsets: torch.Tensor | None = None,
):
g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
g = chunk_local_cumsum(
g, chunk_size=FLA_CHUNK_SIZE, cu_seqlens=cu_seqlens, chunk_indices=chunk_indices
)
# obtain WY representation. u is actually the new v.
A = chunk_scaled_dot_kkt_fwd(
k=k, beta=beta, g=g, cu_seqlens=cu_seqlens, output_dtype=torch.float32
k=k,
beta=beta,
g=g,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
output_dtype=torch.float32,
)
A = solve_tril(
A=A, cu_seqlens=cu_seqlens, chunk_indices=chunk_indices, output_dtype=k.dtype
)
A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype)
w, u = recompute_w_u_fwd(
k=k,
v=v,
@@ -44,6 +55,7 @@ def chunk_gated_delta_rule_fwd(
A=A,
g_cumsum=g,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
)
h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
k=k,
@@ -53,6 +65,8 @@ def chunk_gated_delta_rule_fwd(
initial_state=initial_state,
output_final_state=output_final_state,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
chunk_offsets=chunk_offsets,
)
o = chunk_fwd_o(
q=q,
@@ -62,6 +76,7 @@ def chunk_gated_delta_rule_fwd(
g=g,
scale=scale,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
)
if SUPPRESS_LEVEL < 3:
return g, o, A, final_state, None, None, None
@@ -84,6 +99,8 @@ class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: torch.Tensor | None = None,
chunk_indices: torch.Tensor | None = None,
chunk_offsets: torch.Tensor | None = None,
use_qk_l2norm_in_kernel: bool = False,
):
if use_qk_l2norm_in_kernel:
@@ -100,6 +117,8 @@ class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
initial_state=initial_state,
output_final_state=output_final_state,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
chunk_offsets=chunk_offsets,
)
ctx.scale = scale
ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel
@@ -117,6 +136,8 @@ def chunk_gated_delta_rule(
initial_state: torch.Tensor = None,
output_final_state: bool = False,
cu_seqlens: torch.Tensor | None = None,
chunk_indices: torch.Tensor | None = None,
chunk_offsets: torch.Tensor | None = None,
use_qk_l2norm_in_kernel: bool = False,
):
r"""
@@ -206,6 +227,8 @@ def chunk_gated_delta_rule(
initial_state,
output_final_state,
cu_seqlens,
chunk_indices,
chunk_offsets,
use_qk_l2norm_in_kernel,
)
return o, final_state

View File

@@ -14,7 +14,7 @@ from vllm.triton_utils import tl, triton
from .index import prepare_chunk_indices, prepare_chunk_offsets
from .op import exp
from .utils import use_cuda_graph
from .utils import FLA_CHUNK_SIZE, use_cuda_graph
NUM_WARPS = [2, 4, 8, 16]
@@ -286,9 +286,11 @@ def chunk_gated_delta_rule_fwd_h(
gk: torch.Tensor | None = None,
initial_state: torch.Tensor | None = None,
output_final_state: bool = False,
chunk_size: int = 64, # SY: remove this argument and force chunk size 64?
chunk_size: int = FLA_CHUNK_SIZE,
save_new_value: bool = True,
cu_seqlens: torch.Tensor | None = None,
chunk_indices: torch.Tensor | None = None,
chunk_offsets: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
# This kernel is slightly different from fla to support Q/K with different head numbers.
# In fla, Q/K always have the same head number, so Hg is always equal to H.
@@ -296,20 +298,15 @@ def chunk_gated_delta_rule_fwd_h(
H = u.shape[-2]
BT = chunk_size
chunk_indices = (
prepare_chunk_indices(cu_seqlens, chunk_size)
if cu_seqlens is not None
else None
)
if chunk_indices is None and cu_seqlens is not None:
chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
# N: the actual number of sequences in the batch with either equal or variable lengths
if cu_seqlens is None:
N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
else:
N, NT, chunk_offsets = (
len(cu_seqlens) - 1,
len(chunk_indices),
prepare_chunk_offsets(cu_seqlens, BT),
)
N, NT = len(cu_seqlens) - 1, len(chunk_indices)
if chunk_offsets is None:
chunk_offsets = prepare_chunk_offsets(cu_seqlens, BT)
assert K <= 256, "current kernel does not support head dimension larger than 256."
h = k.new_empty(B, NT, H, V, K)

View File

@@ -146,14 +146,14 @@ def chunk_fwd_o(
g: torch.Tensor | None = None, # cumsum of log decay
scale: float | None = None,
cu_seqlens: torch.Tensor | None = None,
chunk_indices: torch.Tensor | None = None,
chunk_size: int = FLA_CHUNK_SIZE,
) -> torch.Tensor:
B, T, Hg, K, V = *q.shape, v.shape[-1]
H = v.shape[-2]
BT = chunk_size
chunk_indices = (
prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
)
if chunk_indices is None and cu_seqlens is not None:
chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
if scale is None:
scale = k.shape[-1] ** -0.5

View File

@@ -14,6 +14,7 @@ from vllm.triton_utils import tl, triton
from .index import prepare_chunk_indices
from .op import exp
from .utils import FLA_CHUNK_SIZE
@triton.heuristics(
@@ -103,7 +104,8 @@ def chunk_scaled_dot_kkt_fwd(
g: torch.Tensor | None = None,
beta: torch.Tensor | None = None,
cu_seqlens: torch.Tensor | None = None,
chunk_size: int = 64,
chunk_indices: torch.Tensor | None = None,
chunk_size: int = FLA_CHUNK_SIZE,
output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
r"""
@@ -119,6 +121,9 @@ def chunk_scaled_dot_kkt_fwd(
cu_seqlens (torch.Tensor):
The cumulative sequence lengths of the input tensor.
Default: None
chunk_indices (torch.Tensor):
Pre-computed chunk indices. If None and cu_seqlens is provided,
computed internally. Default: None
chunk_size (int):
The chunk size. Default: 64.
output_dtype (torch.dtype):
@@ -132,9 +137,8 @@ def chunk_scaled_dot_kkt_fwd(
B, T, Hg, K = k.shape
H = beta.shape[-1]
BT = chunk_size
chunk_indices = (
prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
)
if chunk_indices is None and cu_seqlens is not None:
chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
A = torch.empty(B, T, H, BT, device=k.device, dtype=output_dtype)

View File

@@ -162,6 +162,7 @@ def chunk_local_cumsum_scalar(
chunk_size: int,
reverse: bool = False,
cu_seqlens: torch.Tensor | None = None,
chunk_indices: torch.Tensor | None = None,
head_first: bool = False,
output_dtype: torch.dtype | None = torch.float,
) -> torch.Tensor:
@@ -172,10 +173,9 @@ def chunk_local_cumsum_scalar(
assert chunk_size == 2 ** (chunk_size.bit_length() - 1), (
"chunk_size must be a power of 2"
)
if chunk_indices is None and cu_seqlens is not None:
chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
BT = chunk_size
chunk_indices = (
prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
)
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)
grid = (NT, B * H)
@@ -199,6 +199,7 @@ def chunk_local_cumsum_vector(
chunk_size: int,
reverse: bool = False,
cu_seqlens: torch.Tensor | None = None,
chunk_indices: torch.Tensor | None = None,
head_first: bool = False,
output_dtype: torch.dtype | None = torch.float,
) -> torch.Tensor:
@@ -206,16 +207,13 @@ def chunk_local_cumsum_vector(
B, H, T, S = g.shape
else:
B, T, H, S = g.shape
BT = chunk_size
chunk_indices = (
prepare_chunk_indices(cu_seqlens, chunk_size)
if cu_seqlens is not None
else None
)
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
assert chunk_size == 2 ** (chunk_size.bit_length() - 1), (
"chunk_size must be a power of 2"
)
if chunk_indices is None and cu_seqlens is not None:
chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
BT = chunk_size
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)
@@ -247,6 +245,7 @@ def chunk_local_cumsum(
chunk_size: int,
reverse: bool = False,
cu_seqlens: torch.Tensor | None = None,
chunk_indices: torch.Tensor | None = None,
head_first: bool = False,
output_dtype: torch.dtype | None = torch.float,
**kwargs,
@@ -257,11 +256,23 @@ def chunk_local_cumsum(
)
if len(g.shape) == 3:
return chunk_local_cumsum_scalar(
g, chunk_size, reverse, cu_seqlens, head_first, output_dtype
g,
chunk_size,
reverse,
cu_seqlens,
chunk_indices,
head_first,
output_dtype,
)
elif len(g.shape) == 4:
return chunk_local_cumsum_vector(
g, chunk_size, reverse, cu_seqlens, head_first, output_dtype
g,
chunk_size,
reverse,
cu_seqlens,
chunk_indices,
head_first,
output_dtype,
)
else:
raise ValueError(

View File

@@ -23,7 +23,7 @@ from .index import prepare_chunk_indices
from .l2norm import l2norm_fwd
from .op import exp, log
from .solve_tril import solve_tril
from .utils import is_amd
from .utils import FLA_CHUNK_SIZE, is_amd
BT_LIST_AUTOTUNE = [32, 64, 128]
NUM_WARPS_AUTOTUNE = [2, 4, 8, 16] if is_amd else [4, 8, 16, 32]
@@ -721,7 +721,7 @@ def chunk_kda_scaled_dot_kkt_fwd(
beta: torch.Tensor | None = None,
scale: float | None = None,
cu_seqlens: torch.Tensor | None = None,
chunk_size: int = 64,
chunk_size: int = FLA_CHUNK_SIZE,
output_dtype: torch.dtype = torch.float32,
) -> tuple[torch.Tensor, torch.Tensor]:
r"""
@@ -1178,7 +1178,7 @@ def chunk_kda_fwd(
output_final_state: bool,
cu_seqlens: torch.Tensor | None = None,
):
chunk_size = 64
chunk_size = FLA_CHUNK_SIZE
g = chunk_local_cumsum(g, chunk_size=chunk_size, cu_seqlens=cu_seqlens)
# the intra Aqk is kept in fp32
# the computation has very marginal effect on the entire throughput
@@ -1189,6 +1189,7 @@ def chunk_kda_fwd(
beta=beta,
scale=scale,
cu_seqlens=cu_seqlens,
chunk_size=chunk_size,
output_dtype=torch.float32,
)
A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype)

View File

@@ -507,6 +507,7 @@ def merge_16x16_to_64x64_inverse_kernel(
def solve_tril(
A: torch.Tensor,
cu_seqlens: torch.Tensor | None = None,
chunk_indices: torch.Tensor | None = None,
output_dtype: torch.dtype = torch.float,
) -> torch.Tensor:
"""
@@ -518,6 +519,8 @@ def solve_tril(
[B, T, H, BT], where BT should only be 16, 32, or 64.
cu_seqlens (torch.Tensor):
The cumulative sequence lengths of the input tensor. Default: `None`.
chunk_indices (torch.Tensor):
Pre-computed chunk indices. Default: `None`.
output_dtype (torch.dtype):
The dtype of the output tensor. Default: `torch.float`.
If `None`, the output dtype will be the same as the input dtype.
@@ -529,9 +532,8 @@ def solve_tril(
output_dtype = A.dtype if output_dtype is None else output_dtype
B, T, H, BT = A.shape
chunk_indices = (
prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
)
if chunk_indices is None and cu_seqlens is not None:
chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, BT)
Ai = torch.zeros_like(A, dtype=output_dtype)

View File

@@ -123,14 +123,14 @@ def recompute_w_u_fwd(
g_cumsum: torch.Tensor,
A: torch.Tensor,
cu_seqlens: torch.Tensor | None,
chunk_indices: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
B, T, Hg, K, V = *k.shape, v.shape[-1]
H = v.shape[-2]
BT = A.shape[-1]
chunk_indices = (
prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
)
if chunk_indices is None and cu_seqlens is not None:
chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
BK = 64
BV = 64

View File

@@ -163,6 +163,8 @@ class ChunkGatedDeltaRule(CustomOp):
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: torch.Tensor | None = None,
chunk_indices: torch.Tensor | None = None,
chunk_offsets: torch.Tensor | None = None,
use_qk_l2norm_in_kernel: bool = True,
):
return fi_chunk_gated_delta_rule(
@@ -187,6 +189,8 @@ class ChunkGatedDeltaRule(CustomOp):
initial_state: torch.Tensor,
output_final_state: bool,
cu_seqlens: torch.Tensor | None = None,
chunk_indices: torch.Tensor | None = None,
chunk_offsets: torch.Tensor | None = None,
use_qk_l2norm_in_kernel: bool = True,
):
return fla_chunk_gated_delta_rule(
@@ -198,6 +202,8 @@ class ChunkGatedDeltaRule(CustomOp):
initial_state=initial_state,
output_final_state=output_final_state,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
chunk_offsets=chunk_offsets,
use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
)
@@ -959,6 +965,8 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
initial_state=initial_state,
output_final_state=True,
cu_seqlens=non_spec_query_start_loc,
chunk_indices=attn_metadata.chunk_indices,
chunk_offsets=attn_metadata.chunk_offsets,
use_qk_l2norm_in_kernel=False,
)
# Init cache