From 1c74d02adbb76ff0338f4a156e6992ca7bf5fb85 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 23 May 2026 20:52:42 +0000 Subject: [PATCH] D1.3: Apply transform_partitioned_tensor_layout before epilogue helpers --- dsv4/kernels/attention/fmha.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index e135cdc7..607072ad 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -9,7 +9,7 @@ from cutlass.cute.nvgpu import cpasync, tcgen05 from cutlass import Float32, BFloat16, Int32, Boolean, const_expr from cutlass.utils import LayoutEnum from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset -from cutlass.utils.gemm.sm100 import epilogue_tmem_copy_and_partition, epilogue_smem_copy_and_partition +from cutlass.utils.gemm.sm100 import epilogue_tmem_copy_and_partition, epilogue_smem_copy_and_partition, transform_partitioned_tensor_layout import cuda.bindings.driver as cuda import cutlass.torch as ct import math @@ -378,8 +378,13 @@ class FmhaKernel: # Set up the TMEM→reg and reg→SMEM copy atoms using CUTLASS helpers tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout) + # Transform layout: ((MMA_ATOM_M, MMA_ATOM_N), MMA_M, MMA_N, STAGE) + # -> ((MMA_ATOM_M, MMA_M), (MMA_ATOM_N, MMA_N), STAGE) + tCtO = transform_partitioned_tensor_layout(tCtO_base) + # Transform gC layout similarly (needed by the helpers) + tCgC_xform = transform_partitioned_tensor_layout(tCgC) tiled_copy_t2r, tTR_tAcc, tTR_rAcc = epilogue_tmem_copy_and_partition( - self, sfw_idx, tCtO_base, tCgC, epi_tile, self.use_2cta_instrs + self, sfw_idx, tCtO, tCgC_xform, epi_tile, self.use_2cta_instrs ) tTR_rC = cute.make_rmem_tensor(tTR_rAcc.shape, self.c_dtype) tiled_copy_r2s, tRS_rC, tRS_sC = epilogue_smem_copy_and_partition( @@ -410,8 +415,8 @@ class FmhaKernel: cute.arch.fence_proxy("async.shared", space="cta") # TMA store from SMEM to GMEM - # Partition sC and gC for TMA store - tCgC_epi = cute.flat_divide(tCgC, epi_tile) + # Partition sC and gC for TMA store (using transformed gC) + tCgC_epi = cute.flat_divide(tCgC_xform, epi_tile) bSG_sC, bSG_gC = cpasync.tma_partition( tma_c, 0, cute.make_layout(1), cute.group_modes(sC, 0, 2),