D1.5: Fix bSG_gC slicing - group trailing modes (CUTLASS pattern)

This commit is contained in:
2026-05-24 01:41:52 +00:00
parent f2ab5790e8
commit 699c646497

View File

@@ -512,8 +512,9 @@ class FmhaKernel:
# TMA store SMEM → GMEM
if warp_idx == self.epilogue_warp_id[0]:
cute.copy(tma_c, bSG_sC[(None, c_buffer)],
bSG_gC[(None, None, None, Int32(0), Int32(0), Int32(0))])
# Group trailing modes and slice (CUTLASS pattern)
bSG_gC_flat = cute.group_modes(bSG_gC, 1, cute.rank(bSG_gC))
cute.copy(tma_c, bSG_sC[(None, c_buffer)], bSG_gC_flat[(None, Int32(0))])
cute.arch.cp_async_bulk_commit_group()
cute.arch.cp_async_bulk_wait_group(0, read=True)
corr_epi_bar.arrive_and_wait()