diff --git a/tests/unit/test_fmha_v3_stage_c.py b/tests/unit/test_fmha_v3_stage_c.py index 652b1cef..c738df23 100644 --- a/tests/unit/test_fmha_v3_stage_c.py +++ b/tests/unit/test_fmha_v3_stage_c.py @@ -56,6 +56,7 @@ class FmhaV3StageCMulti: def __init__(self, s_k=128, scale_softmax=None): # s_k MUST equal actual sequence length n. self.s_k = s_k + self.n_kv_tiles = s_k // 128 self.acc_dtype = Float32; self.qk_acc_dtype = Float32 self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16 self.use_2cta_instrs = False; self.epilog_sync_bar_id = 1