D1.5: Fix TMA store - use 3D tile for local_tile on 3D mC

This commit is contained in:
2026-05-24 01:20:33 +00:00
parent 865832f669
commit 52a5aa61bc

View File

@@ -490,15 +490,16 @@ class FmhaKernel:
# Step 6: TMA store SMEM → GMEM
# The normalized O is now in sC (written by the correction epilog).
# Use group_modes to create 2D views of sC and gC for TMA partition.
# sC has a swizzled layout; group_modes preserves the swizzle in the pointer.
sC_epi = cute.group_modes(sC, 0, 2) # 2D view of swizzled SMEM
gC_epi = cute.group_modes(
cute.local_tile(mC, epi_tile, (Int32(0), Int32(0))), 0, 2
)
# The tma_c was created with CopyBulkTensorTileS2GOp for c (3D) and epi_s (2D SMEM layout).
# We need to partition sC and the GMEM output for the TMA copy.
# Get 2D views: sC has 4D layout ((128,16),1,(4,2),1), group to 2D.
# gC: use local_tile with 3D tile and 3D coordinate.
epi_tile_3d = (epi_tile[0], epi_tile[1], 1)
gC_epi = cute.local_tile(mC, epi_tile_3d, (Int32(0), Int32(0), Int32(0)))
bSG_sC, bSG_gC = cpasync.tma_partition(
tma_c, 0, cute.make_layout(1),
sC_epi, gC_epi,
cute.group_modes(sC, 0, 2),
cute.group_modes(gC_epi, 0, 2),
)
# One TMA store for the full output tile
if warp_idx == self.epilogue_warp_id[0]: