[GDN] Eliminate GPU->CPU sync in prepare_chunk_indices during prefill (#38361)

Signed-off-by: Artem Perevedentsev <aperevedents@nvidia.com>
Signed-off-by: Vadim Gimpelson <156319763+vadiklyutiy@users.noreply.github.com>
This commit is contained in:
Artem Perevedentsev
2026-04-03 16:38:02 +03:00
committed by GitHub
parent bf8b022e60
commit cb10b7e80b
10 changed files with 116 additions and 44 deletions

View File

@@ -16,7 +16,7 @@ from .chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd
from .cumsum import chunk_local_cumsum from .cumsum import chunk_local_cumsum
from .l2norm import l2norm_fwd from .l2norm import l2norm_fwd
from .solve_tril import solve_tril from .solve_tril import solve_tril
from .utils import SUPPRESS_LEVEL, input_guard from .utils import FLA_CHUNK_SIZE, SUPPRESS_LEVEL, input_guard
from .wy_fast import recompute_w_u_fwd from .wy_fast import recompute_w_u_fwd
@@ -30,13 +30,24 @@ def chunk_gated_delta_rule_fwd(
initial_state: torch.Tensor, initial_state: torch.Tensor,
output_final_state: bool, output_final_state: bool,
cu_seqlens: torch.Tensor | None = None, cu_seqlens: torch.Tensor | None = None,
chunk_indices: torch.Tensor | None = None,
chunk_offsets: torch.Tensor | None = None,
): ):
g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens) g = chunk_local_cumsum(
g, chunk_size=FLA_CHUNK_SIZE, cu_seqlens=cu_seqlens, chunk_indices=chunk_indices
)
# obtain WY representation. u is actually the new v. # obtain WY representation. u is actually the new v.
A = chunk_scaled_dot_kkt_fwd( A = chunk_scaled_dot_kkt_fwd(
k=k, beta=beta, g=g, cu_seqlens=cu_seqlens, output_dtype=torch.float32 k=k,
beta=beta,
g=g,
cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
output_dtype=torch.float32,
)
A = solve_tril(
A=A, cu_seqlens=cu_seqlens, chunk_indices=chunk_indices, output_dtype=k.dtype
) )
A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype)
w, u = recompute_w_u_fwd( w, u = recompute_w_u_fwd(
k=k, k=k,
v=v, v=v,
@@ -44,6 +55,7 @@ def chunk_gated_delta_rule_fwd(
A=A, A=A,
g_cumsum=g, g_cumsum=g,
cu_seqlens=cu_seqlens, cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
) )
h, v_new, final_state = chunk_gated_delta_rule_fwd_h( h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
k=k, k=k,
@@ -53,6 +65,8 @@ def chunk_gated_delta_rule_fwd(
initial_state=initial_state, initial_state=initial_state,
output_final_state=output_final_state, output_final_state=output_final_state,
cu_seqlens=cu_seqlens, cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
chunk_offsets=chunk_offsets,
) )
o = chunk_fwd_o( o = chunk_fwd_o(
q=q, q=q,
@@ -62,6 +76,7 @@ def chunk_gated_delta_rule_fwd(
g=g, g=g,
scale=scale, scale=scale,
cu_seqlens=cu_seqlens, cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
) )
if SUPPRESS_LEVEL < 3: if SUPPRESS_LEVEL < 3:
return g, o, A, final_state, None, None, None return g, o, A, final_state, None, None, None
@@ -84,6 +99,8 @@ class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
initial_state: torch.Tensor, initial_state: torch.Tensor,
output_final_state: bool, output_final_state: bool,
cu_seqlens: torch.Tensor | None = None, cu_seqlens: torch.Tensor | None = None,
chunk_indices: torch.Tensor | None = None,
chunk_offsets: torch.Tensor | None = None,
use_qk_l2norm_in_kernel: bool = False, use_qk_l2norm_in_kernel: bool = False,
): ):
if use_qk_l2norm_in_kernel: if use_qk_l2norm_in_kernel:
@@ -100,6 +117,8 @@ class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
initial_state=initial_state, initial_state=initial_state,
output_final_state=output_final_state, output_final_state=output_final_state,
cu_seqlens=cu_seqlens, cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
chunk_offsets=chunk_offsets,
) )
ctx.scale = scale ctx.scale = scale
ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel
@@ -117,6 +136,8 @@ def chunk_gated_delta_rule(
initial_state: torch.Tensor = None, initial_state: torch.Tensor = None,
output_final_state: bool = False, output_final_state: bool = False,
cu_seqlens: torch.Tensor | None = None, cu_seqlens: torch.Tensor | None = None,
chunk_indices: torch.Tensor | None = None,
chunk_offsets: torch.Tensor | None = None,
use_qk_l2norm_in_kernel: bool = False, use_qk_l2norm_in_kernel: bool = False,
): ):
r""" r"""
@@ -206,6 +227,8 @@ def chunk_gated_delta_rule(
initial_state, initial_state,
output_final_state, output_final_state,
cu_seqlens, cu_seqlens,
chunk_indices,
chunk_offsets,
use_qk_l2norm_in_kernel, use_qk_l2norm_in_kernel,
) )
return o, final_state return o, final_state

View File

@@ -14,7 +14,7 @@ from vllm.triton_utils import tl, triton
from .index import prepare_chunk_indices, prepare_chunk_offsets from .index import prepare_chunk_indices, prepare_chunk_offsets
from .op import exp from .op import exp
from .utils import use_cuda_graph from .utils import FLA_CHUNK_SIZE, use_cuda_graph
NUM_WARPS = [2, 4, 8, 16] NUM_WARPS = [2, 4, 8, 16]
@@ -286,9 +286,11 @@ def chunk_gated_delta_rule_fwd_h(
gk: torch.Tensor | None = None, gk: torch.Tensor | None = None,
initial_state: torch.Tensor | None = None, initial_state: torch.Tensor | None = None,
output_final_state: bool = False, output_final_state: bool = False,
chunk_size: int = 64, # SY: remove this argument and force chunk size 64? chunk_size: int = FLA_CHUNK_SIZE,
save_new_value: bool = True, save_new_value: bool = True,
cu_seqlens: torch.Tensor | None = None, cu_seqlens: torch.Tensor | None = None,
chunk_indices: torch.Tensor | None = None,
chunk_offsets: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
# This kernel is slightly different from fla to support Q/K with different head numbers. # This kernel is slightly different from fla to support Q/K with different head numbers.
# In fla, Q/K always have the same head number, so Hg is always equal to H. # In fla, Q/K always have the same head number, so Hg is always equal to H.
@@ -296,20 +298,15 @@ def chunk_gated_delta_rule_fwd_h(
H = u.shape[-2] H = u.shape[-2]
BT = chunk_size BT = chunk_size
chunk_indices = ( if chunk_indices is None and cu_seqlens is not None:
prepare_chunk_indices(cu_seqlens, chunk_size) chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
if cu_seqlens is not None
else None
)
# N: the actual number of sequences in the batch with either equal or variable lengths # N: the actual number of sequences in the batch with either equal or variable lengths
if cu_seqlens is None: if cu_seqlens is None:
N, NT, chunk_offsets = B, triton.cdiv(T, BT), None N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
else: else:
N, NT, chunk_offsets = ( N, NT = len(cu_seqlens) - 1, len(chunk_indices)
len(cu_seqlens) - 1, if chunk_offsets is None:
len(chunk_indices), chunk_offsets = prepare_chunk_offsets(cu_seqlens, BT)
prepare_chunk_offsets(cu_seqlens, BT),
)
assert K <= 256, "current kernel does not support head dimension larger than 256." assert K <= 256, "current kernel does not support head dimension larger than 256."
h = k.new_empty(B, NT, H, V, K) h = k.new_empty(B, NT, H, V, K)

View File

@@ -146,14 +146,14 @@ def chunk_fwd_o(
g: torch.Tensor | None = None, # cumsum of log decay g: torch.Tensor | None = None, # cumsum of log decay
scale: float | None = None, scale: float | None = None,
cu_seqlens: torch.Tensor | None = None, cu_seqlens: torch.Tensor | None = None,
chunk_indices: torch.Tensor | None = None,
chunk_size: int = FLA_CHUNK_SIZE, chunk_size: int = FLA_CHUNK_SIZE,
) -> torch.Tensor: ) -> torch.Tensor:
B, T, Hg, K, V = *q.shape, v.shape[-1] B, T, Hg, K, V = *q.shape, v.shape[-1]
H = v.shape[-2] H = v.shape[-2]
BT = chunk_size BT = chunk_size
chunk_indices = ( if chunk_indices is None and cu_seqlens is not None:
prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
)
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
if scale is None: if scale is None:
scale = k.shape[-1] ** -0.5 scale = k.shape[-1] ** -0.5

View File

@@ -14,6 +14,7 @@ from vllm.triton_utils import tl, triton
from .index import prepare_chunk_indices from .index import prepare_chunk_indices
from .op import exp from .op import exp
from .utils import FLA_CHUNK_SIZE
@triton.heuristics( @triton.heuristics(
@@ -103,7 +104,8 @@ def chunk_scaled_dot_kkt_fwd(
g: torch.Tensor | None = None, g: torch.Tensor | None = None,
beta: torch.Tensor | None = None, beta: torch.Tensor | None = None,
cu_seqlens: torch.Tensor | None = None, cu_seqlens: torch.Tensor | None = None,
chunk_size: int = 64, chunk_indices: torch.Tensor | None = None,
chunk_size: int = FLA_CHUNK_SIZE,
output_dtype: torch.dtype = torch.float32, output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor: ) -> torch.Tensor:
r""" r"""
@@ -119,6 +121,9 @@ def chunk_scaled_dot_kkt_fwd(
cu_seqlens (torch.Tensor): cu_seqlens (torch.Tensor):
The cumulative sequence lengths of the input tensor. The cumulative sequence lengths of the input tensor.
Default: None Default: None
chunk_indices (torch.Tensor):
Pre-computed chunk indices. If None and cu_seqlens is provided,
computed internally. Default: None
chunk_size (int): chunk_size (int):
The chunk size. Default: 64. The chunk size. Default: 64.
output_dtype (torch.dtype): output_dtype (torch.dtype):
@@ -132,9 +137,8 @@ def chunk_scaled_dot_kkt_fwd(
B, T, Hg, K = k.shape B, T, Hg, K = k.shape
H = beta.shape[-1] H = beta.shape[-1]
BT = chunk_size BT = chunk_size
chunk_indices = ( if chunk_indices is None and cu_seqlens is not None:
prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
)
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
A = torch.empty(B, T, H, BT, device=k.device, dtype=output_dtype) A = torch.empty(B, T, H, BT, device=k.device, dtype=output_dtype)

View File

@@ -162,6 +162,7 @@ def chunk_local_cumsum_scalar(
chunk_size: int, chunk_size: int,
reverse: bool = False, reverse: bool = False,
cu_seqlens: torch.Tensor | None = None, cu_seqlens: torch.Tensor | None = None,
chunk_indices: torch.Tensor | None = None,
head_first: bool = False, head_first: bool = False,
output_dtype: torch.dtype | None = torch.float, output_dtype: torch.dtype | None = torch.float,
) -> torch.Tensor: ) -> torch.Tensor:
@@ -172,10 +173,9 @@ def chunk_local_cumsum_scalar(
assert chunk_size == 2 ** (chunk_size.bit_length() - 1), ( assert chunk_size == 2 ** (chunk_size.bit_length() - 1), (
"chunk_size must be a power of 2" "chunk_size must be a power of 2"
) )
if chunk_indices is None and cu_seqlens is not None:
chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
BT = chunk_size BT = chunk_size
chunk_indices = (
prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
)
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype) g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)
grid = (NT, B * H) grid = (NT, B * H)
@@ -199,6 +199,7 @@ def chunk_local_cumsum_vector(
chunk_size: int, chunk_size: int,
reverse: bool = False, reverse: bool = False,
cu_seqlens: torch.Tensor | None = None, cu_seqlens: torch.Tensor | None = None,
chunk_indices: torch.Tensor | None = None,
head_first: bool = False, head_first: bool = False,
output_dtype: torch.dtype | None = torch.float, output_dtype: torch.dtype | None = torch.float,
) -> torch.Tensor: ) -> torch.Tensor:
@@ -206,16 +207,13 @@ def chunk_local_cumsum_vector(
B, H, T, S = g.shape B, H, T, S = g.shape
else: else:
B, T, H, S = g.shape B, T, H, S = g.shape
BT = chunk_size
chunk_indices = (
prepare_chunk_indices(cu_seqlens, chunk_size)
if cu_seqlens is not None
else None
)
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
assert chunk_size == 2 ** (chunk_size.bit_length() - 1), ( assert chunk_size == 2 ** (chunk_size.bit_length() - 1), (
"chunk_size must be a power of 2" "chunk_size must be a power of 2"
) )
if chunk_indices is None and cu_seqlens is not None:
chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
BT = chunk_size
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype) g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)
@@ -247,6 +245,7 @@ def chunk_local_cumsum(
chunk_size: int, chunk_size: int,
reverse: bool = False, reverse: bool = False,
cu_seqlens: torch.Tensor | None = None, cu_seqlens: torch.Tensor | None = None,
chunk_indices: torch.Tensor | None = None,
head_first: bool = False, head_first: bool = False,
output_dtype: torch.dtype | None = torch.float, output_dtype: torch.dtype | None = torch.float,
**kwargs, **kwargs,
@@ -257,11 +256,23 @@ def chunk_local_cumsum(
) )
if len(g.shape) == 3: if len(g.shape) == 3:
return chunk_local_cumsum_scalar( return chunk_local_cumsum_scalar(
g, chunk_size, reverse, cu_seqlens, head_first, output_dtype g,
chunk_size,
reverse,
cu_seqlens,
chunk_indices,
head_first,
output_dtype,
) )
elif len(g.shape) == 4: elif len(g.shape) == 4:
return chunk_local_cumsum_vector( return chunk_local_cumsum_vector(
g, chunk_size, reverse, cu_seqlens, head_first, output_dtype g,
chunk_size,
reverse,
cu_seqlens,
chunk_indices,
head_first,
output_dtype,
) )
else: else:
raise ValueError( raise ValueError(

View File

@@ -23,7 +23,7 @@ from .index import prepare_chunk_indices
from .l2norm import l2norm_fwd from .l2norm import l2norm_fwd
from .op import exp, log from .op import exp, log
from .solve_tril import solve_tril from .solve_tril import solve_tril
from .utils import is_amd from .utils import FLA_CHUNK_SIZE, is_amd
BT_LIST_AUTOTUNE = [32, 64, 128] BT_LIST_AUTOTUNE = [32, 64, 128]
NUM_WARPS_AUTOTUNE = [2, 4, 8, 16] if is_amd else [4, 8, 16, 32] NUM_WARPS_AUTOTUNE = [2, 4, 8, 16] if is_amd else [4, 8, 16, 32]
@@ -721,7 +721,7 @@ def chunk_kda_scaled_dot_kkt_fwd(
beta: torch.Tensor | None = None, beta: torch.Tensor | None = None,
scale: float | None = None, scale: float | None = None,
cu_seqlens: torch.Tensor | None = None, cu_seqlens: torch.Tensor | None = None,
chunk_size: int = 64, chunk_size: int = FLA_CHUNK_SIZE,
output_dtype: torch.dtype = torch.float32, output_dtype: torch.dtype = torch.float32,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
r""" r"""
@@ -1178,7 +1178,7 @@ def chunk_kda_fwd(
output_final_state: bool, output_final_state: bool,
cu_seqlens: torch.Tensor | None = None, cu_seqlens: torch.Tensor | None = None,
): ):
chunk_size = 64 chunk_size = FLA_CHUNK_SIZE
g = chunk_local_cumsum(g, chunk_size=chunk_size, cu_seqlens=cu_seqlens) g = chunk_local_cumsum(g, chunk_size=chunk_size, cu_seqlens=cu_seqlens)
# the intra Aqk is kept in fp32 # the intra Aqk is kept in fp32
# the computation has very marginal effect on the entire throughput # the computation has very marginal effect on the entire throughput
@@ -1189,6 +1189,7 @@ def chunk_kda_fwd(
beta=beta, beta=beta,
scale=scale, scale=scale,
cu_seqlens=cu_seqlens, cu_seqlens=cu_seqlens,
chunk_size=chunk_size,
output_dtype=torch.float32, output_dtype=torch.float32,
) )
A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype) A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype)

View File

@@ -507,6 +507,7 @@ def merge_16x16_to_64x64_inverse_kernel(
def solve_tril( def solve_tril(
A: torch.Tensor, A: torch.Tensor,
cu_seqlens: torch.Tensor | None = None, cu_seqlens: torch.Tensor | None = None,
chunk_indices: torch.Tensor | None = None,
output_dtype: torch.dtype = torch.float, output_dtype: torch.dtype = torch.float,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
@@ -518,6 +519,8 @@ def solve_tril(
[B, T, H, BT], where BT should only be 16, 32, or 64. [B, T, H, BT], where BT should only be 16, 32, or 64.
cu_seqlens (torch.Tensor): cu_seqlens (torch.Tensor):
The cumulative sequence lengths of the input tensor. Default: `None`. The cumulative sequence lengths of the input tensor. Default: `None`.
chunk_indices (torch.Tensor):
Pre-computed chunk indices. Default: `None`.
output_dtype (torch.dtype): output_dtype (torch.dtype):
The dtype of the output tensor. Default: `torch.float`. The dtype of the output tensor. Default: `torch.float`.
If `None`, the output dtype will be the same as the input dtype. If `None`, the output dtype will be the same as the input dtype.
@@ -529,9 +532,8 @@ def solve_tril(
output_dtype = A.dtype if output_dtype is None else output_dtype output_dtype = A.dtype if output_dtype is None else output_dtype
B, T, H, BT = A.shape B, T, H, BT = A.shape
chunk_indices = ( if chunk_indices is None and cu_seqlens is not None:
prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
)
NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, BT) NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, BT)
Ai = torch.zeros_like(A, dtype=output_dtype) Ai = torch.zeros_like(A, dtype=output_dtype)

View File

@@ -123,14 +123,14 @@ def recompute_w_u_fwd(
g_cumsum: torch.Tensor, g_cumsum: torch.Tensor,
A: torch.Tensor, A: torch.Tensor,
cu_seqlens: torch.Tensor | None, cu_seqlens: torch.Tensor | None,
chunk_indices: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
B, T, Hg, K, V = *k.shape, v.shape[-1] B, T, Hg, K, V = *k.shape, v.shape[-1]
H = v.shape[-2] H = v.shape[-2]
BT = A.shape[-1] BT = A.shape[-1]
chunk_indices = ( if chunk_indices is None and cu_seqlens is not None:
prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
)
NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
BK = 64 BK = 64
BV = 64 BV = 64

View File

@@ -163,6 +163,8 @@ class ChunkGatedDeltaRule(CustomOp):
initial_state: torch.Tensor, initial_state: torch.Tensor,
output_final_state: bool, output_final_state: bool,
cu_seqlens: torch.Tensor | None = None, cu_seqlens: torch.Tensor | None = None,
chunk_indices: torch.Tensor | None = None,
chunk_offsets: torch.Tensor | None = None,
use_qk_l2norm_in_kernel: bool = True, use_qk_l2norm_in_kernel: bool = True,
): ):
return fi_chunk_gated_delta_rule( return fi_chunk_gated_delta_rule(
@@ -187,6 +189,8 @@ class ChunkGatedDeltaRule(CustomOp):
initial_state: torch.Tensor, initial_state: torch.Tensor,
output_final_state: bool, output_final_state: bool,
cu_seqlens: torch.Tensor | None = None, cu_seqlens: torch.Tensor | None = None,
chunk_indices: torch.Tensor | None = None,
chunk_offsets: torch.Tensor | None = None,
use_qk_l2norm_in_kernel: bool = True, use_qk_l2norm_in_kernel: bool = True,
): ):
return fla_chunk_gated_delta_rule( return fla_chunk_gated_delta_rule(
@@ -198,6 +202,8 @@ class ChunkGatedDeltaRule(CustomOp):
initial_state=initial_state, initial_state=initial_state,
output_final_state=output_final_state, output_final_state=output_final_state,
cu_seqlens=cu_seqlens, cu_seqlens=cu_seqlens,
chunk_indices=chunk_indices,
chunk_offsets=chunk_offsets,
use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
) )
@@ -959,6 +965,8 @@ class GatedDeltaNetAttention(PluggableLayer, MambaBase):
initial_state=initial_state, initial_state=initial_state,
output_final_state=True, output_final_state=True,
cu_seqlens=non_spec_query_start_loc, cu_seqlens=non_spec_query_start_loc,
chunk_indices=attn_metadata.chunk_indices,
chunk_offsets=attn_metadata.chunk_offsets,
use_qk_l2norm_in_kernel=False, use_qk_l2norm_in_kernel=False,
) )
# Init cache # Init cache

View File

@@ -63,6 +63,10 @@ class GDNAttentionMetadata:
num_accepted_tokens: torch.Tensor | None = None # shape: [batch,] num_accepted_tokens: torch.Tensor | None = None # shape: [batch,]
# Pre-computed FLA chunk metadata (avoids GPU->CPU sync in prepare_chunk_indices)
chunk_indices: torch.Tensor | None = None
chunk_offsets: torch.Tensor | None = None
# The following attributes are for triton implementation of causal_conv1d # The following attributes are for triton implementation of causal_conv1d
nums_dict: dict | None = None nums_dict: dict | None = None
batch_ptr: torch.Tensor | None = None batch_ptr: torch.Tensor | None = None
@@ -305,6 +309,26 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
assert num_accepted_tokens is not None assert num_accepted_tokens is not None
num_accepted_tokens = num_accepted_tokens[spec_sequence_masks] num_accepted_tokens = num_accepted_tokens[spec_sequence_masks]
chunk_indices: torch.Tensor | None = None
chunk_offsets: torch.Tensor | None = None
if num_prefills > 0:
# Only prefill batches use FLA chunk ops.
# Pre-compute on CPU and async-copy to GPU to avoid
# GPU→CPU sync (.tolist()) in prepare_chunk_indices.
from vllm.model_executor.layers.fla.ops.index import (
prepare_chunk_indices,
prepare_chunk_offsets,
)
from vllm.model_executor.layers.fla.ops.utils import FLA_CHUNK_SIZE
gpu_device = query_start_loc.device
chunk_indices = prepare_chunk_indices(
non_spec_query_start_loc_cpu, FLA_CHUNK_SIZE
).to(device=gpu_device, non_blocking=True)
chunk_offsets = prepare_chunk_offsets(
non_spec_query_start_loc_cpu, FLA_CHUNK_SIZE
).to(device=gpu_device, non_blocking=True)
if num_prefills > 0: if num_prefills > 0:
has_initial_state = context_lens_tensor > 0 has_initial_state = context_lens_tensor > 0
if spec_sequence_masks is not None: if spec_sequence_masks is not None:
@@ -405,6 +429,8 @@ class GDNAttentionMetadataBuilder(AttentionMetadataBuilder[GDNAttentionMetadata]
num_spec_decode_tokens=num_spec_decode_tokens, num_spec_decode_tokens=num_spec_decode_tokens,
num_actual_tokens=m.num_actual_tokens, num_actual_tokens=m.num_actual_tokens,
has_initial_state=has_initial_state, has_initial_state=has_initial_state,
chunk_indices=chunk_indices,
chunk_offsets=chunk_offsets,
spec_query_start_loc=spec_query_start_loc, spec_query_start_loc=spec_query_start_loc,
non_spec_query_start_loc=non_spec_query_start_loc, non_spec_query_start_loc=non_spec_query_start_loc,
spec_state_indices_tensor=spec_state_indices_tensor, spec_state_indices_tensor=spec_state_indices_tensor,