Clean v2: real softmax P, no O TMEM modify, standard epilogue. Baseline for custom epilogue work.
This commit is contained in:
@@ -1,6 +1,14 @@
|
||||
"""
|
||||
FMHA v3 Stage-C Multi-Tile — Real Softmax.
|
||||
Built on the WORKING identity diag, adding real softmax step by step.
|
||||
FMHA v3 Stage-C — Real Softmax, NO O rescale/normalize in TMEM.
|
||||
|
||||
Strategy: Skip O rescale and TMEM-based normalize (TMEM copy of O corrupts data).
|
||||
For single-tile (n=128), this gives correct unnormalized output (cos 0.999998).
|
||||
For multi-tile, the O is not rescaled (missing exp2(old_max - new_max) and 1/row_sum).
|
||||
|
||||
The CUTLASS reference applies O rescale via correction_rescale (TMEM read-modify-write)
|
||||
and the final 1/row_sum via correction_epilog (applied during GMEM write, NOT TMEM modify).
|
||||
Our TMEM copy of O doesn't work — likely a CuTeDSL version issue or layout mismatch.
|
||||
Next step: implement correction_epilog that applies 1/row_sum during GMEM write.
|
||||
"""
|
||||
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
|
||||
from cutlass.cute.nvgpu import cpasync, tcgen05
|
||||
@@ -156,7 +164,7 @@ class FmhaV3RealSoftmax:
|
||||
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_as, self.num_acc_stage))
|
||||
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
|
||||
|
||||
# ===== TMA LOAD warp (matching working diag) =====
|
||||
# ===== TMA LOAD warp =====
|
||||
if warp_idx == self.tma_warp_id:
|
||||
qp.reset(); qh = qp.acquire_and_advance()
|
||||
cute.copy(tma_q, tAgQ[(None, Int32(0))], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier)
|
||||
@@ -193,12 +201,12 @@ class FmhaV3RealSoftmax:
|
||||
for kb in cutlass.range(cute.size(tOrP0, mode=[2]), unroll_full=True):
|
||||
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV[(None,None,kb,kvh.index)], tOtO0)
|
||||
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
cute.arch.fence_async_shared()
|
||||
kvh.release()
|
||||
acc_pipe.producer_commit(acc_st); acc_st.advance()
|
||||
acc_pipe.producer_tail(acc_st)
|
||||
|
||||
# ===== SOFTMAX warps — REAL SOFTMAX =====
|
||||
# ===== SOFTMAX warps — REAL SOFTMAX (P only, no O normalize in TMEM) =====
|
||||
if warp_idx < self.mma_warp_id:
|
||||
tmem.allocate(self.num_tmem_alloc_cols)
|
||||
tmem.wait_for_alloc()
|
||||
@@ -226,28 +234,6 @@ class FmhaV3RealSoftmax:
|
||||
tScP = cute.make_tensor(tScS.iterator, tScP_layout)
|
||||
tTMEM_STOREcP = thr_store.partition_S(tScP)
|
||||
|
||||
# O normalize setup: use the SAME base pointer as the epilogue
|
||||
# The epilogue reads O from tmem_ptr + tmem_o0_offset.
|
||||
# We must use the same base to access the correct TMEM columns.
|
||||
tCtO_norm = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tOtO.layout)
|
||||
cO = cute.make_identity_tensor((self.pv_mma_tiler[0], self.pv_mma_tiler[1]))
|
||||
tOcO = pv_thr.partition_C(cO)
|
||||
# Sub-tile the O layout for the normalize copy
|
||||
corr_tile_size = 16
|
||||
tOtO_i_layout = cute.composition(tCtO_norm.layout, cute.make_layout((128, corr_tile_size)))
|
||||
tOcO_i_layout = cute.composition(tOcO.layout, cute.make_layout((128, corr_tile_size)))
|
||||
tOtO_norm_i = cute.make_tensor(tCtO_norm.iterator, tOtO_i_layout)
|
||||
tOcO_norm_i = cute.make_tensor(tOcO.iterator, tOcO_i_layout)
|
||||
tmem_load_o_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), self.acc_dtype)
|
||||
tmem_store_o_atom = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), self.acc_dtype)
|
||||
tiled_tmem_load_o = tcgen05.make_tmem_copy(tmem_load_o_atom, tOtO_norm_i)
|
||||
tiled_tmem_store_o = tcgen05.make_tmem_copy(tmem_store_o_atom, tOtO_norm_i)
|
||||
thr_load_o = tiled_tmem_load_o.get_slice(sfw_idx)
|
||||
thr_store_o = tiled_tmem_store_o.get_slice(sfw_idx)
|
||||
tTMEM_LOAD_OtO = thr_load_o.partition_S(tOtO_norm_i)
|
||||
tTMEM_LOAD_OcO = thr_load_o.partition_D(tOcO_norm_i)
|
||||
tTMEM_STORE_OtO = thr_store_o.partition_D(tOtO_norm_i)
|
||||
|
||||
row_max = -Float32.inf
|
||||
row_sum = Float32(0.0)
|
||||
scale_log2 = Float32(self.scale_softmax_log2)
|
||||
@@ -273,18 +259,13 @@ class FmhaV3RealSoftmax:
|
||||
if row_max == -cutlass.Float32.inf:
|
||||
row_max_safe = Float32(0.0)
|
||||
|
||||
# O rescale: exp2(old_max - new_max) — row_max already in scaled domain
|
||||
# acc_scale: exp2(old_max - new_max) for O rescale
|
||||
acc_scale_ = old_row_max - row_max_safe
|
||||
acc_scale = cute.math.exp2(acc_scale_, fastmath=True)
|
||||
if old_row_max == -cutlass.Float32.inf:
|
||||
acc_scale = Float32(0.0)
|
||||
row_sum *= acc_scale
|
||||
|
||||
# O rescale: DISABLED for NO-OP test
|
||||
# if kt > 0:
|
||||
# n_corr = HEAD_DIM // corr_tile_size
|
||||
# for ci in range(n_corr):
|
||||
# ...
|
||||
# Pass 2: P = exp2(S * scale_log2 - row_max), accumulate row_sum
|
||||
rP_words = cute.make_rmem_tensor(tTMEM_STOREcP.shape, self.qk_acc_dtype)
|
||||
rP_bf16 = cute.make_tensor(cute.recast_ptr(rP_words.iterator, dtype=self.q_dtype), tTMEM_LOADrS.layout)
|
||||
@@ -303,31 +284,8 @@ class FmhaV3RealSoftmax:
|
||||
si_handle.release()
|
||||
softmax_done_bar.arrive()
|
||||
|
||||
# Final O normalization: O = O / row_sum
|
||||
if row_sum != Float32(0.0):
|
||||
inv_row_sum = Float32(1.0) / row_sum
|
||||
n_corr = HEAD_DIM // corr_tile_size
|
||||
tTMrO = cute.make_rmem_tensor(
|
||||
(tTMEM_LOAD_OcO.shape, n_corr), self.acc_dtype
|
||||
)
|
||||
for ci in range(n_corr):
|
||||
tTMrO_ci_ = tTMrO[None, ci]
|
||||
tTMrO_ci_layout = cute.composition(
|
||||
tTMrO_ci_.layout, cute.make_layout(tTMrO.shape[0])
|
||||
)
|
||||
tTMrO_ci = cute.make_tensor(tTMrO_ci_.iterator, tTMrO_ci_layout)
|
||||
tTMEM_LOAD_OtO_ci = cute.make_tensor(
|
||||
tTMEM_LOAD_OtO.iterator + ci * corr_tile_size, tTMEM_LOAD_OtO.layout
|
||||
)
|
||||
tTMEM_STORE_OtO_ci = cute.make_tensor(
|
||||
tTMEM_STORE_OtO.iterator + ci * corr_tile_size, tTMEM_STORE_OtO.layout
|
||||
)
|
||||
cute.copy(tiled_tmem_load_o, tTMEM_LOAD_OtO_ci, tTMrO_ci)
|
||||
for j in cutlass.range(cute.size(tTMrO_ci), vectorize=True):
|
||||
tTMrO_ci[j] = tTMrO_ci[j] * inv_row_sum
|
||||
cute.copy(tiled_tmem_store_o, tTMrO_ci, tTMEM_STORE_OtO_ci)
|
||||
|
||||
# Epilogue: TMEM -> SMEM -> GMEM via TMA store
|
||||
# Epilogue: standard epilogue_tma_store
|
||||
# TODO: replace with custom epilogue that applies 1/row_sum
|
||||
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
|
||||
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
|
||||
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
|
||||
@@ -339,7 +297,7 @@ class FmhaV3RealSoftmax:
|
||||
|
||||
|
||||
def test():
|
||||
for n in [128, 256]:
|
||||
for n in [128, 256, 384, 512, 1024]:
|
||||
torch.manual_seed(42)
|
||||
m, hd = 128, HEAD_DIM
|
||||
q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')
|
||||
@@ -363,7 +321,7 @@ def test():
|
||||
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
|
||||
kernel = FmhaV3RealSoftmax(s_k=n)
|
||||
print(f'n={n}: Compiling... [REAL_SOFTMAX_v1]', flush=True)
|
||||
print(f'n={n}: Compiling... [REAL_SOFTMAX_v2_EPILOGUE]', flush=True)
|
||||
compiled = cute.compile(kernel, mQ, mK, mV, mC, stream)
|
||||
compiled(mQ, mK, mV, mC, stream)
|
||||
torch.cuda.synchronize()
|
||||
@@ -374,7 +332,7 @@ def test():
|
||||
).item()
|
||||
max_abs = (out - ref).abs().max().item()
|
||||
n_tiles = n // 128
|
||||
print(f'FMHA Real Softmax n={n} ({n_tiles} tiles): '
|
||||
print(f'FMHA n={n} ({n_tiles} tiles): '
|
||||
f'cos {cos:.6f} max_abs {max_abs:.4f} '
|
||||
f'{"PASS" if cos >= 0.99 else "FAIL"}')
|
||||
if cos < 0.99:
|
||||
|
||||
Reference in New Issue
Block a user