D1.5: Fix TMA store - use existing gC partition

This commit is contained in:
2026-05-24 00:43:35 +00:00
parent cb1a9d9171
commit f305aa4884

View File

@@ -491,8 +491,8 @@ class FmhaKernel:
# Step 6: TMA store SMEM → GMEM
# The normalized O is now in sC (written by the correction epilog).
# TMA store from sC to the output tensor in GMEM.
# Use cpasync.tma_partition to set up the SMEM/GMEM partition.
gC = cute.local_tile(mC, epi_tile, (Int32(0), Int32(0)))
# Use the pre-partitioned tCgC (GMEM partition) and sC (SMEM buffer).
gC = cute.local_tile(mC, cute.slice_(self.pv_mma_tiler,(None,0,None)),(None,None,None))
tCgC_epi = cute.flat_divide(tCgC, epi_tile)
bSG_sC, bSG_gC = cpasync.tma_partition(
tma_c, 0, cute.make_layout(1),