diff --git a/tests/unit/test_fmha_v3_stage_c_full.py b/tests/unit/test_fmha_v3_stage_c_full.py index 93706a73..921420dd 100644 --- a/tests/unit/test_fmha_v3_stage_c_full.py +++ b/tests/unit/test_fmha_v3_stage_c_full.py @@ -17,7 +17,8 @@ HEAD_DIM = 64 class FmhaV3StageC: def __init__(self, s_k=128, scale_softmax=None): - self.s_k = s_k self.acc_dtype = Float32; self.qk_acc_dtype = Float32 + self.s_k = s_k + self.acc_dtype = Float32; self.qk_acc_dtype = Float32 self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16 self.use_2cta_instrs = False; self.epilog_sync_bar_id = 1 self.cluster_shape_mn = (1, 1); self.cta_group = tcgen05.CtaGroup.ONE