fix: correction_epilog with get_tmem_load_op paired atoms, no TMEM round-trip
This commit is contained in:
@@ -130,10 +130,75 @@ class FmhaV3StageCMulti:
|
||||
tma_v,mV = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn,pv_mma.thr_id),v_fmha,v_s,self.pv_mma_tiler,pv_mma,self.cluster_layout_vmnk.shape)
|
||||
epi_s = cute.select(self.c_smem_s,mode=[0,1])
|
||||
tma_c,mC = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(),c,epi_s,self.epi_tile)
|
||||
self._kernel(qk_mma,pv_mma,tma_q,mQ,tma_k,mK,tma_v,mV,tma_c,mC,self.cluster_layout_vmnk,self.q_smem_s,self.k_smem_s,self.v_smem_s,self.p_tmem_s,self.c_smem_s,self.epi_tile).launch(grid=(1,1,1),block=[self.threads_per_cta,1,1],stream=stream)
|
||||
|
||||
# Pre-compute paired TMEM load atom for correction_epilog
|
||||
epi_corr_tile_size = 32 * 8 // self.o_dtype.width # 16 for BF16
|
||||
epi_subtile = (self.epi_tile[0], epi_corr_tile_size)
|
||||
tmem_load_epi_atom = utils.sm100.get_tmem_load_op(
|
||||
self.pv_mma_tiler, self.c_layout, self.o_dtype, self.acc_dtype,
|
||||
epi_subtile, use_2cta_instrs=False,
|
||||
)
|
||||
|
||||
self._kernel(qk_mma,pv_mma,tma_q,mQ,tma_k,mK,tma_v,mV,tma_c,mC,self.cluster_layout_vmnk,self.q_smem_s,self.k_smem_s,self.v_smem_s,self.p_tmem_s,self.c_smem_s,self.epi_tile,tmem_load_epi_atom).launch(grid=(1,1,1),block=[self.threads_per_cta,1,1],stream=stream)
|
||||
|
||||
def _correction_epilog(self, pv_thr, tOtO, scale, sC, tCgC, tma_c, sfw_idx, tidx, warp_idx, epi_tile, tmem_load_epi_atom):
|
||||
"""CUTLASS correction_epilog: read O from TMEM, normalize, convert, write SMEM→GMEM."""
|
||||
epi_corr_tile_size = 32 * 8 // self.o_dtype.width # 16 for BF16
|
||||
cO = cute.make_identity_tensor((self.pv_mma_tiler[0], self.pv_mma_tiler[1]))
|
||||
tOcO = pv_thr.partition_C(cO)
|
||||
tOsO = pv_thr.partition_C(sC)
|
||||
|
||||
tOtO_i = cute.logical_divide(tOtO, cute.make_layout((128, epi_corr_tile_size)))
|
||||
tOcO_i = cute.logical_divide(tOcO, cute.make_layout((128, epi_corr_tile_size)))
|
||||
tOsO_i = cute.logical_divide(tOsO, cute.make_layout((128, epi_corr_tile_size)))
|
||||
|
||||
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_epi_atom, tOtO_i[(None, None), 0])
|
||||
smem_copy_atom = utils.sm100.get_smem_store_op(
|
||||
self.c_layout, self.o_dtype, self.acc_dtype, tiled_tmem_load
|
||||
)
|
||||
tiled_smem_store = cute.make_tiled_copy_D(smem_copy_atom, tiled_tmem_load)
|
||||
|
||||
thr_tmem_load = tiled_tmem_load.get_slice(sfw_idx)
|
||||
tTMEM_LOADtO = thr_tmem_load.partition_S(tOtO_i[(None, None), None])
|
||||
tTMEM_LOADsO = thr_tmem_load.partition_D(tOsO_i[(None, None), None])
|
||||
tTMEM_LOADcO = thr_tmem_load.partition_D(tOcO_i[(None, None), None])
|
||||
|
||||
n_corr_tiles = self.pv_mma_tiler[1] // epi_corr_tile_size
|
||||
for i in range(n_corr_tiles):
|
||||
tTMrO = cute.make_rmem_tensor(tTMEM_LOADcO[None, 0, 0, i].shape, self.acc_dtype)
|
||||
cute.copy(tiled_tmem_load, tTMEM_LOADtO[None, 0, 0, i], tTMrO)
|
||||
for j in range(cute.size(tTMrO)):
|
||||
tTMrO[j] = tTMrO[j] * scale
|
||||
tSMrO = cute.make_rmem_tensor(tTMrO.shape, self.o_dtype)
|
||||
o_vec = tTMrO.load()
|
||||
tSMrO.store(o_vec.to(self.o_dtype))
|
||||
cute.copy(tiled_smem_store, tSMrO, tTMEM_LOADsO[None, 0, 0, i])
|
||||
|
||||
cute.arch.fence_proxy("async.shared", space="cta")
|
||||
|
||||
# TMA store SMEM → GMEM (same pattern as epilogue_tma_store)
|
||||
tCgC_epi = cute.flat_divide(tCgC, epi_tile)
|
||||
tCsC, tCgC_tma = cpasync.tma_partition(
|
||||
tma_c, 0, cute.make_layout(1),
|
||||
cute.group_modes(sC, 0, 2),
|
||||
cute.group_modes(tCgC_epi, 0, 2),
|
||||
)
|
||||
|
||||
epilog_sync_bar = pipeline.NamedBarrier(
|
||||
barrier_id=self.epilog_sync_bar_id,
|
||||
num_threads=32 * len(self.epilogue_warp_id),
|
||||
)
|
||||
epilog_sync_bar.arrive_and_wait()
|
||||
|
||||
c_buffer = 0
|
||||
if warp_idx == self.epilogue_warp_id[0]:
|
||||
cute.copy(tma_c, tCsC[(None, c_buffer)], tCgC_tma[(None, 0, 0, 0, 0, 0, 0)])
|
||||
cute.arch.cp_async_bulk_commit_group()
|
||||
cute.arch.cp_async_bulk_wait_group(0, read=True)
|
||||
epilog_sync_bar.arrive_and_wait()
|
||||
|
||||
@cute.kernel
|
||||
def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, tma_c, mC, cl_vmnk, q_smem_s, k_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile):
|
||||
def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, tma_c, mC, cl_vmnk, q_smem_s, k_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile, tmem_load_epi_atom):
|
||||
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
|
||||
tidx,_,_ = cute.arch.thread_idx()
|
||||
if warp_idx == self.tma_warp_id:
|
||||
@@ -399,28 +464,12 @@ class FmhaV3StageCMulti:
|
||||
# Wait for MMA's PV[N-1] to commit before reading O.
|
||||
final_o_bar.arrive_and_wait()
|
||||
|
||||
# === Final O normalization: O *= 1/row_sum ===
|
||||
# DIAG: NO-OP TMEM round-trip test — load and store back without modifying
|
||||
# === Correction epilog: one-way TMEM → SMEM with normalize ===
|
||||
inv_row_sum = Float32(1.0) / row_sum
|
||||
|
||||
# SKIP the TMEM round-trip normalize entirely
|
||||
# Just use epilogue_tma_store to read raw PV from TMEM
|
||||
# The inv_row_sum normalization will be applied in Python for now
|
||||
|
||||
# Standard epilogue: TMEM → SMEM → GMEM via TMA store.
|
||||
# O in TMEM is now scaled by 1/row_sum.
|
||||
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
|
||||
acc_cons_st = pipeline.make_pipeline_state(
|
||||
pipeline.PipelineUserType.Consumer, self.num_acc_stage
|
||||
self._correction_epilog(
|
||||
pv_thr, tOtO0, inv_row_sum, sC, tCgC, tma_c, sfw_idx,
|
||||
tidx, warp_idx, epi_tile, tmem_load_epi_atom,
|
||||
)
|
||||
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
|
||||
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
|
||||
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
|
||||
self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC, epi_tile,
|
||||
0, const_expr(lambda x: x), (0, 0, 0),
|
||||
acc_cons_st, acc_pipe, c_pipe,
|
||||
)
|
||||
c_pipe.producer_tail()
|
||||
|
||||
tmem.relinquish_alloc_permit()
|
||||
tmem.free(tmem_ptr)
|
||||
@@ -444,10 +493,6 @@ def test():
|
||||
attn = torch.softmax(attn_raw, dim=-1)
|
||||
ref = attn @ v.float()
|
||||
|
||||
# Compute unnormalized softmax and row_sum for Python-side normalize
|
||||
attn_unnorm = torch.exp(attn_raw - attn_raw.max(dim=-1, keepdim=True).values)
|
||||
row_sum_unnorm = attn_unnorm.sum(dim=-1, keepdim=True)
|
||||
|
||||
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
|
||||
mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
|
||||
mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel))
|
||||
@@ -465,26 +510,17 @@ def test():
|
||||
torch.cuda.synchronize()
|
||||
|
||||
out = c[:, :, 0].float()
|
||||
|
||||
# Python-side normalize: out is raw P@V (unnormalized)
|
||||
# Divide by row_sum to get the correct softmax output
|
||||
out_normalized = out / row_sum_unnorm
|
||||
cos_raw = torch.nn.functional.cosine_similarity(
|
||||
out.flatten().unsqueeze(0), (attn_unnorm @ v.float()).flatten().unsqueeze(0)
|
||||
cos = torch.nn.functional.cosine_similarity(
|
||||
out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)
|
||||
).item()
|
||||
cos_norm = torch.nn.functional.cosine_similarity(
|
||||
out_normalized.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)
|
||||
).item()
|
||||
max_abs = (out_normalized - ref).abs().max().item()
|
||||
max_abs = (out - ref).abs().max().item()
|
||||
n_tiles = n // 128
|
||||
print(f'FMHA Stage-C Multi n={n} ({n_tiles} kv tiles):', flush=True)
|
||||
print(f' Raw PV (no normalize) vs unnorm ref: cos {cos_raw:.6f}', flush=True)
|
||||
print(f' After Python normalize vs softmax ref: cos {cos_norm:.6f} max_abs {max_abs:.4f} '
|
||||
f'{"PASS" if cos_norm >= 0.99 else "FAIL"}', flush=True)
|
||||
if cos_norm < 0.99:
|
||||
print(f' out_normalized[0,:4]={out_normalized[0,:4].tolist()}')
|
||||
print(f'FMHA Stage-C Multi n={n} ({n_tiles} kv tiles): '
|
||||
f'cos {cos:.6f} max_abs {max_abs:.4f} '
|
||||
f'{"PASS" if cos >= 0.99 else "FAIL"}')
|
||||
if cos < 0.99:
|
||||
print(f' out[0,:4]={out[0,:4].tolist()}')
|
||||
print(f' ref[0,:4]={ref[0,:4].tolist()}')
|
||||
print(f' row_sum_unnorm[:4]={row_sum_unnorm[:4,0].tolist()}')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
Reference in New Issue
Block a user