[PERF] Change GDN Attention State Layout from [N, HV, K, V] to [N, HV, V, K] (#33291)

Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
This commit is contained in:
Vadim Gimpelson
2026-02-04 15:20:52 +04:00
committed by GitHub
parent 8e32690869
commit 824058076c
6 changed files with 61 additions and 61 deletions

View File

@@ -138,11 +138,11 @@ def chunk_gated_delta_rule(
Scale factor for the RetNet attention scores.
If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
initial_state (Optional[torch.Tensor]):
Initial state of shape `[N, H, K, V]` for `N` input sequences.
Initial state of shape `[N, H, V, K]` for `N` input sequences.
For equal-length input sequences, `N` equals the batch size `B`.
Default: `None`.
output_final_state (Optional[bool]):
Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
Whether to output the final state of shape `[N, H, V, K]`. Default: `False`.
cu_seqlens (torch.LongTensor):
Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
consistent with the FlashAttention API.
@@ -154,7 +154,7 @@ def chunk_gated_delta_rule(
o (torch.Tensor):
Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
final_state (torch.Tensor):
Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
Final state of shape `[N, H, V, K]` if `output_final_state=True` else `None`.
Examples::
>>> import torch
@@ -168,7 +168,7 @@ def chunk_gated_delta_rule(
>>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda')
>>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid()
>>> g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda'))
>>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda')
>>> h0 = torch.randn(B, H, V, K, dtype=torch.bfloat16, device='cuda')
>>> o, ht = chunk_gated_delta_rule(
q, k, v, g, beta,
initial_state=h0,

View File

@@ -81,70 +81,70 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
NT = tl.cdiv(T, BT)
boh = i_n * NT
# [BK, BV]
b_h1 = tl.zeros([64, BV], dtype=tl.float32)
# [BV, BK]
b_h1 = tl.zeros([BV, 64], dtype=tl.float32)
if K > 64:
b_h2 = tl.zeros([64, BV], dtype=tl.float32)
b_h2 = tl.zeros([BV, 64], dtype=tl.float32)
if K > 128:
b_h3 = tl.zeros([64, BV], dtype=tl.float32)
b_h3 = tl.zeros([BV, 64], dtype=tl.float32)
if K > 192:
b_h4 = tl.zeros([64, BV], dtype=tl.float32)
b_h4 = tl.zeros([BV, 64], dtype=tl.float32)
# calculate offset
h += ((boh * H + i_h) * K * V).to(tl.int64)
h += ((boh * H + i_h) * V * K).to(tl.int64)
v += ((bos * H + i_h) * V).to(tl.int64)
k += ((bos * Hg + i_h // (H // Hg)) * K).to(tl.int64)
w += ((bos * H + i_h) * K).to(tl.int64)
if SAVE_NEW_VALUE:
v_new += ((bos * H + i_h) * V).to(tl.int64)
stride_v = H * V
stride_h = H * K * V
stride_h = H * V * K
stride_k = Hg * K
stride_w = H * K
if USE_INITIAL_STATE:
h0 = h0 + i_nh * K * V
h0 = h0 + i_nh * V * K
if STORE_FINAL_STATE:
ht = ht + i_nh * K * V
ht = ht + i_nh * V * K
# load initial state
if USE_INITIAL_STATE:
p_h0_1 = tl.make_block_ptr(h0, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
p_h0_1 = tl.make_block_ptr(h0, (V, K), (K, 1), (i_v * BV, 0), (BV, 64), (1, 0))
b_h1 += tl.load(p_h0_1, boundary_check=(0, 1)).to(tl.float32)
if K > 64:
p_h0_2 = tl.make_block_ptr(
h0, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)
h0, (V, K), (K, 1), (i_v * BV, 64), (BV, 64), (1, 0)
)
b_h2 += tl.load(p_h0_2, boundary_check=(0, 1)).to(tl.float32)
if K > 128:
p_h0_3 = tl.make_block_ptr(
h0, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)
h0, (V, K), (K, 1), (i_v * BV, 128), (BV, 64), (1, 0)
)
b_h3 += tl.load(p_h0_3, boundary_check=(0, 1)).to(tl.float32)
if K > 192:
p_h0_4 = tl.make_block_ptr(
h0, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)
h0, (V, K), (K, 1), (i_v * BV, 192), (BV, 64), (1, 0)
)
b_h4 += tl.load(p_h0_4, boundary_check=(0, 1)).to(tl.float32)
# main recurrence
for i_t in range(NT):
p_h1 = tl.make_block_ptr(
h + i_t * stride_h, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)
h + i_t * stride_h, (V, K), (K, 1), (i_v * BV, 0), (BV, 64), (1, 0)
)
tl.store(p_h1, b_h1.to(p_h1.dtype.element_ty), boundary_check=(0, 1))
if K > 64:
p_h2 = tl.make_block_ptr(
h + i_t * stride_h, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)
h + i_t * stride_h, (V, K), (K, 1), (i_v * BV, 64), (BV, 64), (1, 0)
)
tl.store(p_h2, b_h2.to(p_h2.dtype.element_ty), boundary_check=(0, 1))
if K > 128:
p_h3 = tl.make_block_ptr(
h + i_t * stride_h, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)
h + i_t * stride_h, (V, K), (K, 1), (i_v * BV, 128), (BV, 64), (1, 0)
)
tl.store(p_h3, b_h3.to(p_h3.dtype.element_ty), boundary_check=(0, 1))
if K > 192:
p_h4 = tl.make_block_ptr(
h + i_t * stride_h, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)
h + i_t * stride_h, (V, K), (K, 1), (i_v * BV, 192), (BV, 64), (1, 0)
)
tl.store(p_h4, b_h4.to(p_h4.dtype.element_ty), boundary_check=(0, 1))
@@ -152,25 +152,25 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
w, (T, K), (stride_w, 1), (i_t * BT, 0), (BT, 64), (1, 0)
)
b_w = tl.load(p_w, boundary_check=(0, 1))
b_v = tl.dot(b_w, b_h1.to(b_w.dtype))
b_v = tl.dot(b_w, tl.trans(b_h1).to(b_w.dtype))
if K > 64:
p_w = tl.make_block_ptr(
w, (T, K), (stride_w, 1), (i_t * BT, 64), (BT, 64), (1, 0)
)
b_w = tl.load(p_w, boundary_check=(0, 1))
b_v += tl.dot(b_w, b_h2.to(b_w.dtype))
b_v += tl.dot(b_w, tl.trans(b_h2).to(b_w.dtype))
if K > 128:
p_w = tl.make_block_ptr(
w, (T, K), (stride_w, 1), (i_t * BT, 128), (BT, 64), (1, 0)
)
b_w = tl.load(p_w, boundary_check=(0, 1))
b_v += tl.dot(b_w, b_h3.to(b_w.dtype))
b_v += tl.dot(b_w, tl.trans(b_h3).to(b_w.dtype))
if K > 192:
p_w = tl.make_block_ptr(
w, (T, K), (stride_w, 1), (i_t * BT, 192), (BT, 64), (1, 0)
)
b_w = tl.load(p_w, boundary_check=(0, 1))
b_v += tl.dot(b_w, b_h4.to(b_w.dtype))
b_v += tl.dot(b_w, tl.trans(b_h4).to(b_w.dtype))
p_v = tl.make_block_ptr(
v, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)
)
@@ -207,7 +207,7 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
mask=(o_k1 < K),
other=0.0,
)
b_h1 *= exp(b_gk_last1)[:, None]
b_h1 *= exp(b_gk_last1)[None, :]
if K > 64:
o_k2 = 64 + o_k1
b_gk_last2 = tl.load(
@@ -215,7 +215,7 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
mask=(o_k2 < K),
other=0.0,
)
b_h2 *= exp(b_gk_last2)[:, None]
b_h2 *= exp(b_gk_last2)[None, :]
if K > 128:
o_k3 = 128 + o_k1
b_gk_last3 = tl.load(
@@ -223,7 +223,7 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
mask=(o_k3 < K),
other=0.0,
)
b_h3 *= exp(b_gk_last3)[:, None]
b_h3 *= exp(b_gk_last3)[None, :]
if K > 192:
o_k4 = 192 + o_k1
b_gk_last4 = tl.load(
@@ -231,49 +231,49 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
mask=(o_k4 < K),
other=0.0,
)
b_h4 *= exp(b_gk_last4)[:, None]
b_h4 *= exp(b_gk_last4)[None, :]
b_v = b_v.to(k.dtype.element_ty)
p_k = tl.make_block_ptr(
k, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1)
)
b_k = tl.load(p_k, boundary_check=(0, 1))
b_h1 += tl.dot(b_k, b_v)
b_h1 += tl.trans(tl.dot(b_k, b_v))
if K > 64:
p_k = tl.make_block_ptr(
k, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1)
)
b_k = tl.load(p_k, boundary_check=(0, 1))
b_h2 += tl.dot(b_k, b_v)
b_h2 += tl.trans(tl.dot(b_k, b_v))
if K > 128:
p_k = tl.make_block_ptr(
k, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1)
)
b_k = tl.load(p_k, boundary_check=(0, 1))
b_h3 += tl.dot(b_k, b_v)
b_h3 += tl.trans(tl.dot(b_k, b_v))
if K > 192:
p_k = tl.make_block_ptr(
k, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1)
)
b_k = tl.load(p_k, boundary_check=(0, 1))
b_h4 += tl.dot(b_k, b_v)
b_h4 += tl.trans(tl.dot(b_k, b_v))
# epilogue
if STORE_FINAL_STATE:
p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
p_ht = tl.make_block_ptr(ht, (V, K), (K, 1), (i_v * BV, 0), (BV, 64), (1, 0))
tl.store(p_ht, b_h1.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
if K > 64:
p_ht = tl.make_block_ptr(
ht, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)
ht, (V, K), (K, 1), (i_v * BV, 64), (BV, 64), (1, 0)
)
tl.store(p_ht, b_h2.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
if K > 128:
p_ht = tl.make_block_ptr(
ht, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)
ht, (V, K), (K, 1), (i_v * BV, 128), (BV, 64), (1, 0)
)
tl.store(p_ht, b_h3.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
if K > 192:
p_ht = tl.make_block_ptr(
ht, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)
ht, (V, K), (K, 1), (i_v * BV, 192), (BV, 64), (1, 0)
)
tl.store(p_ht, b_h4.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
@@ -312,9 +312,9 @@ def chunk_gated_delta_rule_fwd_h(
)
assert K <= 256, "current kernel does not support head dimension larger than 256."
h = k.new_empty(B, NT, H, K, V)
h = k.new_empty(B, NT, H, V, K)
final_state = (
k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None
k.new_empty(N, H, V, K, dtype=torch.float32) if output_final_state else None
)
v_new = torch.empty_like(u) if save_new_value else None

View File

@@ -85,7 +85,7 @@ def chunk_fwd_kernel_o(
k += (bos * Hg + i_h // (H // Hg)) * K
v += (bos * H + i_h) * V
o += (bos * H + i_h) * V
h += (i_tg * H + i_h).to(tl.int64) * K * V
h += (i_tg * H + i_h).to(tl.int64) * V * K
b_o = tl.zeros([BT, BV], dtype=tl.float32)
b_A = tl.zeros([BT, BT], dtype=tl.float32)
@@ -98,17 +98,17 @@ def chunk_fwd_kernel_o(
k, (K, T), (1, Hg * K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)
)
p_h = tl.make_block_ptr(
h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)
h, (V, K), (K, 1), (i_v * BV, i_k * BK), (BV, BK), (1, 0)
)
# [BT, BK]
b_q = tl.load(p_q, boundary_check=(0, 1))
# [BK, BT]
b_k = tl.load(p_k, boundary_check=(0, 1))
# [BK, BV]
# [BV, BK]
b_h = tl.load(p_h, boundary_check=(0, 1))
# [BT, BK] @ [BK, BV] -> [BT, BV]
b_o += tl.dot(b_q, b_h)
b_o += tl.dot(b_q, tl.trans(b_h))
# [BT, BK] @ [BK, BT] -> [BT, BT]
b_A += tl.dot(b_q, b_k)

View File

@@ -97,9 +97,9 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
mask_k = o_k < K
mask_v = o_v < V
mask_h = mask_k[:, None] & mask_v[None, :]
mask_h = mask_v[:, None] & mask_k[None, :]
b_h = tl.zeros([BK, BV], dtype=tl.float32)
b_h = tl.zeros([BV, BK], dtype=tl.float32)
if USE_INITIAL_STATE:
if IS_CONTINUOUS_BATCHING:
if IS_SPEC_DECODING:
@@ -115,8 +115,8 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
return
p_h0 = h0 + state_idx * stride_init_state_token
else:
p_h0 = h0 + bos * HV * K * V
p_h0 = p_h0 + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
p_h0 = h0 + bos * HV * V * K
p_h0 = p_h0 + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
for i_t in range(0, T):
@@ -128,24 +128,24 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6)
b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6)
b_q = b_q * scale
# [BK, BV]
# [BV, BK]
if not IS_KDA:
b_g = tl.load(p_g).to(tl.float32)
b_h *= exp(b_g)
else:
b_gk = tl.load(p_gk).to(tl.float32)
b_h *= exp(b_gk[:, None])
b_h *= exp(b_gk[None, :])
# [BV]
b_v -= tl.sum(b_h * b_k[:, None], 0)
b_v -= tl.sum(b_h * b_k[None, :], 1)
if IS_BETA_HEADWISE:
b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32)
else:
b_beta = tl.load(p_beta).to(tl.float32)
b_v *= b_beta
# [BK, BV]
b_h += b_k[:, None] * b_v[None, :]
# [BV, BK]
b_h += b_v[:, None] * b_k[None, :]
# [BV]
b_o = tl.sum(b_h * b_q[:, None], 0)
b_o = tl.sum(b_h * b_q[None, :], 1)
tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
# keep the states for multi-query tokens
@@ -157,11 +157,11 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
# Only store if state index is valid (not PAD_SLOT_ID)
if final_state_idx >= 0:
p_ht = ht + final_state_idx * stride_final_state_token
p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
p_ht = p_ht + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
else:
p_ht = ht + (bos + i_t) * stride_final_state_token
p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
p_ht = p_ht + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
p_q += H * K
@@ -202,7 +202,7 @@ def fused_recurrent_gated_delta_rule_fwd(
if inplace_final_state:
final_state = initial_state
else:
final_state = q.new_empty(T, HV, K, V, dtype=initial_state.dtype)
final_state = q.new_empty(T, HV, V, K, dtype=initial_state.dtype)
stride_init_state_token = initial_state.stride(0)
stride_final_state_token = final_state.stride(0)
@@ -318,7 +318,7 @@ def fused_recurrent_gated_delta_rule(
Scale factor for the RetNet attention scores.
If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
initial_state (Optional[torch.Tensor]):
Initial state of shape `[N, HV, K, V]` for `N` input sequences.
Initial state of shape `[N, HV, V, K]` for `N` input sequences.
For equal-length input sequences, `N` equals the batch size `B`.
Default: `None`.
inplace_final_state: bool:
@@ -336,7 +336,7 @@ def fused_recurrent_gated_delta_rule(
o (torch.Tensor):
Outputs of shape `[B, T, HV, V]`.
final_state (torch.Tensor):
Final state of shape `[N, HV, K, V]`.
Final state of shape `[N, HV, V, K]`.
Examples::
>>> import torch
@@ -350,7 +350,7 @@ def fused_recurrent_gated_delta_rule(
>>> v = torch.randn(B, T, HV, V, device='cuda')
>>> g = F.logsigmoid(torch.rand(B, T, HV, device='cuda'))
>>> beta = torch.rand(B, T, HV, device='cuda').sigmoid()
>>> h0 = torch.randn(B, HV, K, V, device='cuda')
>>> h0 = torch.randn(B, HV, V, K, device='cuda')
>>> o, ht = fused_recurrent_gated_delta_rule(
q, k, v, g, beta,
initial_state=h0,

View File

@@ -55,7 +55,7 @@ def fused_recurrent_kda_fwd(
if inplace_final_state:
final_state = initial_state
else:
final_state = q.new_empty(T, HV, K, V, dtype=initial_state.dtype)
final_state = q.new_empty(T, HV, V, K, dtype=initial_state.dtype)
stride_init_state_token = initial_state.stride(0)
stride_final_state_token = final_state.stride(0)

View File

@@ -191,8 +191,8 @@ class MambaStateShapeCalculator:
temporal_state_shape = (
divide(num_v_heads, tp_world_size),
head_k_dim,
head_v_dim,
head_k_dim,
)
return conv_state_shape, temporal_state_shape