diff --git a/tests/unit/test_fmha_v3_stage_c2.py b/tests/unit/test_fmha_v3_stage_c2.py index 824d5197..9ea606dc 100644 --- a/tests/unit/test_fmha_v3_stage_c2.py +++ b/tests/unit/test_fmha_v3_stage_c2.py @@ -322,7 +322,9 @@ class FmhaV3StageC2: if warp_idx >= len(self.softmax_warp_ids) and warp_idx < len(self.softmax_warp_ids) + len(self.correction_warp_ids): tmem.wait_for_alloc() corr_idx = tidx % (32 * len(self.correction_warp_ids)) - # Vec load setup + # Vec load setup (compute cS from scratch for correction warps) + cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1])) + tScS = qk_thr.partition_C(cS) tStS_vec_layout = cute.composition(tStS.layout, cute.make_layout((128, 2))) tStS_vec = cute.make_tensor(tStS.iterator + self.tmem_vec0_offset, tStS_vec_layout) tmem_load_vec_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(2)), self.qk_acc_dtype)