[Perf] Fix jit compiles at runtime of fla gated delta rule (#25432)
Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
@@ -40,8 +40,8 @@ def fused_recurrent_gated_delta_rule_fwd_kernel(
|
|||||||
ssm_state_indices,
|
ssm_state_indices,
|
||||||
num_accepted_tokens,
|
num_accepted_tokens,
|
||||||
scale,
|
scale,
|
||||||
N: tl.constexpr, # num of sequences
|
N: tl.int64, # num of sequences
|
||||||
T: tl.constexpr, # num of tokens
|
T: tl.int64, # num of tokens
|
||||||
B: tl.constexpr,
|
B: tl.constexpr,
|
||||||
H: tl.constexpr,
|
H: tl.constexpr,
|
||||||
HV: tl.constexpr,
|
HV: tl.constexpr,
|
||||||
|
|||||||
Reference in New Issue
Block a user