D1.5: Use group_modes on sC for 2D TMA view (preserves swizzle)

This commit is contained in:
2026-05-24 00:52:57 +00:00
parent 577066bb7f
commit f59fd07ba7

View File

@@ -490,11 +490,12 @@ class FmhaKernel:
# Step 6: TMA store SMEM → GMEM
# The normalized O is now in sC (written by the correction epilog).
# The tma_c was created with (c, epi_s, epi_tile) where epi_s = select(c_smem_s, mode=[0,1]).
# We need to partition sC and the GMEM output for the TMA copy.
sC_epi_layout = cute.select(self.c_smem_s, mode=[0, 1]) # 2D SMEM layout
sC_epi = cute.make_tensor(sC.iterator, sC_epi_layout) # 2D view of sC
gC_epi = cute.local_tile(mC, epi_tile, (Int32(0), Int32(0))) # 2D GMEM tile
# Use group_modes to create 2D views of sC and gC for TMA partition.
# sC has a swizzled layout; group_modes preserves the swizzle in the pointer.
sC_epi = cute.group_modes(sC, 0, 2) # 2D view of swizzled SMEM
gC_epi = cute.group_modes(
cute.local_tile(mC, epi_tile, (Int32(0), Int32(0))), 0, 2
)
bSG_sC, bSG_gC = cpasync.tma_partition(
tma_c, 0, cute.make_layout(1),
sC_epi, gC_epi,