fix: use gC not tCgC for TMA partition, group modes 0-3

This commit is contained in:
2026-05-23 01:20:52 +00:00
parent 420ed0c5d8
commit f4c474ced9

View File

@@ -460,7 +460,7 @@ class FmhaV3StageCMulti:
cute.arch.fence_proxy("async.shared", space="cta")
# TMA store: SMEM → GMEM (reuse existing tma_c from kernel setup)
# TMA store: SMEM → GMEM
# Sync all softmax warps before TMA store
softmax_all_bar = pipeline.NamedBarrier(
barrier_id=5, num_threads=32 * len(self.epilogue_warp_id)
@@ -468,10 +468,11 @@ class FmhaV3StageCMulti:
softmax_all_bar.arrive_and_wait()
# Partition SMEM and GMEM for TMA store
epi_s = cute.select(self.c_smem_s, mode=[0, 1])
tCsC, tCgC_tma = cpasync.tma_partition(
tma_c, 0, cute.make_layout((1,)),
cute.group_modes(sC, 0, 2),
cute.group_modes(tCgC, 0, 3),
cute.group_modes(gC, 0, 3),
)
cute.copy(tma_c, tCsC, tCgC_tma)
cute.arch.cp_async_bulk_commit_group()