diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index a2188329..82bcc044 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -15,6 +15,16 @@ import cutlass.torch as ct import math +def _transform_partitioned_tensor_layout(tensor): + """Transform MMA layout: ((ATOM_M, ATOM_N), MMA_M, MMA_N, ...rest) + -> ((ATOM_M, MMA_M), (ATOM_N, MMA_N), ...rest). + Same as CUTLASS utils.gemm.sm100.transform_partitioned_tensor_layout.""" + layout = tensor.layout; shape = layout.shape; stride = layout.stride + new_shape = ((shape[0][0], shape[1]), (shape[0][1], shape[2]), *shape[3:]) + new_stride = ((stride[0][0], stride[1]), (stride[0][1], stride[2]), *stride[3:]) + return cute.make_tensor(tensor.iterator, cute.make_layout(shape=new_shape, stride=new_stride)) + + class FmhaKernel: def __init__(self, head_dim=64, s_k=128, scale_softmax=None, use_smem_p=None, normalize=True): self.head_dim = head_dim @@ -410,103 +420,98 @@ class FmhaKernel: final_o_bar.arrive_and_wait() # ============================================================ - # CORRECTION EPILOG: One-way TMEM → registers → normalize → SMEM + # CORRECTION EPILOG: One-way TMEM → registers → normalize → SMEM → GMEM # ============================================================ - # Uses paired atoms from get_tmem_load_op + get_smem_store_op - # to preserve the C-fragment layout. No TMEM write-back. - # Based on CUTLASS FMHA reference's correction_epilog pattern. - # Eliminates the 3% per-tile TMEM round-trip error. + # Follows CUTLASS epilogue_tma_store pattern exactly: + # transform_partitioned_tensor_layout → flat_divide → + # get_tmem_load_op → make_tmem_copy → partition_S → + # get_smem_store_op → make_tiled_copy_D → partition_D → + # cpasync.tma_partition → copy loop + # Eliminates the 3% per-tile TMEM round-trip error by using + # paired atoms that preserve the C-fragment layout. # ============================================================ # D5a: When normalize=False, still do one-way trip but skip 1/row_sum. if const_expr(self.normalize): inv_row_sum = Float32(1.0) / row_sum - # Step 1: logical_divide O and sC into correction sub-tiles. - tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tOtO.layout) - tOcO = pv_thr.partition_C(cute.make_identity_tensor(self.pv_mma_tiler[:2])) - tOsO = pv_thr.partition_C(sC) - corr_ts = corr_tile_size # sub-tile N-dim (16 for BF16) - tOtO_i = cute.logical_divide(tCtO_base, cute.make_layout((128, corr_ts))) - tOcO_i = cute.logical_divide(tOcO, cute.make_layout((128, corr_ts))) - tOsO_i = cute.logical_divide(tOsO, cute.make_layout((128, corr_ts))) + # Step 1: Transform partitioned tensor layouts (CUTLASS pattern) + # ((ATOM_M, ATOM_N), MMA_M, MMA_N, ...) -> ((ATOM_M, MMA_M), (ATOM_N, MMA_N), ...) + tOtO_xfm = _transform_partitioned_tensor_layout( + cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tOtO.layout)) + tCgC_xfm = _transform_partitioned_tensor_layout(tCgC) - # Step 2: Build TMEM load copy using get_tmem_load_op (paired atom). - epi_subtile = (self.epi_tile[0], corr_ts) + # Step 2: TMEM load copy (epilogue_tmem_copy_and_partition pattern) from cutlass.utils.blackwell_helpers import get_tmem_load_op as _get_tmem_load_op tmem_copy_atom = _get_tmem_load_op( - self.pv_mma_tiler, self.c_layout, self.o_dtype, self.acc_dtype, - epi_subtile, use_2cta_instrs=self.use_2cta_instrs, + self.cta_tile_shape_mnk, self.c_layout, self.o_dtype, self.acc_dtype, + epi_tile, self.use_2cta_instrs, ) - # tOtO_i has shape ((128, corr_ts), n_corr_tiles) after logical_divide. - # make_tmem_copy needs a tensor with the sub-tile layout. - # Slice to the first sub-tile to get the right layout for the copy atom. - tOtO_sub0 = tOtO_i[(None, None), 0] # first sub-tile - tiled_tmem_load_corr = tcgen05.make_tmem_copy(tmem_copy_atom, tOtO_sub0) + # flat_divide by epi_tile to create sub-tiled views + tOtO_epi = cute.flat_divide(tOtO_xfm, epi_tile) + tCgC_epi = cute.flat_divide(tCgC_xfm, epi_tile) + # make_tmem_copy with the first sub-tile shape + tiled_copy_t2r = tcgen05.make_tmem_copy(tmem_copy_atom, tOtO_epi[(None, None, 0, 0, 0)]) + thr_t2r = tiled_copy_t2r.get_slice(sfw_idx) + # Partition source (TMEM) and destination (GMEM-derived register shape) + tTR_tAcc = thr_t2r.partition_S(tOtO_epi) + tTR_gC = thr_t2r.partition_D(tCgC_epi) + tTR_rAcc = cute.make_rmem_tensor( + tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype) - # Step 3: Build SMEM store copy using get_smem_store_op (paired with TMEM load). + # Step 3: SMEM store copy (epilogue_smem_copy_and_partition pattern) smem_copy_atom = get_smem_store_op( - self.c_layout, self.o_dtype, self.acc_dtype, tiled_tmem_load_corr - ) - tiled_smem_store_corr = cute.make_tiled_copy_D(smem_copy_atom, tiled_tmem_load_corr) + self.c_layout, self.o_dtype, self.acc_dtype, tiled_copy_t2r) + tiled_copy_r2s = cute.make_tiled_copy_D(smem_copy_atom, tiled_copy_t2r) + thr_r2s = tiled_copy_r2s.get_slice(sfw_idx) + tRS_sC = thr_r2s.partition_D(sC) + tTR_rC = cute.make_rmem_tensor(tRS_sC[(None, None, None, 0)].shape, self.o_dtype) - # Step 4: Partition source (TMEM) and destination (SMEM) for each softmax thread. - thr_tmem_corr = tiled_tmem_load_corr.get_slice(sfw_idx) - thr_smem_corr = tiled_smem_store_corr.get_slice(sfw_idx) - # Partition the sub-tiled O for the correction loop. - tTMEM_CORRtO = thr_tmem_corr.partition_S(tOtO_i[(None, None), None]) - tSMEM_CORRsO = thr_smem_corr.partition_D(tOsO_i[(None, None), None]) - tSMEM_CORRcO = thr_smem_corr.partition_S(tOcO_i[(None, None), None]) - - # Step 5: Correction loop — for each sub-tile: TMEM → reg (normalize) → SMEM - for i in range(n_corr_tiles): - tTMEM_CORRtO_i = tTMEM_CORRtO[None, 0, 0, i] - tSMEM_CORRsO_i = tSMEM_CORRsO[None, 0, 0, i] - # Create register tensor for this sub-tile using the SMEM copy's source layout - tTMrO = cute.make_rmem_tensor(tSMEM_CORRcO[None, 0, 0, i].shape, self.acc_dtype) - - # Load O from TMEM using paired atom (preserves C-fragment layout) - cute.copy(tiled_tmem_load_corr, tTMEM_CORRtO_i, tTMrO) - - # Normalize: multiply by inv_row_sum (exact in FP32) - if const_expr(self.normalize): - for j in cutlass.range(cute.size(tTMrO), vectorize=True): - tTMrO[j] = tTMrO[j] * inv_row_sum - - # Convert to output dtype and store to SMEM via paired atom - tSMrO = cute.make_rmem_tensor(tTMrO.shape, self.o_dtype) - o_vec = tTMrO.load() - tSMrO.store(o_vec.to(self.o_dtype)) - cute.copy(tiled_smem_store_corr, tSMrO, tSMEM_CORRsO_i) - - # Fence SMEM writes and sync before TMA store - cute.arch.fence_proxy("async.shared", space="cta") - # Barrier: ensure all softmax warps have finished writing to SMEM - # before TMA store reads from it. Use a separate barrier ID. - corr_epi_bar = pipeline.NamedBarrier( - barrier_id=5, num_threads=32 * len(self.epilogue_warp_id) - ) - corr_epi_bar.arrive_and_wait() - - # Step 6: TMA store SMEM → GMEM - # The normalized O is now in sC (written by the correction epilog). - # The tma_c was created with CopyBulkTensorTileS2GOp for c (3D) and epi_s (2D SMEM layout). - # We need to partition sC and the GMEM output for the TMA copy. - # Use flat_divide on the already-partitioned tCgC (same pattern - # as CUTLASS epilogue_tma_store), then tma_partition. - tCgC_epi = cute.flat_divide(tCgC, epi_tile) + # Step 4: TMA store partition (cpasync.tma_partition for S2G) + # flat_divide tCgC for TMA partition (need the un-xfm version for tma_partition) + tCgC_epi_tma = cute.flat_divide(tCgC, epi_tile) bSG_sC, bSG_gC = cpasync.tma_partition( tma_c, 0, cute.make_layout(1), cute.group_modes(sC, 0, 2), - cute.group_modes(tCgC_epi, 0, 2), + cute.group_modes(tCgC_epi_tma, 0, 2), ) - # Group all modes >= 1 into one (CUTLASS pattern) - bSG_gC_flat = cute.group_modes(bSG_gC, 1, cute.rank(bSG_gC)) - # One TMA store for the full output tile - if warp_idx == self.epilogue_warp_id[0]: - cute.copy(tma_c, bSG_sC[(None, 0)], bSG_gC_flat[(None, Int32(0))]) - cute.arch.cp_async_bulk_commit_group() - cute.arch.cp_async_bulk_wait_group(0, read=True) + + # Step 5: Correction loop — for each sub-tile: TMEM → reg → normalize → SMEM + tTR_tAcc_g = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc)) + subtile_cnt = cute.size(tTR_tAcc_g.shape, mode=[3]) + + for subtile_idx in range(subtile_cnt): + # Load O from TMEM (preserves C-fragment layout via paired atom) + tTR_tAcc_mn = tTR_tAcc_g[(None, None, None, subtile_idx)] + cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc) + + # Normalize: multiply by inv_row_sum (exact in FP32) + if const_expr(self.normalize): + for j in cutlass.range(cute.size(tTR_rAcc), vectorize=True): + tTR_rAcc[j] = tTR_rAcc[j] * inv_row_sum + + # Convert to output dtype and store to SMEM + acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load() + acc_vec = acc_vec.to(self.o_dtype) + tTR_rC.store(acc_vec) + + # Store to SMEM + c_buffer = subtile_idx % self.num_c_stage + cute.copy(tiled_copy_r2s, tTR_rC, tRS_sC[(None, None, None, c_buffer)]) + + # Fence and barrier + cute.arch.fence_proxy("async.shared", space="cta") + corr_epi_bar = pipeline.NamedBarrier( + barrier_id=5, num_threads=32 * len(self.epilogue_warp_id)) + corr_epi_bar.arrive_and_wait() + + # TMA store SMEM → GMEM + if warp_idx == self.epilogue_warp_id[0]: + cute.copy(tma_c, bSG_sC[(None, c_buffer)], + bSG_gC[(None, None, None, Int32(0), Int32(0), Int32(0))]) + cute.arch.cp_async_bulk_commit_group() + cute.arch.cp_async_bulk_wait_group(0, read=True) + corr_epi_bar.arrive_and_wait() # D5a: Write LSE (log-softmax) when normalize=False # lse = ln(row_sum) + row_max * ln(2)