From f421fb6fb1bb959efa755985943bd34b360eb754 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sun, 24 May 2026 01:20:33 +0000 Subject: [PATCH] D1.5: Fix TMA store - use 3D tile for local_tile on 3D mC --- dsv4/kernels/attention/fmha.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 01a949a4..2722e10a 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -490,15 +490,16 @@ class FmhaKernel: # Step 6: TMA store SMEM → GMEM # The normalized O is now in sC (written by the correction epilog). - # Use group_modes to create 2D views of sC and gC for TMA partition. - # sC has a swizzled layout; group_modes preserves the swizzle in the pointer. - sC_epi = cute.group_modes(sC, 0, 2) # 2D view of swizzled SMEM - gC_epi = cute.group_modes( - cute.local_tile(mC, epi_tile, (Int32(0), Int32(0))), 0, 2 - ) + # The tma_c was created with CopyBulkTensorTileS2GOp for c (3D) and epi_s (2D SMEM layout). + # We need to partition sC and the GMEM output for the TMA copy. + # Get 2D views: sC has 4D layout ((128,16),1,(4,2),1), group to 2D. + # gC: use local_tile with 3D tile and 3D coordinate. + epi_tile_3d = (epi_tile[0], epi_tile[1], 1) + gC_epi = cute.local_tile(mC, epi_tile_3d, (Int32(0), Int32(0), Int32(0))) bSG_sC, bSG_gC = cpasync.tma_partition( tma_c, 0, cute.make_layout(1), - sC_epi, gC_epi, + cute.group_modes(sC, 0, 2), + cute.group_modes(gC_epi, 0, 2), ) # One TMA store for the full output tile if warp_idx == self.epilogue_warp_id[0]: