[PERF] Speed-up of GDN attention decode part (Qwen3-Next) (#31722)
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
This commit is contained in:
@@ -189,7 +189,7 @@ def fused_recurrent_gated_delta_rule_fwd(
|
||||
B, T, H, K, V = *k.shape, v.shape[-1]
|
||||
HV = v.shape[2]
|
||||
N = B if cu_seqlens is None else len(cu_seqlens) - 1
|
||||
BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 8)
|
||||
BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 32)
|
||||
NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
|
||||
assert NK == 1, "NK > 1 is not supported yet"
|
||||
num_stages = 3
|
||||
|
||||
Reference in New Issue
Block a user