# Files
# nvfp4-megamoe-kernel/tests/archive/unit_test_fmha_v3_per_row.py
# biondizzle c395b279d2 Clean up tests: archive superseded files, keep only essential unit tests
# Kept in tests/unit/:
# - test_fmha_v3.py (stages A+B)
# - test_fmha_v3_diag.py (identity softmax, n=128+256)
# - test_fmha_v3_stage_c.py (real softmax, n=128 cos 0.999998)
# - layertest.py + cudagraph_test.py (required for every change)
# - infrastructure: cache, custom_op, cutedsl, router, fp4, fused, interleave
#
# Archived: 19 superseded unit tests + 10 root-level scratch files
# Root level: only fmha_v3_stage_c_example7.py remains (now in unit/)
# 2026-05-22 20:25:27 +00:00
#
# 588 lines
# 36 KiB
# Python

"""
FMHA v3 + Stage C: QK -> online softmax -> PV with KV-tile interleaving.
Stage C: row_max, exp2, O rescale, row_sum, final normalization.
FMHA pattern P store preserved from Stage B.
"""
import math
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cuda.bindings.driver as cuda
import cutlass.torch as ct
# Attention head dimension, shared by the kernel's tile shapes and the host test.
HEAD_DIM: int = 64
class FmhaV3Softmax:
    """Single-CTA tcgen05 attention test kernel with online softmax (stage C).

    Computes ``c = softmax(q @ k^T / sqrt(HEAD_DIM)) @ v`` for one 128-row
    query tile, walking the keys in 128-key KV tiles.

    For every KV tile:
      * the MMA warp computes S = Q K^T into tensor memory (TMEM);
      * the four softmax warps update the per-row running max and sum,
        rescale the running O accumulator, and write P = exp2(...) back to
        TMEM;
      * the MMA warp then accumulates O (+)= P V.

    After the last tile, the softmax warps divide O by the row sums and hand
    it to a TMA-store epilogue.

    Warp roles (``threads_per_cta`` = 192): warps 0-3 run softmax and the
    epilogue, warp 4 issues MMAs, and warp 5 issues TMA loads. ``__call__``
    launches a single CTA.

    Args:
        s_k: number of keys; ``__call__`` uses it to reinterpret ``v``.
    """

    def __init__(self, s_k: int = 128):
        self.s_k = s_k
        # Both GEMMs accumulate in FP32. Operand/output dtypes default to BF16
        # and are re-read from the actual tensors in __call__.
        self.acc_dtype = Float32
        self.qk_acc_dtype = Float32
        self.q_dtype = BFloat16
        self.o_dtype = BFloat16
        self.c_dtype = BFloat16
        self.use_2cta_instrs = False
        self.epilog_sync_bar_id = 1
        self.cluster_shape_mn = (1, 1)
        self.cta_group = tcgen05.CtaGroup.ONE
        # Warp roles: 0-3 softmax/epilogue, 4 MMA issue, 5 TMA loads.
        self.epilogue_warp_id = (0, 1, 2, 3)
        self.mma_warp_id = 4
        self.tma_warp_id = 5
        self.threads_per_cta = 192
        # Pipeline depths.
        self.num_c_stage = 2
        self.kv_stage = 2
        self.q_stage = 1

    def _setup(self, qk_mma, pv_mma):
        """Derive tile shapes, SMEM layouts, the TMEM column map and TMA byte counts.

        Called from ``__call__`` once both tiled MMAs exist; every result is
        stored on ``self`` for ``_kernel``.

        Raises:
            ValueError: if the required TMEM allocation exceeds 512 columns.
        """
        qk_ik = cute.size(qk_mma.shape_mnk, mode=[2])
        self.qk_mma_tiler = (128, 128, qk_ik * 4)
        pv_ik = cute.size(pv_mma.shape_mnk, mode=[2])
        # PV reduces over one KV tile: 128 keys as a whole number of MMA K-steps.
        self.pv_mma_tiler = (128, HEAD_DIM, pv_ik * (128 // pv_ik))
        self.mma_tiler = self.qk_mma_tiler
        self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1, 1, 1)), (qk_mma.thr_id.shape,))
        self.cta_tile_shape_mnk = (
            self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
            HEAD_DIM,
            self.qk_mma_tiler[2],
        )
        # NOTE(review): this overwrites the c_layout that __call__ derived from
        # the output tensor just before calling _setup, so the epilogue always
        # assumes a row-major C. That matches the test's contiguous
        # (m, HEAD_DIM, 1) output; confirm before using other layouts.
        self.c_layout = LayoutEnum.ROW_MAJOR
        self.epi_tile = utils.sm100.compute_epilogue_tile_shape(self.cta_tile_shape_mnk, False, self.c_layout, self.o_dtype)
        self.num_ab_stage = 1
        self.num_acc_stage = 1
        # SMEM layouts: Q (A of QK), K (B of QK), V (B of PV), output staging.
        self.q_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.qk_mma_tiler, self.q_dtype, self.q_stage)
        self.k_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.qk_mma_tiler, self.q_dtype, self.kv_stage)
        self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.q_dtype, self.kv_stage)
        self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2)
        # A-operand layout for P; the kernel applies it to a TMEM pointer.
        self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
        pv_thr = pv_mma.get_slice(0)
        pv_as = pv_thr.partition_shape_C(self.pv_mma_tiler[:2])
        tOtO = pv_thr.make_fragment_C(pv_as)
        # ---- TMEM column map (FP32 columns) ----
        # S: [0, qk_mma_tiler[1]) = [0, 128)
        # P: [tmem_p0_offset, tmem_p0_offset + p_cols_fp32) = [32, 96).
        #    Two q_dtype values are packed per FP32 column, and P lies inside
        #    the S columns.
        # O: must NOT overlap P; placed after max(S end, P end), 32-aligned.
        self.tmem_s0_offset = 0
        self.tmem_p0_offset = 32
        p_cols_fp32 = self.pv_mma_tiler[2] * self.q_dtype.width // self.qk_acc_dtype.width
        p_end = self.tmem_p0_offset + p_cols_fp32  # 32 + 64 = 96
        s_cols = self.qk_mma_tiler[1]  # 128
        o_after = max(s_cols, p_end)  # 128
        self.tmem_o0_offset = ((o_after + 31) // 32) * 32  # 128
        # Per-row vector scratch (acc_scale / row_sum) reuses S columns at
        # offset 0. The kernel writes it only after every softmax thread has
        # loaded S for the current tile.
        self.tmem_vec_offset = 0
        o_cols = find_tmem_tensor_col_offset(tOtO)  # column footprint of O
        total = self.tmem_o0_offset + o_cols
        # The allocation must be a power of 2, at least 32 and at most 512 columns.
        self.num_tmem_alloc_cols = 32
        while self.num_tmem_alloc_cols < total:
            self.num_tmem_alloc_cols *= 2
        if self.num_tmem_alloc_cols > 512:
            raise ValueError(f"TMEM footprint of {total} columns exceeds the 512 available")
        cta = cute.size(qk_mma.thr_id.shape)
        q_s = cute.slice_(self.q_smem_s, (None, None, None, 0))
        k_s = cute.slice_(self.k_smem_s, (None, None, None, 0))
        # Bytes delivered per TMA stage, used as the mbarrier transaction counts.
        # The K count is reused for V stages. This presumes the V tile
        # (HEAD_DIM x 128) is the same size as the K tile (128 x HEAD_DIM).
        self.q_tx_bytes = cute.size_in_bytes(self.q_dtype, q_s) * cta
        self.kv_tx_bytes = cute.size_in_bytes(self.q_dtype, k_s) * cta
        # Softmax runs in base 2: exp(x / sqrt(d)) == exp2(x * log2(e) / sqrt(d)).
        self.scale_softmax_log2 = Float32(1.0 / math.sqrt(HEAD_DIM) * math.log2(math.e))

    @cute.jit
    def __call__(self, q, k, v, c, stream):
        """Build MMAs and TMA atoms for ``q``/``k``/``v``/``c``, then launch ``_kernel``.

        ``c`` receives the attention output. The kernel runs as one CTA of
        ``threads_per_cta`` threads on ``stream``.
        """
        self.q_dtype = q.element_type
        self.o_dtype = c.element_type
        self.c_dtype = self.o_dtype
        self.a_major = LayoutEnum.from_tensor(q).mma_major_mode()
        self.b_major = LayoutEnum.from_tensor(k).mma_major_mode()
        # FMHA-style V: reinterpret V's memory as a (HEAD_DIM, s_k, 1) tensor
        # with unit stride along HEAD_DIM, for use as the PV B operand.
        # The key count comes from self.s_k (constructor), not from v's
        # runtime shape; an earlier note marked that as broken in @cute.jit.
        # Assumes v is a contiguous (s_k, HEAD_DIM) buffer, as the test's
        # v.unsqueeze(-1) is. Confirm for other callers.
        v_fmha = cute.make_tensor(
            v.iterator,
            cute.make_layout(
                (HEAD_DIM, self.s_k, 1),
                stride=(1, HEAD_DIM, HEAD_DIM * self.s_k),
            ),
        )
        self.v_major = LayoutEnum.from_tensor(v_fmha).mma_major_mode()
        self.c_layout = LayoutEnum.from_tensor(c)
        # QK: A = Q, B = K, both from SMEM, 128x128 output tile.
        qk_mma = utils.sm100.make_trivial_tiled_mma(
            self.q_dtype, self.q_dtype, self.a_major, self.b_major,
            self.qk_acc_dtype, self.cta_group, (128, 128), tcgen05.OperandSource.SMEM,
        )
        # PV: A = P read from TMEM (K-major), B = V from SMEM, 128xHEAD_DIM tile.
        pv_mma = utils.sm100.make_trivial_tiled_mma(
            self.q_dtype, self.q_dtype, cute.nvgpu.OperandMajorMode.K, self.v_major,
            self.qk_acc_dtype, self.cta_group, (128, HEAD_DIM), tcgen05.OperandSource.TMEM,
        )
        self._setup(qk_mma, pv_mma)
        q_s = cute.slice_(self.q_smem_s, (None, None, None, 0))
        k_s = cute.slice_(self.k_smem_s, (None, None, None, 0))
        v_s = cute.slice_(self.v_smem_s, (None, None, None, 0))
        tma_q, mQ = cute.nvgpu.make_tiled_tma_atom_A(
            utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
            q, q_s, self.qk_mma_tiler, qk_mma, self.cluster_layout_vmnk.shape,
        )
        tma_k, mK = cute.nvgpu.make_tiled_tma_atom_B(
            utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
            k, k_s, self.qk_mma_tiler, qk_mma, self.cluster_layout_vmnk.shape,
        )
        tma_v, mV = cute.nvgpu.make_tiled_tma_atom_B(
            utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, pv_mma.thr_id),
            v_fmha, v_s, self.pv_mma_tiler, pv_mma, self.cluster_layout_vmnk.shape,
        )
        epi_s = cute.select(self.c_smem_s, mode=[0, 1])
        tma_c, mC = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_s, self.epi_tile)
        self._kernel(
            qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, tma_c, mC,
            self.cluster_layout_vmnk, self.q_smem_s, self.k_smem_s, self.v_smem_s,
            self.p_tmem_s, self.c_smem_s, self.epi_tile,
        ).launch(grid=(1, 1, 1), block=[self.threads_per_cta, 1, 1], stream=stream)

    @cute.kernel
    def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, tma_c, mC, cl_vmnk, q_smem_s, k_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile):
        """Device kernel: TMA loads, QK/PV MMAs, online softmax and epilogue.

        Warp ``tma_warp_id`` loads Q once, then K and V for each KV tile.
        Warp ``mma_warp_id`` issues S = Q K^T, then O (+)= P V, per tile.
        Warps ``epilogue_warp_id`` allocate TMEM, run the online softmax,
        normalise O and store it to ``mC`` through the TMA epilogue.
        """
        warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
        tidx, _, _ = cute.arch.thread_idx()
        if warp_idx == self.tma_warp_id:
            cpasync.prefetch_descriptor(tma_q)
            cpasync.prefetch_descriptor(tma_k)
            cpasync.prefetch_descriptor(tma_v)
            cpasync.prefetch_descriptor(tma_c)

        # Shared-memory storage for the pipeline mbarriers and TMEM allocator state.
        @cute.struct
        class SS:
            q_bar: cute.struct.MemRange[cutlass.Int64, self.q_stage * 2]
            kv_bar: cute.struct.MemRange[cutlass.Int64, self.kv_stage * 2]
            s_bar: cute.struct.MemRange[cutlass.Int64, 2]
            acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
            tmem_dealloc: cutlass.Int64
            holding: cutlass.Int32

        smem = utils.SmemAllocator()
        st = smem.allocate(SS)
        # Q and K/V pipelines: TMA warp produces, MMA warp consumes.
        qp, qc = pipeline.PipelineTmaUmma.create(
            barrier_storage=st.q_bar.data_ptr(),
            num_stages=self.q_stage,
            producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
            consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
            tx_count=self.q_tx_bytes,
            cta_layout_vmnk=cl_vmnk,
            defer_sync=True,
        ).make_participants()
        kvp, kvc = pipeline.PipelineTmaUmma.create(
            barrier_storage=st.kv_bar.data_ptr(),
            num_stages=self.kv_stage,
            producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
            consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
            tx_count=self.kv_tx_bytes,
            cta_layout_vmnk=cl_vmnk,
            defer_sync=True,
        ).make_participants()
        # S-ready pipeline: MMA warp produces, the softmax threads consume.
        s_prod, s_cons = pipeline.PipelineUmmaAsync.create(
            barrier_storage=st.s_bar.data_ptr(),
            num_stages=1,
            producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
            consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)),
        ).make_participants()
        # Named barriers:
        #   3: softmax warps -> MMA warp, P stored for this tile.
        #   4: MMA warp -> softmax warps, PV MMAs issued for this tile.
        #   5: softmax-only handoff around the TMEM vector scratch.
        softmax_done_bar = pipeline.NamedBarrier(barrier_id=3, num_threads=32 + 32 * len(self.epilogue_warp_id))
        pv_done_bar = pipeline.NamedBarrier(barrier_id=4, num_threads=32 + 32 * len(self.epilogue_warp_id))
        vec_handoff_bar = pipeline.NamedBarrier(barrier_id=5, num_threads=32 * len(self.epilogue_warp_id))
        # O-accumulator pipeline: MMA warp commits once after the last tile;
        # the TMA-store epilogue consumes it.
        acc_pipe = pipeline.PipelineUmmaAsync.create(
            barrier_storage=st.acc_bar.data_ptr(),
            num_stages=self.num_acc_stage,
            producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
            consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)),
            cta_layout_vmnk=cl_vmnk,
            defer_sync=True,
        )
        tmem_bar = pipeline.NamedBarrier(barrier_id=2, num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
        tmem = utils.TmemAllocator(
            st.holding.ptr,
            barrier_for_retrieve=tmem_bar,
            allocator_warp_id=self.epilogue_warp_id[0],
            is_two_cta=cute.size(qk_mma.thr_id.shape) == 2,
            two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr,
        )
        pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
        sQ = smem.allocate_tensor(element_type=self.q_dtype, layout=q_smem_s.outer, byte_alignment=128, swizzle=q_smem_s.inner)
        sK = smem.allocate_tensor(element_type=self.q_dtype, layout=k_smem_s.outer, byte_alignment=128, swizzle=k_smem_s.inner)
        sV = smem.allocate_tensor(element_type=self.q_dtype, layout=v_smem_s.outer, byte_alignment=128, swizzle=v_smem_s.inner)
        sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
        # Global tiles: Q/K by the QK tiler, V/C by the PV tiler.
        gQ = cute.local_tile(mQ, cute.slice_(self.qk_mma_tiler, (None, 0, None)), (None, None, None))
        gK = cute.local_tile(mK, cute.slice_(self.qk_mma_tiler, (0, None, None)), (None, None, None))
        gV = cute.local_tile(mV, cute.slice_(self.pv_mma_tiler, (0, None, None)), (None, None, None))
        gC = cute.local_tile(mC, cute.slice_(self.pv_mma_tiler, (None, None, 0)), (None, None, None))
        n_kv_tiles = cute.size(gK, mode=[3])
        qk_thr = qk_mma.get_slice(0)
        pv_thr = pv_mma.get_slice(0)
        tCgQ = qk_thr.partition_A(gQ)
        tCgK = qk_thr.partition_B(gK)
        tCgV = pv_thr.partition_B(gV)
        tCgC = pv_thr.partition_C(gC)
        a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0, 0, None, 0)).shape)
        tAsQ, tAgQ = cpasync.tma_partition(tma_q, 0, a_lay, cute.group_modes(sQ, 0, 3), cute.group_modes(tCgQ, 0, 3))
        b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0, None, 0, 0)).shape)
        tBsK, tBgK = cpasync.tma_partition(tma_k, 0, b_lay, cute.group_modes(sK, 0, 3), cute.group_modes(tCgK, 0, 3))
        tVsV, tVgV = cpasync.tma_partition(tma_v, 0, b_lay, cute.group_modes(sV, 0, 3), cute.group_modes(tCgV, 0, 3))
        tAgQ = tAgQ[(None, 0, None, 0)]
        tBgK = tBgK[(None, 0, None, 0)]
        tVgV = tVgV[(None, 0, None, 0)]
        tCrQ = qk_mma.make_fragment_A(sQ)
        tCrK = qk_mma.make_fragment_B(sK)
        tCrV = pv_mma.make_fragment_B(sV)
        qk_as = qk_thr.partition_shape_C(self.qk_mma_tiler[:2])
        tStS = qk_thr.make_fragment_C(qk_as)
        tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
        pv_as = pv_thr.partition_shape_C(self.pv_mma_tiler[:2])
        tOtO = pv_thr.make_fragment_C(pv_as)
        tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
        # --- PV read view (for MMA only, NOT for softmax store) ---
        tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
        tOrP_base = pv_thr.make_fragment_A(tP)
        tOrP = tOrP_base[(None, None, None, 0)]
        # The P iterator counts q_dtype elements, so the FP32 column offset is
        # scaled by (acc width / q width).
        tOrP0 = cute.make_tensor(
            tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
            tOrP.layout,
        )
        tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_as, self.num_acc_stage))
        tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_as, self.num_acc_stage))
        pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)

        # ===================== TMA LOAD WARP =====================
        if warp_idx == self.tma_warp_id:
            qp.reset()
            qh = qp.acquire_and_advance()
            cute.copy(tma_q, tAgQ[(None, qh.count)], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier)
            qp.tail()
            kvp.reset()
            pk = kvp.try_acquire()
            for kt in cutlass.range(n_kv_tiles, unroll=1):
                # K and V share the kv pipeline: K takes one stage, V the next.
                # NOTE(review): because of that sharing, the handles' `count`
                # advances twice per KV tile. If `count` is the number of
                # acquires since reset (as in the CUTLASS GEMM examples), K is
                # fetched from tile 2*kt and V from tile 2*kt+1 rather than
                # tile kt. Confirm, and index with `kt` if so.
                kh = kvp.acquire_and_advance(pk)
                cute.copy(tma_k, tBgK[(None, kh.count)], tBsK[(None, kh.index)], tma_bar_ptr=kh.barrier)
                pk = cutlass.Boolean(1)
                vh = kvp.acquire_and_advance(pk)
                cute.copy(tma_v, tVgV[(None, vh.count)], tVsV[(None, vh.index)], tma_bar_ptr=vh.barrier)
                pk = cutlass.Boolean(1)
            kvp.tail()

        # ===================== MMA WARP =====================
        if warp_idx == self.mma_warp_id:
            tmem.wait_for_alloc()
            qc.reset()
            qh = qc.wait_and_advance()
            qh.release()
            kvc.reset()
            pk = kvc.try_wait()
            acc_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
            acc_pipe.producer_acquire(acc_st)
            for kt in range(n_kv_tiles):
                # S = Q K^T for this KV tile.
                kh = kvc.wait_and_advance(pk)
                pk = cutlass.Boolean(1)
                sh = s_prod.acquire_and_advance()
                qk_mma.set(tcgen05.Field.ACCUMULATE, False)
                for kb in cutlass.range(cute.size(tCrQ, mode=[2]), unroll_full=True):
                    cute.gemm(qk_mma, tStS0, tCrQ[(None, None, kb, 0)], tCrK[(None, None, kb, kh.index)], tStS0)
                    qk_mma.set(tcgen05.Field.ACCUMULATE, True)
                cute.arch.fence_view_async_tmem_store()
                sh.commit()
                kh.release()
                # Wait until the softmax warps have stored P for this tile.
                softmax_done_bar.arrive_and_wait()
                # O (+)= P V; the first tile overwrites O instead of accumulating.
                vh = kvc.wait_and_advance(pk)
                pk = cutlass.Boolean(1)
                pv_mma.set(tcgen05.Field.ACCUMULATE, kt != 0)
                for kb in cutlass.range(cute.size(tOrP0, mode=[2]), unroll_full=True):
                    cute.gemm(pv_mma, tOtO0, tOrP0[(None, None, kb)], tCrV[(None, None, kb, vh.index)], tOtO0)
                    pv_mma.set(tcgen05.Field.ACCUMULATE, True)
                cute.arch.fence_view_async_tmem_store()
                vh.release()
                # NOTE(review): this barrier is arrived right after the PV MMAs
                # are *issued*. tcgen05.mma executes asynchronously, so it may
                # not by itself guarantee that O is complete when the softmax
                # warps read it. Consider signalling through an mbarrier
                # committed with tcgen05.commit. Confirm.
                pv_done_bar.arrive()
            acc_pipe.producer_commit(acc_st)
            acc_st.advance()
            acc_pipe.producer_tail(acc_st)

        # ===================== EPILOGUE WARPS (STAGE C: ONLINE SOFTMAX) =====================
        if warp_idx < self.mma_warp_id:
            tmem.allocate(self.num_tmem_alloc_cols)
            tmem.wait_for_alloc()
            tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
            sfw_idx = tidx % (32 * len(self.epilogue_warp_id))
            # --- S load (QK C-fragment) ---
            tmem_load_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
            tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)
            thr_load = tiled_tmem_load.get_slice(sfw_idx)
            tTMEM_LOADtS = thr_load.partition_S(tStS0)
            cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
            tScS = qk_thr.partition_C(cS)
            tTMEM_LOADcS = thr_load.partition_D(tScS)
            # --- P store (QK C-fragment composition, FMHA pattern) ---
            # P (q_dtype) is packed into FP32 TMEM columns: p_cols_fp32 columns.
            p_cols_fp32 = self.pv_mma_tiler[2] * self.q_dtype.width // self.qk_acc_dtype.width
            tStP_layout = cute.composition(tStS.layout, cute.make_layout((self.pv_mma_tiler[0], p_cols_fp32)))
            tStP0 = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStP_layout)
            tmem_store_atom = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
            tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStP0)
            thr_store = tiled_tmem_store.get_slice(sfw_idx)
            tTMEM_STOREtP = thr_store.partition_D(tStP0)
            tScP_layout = cute.composition(tScS.layout, cute.make_layout((self.pv_mma_tiler[0], p_cols_fp32)))
            tScP = cute.make_tensor(tScS.iterator, tScP_layout)
            tTMEM_STOREcP = thr_store.partition_S(tScP)
            # --- Vector TMEM (per-row scratch, FMHA pattern) ---
            # composition(tStS.layout, (128, 2)) = 2 FP32 columns per logical row.
            # Inside the loop, slot 0 holds acc_scale. After the loop, slot 0
            # holds the final row_sum. Slot 1 receives the row max but is
            # never read back.
            # Located at tmem_vec_offset, i.e. inside the S columns.
            tStS_vec_layout = cute.composition(tStS.layout, cute.make_layout((128, 2)))
            tStS_vec = cute.make_tensor(tStS.iterator + self.tmem_vec_offset, tStS_vec_layout)
            tScS_vec_layout = cute.composition(tScS.layout, cute.make_layout((128, 2)))
            tScS_vec = cute.make_tensor(tScS.iterator, tScS_vec_layout)
            tmem_store_vec_atom = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(2)), self.qk_acc_dtype)
            tiled_tmem_store_vec = tcgen05.make_tmem_copy(tmem_store_vec_atom, tStS_vec)
            thr_tmem_store_vec = tiled_tmem_store_vec.get_slice(sfw_idx)
            tTMEM_STORE_VECtS = thr_tmem_store_vec.partition_D(tStS_vec)
            tTMEM_STORE_VECcS = thr_tmem_store_vec.partition_S(tScS_vec)
            tmem_load_vec_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(2)), self.qk_acc_dtype)
            tiled_tmem_load_vec = tcgen05.make_tmem_copy(tmem_load_vec_atom, tStS_vec)
            thr_tmem_load_vec = tiled_tmem_load_vec.get_slice(sfw_idx)
            tTMEM_LOAD_VECtS = thr_tmem_load_vec.partition_S(tStS_vec)
            tTMEM_LOAD_VECcS = thr_tmem_load_vec.partition_D(tScS_vec)
            # --- C6: O TMEM load/store for rescale (correction_rescale pattern) ---
            # O is processed in column chunks of corr_tile_size.
            corr_tile_size = 16
            cO = cute.make_identity_tensor((self.pv_mma_tiler[0], self.pv_mma_tiler[1]))
            tOcO = pv_thr.partition_C(cO)
            o_tmem_load_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), self.qk_acc_dtype)
            o_tmem_store_atom = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), self.qk_acc_dtype)
            tOtO_i_layout = cute.composition(tOtO0.layout, cute.make_layout((128, corr_tile_size)))
            tOcO_i_layout = cute.composition(tOcO.layout, cute.make_layout((128, corr_tile_size)))
            tOtO_i = cute.make_tensor(tOtO0.iterator, tOtO_i_layout)
            tOcO_i = cute.make_tensor(tOcO.iterator, tOcO_i_layout)
            o_tiled_tmem_load = tcgen05.make_tmem_copy(o_tmem_load_atom, tOtO_i)
            o_tiled_tmem_store = tcgen05.make_tmem_copy(o_tmem_store_atom, tOtO_i)
            o_thr_load = o_tiled_tmem_load.get_slice(sfw_idx)
            o_thr_store = o_tiled_tmem_store.get_slice(sfw_idx)
            tTMEM_LOADtO = o_thr_load.partition_S(tOtO_i)
            tTMEM_LOADcO = o_thr_load.partition_D(tOcO_i)
            tTMEM_STOREtO = o_thr_store.partition_D(tOtO_i)
            o_col_tiles = self.pv_mma_tiler[1] // corr_tile_size
            # --- C2: Per-QK-fragment-row state (persist across KV tiles) ---
            # The QK TMEM load fragment is logically 4 rows x 32 columns for each
            # softmax thread. The old scalar row_max/row_sum reduced across all
            # 4 rows and therefore produced a row_sum around 4.0. Keep one
            # online-softmax state per local QK row.
            qk_frg_cnt = 4
            qk_frg_tile = cute.size(tTMEM_LOADcS) // qk_frg_cnt
            tTMEM_LOADcS_frg = cute.logical_divide(tTMEM_LOADcS, cute.make_layout(qk_frg_tile))
            qk_row0 = tTMEM_LOADcS_frg[0, 0][0]
            qk_row1 = tTMEM_LOADcS_frg[0, 1][0]
            qk_row2 = tTMEM_LOADcS_frg[0, 2][0]
            qk_row3 = tTMEM_LOADcS_frg[0, 3][0]
            row_max0 = -cutlass.Float32.inf
            row_max1 = -cutlass.Float32.inf
            row_max2 = -cutlass.Float32.inf
            row_max3 = -cutlass.Float32.inf
            row_sum0 = cutlass.Float32(0.0)
            row_sum1 = cutlass.Float32(0.0)
            row_sum2 = cutlass.Float32(0.0)
            row_sum3 = cutlass.Float32(0.0)
            # --- C3: QK scale = 1/sqrt(HEAD_DIM) * log2(e) for exp2 ---
            scale = self.scale_softmax_log2
            # =============================================================
            # Per-KV-tile online softmax loop
            # =============================================================
            for kt in range(n_kv_tiles):
                si_handle = s_cons.wait_and_advance()
                # Load S from TMEM (FP32, QK C-fragment layout). Because the
                # vector buffer reuses the S columns, all softmax threads must
                # finish this load before any thread writes vector data.
                tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype)
                cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS)
                cute.arch.fence_view_async_tmem_load()
                vec_handoff_bar.arrive_and_wait()
                frg_cnt = 4
                frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt
                tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile))
                # --- C4: Compute tile_max independently for each local QK row ---
                old_row_max0 = row_max0
                old_row_max1 = row_max1
                old_row_max2 = row_max2
                old_row_max3 = row_max3
                row_max0 = tTMEM_LOADrS_frg[None, 0].load().reduce(cute.ReductionOp.MAX, row_max0, 0)
                row_max1 = tTMEM_LOADrS_frg[None, 1].load().reduce(cute.ReductionOp.MAX, row_max1, 0)
                row_max2 = tTMEM_LOADrS_frg[None, 2].load().reduce(cute.ReductionOp.MAX, row_max2, 0)
                row_max3 = tTMEM_LOADrS_frg[None, 3].load().reduce(cute.ReductionOp.MAX, row_max3, 0)
                # Substitute 0 for a max that is still -inf, so (S - max)
                # never evaluates (-inf) - (-inf) = NaN.
                row_max0_safe = row_max0
                row_max1_safe = row_max1
                row_max2_safe = row_max2
                row_max3_safe = row_max3
                if row_max0 == -cutlass.Float32.inf:
                    row_max0_safe = cutlass.Float32(0.0)
                if row_max1 == -cutlass.Float32.inf:
                    row_max1_safe = cutlass.Float32(0.0)
                if row_max2 == -cutlass.Float32.inf:
                    row_max2_safe = cutlass.Float32(0.0)
                if row_max3 == -cutlass.Float32.inf:
                    row_max3_safe = cutlass.Float32(0.0)
                # --- C5: Per-row O-rescale factors for the already-accumulated O ---
                acc_scale0 = cute.math.exp2(scale * (old_row_max0 - row_max0_safe), fastmath=True)
                acc_scale1 = cute.math.exp2(scale * (old_row_max1 - row_max1_safe), fastmath=True)
                acc_scale2 = cute.math.exp2(scale * (old_row_max2 - row_max2_safe), fastmath=True)
                acc_scale3 = cute.math.exp2(scale * (old_row_max3 - row_max3_safe), fastmath=True)
                # --- C6: Rescale O in TMEM using a row-indexed vector handoff ---
                # Store per-QK-row acc_scale into vec[row, 0], then read vec[pv_row, 0]
                # from the PV/O partition. This is the CUTLASS-style vector bridge,
                # but folded into the same four softmax warps, so it needs an
                # explicit warpgroup barrier between store and load.
                # Only from the second tile on: the first PV overwrites O.
                if kt > 0:
                    pv_done_bar.arrive_and_wait()
                    thr_vs0 = tiled_tmem_store_vec.get_slice(qk_row0)
                    tVStore0 = thr_vs0.partition_D(tStS_vec)
                    tVStoreSrc0 = thr_vs0.partition_S(tScS_vec)
                    rVec0 = cute.make_rmem_tensor(tVStoreSrc0.shape, self.qk_acc_dtype)
                    rVec0[0] = acc_scale0
                    rVec0[1] = row_max0_safe
                    cute.copy(tiled_tmem_store_vec, rVec0, tVStore0)
                    thr_vs1 = tiled_tmem_store_vec.get_slice(qk_row1)
                    tVStore1 = thr_vs1.partition_D(tStS_vec)
                    tVStoreSrc1 = thr_vs1.partition_S(tScS_vec)
                    rVec1 = cute.make_rmem_tensor(tVStoreSrc1.shape, self.qk_acc_dtype)
                    rVec1[0] = acc_scale1
                    rVec1[1] = row_max1_safe
                    cute.copy(tiled_tmem_store_vec, rVec1, tVStore1)
                    thr_vs2 = tiled_tmem_store_vec.get_slice(qk_row2)
                    tVStore2 = thr_vs2.partition_D(tStS_vec)
                    tVStoreSrc2 = thr_vs2.partition_S(tScS_vec)
                    rVec2 = cute.make_rmem_tensor(tVStoreSrc2.shape, self.qk_acc_dtype)
                    rVec2[0] = acc_scale2
                    rVec2[1] = row_max2_safe
                    cute.copy(tiled_tmem_store_vec, rVec2, tVStore2)
                    thr_vs3 = tiled_tmem_store_vec.get_slice(qk_row3)
                    tVStore3 = thr_vs3.partition_D(tStS_vec)
                    tVStoreSrc3 = thr_vs3.partition_S(tScS_vec)
                    rVec3 = cute.make_rmem_tensor(tVStoreSrc3.shape, self.qk_acc_dtype)
                    rVec3[0] = acc_scale3
                    rVec3[1] = row_max3_safe
                    cute.copy(tiled_tmem_store_vec, rVec3, tVStore3)
                    cute.arch.fence_view_async_tmem_store()
                    vec_handoff_bar.arrive_and_wait()
                    pv_row = tTMEM_LOADcO[0][0]
                    thr_vl = tiled_tmem_load_vec.get_slice(pv_row)
                    tVLoad = thr_vl.partition_S(tStS_vec)
                    tVLoadDst = thr_vl.partition_D(tScS_vec)
                    rVecPV = cute.make_rmem_tensor(tVLoadDst.shape, self.qk_acc_dtype)
                    cute.copy(tiled_tmem_load_vec, tVLoad, rVecPV)
                    cute.arch.fence_view_async_tmem_load()
                    acc_scale_pv = rVecPV[0]
                    tTMrO = cute.make_rmem_tensor((tTMEM_LOADcO.shape, o_col_tiles), self.qk_acc_dtype)
                    for i in range(o_col_tiles):
                        tTMrO_i_ = tTMrO[None, i]
                        tTMrO_i_layout = cute.composition(tTMrO_i_.layout, cute.make_layout(tTMrO.shape[0]))
                        tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout)
                        tTMEM_LOADtO_i = cute.make_tensor(tTMEM_LOADtO.iterator + i * corr_tile_size, tTMEM_LOADtO.layout)
                        tTMEM_STOREtO_i = cute.make_tensor(tTMEM_STOREtO.iterator + i * corr_tile_size, tTMEM_STOREtO.layout)
                        cute.copy(o_tiled_tmem_load, tTMEM_LOADtO_i, tTMrO_i)
                        for j in cutlass.range(cute.size(tTMrO_i), vectorize=True):
                            tTMrO_i[j] = tTMrO_i[j] * acc_scale_pv
                        cute.copy(o_tiled_tmem_store, tTMrO_i, tTMEM_STOREtO_i)
                    cute.arch.fence_view_async_tmem_store()
                # Rescale the four online row sums.
                row_sum0 = row_sum0 * acc_scale0
                row_sum1 = row_sum1 * acc_scale1
                row_sum2 = row_sum2 * acc_scale2
                row_sum3 = row_sum3 * acc_scale3
                # --- C7: Compute P = exp2((S - row_max[row]) * scale), per row ---
                minus_row_max_scale0 = (cutlass.Float32(0.0) - row_max0_safe) * scale
                minus_row_max_scale1 = (cutlass.Float32(0.0) - row_max1_safe) * scale
                minus_row_max_scale2 = (cutlass.Float32(0.0) - row_max2_safe) * scale
                minus_row_max_scale3 = (cutlass.Float32(0.0) - row_max3_safe) * scale
                # rP_words is the FP32-word register buffer stored to TMEM.
                # rP_bf16 views the same registers as q_dtype elements in the
                # S fragment layout.
                rP_words = cute.make_rmem_tensor(tTMEM_STOREcP.shape, self.qk_acc_dtype)
                rP_bf16 = cute.make_tensor(cute.recast_ptr(rP_words.iterator, dtype=self.q_dtype), tTMEM_LOADrS.layout)
                rP_bf16_frg = cute.logical_divide(rP_bf16, cute.make_layout(frg_tile))
                for k in cutlass.range(cute.size(tTMEM_LOADrS_frg, mode=[0]), vectorize=True):
                    tTMEM_LOADrS_frg[k, 0] = tTMEM_LOADrS_frg[k, 0] * scale + minus_row_max_scale0
                    tTMEM_LOADrS_frg[k, 0] = cute.math.exp2(tTMEM_LOADrS_frg[k, 0], fastmath=True)
                s_vec0 = tTMEM_LOADrS_frg[None, 0].load()
                rP_bf16_frg[None, 0].store(s_vec0.to(self.q_dtype))
                for k in cutlass.range(cute.size(tTMEM_LOADrS_frg, mode=[0]), vectorize=True):
                    tTMEM_LOADrS_frg[k, 1] = tTMEM_LOADrS_frg[k, 1] * scale + minus_row_max_scale1
                    tTMEM_LOADrS_frg[k, 1] = cute.math.exp2(tTMEM_LOADrS_frg[k, 1], fastmath=True)
                s_vec1 = tTMEM_LOADrS_frg[None, 1].load()
                rP_bf16_frg[None, 1].store(s_vec1.to(self.q_dtype))
                for k in cutlass.range(cute.size(tTMEM_LOADrS_frg, mode=[0]), vectorize=True):
                    tTMEM_LOADrS_frg[k, 2] = tTMEM_LOADrS_frg[k, 2] * scale + minus_row_max_scale2
                    tTMEM_LOADrS_frg[k, 2] = cute.math.exp2(tTMEM_LOADrS_frg[k, 2], fastmath=True)
                s_vec2 = tTMEM_LOADrS_frg[None, 2].load()
                rP_bf16_frg[None, 2].store(s_vec2.to(self.q_dtype))
                for k in cutlass.range(cute.size(tTMEM_LOADrS_frg, mode=[0]), vectorize=True):
                    tTMEM_LOADrS_frg[k, 3] = tTMEM_LOADrS_frg[k, 3] * scale + minus_row_max_scale3
                    tTMEM_LOADrS_frg[k, 3] = cute.math.exp2(tTMEM_LOADrS_frg[k, 3], fastmath=True)
                s_vec3 = tTMEM_LOADrS_frg[None, 3].load()
                rP_bf16_frg[None, 3].store(s_vec3.to(self.q_dtype))
                # Store P to TMEM, release S, and let the MMA warp issue PV.
                cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP)
                cute.arch.fence_view_async_tmem_store()
                si_handle.release()
                softmax_done_bar.arrive()
                # --- C8: Row sum accumulation, independently for each local QK row ---
                # Sums the FP32 exp2 values (before rounding P to q_dtype).
                tile_sum0 = tTMEM_LOADrS_frg[None, 0].load().reduce(cute.ReductionOp.ADD, cutlass.Float32(0.0), 0)
                tile_sum1 = tTMEM_LOADrS_frg[None, 1].load().reduce(cute.ReductionOp.ADD, cutlass.Float32(0.0), 0)
                tile_sum2 = tTMEM_LOADrS_frg[None, 2].load().reduce(cute.ReductionOp.ADD, cutlass.Float32(0.0), 0)
                tile_sum3 = tTMEM_LOADrS_frg[None, 3].load().reduce(cute.ReductionOp.ADD, cutlass.Float32(0.0), 0)
                row_sum0 = row_sum0 + tile_sum0
                row_sum1 = row_sum1 + tile_sum1
                row_sum2 = row_sum2 + tile_sum2
                row_sum3 = row_sum3 + tile_sum3
            # --- C9: Final normalization via row-indexed TMEM vector ---
            # Wait for the final PV MMA to finish producing O.
            pv_done_bar.arrive_and_wait()
            # Publish final row_sum per QK row into vec[row, 0].
            thr_vs0 = tiled_tmem_store_vec.get_slice(qk_row0)
            tVStore0 = thr_vs0.partition_D(tStS_vec)
            tVStoreSrc0 = thr_vs0.partition_S(tScS_vec)
            rVec0 = cute.make_rmem_tensor(tVStoreSrc0.shape, self.qk_acc_dtype)
            rVec0[0] = row_sum0
            rVec0[1] = row_max0
            cute.copy(tiled_tmem_store_vec, rVec0, tVStore0)
            thr_vs1 = tiled_tmem_store_vec.get_slice(qk_row1)
            tVStore1 = thr_vs1.partition_D(tStS_vec)
            tVStoreSrc1 = thr_vs1.partition_S(tScS_vec)
            rVec1 = cute.make_rmem_tensor(tVStoreSrc1.shape, self.qk_acc_dtype)
            rVec1[0] = row_sum1
            rVec1[1] = row_max1
            cute.copy(tiled_tmem_store_vec, rVec1, tVStore1)
            thr_vs2 = tiled_tmem_store_vec.get_slice(qk_row2)
            tVStore2 = thr_vs2.partition_D(tStS_vec)
            tVStoreSrc2 = thr_vs2.partition_S(tScS_vec)
            rVec2 = cute.make_rmem_tensor(tVStoreSrc2.shape, self.qk_acc_dtype)
            rVec2[0] = row_sum2
            rVec2[1] = row_max2
            cute.copy(tiled_tmem_store_vec, rVec2, tVStore2)
            thr_vs3 = tiled_tmem_store_vec.get_slice(qk_row3)
            tVStore3 = thr_vs3.partition_D(tStS_vec)
            tVStoreSrc3 = thr_vs3.partition_S(tScS_vec)
            rVec3 = cute.make_rmem_tensor(tVStoreSrc3.shape, self.qk_acc_dtype)
            rVec3[0] = row_sum3
            rVec3[1] = row_max3
            cute.copy(tiled_tmem_store_vec, rVec3, tVStore3)
            cute.arch.fence_view_async_tmem_store()
            vec_handoff_bar.arrive_and_wait()
            # Read the correct row_sum for this PV/O row and normalize O.
            pv_row_final = tTMEM_LOADcO[0][0]
            thr_vl_final = tiled_tmem_load_vec.get_slice(pv_row_final)
            tVLoadFinal = thr_vl_final.partition_S(tStS_vec)
            tVLoadFinalDst = thr_vl_final.partition_D(tScS_vec)
            rVecFinal = cute.make_rmem_tensor(tVLoadFinalDst.shape, self.qk_acc_dtype)
            cute.copy(tiled_tmem_load_vec, tVLoadFinal, rVecFinal)
            cute.arch.fence_view_async_tmem_load()
            inv_row_sum = cutlass.Float32(1.0) / rVecFinal[0]
            tTMrO_final = cute.make_rmem_tensor((tTMEM_LOADcO.shape, o_col_tiles), self.qk_acc_dtype)
            for i in range(o_col_tiles):
                tTMrO_i_ = tTMrO_final[None, i]
                tTMrO_i_layout = cute.composition(tTMrO_i_.layout, cute.make_layout(tTMrO_final.shape[0]))
                tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout)
                tTMEM_LOADtO_i = cute.make_tensor(
                    tTMEM_LOADtO.iterator + i * corr_tile_size, tTMEM_LOADtO.layout)
                tTMEM_STOREtO_i = cute.make_tensor(
                    tTMEM_STOREtO.iterator + i * corr_tile_size, tTMEM_STOREtO.layout)
                cute.copy(o_tiled_tmem_load, tTMEM_LOADtO_i, tTMrO_i)
                for j in cutlass.range(cute.size(tTMrO_i), vectorize=True):
                    tTMrO_i[j] = tTMrO_i[j] * inv_row_sum
                cute.copy(o_tiled_tmem_store, tTMrO_i, tTMEM_STOREtO_i)
            cute.arch.fence_view_async_tmem_store()
            # Now O in TMEM is normalized. Use standard epilogue_tma_store with identity.
            tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
            acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
            c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
            c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
            acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
                self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC, epi_tile, 0,
                const_expr(lambda x: x),
                (0, 0, 0), acc_cons_st, acc_pipe, c_pipe)
            c_pipe.producer_tail()
            tmem.relinquish_alloc_permit()
            tmem.free(tmem_ptr)
def test(*, sizes=(128, 256, 384), threshold=0.999):
    """Compile and run FmhaV3Softmax for several KV lengths and check it against PyTorch.

    For each ``n`` in ``sizes``:
      * draw random BF16 ``q`` (128 x HEAD_DIM x 1), ``k`` (n x HEAD_DIM x 1)
        and ``v`` (n x HEAD_DIM) on CUDA;
      * compute the FP32 reference ``softmax(q @ k.T / sqrt(HEAD_DIM)) @ v``;
      * run the kernel into ``c``;
      * print the cosine similarity and the max absolute error.

    Args:
        sizes: KV lengths to test (keyword-only). The kernel consumes K/V in
            128-key tiles, so these are presumably meant to be multiples of 128.
        threshold: minimum cosine similarity counted as PASS (keyword-only).

    Raises:
        AssertionError: if any size scored below ``threshold``. The check runs
            only after every size has been run and reported.
    """
    torch.manual_seed(42)
    m, hd = 128, HEAD_DIM
    failures = []
    for n in sizes:
        q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device="cuda")
        k = torch.randn(n, hd, 1, dtype=torch.bfloat16, device="cuda")
        v = torch.randn(n, hd, dtype=torch.bfloat16, device="cuda")
        v_kernel = v.unsqueeze(-1)
        c = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device="cuda")
        # FP32 reference computed from the same BF16 inputs.
        qf = q[:, :, 0].float()
        kf = k[:, :, 0].float()
        attn = qf @ kf.T / math.sqrt(hd)
        ref = torch.softmax(attn, dim=-1) @ v.float()
        mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
        mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
        mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel))
        mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
        stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
        kernel = FmhaV3Softmax(s_k=n)
        print(f"n={n}: Compiling...", flush=True)
        compiled = cute.compile(kernel, mQ, mK, mV, mC, stream)
        print(f"n={n}: tmem: s0={kernel.tmem_s0_offset} p0={kernel.tmem_p0_offset} o0={kernel.tmem_o0_offset} vec={kernel.tmem_vec_offset} alloc={kernel.num_tmem_alloc_cols}", flush=True)
        print(f"n={n}: Running...", flush=True)
        compiled(mQ, mK, mV, mC, stream)
        torch.cuda.synchronize()
        out = c[:, :, 0].float()
        cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
        max_err = (out - ref).abs().max().item()
        # A NaN cosine compares False and is therefore reported as FAIL.
        passed = cos >= threshold
        print(f"FMHA softmax n={n}: cosine {cos:.6f} max_err {max_err:.6f} {'PASS' if passed else 'FAIL'}", flush=True)
        if not passed:
            failures.append((n, cos))
    if failures:
        raise AssertionError(f"FMHA softmax below cosine threshold {threshold}: {failures}")
if __name__ == "__main__":
    # Executed as a script: compile the kernel and run the GPU check.
    test()