From 540399eca3d016a584e2df18e522f84a074e6b46 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Fri, 22 May 2026 09:52:59 +0000 Subject: [PATCH] fix: define cS and tScS in correction warps (not visible across if blocks) --- tests/unit/test_fmha_v3_stage_c2.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/unit/test_fmha_v3_stage_c2.py b/tests/unit/test_fmha_v3_stage_c2.py index 824d5197..9ea606dc 100644 --- a/tests/unit/test_fmha_v3_stage_c2.py +++ b/tests/unit/test_fmha_v3_stage_c2.py @@ -322,7 +322,9 @@ class FmhaV3StageC2: if warp_idx >= len(self.softmax_warp_ids) and warp_idx < len(self.softmax_warp_ids) + len(self.correction_warp_ids): tmem.wait_for_alloc() corr_idx = tidx % (32 * len(self.correction_warp_ids)) - # Vec load setup + # Vec load setup (compute cS from scratch for correction warps) + cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1])) + tScS = qk_thr.partition_C(cS) tStS_vec_layout = cute.composition(tStS.layout, cute.make_layout((128, 2))) tStS_vec = cute.make_tensor(tStS.iterator + self.tmem_vec0_offset, tStS_vec_layout) tmem_load_vec_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(2)), self.qk_acc_dtype)