From 9705b079690adc76265d36e86e0b5c4b1b3b4e76 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sun, 24 May 2026 00:48:18 +0000 Subject: [PATCH] D1.5: Fix TMA store - use group_modes on sC and tCgC --- dsv4/kernels/attention/fmha.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index dabd0aed..57ef90b0 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -492,11 +492,13 @@ class FmhaKernel: # The normalized O is now in sC (written by the correction epilog). # Use the same TMA store pattern as the CUTLASS FMHA reference. # Partition sC and gC for the bulk TMA copy. - sC_epi = cute.select(sC, mode=[0, 1]) # 2D view for TMA - gC_epi = cute.local_tile(mC, epi_tile, (Int32(0), Int32(0))) # 2D output tile + # Get 2D views of sC and gC for TMA. + sC_epi = sC # sC already has c_smem_s layout, TMA can handle it + gC_epi = tCgC # Use the pre-partitioned GMEM tensor bSG_sC, bSG_gC = cpasync.tma_partition( tma_c, 0, cute.make_layout(1), - sC_epi, gC_epi, + cute.group_modes(sC, 0, 2), + cute.group_modes(gC_epi, 0, 2), ) # One TMA store for the full output tile if warp_idx == self.epilogue_warp_id[0]: