D1.3: Apply transform_partitioned_tensor_layout before epilogue helpers

This commit is contained in:
2026-05-23 20:52:42 +00:00
parent 1cf7140ea3
commit 1c74d02adb

View File

@@ -9,7 +9,7 @@ from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
from cutlass.utils.gemm.sm100 import epilogue_tmem_copy_and_partition, epilogue_smem_copy_and_partition
from cutlass.utils.gemm.sm100 import epilogue_tmem_copy_and_partition, epilogue_smem_copy_and_partition, transform_partitioned_tensor_layout
import cuda.bindings.driver as cuda
import cutlass.torch as ct
import math
@@ -378,8 +378,13 @@ class FmhaKernel:
# Set up the TMEM→reg and reg→SMEM copy atoms using CUTLASS helpers
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
# Transform layout: ((MMA_ATOM_M, MMA_ATOM_N), MMA_M, MMA_N, STAGE)
# -> ((MMA_ATOM_M, MMA_M), (MMA_ATOM_N, MMA_N), STAGE)
tCtO = transform_partitioned_tensor_layout(tCtO_base)
# Apply the same layout transform to tCgC; the result is passed to
# epilogue_tmem_copy_and_partition and reused for the TMA-store partition below
tCgC_xform = transform_partitioned_tensor_layout(tCgC)
tiled_copy_t2r, tTR_tAcc, tTR_rAcc = epilogue_tmem_copy_and_partition(
self, sfw_idx, tCtO_base, tCgC, epi_tile, self.use_2cta_instrs
self, sfw_idx, tCtO, tCgC_xform, epi_tile, self.use_2cta_instrs
)
tTR_rC = cute.make_rmem_tensor(tTR_rAcc.shape, self.c_dtype)
tiled_copy_r2s, tRS_rC, tRS_sC = epilogue_smem_copy_and_partition(
@@ -410,8 +415,8 @@ class FmhaKernel:
cute.arch.fence_proxy("async.shared", space="cta")
# TMA store from SMEM to GMEM
# Partition sC and gC for TMA store
tCgC_epi = cute.flat_divide(tCgC, epi_tile)
# Partition sC and gC for TMA store (using transformed gC)
tCgC_epi = cute.flat_divide(tCgC_xform, epi_tile)
bSG_sC, bSG_gC = cpasync.tma_partition(
tma_c, 0, cute.make_layout(1),
cute.group_modes(sC, 0, 2),