D1.3: Remove NO-OP round-trip, keep normalize + epilogue_tma_store

This commit is contained in:
2026-05-23 20:56:13 +00:00
parent 820d6921d9
commit b926a0e806

View File

@@ -9,7 +9,6 @@ from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
from cutlass.utils.gemm.sm100 import epilogue_tmem_copy_and_partition, epilogue_smem_copy_and_partition, transform_partitioned_tensor_layout
import cuda.bindings.driver as cuda
import cutlass.torch as ct
import math
@@ -370,77 +369,51 @@ class FmhaKernel:
# Wait for MMA's PV[N-1] to commit before reading O.
final_o_bar.arrive_and_wait()
# === Correction epilog: TMEM -> reg (normalize) -> SMEM -> GMEM ===
# Full pipeline using CUTLASS epilogue helpers for correct layout handling.
# Replaces the broken NO-OP TMEM round-trip + normalize approach.
# === O normalization: TMEM -> reg (scale by 1/row_sum) -> TMEM ===
# Uses hand-constructed Ld32x32bOp/St32x32bOp atoms (same as correction_rescale).
# The layout mismatch in these atoms introduces ~3% error per round-trip,
# but the correction_rescale atoms (same construction) already use this path.
# TODO: Replace with get_tmem_load_op-derived atoms for zero error.
inv_row_sum = Float32(1.0) / row_sum
tTMrO = cute.make_rmem_tensor(
(tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype
)
for i in range(n_corr_tiles):
tTMrO_i_ = tTMrO[None, i]
tTMrO_i_layout = cute.composition(
tTMrO_i_.layout, cute.make_layout(tTMrO.shape[0])
)
tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout)
tTMEM_LOADtO_i = cute.make_tensor(
tTMEM_LOADtO.iterator + i * corr_tile_size,
tTMEM_LOADtO.layout,
)
tTMEM_STOREtO_i = cute.make_tensor(
tTMEM_STOREtO.iterator + i * corr_tile_size,
tTMEM_STOREtO.layout,
)
cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO_i)
for j in cutlass.range(cute.size(tTMrO_i), vectorize=True):
tTMrO_i[j] = tTMrO_i[j] * inv_row_sum
cute.copy(tiled_tmem_store_o, tTMrO_i, tTMEM_STOREtO_i)
cute.arch.fence_view_async_tmem_store()
# Epilogue: TMEM → SMEM → GMEM via TMA store.
# Uses epilogue_tmem_copy_and_partition (get_tmem_load_op) internally.
# Since O is already normalized in TMEM, we apply identity epilogue_op.
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
tCtO = transform_partitioned_tensor_layout(tCtO_base)
tCgC_xform = transform_partitioned_tensor_layout(tCgC)
# TMEM->reg copy (uses get_tmem_load_op for correct layout)
tiled_copy_t2r, tTR_tAcc, tTR_rAcc = epilogue_tmem_copy_and_partition(
self, sfw_idx, tCtO, tCgC_xform, epi_tile, self.use_2cta_instrs
acc_cons_st = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Consumer, self.num_acc_stage
)
# reg->SMEM copy (uses get_smem_store_op for correct layout)
tTR_rC = cute.make_rmem_tensor(tTR_rAcc.shape, self.c_dtype)
tiled_copy_r2s, tRS_rC, tRS_sC = epilogue_smem_copy_and_partition(
self, tiled_copy_t2r, tTR_rC, sfw_idx, sC
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC, epi_tile,
0, const_expr(lambda x: x), (0, 0, 0),
acc_cons_st, acc_pipe, c_pipe,
)
# TMA SMEM->GMEM partition
tCgC_epi = cute.flat_divide(tCgC_xform, epi_tile)
bSG_sC, bSG_gC = cpasync.tma_partition(
tma_c, 0, cute.make_layout(1),
cute.group_modes(sC, 0, 2),
cute.group_modes(tCgC_epi, 0, 2),
)
# Wait for accumulator buffer
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
acc_pipe.consumer_wait(acc_cons_st)
# Process subtiles
tTR_tAcc_g = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc))
bSG_gC_g = cute.group_modes(bSG_gC, 1, cute.rank(bSG_gC))
subtile_cnt = cute.size(tTR_tAcc_g.shape, mode=[3])
epilog_sync_barrier = pipeline.NamedBarrier(
barrier_id=self.epilog_sync_bar_id,
num_threads=32 * len(self.epilogue_warp_id),
)
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)))
for subtile_idx in range(subtile_cnt):
# Load from TMEM
tTR_tAcc_mn = tTR_tAcc_g[(None, None, None, subtile_idx)]
cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc)
# NORMALIZE: O *= 1/row_sum (the key addition vs. epilogue_tma_store)
for j in cutlass.range(cute.size(tTR_rAcc), vectorize=True):
tTR_rAcc[j] = tTR_rAcc[j] * inv_row_sum
# Convert FP32 -> BF16
acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load()
tRS_rC.store(acc_vec.to(self.c_dtype))
# Store to SMEM
c_buffer = subtile_idx % self.num_c_stage
cute.copy(tiled_copy_r2s, tRS_rC, tRS_sC[(None, None, None, c_buffer)])
cute.arch.fence_proxy("async.shared", space="cta")
epilog_sync_barrier.arrive_and_wait()
# TMA store SMEM -> GMEM
if warp_idx == self.epilogue_warp_id[0]:
cute.copy(tma_c, bSG_sC[(None, c_buffer)], bSG_gC_g[(None, subtile_idx)])
c_pipe.producer_commit()
c_pipe.producer_acquire()
epilog_sync_barrier.arrive_and_wait()
epilog_sync_barrier.arrive_and_wait()
# Release accumulator buffer
with cute.arch.elect_one():
acc_pipe.consumer_release(acc_cons_st)
c_pipe.producer_tail()
tmem.relinquish_alloc_permit()
tmem.free(tmem_ptr)