Fix: index the TMA copy operands with 3D coordinates (bSG_sC has 3 modes)

This commit is contained in:
2026-05-27 05:00:39 +00:00
parent b0ebf41ee3
commit 4652cab8b4

View File

@@ -570,7 +570,7 @@ class FmhaKernel:
)
c_pipe.producer_acquire()
if warp_idx == self.epilogue_warp_id[0]:
-cute.copy(tma_c, bSG_sC[(None, Int32(0))], bSG_gC[(None, Int32(0))])
+cute.copy(tma_c, bSG_sC[(None, None, Int32(0))], bSG_gC[(None, None, Int32(0))])
c_pipe.producer_commit()
c_pipe.producer_acquire()
epilog_sync_barrier.arrive_and_wait()