fix: correct bSG_gC indexing (6 modes)
This commit is contained in:
@@ -392,7 +392,7 @@ class FmhaV3StageCMulti:
|
||||
num_threads=32 * len(self.epilogue_warp_id),
|
||||
)
|
||||
epi_bar.arrive_and_wait()
|
||||
cute.copy(tma_c, bSG_sC[(None, 0)], bSG_gC[(None, 0, 0)])
|
||||
cute.copy(tma_c, bSG_sC[(None, 0)], bSG_gC[(None, 0, 0, 0, 0, 0)])
|
||||
cute.arch.cp_async_bulk_commit_group()
|
||||
cute.arch.cp_async_bulk_wait_group(0, read=True)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user