From f59fd07ba7b180097240bb4fc0ed9ee5ca249416 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sun, 24 May 2026 00:52:57 +0000 Subject: [PATCH] D1.5: Use group_modes on sC for 2D TMA view (preserves swizzle) --- dsv4/kernels/attention/fmha.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index b5ca201b..01a949a4 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -490,11 +490,12 @@ class FmhaKernel: # Step 6: TMA store SMEM → GMEM # The normalized O is now in sC (written by the correction epilog). - # The tma_c was created with (c, epi_s, epi_tile) where epi_s = select(c_smem_s, mode=[0,1]). - # We need to partition sC and the GMEM output for the TMA copy. - sC_epi_layout = cute.select(self.c_smem_s, mode=[0, 1]) # 2D SMEM layout - sC_epi = cute.make_tensor(sC.iterator, sC_epi_layout) # 2D view of sC - gC_epi = cute.local_tile(mC, epi_tile, (Int32(0), Int32(0))) # 2D GMEM tile + # Use group_modes to create 2D views of sC and gC for TMA partition. + # sC has a swizzled layout; group_modes preserves the swizzle in the pointer. + sC_epi = cute.group_modes(sC, 0, 2) # 2D view of swizzled SMEM + gC_epi = cute.group_modes( + cute.local_tile(mC, epi_tile, (Int32(0), Int32(0))), 0, 2 + ) bSG_sC, bSG_gC = cpasync.tma_partition( tma_c, 0, cute.make_layout(1), sC_epi, gC_epi,