fix: define cS and tScS in correction warps (not visible across if blocks)

This commit is contained in:
2026-05-22 09:52:59 +00:00
parent ee859099bd
commit 540399eca3

View File

@@ -322,7 +322,9 @@ class FmhaV3StageC2:
if warp_idx >= len(self.softmax_warp_ids) and warp_idx < len(self.softmax_warp_ids) + len(self.correction_warp_ids):
tmem.wait_for_alloc()
corr_idx = tidx % (32 * len(self.correction_warp_ids))
# Vec load setup
# Vec load setup (compute cS from scratch for correction warps)
cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
tScS = qk_thr.partition_C(cS)
tStS_vec_layout = cute.composition(tStS.layout, cute.make_layout((128, 2)))
tStS_vec = cute.make_tensor(tStS.iterator + self.tmem_vec0_offset, tStS_vec_layout)
tmem_load_vec_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(2)), self.qk_acc_dtype)