Fix: 3D coords for TMA copy (bSG_sC has 3 modes)
This commit is contained in:
@@ -570,7 +570,7 @@ class FmhaKernel:
|
||||
)
|
||||
c_pipe.producer_acquire()
|
||||
if warp_idx == self.epilogue_warp_id[0]:
|
||||
-cute.copy(tma_c, bSG_sC[(None, Int32(0))], bSG_gC[(None, Int32(0))])
|
||||
+cute.copy(tma_c, bSG_sC[(None, None, Int32(0))], bSG_gC[(None, None, Int32(0))])
|
||||
c_pipe.producer_commit()
|
||||
c_pipe.producer_acquire()
|
||||
epilog_sync_barrier.arrive_and_wait()
|
||||
|
||||
Reference in New Issue
Block a user