D1.3: Apply transform_partitioned_tensor_layout before epilogue helpers
This commit is contained in:
@@ -9,7 +9,7 @@ from cutlass.cute.nvgpu import cpasync, tcgen05
|
||||
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
|
||||
from cutlass.utils import LayoutEnum
|
||||
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
|
||||
from cutlass.utils.gemm.sm100 import epilogue_tmem_copy_and_partition, epilogue_smem_copy_and_partition
|
||||
from cutlass.utils.gemm.sm100 import epilogue_tmem_copy_and_partition, epilogue_smem_copy_and_partition, transform_partitioned_tensor_layout
|
||||
import cuda.bindings.driver as cuda
|
||||
import cutlass.torch as ct
|
||||
import math
|
||||
@@ -378,8 +378,13 @@ class FmhaKernel:
|
||||
|
||||
# Set up the TMEM→reg and reg→SMEM copy atoms using CUTLASS helpers
|
||||
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
|
||||
# Transform the tCtO_base layout: ((MMA_ATOM_M, MMA_ATOM_N), MMA_M, MMA_N, STAGE)
|
||||
# -> ((MMA_ATOM_M, MMA_M), (MMA_ATOM_N, MMA_N), STAGE)
|
||||
tCtO = transform_partitioned_tensor_layout(tCtO_base)
|
||||
# Apply the same layout transform to tCgC (giving tCgC_xform); the epilogue helpers need the transformed layout
|
||||
tCgC_xform = transform_partitioned_tensor_layout(tCgC)
|
||||
tiled_copy_t2r, tTR_tAcc, tTR_rAcc = epilogue_tmem_copy_and_partition(
|
||||
self, sfw_idx, tCtO_base, tCgC, epi_tile, self.use_2cta_instrs
|
||||
self, sfw_idx, tCtO, tCgC_xform, epi_tile, self.use_2cta_instrs
|
||||
)
|
||||
tTR_rC = cute.make_rmem_tensor(tTR_rAcc.shape, self.c_dtype)
|
||||
tiled_copy_r2s, tRS_rC, tRS_sC = epilogue_smem_copy_and_partition(
|
||||
@@ -410,8 +415,8 @@ class FmhaKernel:
|
||||
cute.arch.fence_proxy("async.shared", space="cta")
|
||||
|
||||
# TMA store from SMEM to GMEM
|
||||
# Partition sC and gC for TMA store
|
||||
tCgC_epi = cute.flat_divide(tCgC, epi_tile)
|
||||
# Partition sC and gC for the TMA store, using the layout-transformed gC (tCgC_xform)
|
||||
tCgC_epi = cute.flat_divide(tCgC_xform, epi_tile)
|
||||
bSG_sC, bSG_gC = cpasync.tma_partition(
|
||||
tma_c, 0, cute.make_layout(1),
|
||||
cute.group_modes(sC, 0, 2),
|
||||
|
||||
Reference in New Issue
Block a user