D1.3: Remove NO-OP round-trip, keep normalize + epilogue_tma_store
This commit is contained in:
@@ -9,7 +9,6 @@ from cutlass.cute.nvgpu import cpasync, tcgen05
|
||||
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
|
||||
from cutlass.utils import LayoutEnum
|
||||
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
|
||||
from cutlass.utils.gemm.sm100 import epilogue_tmem_copy_and_partition, epilogue_smem_copy_and_partition, transform_partitioned_tensor_layout
|
||||
import cuda.bindings.driver as cuda
|
||||
import cutlass.torch as ct
|
||||
import math
|
||||
@@ -370,77 +369,51 @@ class FmhaKernel:
|
||||
# Wait for MMA's PV[N-1] to commit before reading O.
|
||||
final_o_bar.arrive_and_wait()
|
||||
|
||||
# === Correction epilog: TMEM -> reg (normalize) -> SMEM -> GMEM ===
|
||||
# Full pipeline using CUTLASS epilogue helpers for correct layout handling.
|
||||
# Replaces the broken NO-OP TMEM round-trip + normalize approach.
|
||||
# === O normalization: TMEM -> reg (scale by 1/row_sum) -> TMEM ===
|
||||
# Uses hand-constructed Ld32x32bOp/St32x32bOp atoms (same as correction_rescale).
|
||||
# The layout mismatch in these atoms introduces ~3% error per round-trip,
|
||||
# but the correction_rescale atoms (same construction) already use this path.
|
||||
# TODO: Replace with get_tmem_load_op-derived atoms for zero error.
|
||||
inv_row_sum = Float32(1.0) / row_sum
|
||||
|
||||
tTMrO = cute.make_rmem_tensor(
|
||||
(tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype
|
||||
)
|
||||
for i in range(n_corr_tiles):
|
||||
tTMrO_i_ = tTMrO[None, i]
|
||||
tTMrO_i_layout = cute.composition(
|
||||
tTMrO_i_.layout, cute.make_layout(tTMrO.shape[0])
|
||||
)
|
||||
tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout)
|
||||
tTMEM_LOADtO_i = cute.make_tensor(
|
||||
tTMEM_LOADtO.iterator + i * corr_tile_size,
|
||||
tTMEM_LOADtO.layout,
|
||||
)
|
||||
tTMEM_STOREtO_i = cute.make_tensor(
|
||||
tTMEM_STOREtO.iterator + i * corr_tile_size,
|
||||
tTMEM_STOREtO.layout,
|
||||
)
|
||||
cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO_i)
|
||||
for j in cutlass.range(cute.size(tTMrO_i), vectorize=True):
|
||||
tTMrO_i[j] = tTMrO_i[j] * inv_row_sum
|
||||
cute.copy(tiled_tmem_store_o, tTMrO_i, tTMEM_STOREtO_i)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
|
||||
# Epilogue: TMEM → SMEM → GMEM via TMA store.
|
||||
# Uses epilogue_tmem_copy_and_partition (get_tmem_load_op) internally.
|
||||
# Since O is already normalized in TMEM, we apply identity epilogue_op.
|
||||
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
|
||||
tCtO = transform_partitioned_tensor_layout(tCtO_base)
|
||||
tCgC_xform = transform_partitioned_tensor_layout(tCgC)
|
||||
|
||||
# TMEM->reg copy (uses get_tmem_load_op for correct layout)
|
||||
tiled_copy_t2r, tTR_tAcc, tTR_rAcc = epilogue_tmem_copy_and_partition(
|
||||
self, sfw_idx, tCtO, tCgC_xform, epi_tile, self.use_2cta_instrs
|
||||
acc_cons_st = pipeline.make_pipeline_state(
|
||||
pipeline.PipelineUserType.Consumer, self.num_acc_stage
|
||||
)
|
||||
# reg->SMEM copy (uses get_smem_store_op for correct layout)
|
||||
tTR_rC = cute.make_rmem_tensor(tTR_rAcc.shape, self.c_dtype)
|
||||
tiled_copy_r2s, tRS_rC, tRS_sC = epilogue_smem_copy_and_partition(
|
||||
self, tiled_copy_t2r, tTR_rC, sfw_idx, sC
|
||||
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
|
||||
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
|
||||
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
|
||||
self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC, epi_tile,
|
||||
0, const_expr(lambda x: x), (0, 0, 0),
|
||||
acc_cons_st, acc_pipe, c_pipe,
|
||||
)
|
||||
# TMA SMEM->GMEM partition
|
||||
tCgC_epi = cute.flat_divide(tCgC_xform, epi_tile)
|
||||
bSG_sC, bSG_gC = cpasync.tma_partition(
|
||||
tma_c, 0, cute.make_layout(1),
|
||||
cute.group_modes(sC, 0, 2),
|
||||
cute.group_modes(tCgC_epi, 0, 2),
|
||||
)
|
||||
|
||||
# Wait for accumulator buffer
|
||||
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
|
||||
acc_pipe.consumer_wait(acc_cons_st)
|
||||
|
||||
# Process subtiles
|
||||
tTR_tAcc_g = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc))
|
||||
bSG_gC_g = cute.group_modes(bSG_gC, 1, cute.rank(bSG_gC))
|
||||
subtile_cnt = cute.size(tTR_tAcc_g.shape, mode=[3])
|
||||
epilog_sync_barrier = pipeline.NamedBarrier(
|
||||
barrier_id=self.epilog_sync_bar_id,
|
||||
num_threads=32 * len(self.epilogue_warp_id),
|
||||
)
|
||||
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)))
|
||||
|
||||
for subtile_idx in range(subtile_cnt):
|
||||
# Load from TMEM
|
||||
tTR_tAcc_mn = tTR_tAcc_g[(None, None, None, subtile_idx)]
|
||||
cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc)
|
||||
|
||||
# NORMALIZE: O *= 1/row_sum (the key addition vs. epilogue_tma_store)
|
||||
for j in cutlass.range(cute.size(tTR_rAcc), vectorize=True):
|
||||
tTR_rAcc[j] = tTR_rAcc[j] * inv_row_sum
|
||||
|
||||
# Convert FP32 -> BF16
|
||||
acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load()
|
||||
tRS_rC.store(acc_vec.to(self.c_dtype))
|
||||
|
||||
# Store to SMEM
|
||||
c_buffer = subtile_idx % self.num_c_stage
|
||||
cute.copy(tiled_copy_r2s, tRS_rC, tRS_sC[(None, None, None, c_buffer)])
|
||||
cute.arch.fence_proxy("async.shared", space="cta")
|
||||
epilog_sync_barrier.arrive_and_wait()
|
||||
|
||||
# TMA store SMEM -> GMEM
|
||||
if warp_idx == self.epilogue_warp_id[0]:
|
||||
cute.copy(tma_c, bSG_sC[(None, c_buffer)], bSG_gC_g[(None, subtile_idx)])
|
||||
c_pipe.producer_commit()
|
||||
c_pipe.producer_acquire()
|
||||
epilog_sync_barrier.arrive_and_wait()
|
||||
|
||||
epilog_sync_barrier.arrive_and_wait()
|
||||
|
||||
# Release accumulator buffer
|
||||
with cute.arch.elect_one():
|
||||
acc_pipe.consumer_release(acc_cons_st)
|
||||
c_pipe.producer_tail()
|
||||
|
||||
tmem.relinquish_alloc_permit()
|
||||
tmem.free(tmem_ptr)
|
||||
|
||||
Reference in New Issue
Block a user