[PERF] Speed-up of GDN attention decode part (Qwen3-Next) (#31722)

Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
This commit is contained in:
Vadim Gimpelson
2026-01-06 21:32:46 +04:00
committed by GitHub
parent 4c73be14e0
commit 22dffca982

View File

@@ -189,7 +189,7 @@ def fused_recurrent_gated_delta_rule_fwd(
 B, T, H, K, V = *k.shape, v.shape[-1]
 HV = v.shape[2]
 N = B if cu_seqlens is None else len(cu_seqlens) - 1
-BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 8)
+BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 32)
 NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
 assert NK == 1, "NK > 1 is not supported yet"
 num_stages = 3