[Perf] Fix jit compiles at runtime of fla gated delta rule (#25432)

Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
Corey Lowman
2025-09-23 23:16:13 -04:00
committed by GitHub
parent c30b405b8f
commit d747c2ef18

View File

@@ -40,8 +40,8 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
ssm_state_indices, ssm_state_indices,
num_accepted_tokens, num_accepted_tokens,
scale, scale,
N: tl.constexpr, # num of sequences N: tl.int64, # num of sequences
T: tl.constexpr, # num of tokens T: tl.int64, # num of tokens
B: tl.constexpr, B: tl.constexpr,
H: tl.constexpr, H: tl.constexpr,
HV: tl.constexpr, HV: tl.constexpr,