FMHA v3 Stage-C full: 12-warp pipeline with real softmax + correction + epilogue

- Softmax warps (0-3): online row max, exp2 scaling, P store, vec broadcast
- Correction warps (4-7): online O rescale, final normalization, SMEM write
- MMA warp (8): QK->S, PV->O with proper pipeline chaining
- TMA warp (9): Q/K/V load
- Epilogue warp (10): TMA store O from SMEM to GMEM
- Empty warp (11): tmem dealloc mbar init
- Pipeline chain: mma_s -> softmax -> s_corr -> correction -> corr_epi -> epilogue
- Plus mma_corr -> correction for O rescale
- Reference test uses softmax(Q@K^T/sqrt(d))@V
This commit is contained in:
2026-05-22 09:18:56 +00:00
parent b81ed1924b
commit 208af3eadd
2 changed files with 450 additions and 587 deletions

View File

@@ -1,587 +0,0 @@
"""
FMHA v3 + Stage C: QK -> online softmax -> PV with KV-tile interleaving.
Stage C: row_max, exp2, O rescale, row_sum, final normalization.
FMHA pattern P store preserved from Stage B.
"""
import math
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cuda.bindings.driver as cuda
import cutlass.torch as ct
HEAD_DIM = 64
class FmhaV3Softmax:
def __init__(self, s_k: int = 128):
self.s_k = s_k
self.acc_dtype = Float32; self.qk_acc_dtype = Float32
self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16
self.use_2cta_instrs = False; self.epilog_sync_bar_id = 1
self.cluster_shape_mn = (1, 1); self.cta_group = tcgen05.CtaGroup.ONE
self.epilogue_warp_id = (0,1,2,3); self.mma_warp_id = 4; self.tma_warp_id = 5
self.threads_per_cta = 192; self.num_c_stage = 2
self.kv_stage = 2; self.q_stage = 1; self.num_c_stage = 2
def _setup(self, qk_mma, pv_mma):
qk_ik = cute.size(qk_mma.shape_mnk, mode=[2])
self.qk_mma_tiler = (128, 128, qk_ik * 4)
pv_ik = cute.size(pv_mma.shape_mnk, mode=[2])
self.pv_mma_tiler = (128, HEAD_DIM, pv_ik * (128 // pv_ik))
self.mma_tiler = self.qk_mma_tiler
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
self.cta_tile_shape_mnk = (self.qk_mma_tiler[0]//cute.size(qk_mma.thr_id.shape), HEAD_DIM, self.qk_mma_tiler[2])
self.c_layout = LayoutEnum.ROW_MAJOR
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(self.cta_tile_shape_mnk, False, self.c_layout, self.o_dtype)
self.num_ab_stage = 1; self.num_acc_stage = 1
self.q_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.qk_mma_tiler, self.q_dtype, self.q_stage)
self.k_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.qk_mma_tiler, self.q_dtype, self.kv_stage)
self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.q_dtype, self.kv_stage)
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2)
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
qk_thr = qk_mma.get_slice(0); qk_as = qk_thr.partition_shape_C(self.qk_mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_as)
pv_thr = pv_mma.get_slice(0); pv_as = pv_thr.partition_shape_C(self.pv_mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_as)
self.tmem_s0_offset = 0; self.tmem_p0_offset = 32
# P occupies [tmem_p0_offset, tmem_p0_offset + p_cols_fp32)
# S occupies [0, qk_mma_tiler[1]) = [0, 128)
# O must NOT overlap P. Place O after max(S end, P end), aligned to 32.
p_cols_fp32 = self.pv_mma_tiler[2] * self.q_dtype.width // self.qk_acc_dtype.width
p_end = self.tmem_p0_offset + p_cols_fp32 # 32 + 64 = 96
s_cols = self.qk_mma_tiler[1] # 128
o_after = max(s_cols, p_end) # 128
self.tmem_o0_offset = ((o_after + 31) // 32) * 32
self.tmem_vec_offset = 0 # Reuse S region for per-row inv_row_sum vector # align to 32 = 128
self.tmem_vec_offset = 0 # Reuse S region (free after softmax loop)
o_cols = find_tmem_tensor_col_offset(tOtO) # footprint of O
total = self.tmem_o0_offset + o_cols
# Must be multiple of 32 AND power of 2
self.num_tmem_alloc_cols = 1
while self.num_tmem_alloc_cols < total:
self.num_tmem_alloc_cols *= 2
cta = cute.size(qk_mma.thr_id.shape)
q_s = cute.slice_(self.q_smem_s,(None,None,None,0)); k_s = cute.slice_(self.k_smem_s,(None,None,None,0))
self.q_tx_bytes = cute.size_in_bytes(self.q_dtype, q_s) * cta
self.kv_tx_bytes = cute.size_in_bytes(self.q_dtype, k_s) * cta
self.scale_softmax_log2 = Float32(1.0 / math.sqrt(HEAD_DIM) * math.log2(math.e))
@cute.jit
def __call__(self, q, k, v, c, stream):
self.q_dtype = q.element_type; self.o_dtype = c.element_type; self.c_dtype = self.o_dtype
self.a_major = LayoutEnum.from_tensor(q).mma_major_mode()
self.b_major = LayoutEnum.from_tensor(k).mma_major_mode()
# # s_k hardcoded # BROKEN in @cute.jit
# FMHA-style V: reconstruct as (HEAD_DIM, s_k, 1) MN-major
v_fmha = cute.make_tensor(
v.iterator,
cute.make_layout(
(HEAD_DIM, self.s_k, 1),
stride=(1, HEAD_DIM, HEAD_DIM * self.s_k),
),
)
self.v_major = LayoutEnum.from_tensor(v_fmha).mma_major_mode()
self.c_layout = LayoutEnum.from_tensor(c)
qk_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, self.a_major, self.b_major, self.qk_acc_dtype, self.cta_group, (128,128), tcgen05.OperandSource.SMEM)
pv_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, cute.nvgpu.OperandMajorMode.K, self.v_major, self.qk_acc_dtype, self.cta_group, (128,HEAD_DIM), tcgen05.OperandSource.TMEM)
self._setup(qk_mma, pv_mma)
q_s = cute.slice_(self.q_smem_s,(None,None,None,0)); k_s = cute.slice_(self.k_smem_s,(None,None,None,0)); v_s = cute.slice_(self.v_smem_s,(None,None,None,0))
tma_q,mQ = cute.nvgpu.make_tiled_tma_atom_A(utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn,qk_mma.thr_id),q,q_s,self.qk_mma_tiler,qk_mma,self.cluster_layout_vmnk.shape)
tma_k,mK = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn,qk_mma.thr_id),k,k_s,self.qk_mma_tiler,qk_mma,self.cluster_layout_vmnk.shape)
tma_v,mV = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn,pv_mma.thr_id),v_fmha,v_s,self.pv_mma_tiler,pv_mma,self.cluster_layout_vmnk.shape)
epi_s = cute.select(self.c_smem_s,mode=[0,1])
tma_c,mC = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(),c,epi_s,self.epi_tile)
self._kernel(qk_mma,pv_mma,tma_q,mQ,tma_k,mK,tma_v,mV,tma_c,mC,self.cluster_layout_vmnk,self.q_smem_s,self.k_smem_s,self.v_smem_s,self.p_tmem_s,self.c_smem_s,self.epi_tile).launch(grid=(1,1,1),block=[self.threads_per_cta,1,1],stream=stream)
@cute.kernel
def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, tma_c, mC, cl_vmnk, q_smem_s, k_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile):
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
tidx,_,_ = cute.arch.thread_idx()
if warp_idx == self.tma_warp_id:
cpasync.prefetch_descriptor(tma_q); cpasync.prefetch_descriptor(tma_k); cpasync.prefetch_descriptor(tma_v); cpasync.prefetch_descriptor(tma_c)
@cute.struct
class SS:
q_bar: cute.struct.MemRange[cutlass.Int64, self.q_stage*2]
kv_bar: cute.struct.MemRange[cutlass.Int64, self.kv_stage*2]
s_bar: cute.struct.MemRange[cutlass.Int64, 2]
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage*2]
tmem_dealloc: cutlass.Int64; holding: cutlass.Int32
smem = utils.SmemAllocator(); st = smem.allocate(SS)
qp,qc = pipeline.PipelineTmaUmma.create(barrier_storage=st.q_bar.data_ptr(),num_stages=self.q_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),tx_count=self.q_tx_bytes,cta_layout_vmnk=cl_vmnk,defer_sync=True).make_participants()
kvp,kvc = pipeline.PipelineTmaUmma.create(barrier_storage=st.kv_bar.data_ptr(),num_stages=self.kv_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),tx_count=self.kv_tx_bytes,cta_layout_vmnk=cl_vmnk,defer_sync=True).make_participants()
s_prod,s_cons = pipeline.PipelineUmmaAsync.create(barrier_storage=st.s_bar.data_ptr(),num_stages=1,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,32*len(self.epilogue_warp_id))).make_participants()
softmax_done_bar = pipeline.NamedBarrier(barrier_id=3, num_threads=32 + 32*len(self.epilogue_warp_id))
pv_done_bar = pipeline.NamedBarrier(barrier_id=4, num_threads=32 + 32*len(self.epilogue_warp_id))
vec_handoff_bar = pipeline.NamedBarrier(barrier_id=5, num_threads=32*len(self.epilogue_warp_id))
acc_pipe = pipeline.PipelineUmmaAsync.create(barrier_storage=st.acc_bar.data_ptr(),num_stages=self.num_acc_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,32*len(self.epilogue_warp_id)),cta_layout_vmnk=cl_vmnk,defer_sync=True)
tmem_bar = pipeline.NamedBarrier(barrier_id=2,num_threads=32*len((self.mma_warp_id,*self.epilogue_warp_id)))
tmem = utils.TmemAllocator(st.holding.ptr,barrier_for_retrieve=tmem_bar,allocator_warp_id=self.epilogue_warp_id[0],is_two_cta=cute.size(qk_mma.thr_id.shape)==2,two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk,is_relaxed=True)
sQ = smem.allocate_tensor(element_type=self.q_dtype,layout=q_smem_s.outer,byte_alignment=128,swizzle=q_smem_s.inner)
sK = smem.allocate_tensor(element_type=self.q_dtype,layout=k_smem_s.outer,byte_alignment=128,swizzle=k_smem_s.inner)
sV = smem.allocate_tensor(element_type=self.q_dtype,layout=v_smem_s.outer,byte_alignment=128,swizzle=v_smem_s.inner)
sC = smem.allocate_tensor(element_type=self.o_dtype,layout=c_smem_s.outer,byte_alignment=128,swizzle=c_smem_s.inner)
gQ = cute.local_tile(mQ,cute.slice_(self.qk_mma_tiler,(None,0,None)),(None,None,None))
gK = cute.local_tile(mK,cute.slice_(self.qk_mma_tiler,(0,None,None)),(None,None,None))
gV = cute.local_tile(mV,cute.slice_(self.pv_mma_tiler,(0,None,None)),(None,None,None))
gC = cute.local_tile(mC,cute.slice_(self.pv_mma_tiler,(None,None,0)),(None,None,None))
n_kv_tiles = cute.size(gK, mode=[3])
qk_thr = qk_mma.get_slice(0); pv_thr = pv_mma.get_slice(0)
tCgQ = qk_thr.partition_A(gQ); tCgK = qk_thr.partition_B(gK)
tCgV = pv_thr.partition_B(gV); tCgC = pv_thr.partition_C(gC)
a_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,0,None,0)).shape)
tAsQ,tAgQ = cpasync.tma_partition(tma_q,0,a_lay,cute.group_modes(sQ,0,3),cute.group_modes(tCgQ,0,3))
b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape)
tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3))
tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3))
tAgQ = tAgQ[(None,0,None,0)]; tBgK = tBgK[(None,0,None,0)]; tVgV = tVgV[(None,0,None,0)]
tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK)
tCrV = pv_mma.make_fragment_B(sV)
qk_as = qk_thr.partition_shape_C(self.qk_mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_as)
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
pv_as = pv_thr.partition_shape_C(self.pv_mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_as)
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
# --- PV read view (for MMA only, NOT for softmax store) ---
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
tOrP_base = pv_thr.make_fragment_A(tP)
tOrP = tOrP_base[(None,None,None,0)]
tOrP0 = cute.make_tensor(
tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
tOrP.layout)
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_as, self.num_acc_stage))
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_as, self.num_acc_stage))
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
# TMA LOAD
if warp_idx == self.tma_warp_id:
qp.reset(); qh = qp.acquire_and_advance()
cute.copy(tma_q,tAgQ[(None,qh.count)],tAsQ[(None,qh.index)],tma_bar_ptr=qh.barrier)
qp.tail()
kvp.reset(); pk = kvp.try_acquire()
for kt in cutlass.range(n_kv_tiles,unroll=1):
kh = kvp.acquire_and_advance(pk)
cute.copy(tma_k,tBgK[(None,kh.count)],tBsK[(None,kh.index)],tma_bar_ptr=kh.barrier)
pk = cutlass.Boolean(1)
vh = kvp.acquire_and_advance(pk)
cute.copy(tma_v,tVgV[(None,vh.count)],tVsV[(None,vh.index)],tma_bar_ptr=vh.barrier)
pk = cutlass.Boolean(1)
kvp.tail()
# MMA
if warp_idx == self.mma_warp_id:
tmem.wait_for_alloc()
qc.reset(); qh = qc.wait_and_advance(); qh.release()
kvc.reset(); pk = kvc.try_wait()
acc_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
acc_pipe.producer_acquire(acc_st)
for kt in range(n_kv_tiles):
kh = kvc.wait_and_advance(pk); pk = cutlass.Boolean(1)
sh = s_prod.acquire_and_advance()
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
for kb in cutlass.range(cute.size(tCrQ,mode=[2]), unroll_full=True):
cute.gemm(qk_mma, tStS0, tCrQ[(None,None,kb,0)], tCrK[(None,None,kb,kh.index)], tStS0)
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
cute.arch.fence_view_async_tmem_store()
sh.commit(); kh.release()
softmax_done_bar.arrive_and_wait()
vh = kvc.wait_and_advance(pk); pk = cutlass.Boolean(1)
pv_mma.set(tcgen05.Field.ACCUMULATE, kt != 0)
for kb in cutlass.range(cute.size(tOrP0,mode=[2]), unroll_full=True):
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV[(None,None,kb,vh.index)], tOtO0)
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
cute.arch.fence_view_async_tmem_store()
vh.release()
pv_done_bar.arrive()
acc_pipe.producer_commit(acc_st); acc_st.advance()
acc_pipe.producer_tail(acc_st)
# ===================== EPILOGUE WARPS (STAGE C: ONLINE SOFTMAX) =====================
if warp_idx < self.mma_warp_id:
tmem.allocate(self.num_tmem_alloc_cols)
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
sfw_idx = tidx % (32 * len(self.epilogue_warp_id))
# --- S load (QK C-fragment) ---
tmem_load_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)
thr_load = tiled_tmem_load.get_slice(sfw_idx)
tTMEM_LOADtS = thr_load.partition_S(tStS0)
cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
tScS = qk_thr.partition_C(cS)
tTMEM_LOADcS = thr_load.partition_D(tScS)
# --- P store (QK C-fragment composition, FMHA pattern) ---
p_cols_fp32 = self.pv_mma_tiler[2] * self.q_dtype.width // self.qk_acc_dtype.width
tStP_layout = cute.composition(tStS.layout, cute.make_layout((self.pv_mma_tiler[0], p_cols_fp32)))
tStP0 = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStP_layout)
tmem_store_atom = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStP0)
thr_store = tiled_tmem_store.get_slice(sfw_idx)
tTMEM_STOREtP = thr_store.partition_D(tStP0)
tScP_layout = cute.composition(tScS.layout, cute.make_layout((self.pv_mma_tiler[0], p_cols_fp32)))
tScP = cute.make_tensor(tScS.iterator, tScP_layout)
tTMEM_STOREcP = thr_store.partition_S(tScP)
# --- Vector TMEM (per-row row_sum storage, FMHA pattern) ---
# composition(tStS.layout, (128, 2)) = 2 FP32 columns per logical row
# vec[0] = row_sum (final, after loop), vec[1] = unused
# Reuses S TMEM region (offset 0), free after softmax loop writes
tStS_vec_layout = cute.composition(tStS.layout, cute.make_layout((128, 2)))
tStS_vec = cute.make_tensor(tStS.iterator + self.tmem_vec_offset, tStS_vec_layout)
tScS_vec_layout = cute.composition(tScS.layout, cute.make_layout((128, 2)))
tScS_vec = cute.make_tensor(tScS.iterator, tScS_vec_layout)
tmem_store_vec_atom = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(2)), self.qk_acc_dtype)
tiled_tmem_store_vec = tcgen05.make_tmem_copy(tmem_store_vec_atom, tStS_vec)
thr_tmem_store_vec = tiled_tmem_store_vec.get_slice(sfw_idx)
tTMEM_STORE_VECtS = thr_tmem_store_vec.partition_D(tStS_vec)
tTMEM_STORE_VECcS = thr_tmem_store_vec.partition_S(tScS_vec)
tmem_load_vec_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(2)), self.qk_acc_dtype)
tiled_tmem_load_vec = tcgen05.make_tmem_copy(tmem_load_vec_atom, tStS_vec)
thr_tmem_load_vec = tiled_tmem_load_vec.get_slice(sfw_idx)
tTMEM_LOAD_VECtS = thr_tmem_load_vec.partition_S(tStS_vec)
tTMEM_LOAD_VECcS = thr_tmem_load_vec.partition_D(tScS_vec)
# --- C6: O TMEM load/store for rescale (correction_rescale pattern) ---
corr_tile_size = 16
cO = cute.make_identity_tensor((self.pv_mma_tiler[0], self.pv_mma_tiler[1]))
tOcO = pv_thr.partition_C(cO)
o_tmem_load_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), self.qk_acc_dtype)
o_tmem_store_atom = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), self.qk_acc_dtype)
tOtO_i_layout = cute.composition(tOtO0.layout, cute.make_layout((128, corr_tile_size)))
tOcO_i_layout = cute.composition(tOcO.layout, cute.make_layout((128, corr_tile_size)))
tOtO_i = cute.make_tensor(tOtO0.iterator, tOtO_i_layout)
tOcO_i = cute.make_tensor(tOcO.iterator, tOcO_i_layout)
o_tiled_tmem_load = tcgen05.make_tmem_copy(o_tmem_load_atom, tOtO_i)
o_tiled_tmem_store = tcgen05.make_tmem_copy(o_tmem_store_atom, tOtO_i)
o_thr_load = o_tiled_tmem_load.get_slice(sfw_idx)
o_thr_store = o_tiled_tmem_store.get_slice(sfw_idx)
tTMEM_LOADtO = o_thr_load.partition_S(tOtO_i)
tTMEM_LOADcO = o_thr_load.partition_D(tOcO_i)
tTMEM_STOREtO = o_thr_store.partition_D(tOtO_i)
o_col_tiles = self.pv_mma_tiler[1] // corr_tile_size
# --- C2: Per-QK-fragment-row state (persist across KV tiles) ---
# The QK TMEM load fragment is logically 4 rows x 32 columns for each
# softmax thread. The old scalar row_max/row_sum reduced across all
# 4 rows and therefore produced a row_sum around 4.0. Keep one
# online-softmax state per local QK row.
qk_frg_cnt = 4
qk_frg_tile = cute.size(tTMEM_LOADcS) // qk_frg_cnt
tTMEM_LOADcS_frg = cute.logical_divide(tTMEM_LOADcS, cute.make_layout(qk_frg_tile))
qk_row0 = tTMEM_LOADcS_frg[0, 0][0]
qk_row1 = tTMEM_LOADcS_frg[0, 1][0]
qk_row2 = tTMEM_LOADcS_frg[0, 2][0]
qk_row3 = tTMEM_LOADcS_frg[0, 3][0]
row_max0 = -cutlass.Float32.inf
row_max1 = -cutlass.Float32.inf
row_max2 = -cutlass.Float32.inf
row_max3 = -cutlass.Float32.inf
row_sum0 = cutlass.Float32(0.0)
row_sum1 = cutlass.Float32(0.0)
row_sum2 = cutlass.Float32(0.0)
row_sum3 = cutlass.Float32(0.0)
# --- C3: QK scale = 1/sqrt(HEAD_DIM) * log2(e) for exp2 ---
scale = self.scale_softmax_log2
# =============================================================
# Per-KV-tile online softmax loop
# =============================================================
for kt in range(n_kv_tiles):
si_handle = s_cons.wait_and_advance()
# Load S from TMEM (FP32, QK C-fragment layout). Because the
# vector buffer reuses the S columns, all softmax threads must
# finish this load before any thread writes vector data.
tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype)
cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS)
cute.arch.fence_view_async_tmem_load()
vec_handoff_bar.arrive_and_wait()
frg_cnt = 4
frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt
tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile))
# --- C4: Compute tile_max independently for each local QK row ---
old_row_max0 = row_max0
old_row_max1 = row_max1
old_row_max2 = row_max2
old_row_max3 = row_max3
row_max0 = tTMEM_LOADrS_frg[None, 0].load().reduce(cute.ReductionOp.MAX, row_max0, 0)
row_max1 = tTMEM_LOADrS_frg[None, 1].load().reduce(cute.ReductionOp.MAX, row_max1, 0)
row_max2 = tTMEM_LOADrS_frg[None, 2].load().reduce(cute.ReductionOp.MAX, row_max2, 0)
row_max3 = tTMEM_LOADrS_frg[None, 3].load().reduce(cute.ReductionOp.MAX, row_max3, 0)
row_max0_safe = row_max0
row_max1_safe = row_max1
row_max2_safe = row_max2
row_max3_safe = row_max3
if row_max0 == -cutlass.Float32.inf:
row_max0_safe = cutlass.Float32(0.0)
if row_max1 == -cutlass.Float32.inf:
row_max1_safe = cutlass.Float32(0.0)
if row_max2 == -cutlass.Float32.inf:
row_max2_safe = cutlass.Float32(0.0)
if row_max3 == -cutlass.Float32.inf:
row_max3_safe = cutlass.Float32(0.0)
# --- C5: Per-row O-rescale factors for the already-accumulated O ---
acc_scale0 = cute.math.exp2(scale * (old_row_max0 - row_max0_safe), fastmath=True)
acc_scale1 = cute.math.exp2(scale * (old_row_max1 - row_max1_safe), fastmath=True)
acc_scale2 = cute.math.exp2(scale * (old_row_max2 - row_max2_safe), fastmath=True)
acc_scale3 = cute.math.exp2(scale * (old_row_max3 - row_max3_safe), fastmath=True)
# --- C6: Rescale O in TMEM using a row-indexed vector handoff ---
# Store per-QK-row acc_scale into vec[row, 0], then read vec[pv_row, 0]
# from the PV/O partition. This is the CUTLASS-style vector bridge,
# but folded into the same four softmax warps, so it needs an
# explicit warpgroup barrier between store and load.
if kt > 0:
pv_done_bar.arrive_and_wait()
thr_vs0 = tiled_tmem_store_vec.get_slice(qk_row0)
tVStore0 = thr_vs0.partition_D(tStS_vec)
tVStoreSrc0 = thr_vs0.partition_S(tScS_vec)
rVec0 = cute.make_rmem_tensor(tVStoreSrc0.shape, self.qk_acc_dtype)
rVec0[0] = acc_scale0
rVec0[1] = row_max0_safe
cute.copy(tiled_tmem_store_vec, rVec0, tVStore0)
thr_vs1 = tiled_tmem_store_vec.get_slice(qk_row1)
tVStore1 = thr_vs1.partition_D(tStS_vec)
tVStoreSrc1 = thr_vs1.partition_S(tScS_vec)
rVec1 = cute.make_rmem_tensor(tVStoreSrc1.shape, self.qk_acc_dtype)
rVec1[0] = acc_scale1
rVec1[1] = row_max1_safe
cute.copy(tiled_tmem_store_vec, rVec1, tVStore1)
thr_vs2 = tiled_tmem_store_vec.get_slice(qk_row2)
tVStore2 = thr_vs2.partition_D(tStS_vec)
tVStoreSrc2 = thr_vs2.partition_S(tScS_vec)
rVec2 = cute.make_rmem_tensor(tVStoreSrc2.shape, self.qk_acc_dtype)
rVec2[0] = acc_scale2
rVec2[1] = row_max2_safe
cute.copy(tiled_tmem_store_vec, rVec2, tVStore2)
thr_vs3 = tiled_tmem_store_vec.get_slice(qk_row3)
tVStore3 = thr_vs3.partition_D(tStS_vec)
tVStoreSrc3 = thr_vs3.partition_S(tScS_vec)
rVec3 = cute.make_rmem_tensor(tVStoreSrc3.shape, self.qk_acc_dtype)
rVec3[0] = acc_scale3
rVec3[1] = row_max3_safe
cute.copy(tiled_tmem_store_vec, rVec3, tVStore3)
cute.arch.fence_view_async_tmem_store()
vec_handoff_bar.arrive_and_wait()
pv_row = tTMEM_LOADcO[0][0]
thr_vl = tiled_tmem_load_vec.get_slice(pv_row)
tVLoad = thr_vl.partition_S(tStS_vec)
tVLoadDst = thr_vl.partition_D(tScS_vec)
rVecPV = cute.make_rmem_tensor(tVLoadDst.shape, self.qk_acc_dtype)
cute.copy(tiled_tmem_load_vec, tVLoad, rVecPV)
cute.arch.fence_view_async_tmem_load()
acc_scale_pv = rVecPV[0]
tTMrO = cute.make_rmem_tensor((tTMEM_LOADcO.shape, o_col_tiles), self.qk_acc_dtype)
for i in range(o_col_tiles):
tTMrO_i_ = tTMrO[None, i]
tTMrO_i_layout = cute.composition(tTMrO_i_.layout, cute.make_layout(tTMrO.shape[0]))
tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout)
tTMEM_LOADtO_i = cute.make_tensor(tTMEM_LOADtO.iterator + i * corr_tile_size, tTMEM_LOADtO.layout)
tTMEM_STOREtO_i = cute.make_tensor(tTMEM_STOREtO.iterator + i * corr_tile_size, tTMEM_STOREtO.layout)
cute.copy(o_tiled_tmem_load, tTMEM_LOADtO_i, tTMrO_i)
for j in cutlass.range(cute.size(tTMrO_i), vectorize=True):
tTMrO_i[j] = tTMrO_i[j] * acc_scale_pv
cute.copy(o_tiled_tmem_store, tTMrO_i, tTMEM_STOREtO_i)
cute.arch.fence_view_async_tmem_store()
# Rescale the four online row sums.
row_sum0 = row_sum0 * acc_scale0
row_sum1 = row_sum1 * acc_scale1
row_sum2 = row_sum2 * acc_scale2
row_sum3 = row_sum3 * acc_scale3
# --- C7: Compute P = exp2((S - row_max[row]) * scale), per row ---
minus_row_max_scale0 = (cutlass.Float32(0.0) - row_max0_safe) * scale
minus_row_max_scale1 = (cutlass.Float32(0.0) - row_max1_safe) * scale
minus_row_max_scale2 = (cutlass.Float32(0.0) - row_max2_safe) * scale
minus_row_max_scale3 = (cutlass.Float32(0.0) - row_max3_safe) * scale
rP_words = cute.make_rmem_tensor(tTMEM_STOREcP.shape, self.qk_acc_dtype)
rP_bf16 = cute.make_tensor(cute.recast_ptr(rP_words.iterator, dtype=self.q_dtype), tTMEM_LOADrS.layout)
rP_bf16_frg = cute.logical_divide(rP_bf16, cute.make_layout(frg_tile))
for k in cutlass.range(cute.size(tTMEM_LOADrS_frg, mode=[0]), vectorize=True):
tTMEM_LOADrS_frg[k, 0] = tTMEM_LOADrS_frg[k, 0] * scale + minus_row_max_scale0
tTMEM_LOADrS_frg[k, 0] = cute.math.exp2(tTMEM_LOADrS_frg[k, 0], fastmath=True)
s_vec0 = tTMEM_LOADrS_frg[None, 0].load()
rP_bf16_frg[None, 0].store(s_vec0.to(self.q_dtype))
for k in cutlass.range(cute.size(tTMEM_LOADrS_frg, mode=[0]), vectorize=True):
tTMEM_LOADrS_frg[k, 1] = tTMEM_LOADrS_frg[k, 1] * scale + minus_row_max_scale1
tTMEM_LOADrS_frg[k, 1] = cute.math.exp2(tTMEM_LOADrS_frg[k, 1], fastmath=True)
s_vec1 = tTMEM_LOADrS_frg[None, 1].load()
rP_bf16_frg[None, 1].store(s_vec1.to(self.q_dtype))
for k in cutlass.range(cute.size(tTMEM_LOADrS_frg, mode=[0]), vectorize=True):
tTMEM_LOADrS_frg[k, 2] = tTMEM_LOADrS_frg[k, 2] * scale + minus_row_max_scale2
tTMEM_LOADrS_frg[k, 2] = cute.math.exp2(tTMEM_LOADrS_frg[k, 2], fastmath=True)
s_vec2 = tTMEM_LOADrS_frg[None, 2].load()
rP_bf16_frg[None, 2].store(s_vec2.to(self.q_dtype))
for k in cutlass.range(cute.size(tTMEM_LOADrS_frg, mode=[0]), vectorize=True):
tTMEM_LOADrS_frg[k, 3] = tTMEM_LOADrS_frg[k, 3] * scale + minus_row_max_scale3
tTMEM_LOADrS_frg[k, 3] = cute.math.exp2(tTMEM_LOADrS_frg[k, 3], fastmath=True)
s_vec3 = tTMEM_LOADrS_frg[None, 3].load()
rP_bf16_frg[None, 3].store(s_vec3.to(self.q_dtype))
# Store P to TMEM.
cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP)
cute.arch.fence_view_async_tmem_store()
si_handle.release()
softmax_done_bar.arrive()
# --- C8: Row sum accumulation, independently for each local QK row ---
tile_sum0 = tTMEM_LOADrS_frg[None, 0].load().reduce(cute.ReductionOp.ADD, cutlass.Float32(0.0), 0)
tile_sum1 = tTMEM_LOADrS_frg[None, 1].load().reduce(cute.ReductionOp.ADD, cutlass.Float32(0.0), 0)
tile_sum2 = tTMEM_LOADrS_frg[None, 2].load().reduce(cute.ReductionOp.ADD, cutlass.Float32(0.0), 0)
tile_sum3 = tTMEM_LOADrS_frg[None, 3].load().reduce(cute.ReductionOp.ADD, cutlass.Float32(0.0), 0)
row_sum0 = row_sum0 + tile_sum0
row_sum1 = row_sum1 + tile_sum1
row_sum2 = row_sum2 + tile_sum2
row_sum3 = row_sum3 + tile_sum3
# --- C9: Final normalization via row-indexed TMEM vector ---
# Wait for the final PV MMA to finish producing O.
pv_done_bar.arrive_and_wait()
# Publish final row_sum per QK row into vec[row, 0].
thr_vs0 = tiled_tmem_store_vec.get_slice(qk_row0)
tVStore0 = thr_vs0.partition_D(tStS_vec)
tVStoreSrc0 = thr_vs0.partition_S(tScS_vec)
rVec0 = cute.make_rmem_tensor(tVStoreSrc0.shape, self.qk_acc_dtype)
rVec0[0] = row_sum0
rVec0[1] = row_max0
cute.copy(tiled_tmem_store_vec, rVec0, tVStore0)
thr_vs1 = tiled_tmem_store_vec.get_slice(qk_row1)
tVStore1 = thr_vs1.partition_D(tStS_vec)
tVStoreSrc1 = thr_vs1.partition_S(tScS_vec)
rVec1 = cute.make_rmem_tensor(tVStoreSrc1.shape, self.qk_acc_dtype)
rVec1[0] = row_sum1
rVec1[1] = row_max1
cute.copy(tiled_tmem_store_vec, rVec1, tVStore1)
thr_vs2 = tiled_tmem_store_vec.get_slice(qk_row2)
tVStore2 = thr_vs2.partition_D(tStS_vec)
tVStoreSrc2 = thr_vs2.partition_S(tScS_vec)
rVec2 = cute.make_rmem_tensor(tVStoreSrc2.shape, self.qk_acc_dtype)
rVec2[0] = row_sum2
rVec2[1] = row_max2
cute.copy(tiled_tmem_store_vec, rVec2, tVStore2)
thr_vs3 = tiled_tmem_store_vec.get_slice(qk_row3)
tVStore3 = thr_vs3.partition_D(tStS_vec)
tVStoreSrc3 = thr_vs3.partition_S(tScS_vec)
rVec3 = cute.make_rmem_tensor(tVStoreSrc3.shape, self.qk_acc_dtype)
rVec3[0] = row_sum3
rVec3[1] = row_max3
cute.copy(tiled_tmem_store_vec, rVec3, tVStore3)
cute.arch.fence_view_async_tmem_store()
vec_handoff_bar.arrive_and_wait()
# Read the correct row_sum for this PV/O row and normalize O.
pv_row_final = tTMEM_LOADcO[0][0]
thr_vl_final = tiled_tmem_load_vec.get_slice(pv_row_final)
tVLoadFinal = thr_vl_final.partition_S(tStS_vec)
tVLoadFinalDst = thr_vl_final.partition_D(tScS_vec)
rVecFinal = cute.make_rmem_tensor(tVLoadFinalDst.shape, self.qk_acc_dtype)
cute.copy(tiled_tmem_load_vec, tVLoadFinal, rVecFinal)
cute.arch.fence_view_async_tmem_load()
inv_row_sum = cutlass.Float32(1.0) / rVecFinal[0]
tTMrO_final = cute.make_rmem_tensor((tTMEM_LOADcO.shape, o_col_tiles), self.qk_acc_dtype)
for i in range(o_col_tiles):
tTMrO_i_ = tTMrO_final[None, i]
tTMrO_i_layout = cute.composition(tTMrO_i_.layout, cute.make_layout(tTMrO_final.shape[0]))
tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout)
tTMEM_LOADtO_i = cute.make_tensor(
tTMEM_LOADtO.iterator + i * corr_tile_size, tTMEM_LOADtO.layout)
tTMEM_STOREtO_i = cute.make_tensor(
tTMEM_STOREtO.iterator + i * corr_tile_size, tTMEM_STOREtO.layout)
cute.copy(o_tiled_tmem_load, tTMEM_LOADtO_i, tTMrO_i)
for j in cutlass.range(cute.size(tTMrO_i), vectorize=True):
tTMrO_i[j] = tTMrO_i[j] * inv_row_sum
cute.copy(o_tiled_tmem_store, tTMrO_i, tTMEM_STOREtO_i)
cute.arch.fence_view_async_tmem_store()
# Now O in TMEM is normalized. Use standard epilogue_tma_store with identity.
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC, epi_tile, 0,
const_expr(lambda x: x),
(0,0,0), acc_cons_st, acc_pipe, c_pipe)
c_pipe.producer_tail()
tmem.relinquish_alloc_permit()
tmem.free(tmem_ptr)
def test():
import math
torch.manual_seed(42)
for n in [128, 256, 384]:
m, hd = 128, HEAD_DIM
q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device="cuda")
k = torch.randn(n, hd, 1, dtype=torch.bfloat16, device="cuda")
v = torch.randn(n, hd, dtype=torch.bfloat16, device="cuda")
v_kernel = v.unsqueeze(-1)
c = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device="cuda")
qf = q[:,:,0].float(); kf = k[:,:,0].float()
attn = qf @ kf.T / math.sqrt(hd)
ref = torch.softmax(attn, dim=-1) @ v.float()
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel))
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
kernel = FmhaV3Softmax(s_k=n)
print(f"n={n}: Compiling...", flush=True)
compiled = cute.compile(kernel, mQ, mK, mV, mC, stream)
print(f"n={n}: tmem: s0={kernel.tmem_s0_offset} p0={kernel.tmem_p0_offset} o0={kernel.tmem_o0_offset} vec={kernel.tmem_vec_offset} alloc={kernel.num_tmem_alloc_cols}", flush=True)
print(f"n={n}: Running...", flush=True)
compiled(mQ, mK, mV, mC, stream)
torch.cuda.synchronize()
out = c[:,:,0].float()
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
max_err = (out - ref).abs().max().item()
print(f"FMHA softmax n={n}: cosine {cos:.6f} max_err {max_err:.6f} {'PASS' if cos >= 0.999 else 'FAIL'}", flush=True)
if __name__ == "__main__":
test()

View File

@@ -0,0 +1,450 @@
"""
FMHA v3 Stage-C Full: Production Blackwell pipeline with real softmax + correction.
Architecture (12-warps, matches CUTLASS FMHA):
softmax warps 0-3 : S(TMEM) -> softmax -> P(TMEM), vec(TMEM)
correction warps 4-7 : vec(TMEM) + O(TMEM) -> corrected O(SMEM)
MMA warp 8 : QK and PV
TMA/load warp 9 : Q/K/V load
epilogue warp 10 : corrected O SMEM -> GMEM via TMA
empty warp 11 : tmem dealloc mbar init
"""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cuda.bindings.driver as cuda
import cutlass.torch as ct
import math
HEAD_DIM = 64
class FmhaV3StageC:
def __init__(self, s_k=128, scale_softmax=None):
self.s_k = s_k
self.acc_dtype = Float32; self.qk_acc_dtype = Float32; self.pv_acc_dtype = Float32
self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16
self.use_2cta_instrs = False; self.cluster_shape_mn = (1, 1); self.cta_group = tcgen05.CtaGroup.ONE
# 12-warp layout
self.softmax_warp_ids = (0, 1, 2, 3)
self.correction_warp_ids = (4, 5, 6, 7)
self.mma_warp_id = 8; self.tma_warp_id = 9
self.epilogue_warp_id = 10; self.empty_warp_id = 11
self.threads_per_cta = 32 * 12
# Pipeline stages
self.mma_softmax_stage = 1; self.softmax_corr_stage = 1
self.mma_corr_stage = 2; self.epi_stage = 2
# TMA stages
self.kv_stage = 2; self.q_stage = 1; self.num_c_stage = 2
# Softmax scaling
self.scale_softmax = scale_softmax if scale_softmax is not None else 1.0 / math.sqrt(HEAD_DIM)
self.scale_softmax_log2 = self.scale_softmax * math.log2(math.e)
self.scale_output = 1.0
def _setup(self, qk_mma, pv_mma):
qk_ik = cute.size(qk_mma.shape_mnk, mode=[2])
self.qk_mma_tiler = (128, 128, qk_ik * 4)
pv_ik = cute.size(pv_mma.shape_mnk, mode=[2])
self.pv_mma_tiler = (128, HEAD_DIM, pv_ik * (128 // pv_ik))
self.mma_tiler = self.qk_mma_tiler
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
self.cta_tile_shape_mnk = (self.qk_mma_tiler[0]//cute.size(qk_mma.thr_id.shape), HEAD_DIM, self.qk_mma_tiler[2])
self.c_layout = LayoutEnum.ROW_MAJOR
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(self.cta_tile_shape_mnk, False, self.c_layout, self.o_dtype)
self.num_ab_stage = 1; self.num_acc_stage = 1
self.q_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.qk_mma_tiler, self.q_dtype, self.q_stage)
self.k_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.qk_mma_tiler, self.q_dtype, self.kv_stage)
self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.q_dtype, self.kv_stage)
self.c_smem_s = utils.sm100.make_epilogue_smem_layout(self.o_dtype, self.c_layout, self.epi_tile, self.epi_stage)
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
qk_thr = qk_mma.get_slice(0); qk_as = qk_thr.partition_shape_C(self.qk_mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_as)
pv_thr = pv_mma.get_slice(0); pv_as = pv_thr.partition_shape_C(self.pv_mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_as)
self.tmem_s0_offset = 0; self.tmem_vec0_offset = 0; self.tmem_p0_offset = 32
p_cols_fp32 = self.pv_mma_tiler[2] * self.q_dtype.width // self.qk_acc_dtype.width
p_end = self.tmem_p0_offset + p_cols_fp32; s_cols = self.qk_mma_tiler[1]
o_after = max(s_cols, p_end)
self.tmem_o0_offset = ((o_after + 31) // 32) * 32
o_cols = find_tmem_tensor_col_offset(tOtO); total = self.tmem_o0_offset + o_cols
self.num_tmem_alloc_cols = 1
while self.num_tmem_alloc_cols < total: self.num_tmem_alloc_cols *= 2
cta = cute.size(qk_mma.thr_id.shape)
q_s = cute.slice_(self.q_smem_s,(None,None,None,0)); k_s = cute.slice_(self.k_smem_s,(None,None,None,0))
self.q_tx_bytes = cute.size_in_bytes(self.q_dtype, q_s) * cta
self.kv_tx_bytes = cute.size_in_bytes(self.q_dtype, k_s) * cta
@cute.jit
def __call__(self, q, k, v, c, stream):
self.q_dtype = q.element_type; self.o_dtype = c.element_type; self.c_dtype = self.o_dtype
self.a_major = LayoutEnum.from_tensor(q).mma_major_mode()
self.b_major = LayoutEnum.from_tensor(k).mma_major_mode()
# FMHA-style V: reconstruct as (HEAD_DIM, s_k, 1) MN-major
v_fmha = cute.make_tensor(
v.iterator,
cute.make_layout(
(HEAD_DIM, self.s_k, 1),
stride=(1, HEAD_DIM, HEAD_DIM * self.s_k),
),
)
self.v_major = LayoutEnum.from_tensor(v_fmha).mma_major_mode()
self.c_layout = LayoutEnum.from_tensor(c)
qk_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, self.a_major, self.b_major, self.qk_acc_dtype, self.cta_group, (128,128), tcgen05.OperandSource.SMEM)
pv_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, cute.nvgpu.OperandMajorMode.K, self.v_major, self.qk_acc_dtype, self.cta_group, (128,HEAD_DIM), tcgen05.OperandSource.TMEM)
self._setup(qk_mma, pv_mma)
q_s = cute.slice_(self.q_smem_s,(None,None,None,0)); k_s = cute.slice_(self.k_smem_s,(None,None,None,0)); v_s = cute.slice_(self.v_smem_s,(None,None,None,0))
tma_q,mQ = cute.nvgpu.make_tiled_tma_atom_A(utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn,qk_mma.thr_id),q,q_s,self.qk_mma_tiler,qk_mma,self.cluster_layout_vmnk.shape)
@cute.kernel
def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, tma_c, mC, cl_vmnk, q_smem_s, k_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile):
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
tidx, _, _ = cute.arch.thread_idx()
if warp_idx == self.tma_warp_id:
cpasync.prefetch_descriptor(tma_q); cpasync.prefetch_descriptor(tma_k); cpasync.prefetch_descriptor(tma_v); cpasync.prefetch_descriptor(tma_c)
@cute.struct
class SS:
q_bar: cute.struct.MemRange[cutlass.Int64, self.q_stage * 2]
kv_bar: cute.struct.MemRange[cutlass.Int64, self.kv_stage * 2]
mma_s_bar: cute.struct.MemRange[cutlass.Int64, self.mma_softmax_stage * 2]
s_corr_bar: cute.struct.MemRange[cutlass.Int64, self.softmax_corr_stage * 2]
mma_corr_bar: cute.struct.MemRange[cutlass.Int64, self.mma_corr_stage * 2]
corr_epi_bar: cute.struct.MemRange[cutlass.Int64, self.epi_stage * 2]
tmem_dealloc: cutlass.Int64; holding: cutlass.Int32
smem = utils.SmemAllocator(); st = smem.allocate(SS)
def cg(n): return pipeline.CooperativeGroup(pipeline.Agent.Thread, n)
qp, qc = pipeline.PipelineTmaUmma.create(barrier_storage=st.q_bar.data_ptr(), num_stages=self.q_stage, producer_group=cg(1), consumer_group=cg(1), tx_count=self.q_tx_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True).make_participants()
kvp, kvc = pipeline.PipelineTmaUmma.create(barrier_storage=st.kv_bar.data_ptr(), num_stages=self.kv_stage, producer_group=cg(1), consumer_group=cg(1), tx_count=self.kv_tx_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True).make_participants()
mma_s_prod, mma_s_cons = pipeline.PipelineUmmaAsync.create(barrier_storage=st.mma_s_bar.data_ptr(), num_stages=self.mma_softmax_stage, producer_group=cg(1), consumer_group=cg(32 * len(self.softmax_warp_ids)), cta_layout_vmnk=cl_vmnk, defer_sync=True).make_participants()
s_corr_prod, s_corr_cons = pipeline.PipelineAsync.create(barrier_storage=st.s_corr_bar.data_ptr(), num_stages=self.softmax_corr_stage, producer_group=cg(32 * len(self.softmax_warp_ids)), consumer_group=cg(32 * len(self.correction_warp_ids))).make_participants()
mma_corr_prod, mma_corr_cons = pipeline.PipelineUmmaAsync.create(barrier_storage=st.mma_corr_bar.data_ptr(), num_stages=self.mma_corr_stage, producer_group=cg(1), consumer_group=cg(32 * len(self.correction_warp_ids)), cta_layout_vmnk=cl_vmnk, defer_sync=True).make_participants()
corr_epi_prod, corr_epi_cons = pipeline.PipelineAsync.create(barrier_storage=st.corr_epi_bar.data_ptr(), num_stages=self.epi_stage, producer_group=cg(32 * len(self.correction_warp_ids)), consumer_group=cg(32)).make_participants()
tmem_bar = pipeline.NamedBarrier(barrier_id=2, num_threads=32 * len((*self.softmax_warp_ids, *self.correction_warp_ids, self.mma_warp_id)))
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar, allocator_warp_id=self.softmax_warp_ids[0], is_two_cta=cute.size(qk_mma.thr_id.shape) == 2, two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
if warp_idx == self.empty_warp_id:
cute.arch.mbarrier_init(st.tmem_dealloc, 32 * len((*self.softmax_warp_ids, *self.correction_warp_ids)))
cute.arch.mbarrier_init_fence()
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
sQ = smem.allocate_tensor(element_type=self.q_dtype, layout=q_smem_s.outer, byte_alignment=128, swizzle=q_smem_s.inner)
sK = smem.allocate_tensor(element_type=self.q_dtype, layout=k_smem_s.outer, byte_alignment=128, swizzle=k_smem_s.inner)
sV = smem.allocate_tensor(element_type=self.q_dtype, layout=v_smem_s.outer, byte_alignment=128, swizzle=v_smem_s.inner)
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
gQ = cute.local_tile(mQ, cute.slice_(self.qk_mma_tiler, (None, 0, None)), (None, None, None))
gK = cute.local_tile(mK, cute.slice_(self.qk_mma_tiler, (0, None, None)), (None, None, None))
gV = cute.local_tile(mV, cute.slice_(self.pv_mma_tiler, (0, None, None)), (None, None, None))
gC = cute.local_tile(mC, cute.slice_(self.pv_mma_tiler, (None, None, 0)), (None, None, None))
n_kv_tiles = cute.size(gK, mode=[3])
qk_thr = qk_mma.get_slice(0); pv_thr = pv_mma.get_slice(0)
tCgQ = qk_thr.partition_A(gQ); tCgK = qk_thr.partition_B(gK)
tCgV = pv_thr.partition_B(gV); tCgC = pv_thr.partition_C(gC)
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0, 0, None, 0)).shape)
tAsQ, tAgQ = cpasync.tma_partition(tma_q, 0, a_lay, cute.group_modes(sQ, 0, 3), cute.group_modes(tCgQ, 0, 3))
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0, None, 0, 0)).shape)
tBsK, tBgK = cpasync.tma_partition(tma_k, 0, b_lay, cute.group_modes(sK, 0, 3), cute.group_modes(tCgK, 0, 3))
tVsV, tVgV = cpasync.tma_partition(tma_v, 0, b_lay, cute.group_modes(sV, 0, 3), cute.group_modes(tCgV, 0, 3))
tAgQ = tAgQ[(None, 0, None, 0)]; tBgK = tBgK[(None, 0, None, 0)]; tVgV = tVgV[(None, 0, None, 0)]
tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK); tCrV = pv_mma.make_fragment_B(sV)
qk_as = qk_thr.partition_shape_C(self.qk_mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_as)
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
pv_as = pv_thr.partition_shape_C(self.pv_mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_as)
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
tOrP_base = pv_thr.make_fragment_A(tP)
tOrP = tOrP_base[(None, None, None, 0)]
tOrP0 = cute.make_tensor(tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset, tOrP.layout)
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_as, 1))
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
# ==================== TMA WARP (9) ====================
if warp_idx == self.tma_warp_id:
qp.reset(); qh = qp.acquire_and_advance()
cute.copy(tma_q, tAgQ[(None, qh.count)], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier)
qp.tail()
kvp.reset(); pk = kvp.try_acquire()
for kt in cutlass.range(n_kv_tiles, unroll=1):
kh = kvp.acquire_and_advance(pk)
cute.copy(tma_k, tBgK[(None, kh.count)], tBsK[(None, kh.index)], tma_bar_ptr=kh.barrier)
pk = cutlass.Boolean(1)
vh = kvp.acquire_and_advance(pk)
cute.copy(tma_v, tVgV[(None, vh.count)], tVsV[(None, vh.index)], tma_bar_ptr=vh.barrier)
pk = cutlass.Boolean(1)
kvp.tail()
# ==================== MMA WARP (8) ====================
if warp_idx == self.mma_warp_id:
tmem.wait_for_alloc()
qc.reset(); qh = qc.wait_and_advance(); qh.release()
kvc.reset(); pk = kvc.try_wait()
for kt in range(n_kv_tiles):
# QK -> S
kh = kvc.wait_and_advance(pk); pk = cutlass.Boolean(1)
sh = mma_s_prod.acquire_and_advance()
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
for kb in cutlass.range(cute.size(tCrQ, mode=[2]), unroll_full=True):
cute.gemm(qk_mma, tStS0, tCrQ[(None, None, kb, 0)], tCrK[(None, None, kb, kh.index)], tStS0)
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
cute.arch.fence_view_async_tmem_store(); sh.commit(); kh.release()
# PV -> O
vh = kvc.wait_and_advance(pk); pk = cutlass.Boolean(1)
oh = mma_corr_prod.acquire_and_advance()
pv_mma.set(tcgen05.Field.ACCUMULATE, kt != 0)
for kb in cutlass.range(cute.size(tOrP0, mode=[2]), unroll_full=True):
cute.gemm(pv_mma, tOtO0, tOrP0[(None, None, kb)], tCrV[(None, None, kb, vh.index)], tOtO0)
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
cute.arch.fence_view_async_tmem_store(); oh.commit(); vh.release()
mma_s_prod.tail(); mma_corr_prod.tail()
cute.arch.relinquish_tmem_alloc_permit()
cute.arch.mbarrier_wait(st.tmem_dealloc, 0)
tmem_ptr = cute.arch.retrieve_tmem_ptr(self.qk_acc_dtype, alignment=16, ptr_to_buffer_holding_addr=st.holding)
cute.arch.dealloc_tmem(tmem_ptr, Int32(self.num_tmem_alloc_cols))
# ==================== SOFTMAX WARPS (0-3) ====================
if warp_idx < len(self.softmax_warp_ids):
tmem.allocate(self.num_tmem_alloc_cols); tmem.wait_for_alloc()
sfw_idx = tidx % (32 * len(self.softmax_warp_ids))
# S load setup
tmem_load_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)
thr_load = tiled_tmem_load.get_slice(sfw_idx)
tTMEM_LOADtS = thr_load.partition_S(tStS0)
cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
tScS = qk_thr.partition_C(cS)
tTMEM_LOADcS = thr_load.partition_D(tScS)
# P store setup (QK C-fragment layout composition, FMHA pattern)
p_cols_fp32 = self.pv_mma_tiler[2] * self.q_dtype.width // self.qk_acc_dtype.width
tStP_layout = cute.composition(tStS.layout, cute.make_layout((self.pv_mma_tiler[0], p_cols_fp32)))
tStP0 = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStP_layout)
tmem_store_atom = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStP0)
thr_store = tiled_tmem_store.get_slice(sfw_idx)
tTMEM_STOREtP = thr_store.partition_D(tStP0)
tScP_layout = cute.composition(tScS.layout, cute.make_layout((self.pv_mma_tiler[0], p_cols_fp32)))
tTMEM_STOREcP = thr_store.partition_S(cute.make_tensor(tScS.iterator, tScP_layout))
# Vec store setup
tStS_vec_layout = cute.composition(tStS.layout, cute.make_layout((128, 2)))
tStS_vec = cute.make_tensor(tStS.iterator + self.tmem_vec0_offset, tStS_vec_layout)
tmem_store_vec_atom = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(2)), self.qk_acc_dtype)
tiled_tmem_store_vec = tcgen05.make_tmem_copy(tmem_store_vec_atom, tStS_vec)
thr_store_vec = tiled_tmem_store_vec.get_slice(sfw_idx)
tTMEM_STORE_VECtS = thr_store_vec.partition_D(tStS_vec)
tScS_vec_layout = cute.composition(tScS.layout, cute.make_layout((128, 2)))
tScS_vec = cute.make_tensor(tScS.iterator, tScS_vec_layout)
tTMEM_STORE_VECcS = thr_store_vec.partition_S(tScS_vec)
row_max = -Float32.inf; row_sum = Float32(0.0)
vec_handle = s_corr_prod.acquire_and_advance()
scale_log2 = Float32(self.scale_softmax_log2)
for kt in range(n_kv_tiles):
si_handle = mma_s_cons.wait_and_advance()
# Load S from TMEM
tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype)
cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS)
cute.arch.fence_view_async_tmem_load()
# Row max
old_row_max = row_max
row_max = tTMEM_LOADrS.load().reduce(cute.ReductionOp.MAX, row_max, 0)
row_max_safe = row_max
if row_max == -cutlass.Float32.inf: row_max_safe = Float32(0.0)
# Vec = [old_max, new_max]
tTMEM_STORE_VECrS = cute.make_rmem_tensor(tTMEM_STORE_VECcS.shape, self.qk_acc_dtype)
tTMEM_STORE_VECrS[0] = old_row_max; tTMEM_STORE_VECrS[1] = row_max_safe
cute.copy(tiled_tmem_store_vec, tTMEM_STORE_VECrS, tTMEM_STORE_VECtS)
cute.arch.fence_view_async_tmem_store()
vec_handle.commit()
# P = exp2((S - new_max) * scale_log2) via register bridge
rP_words = cute.make_rmem_tensor(tTMEM_STOREcP.shape, self.qk_acc_dtype)
rP_bf16 = cute.make_tensor(cute.recast_ptr(rP_words.iterator, dtype=self.q_dtype), tTMEM_LOADrS.layout)
minus_row_max_scale = (Float32(0.0) - row_max_safe) * scale_log2
# Scale existing row_sum
acc_scale_ = scale_log2 * (old_row_max - row_max_safe)
acc_scale = cute.math.exp2(acc_scale_, fastmath=True)
if old_row_max == -cutlass.Float32.inf: acc_scale = Float32(0.0)
row_sum *= acc_scale
frg_cnt = 4
frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt
tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile))
rP_bf16_frg = cute.logical_divide(rP_bf16, cute.make_layout(frg_tile))
for j in range(frg_cnt):
for k in cutlass.range(cute.size(tTMEM_LOADrS_frg, mode=[0]), vectorize=True):
tTMEM_LOADrS_frg[k, j] = tTMEM_LOADrS_frg[k, j] * scale_log2 + minus_row_max_scale
tTMEM_LOADrS_frg[k, j] = cute.math.exp2(tTMEM_LOADrS_frg[k, j], fastmath=True)
s_vec = tTMEM_LOADrS_frg[None, j].load()
rP_bf16_frg[None, j].store(s_vec.to(self.q_dtype))
for k in cutlass.range(cute.size(tTMEM_LOADrS_frg, mode=[0])):
row_sum = row_sum + tTMEM_LOADrS_frg[k, j]
cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP)
cute.arch.fence_view_async_tmem_store()
si_handle.release()
vec_handle = s_corr_prod.acquire_and_advance()
# Final vec = [row_sum, row_max] for correction epilog
tTMEM_STORE_VECrS = cute.make_rmem_tensor(tTMEM_STORE_VECcS.shape, self.qk_acc_dtype)
tTMEM_STORE_VECrS[0] = row_sum; tTMEM_STORE_VECrS[1] = row_max
cute.copy(tiled_tmem_store_vec, tTMEM_STORE_VECrS, tTMEM_STORE_VECtS)
cute.arch.fence_view_async_tmem_store()
vec_handle.commit()
s_corr_prod.acquire() # balance final pipe step
s_corr_prod.tail()
cute.arch.mbarrier_arrive(st.tmem_dealloc)
tmem.relinquish_alloc_permit()
# ==================== CORRECTION WARPS (4-7) ====================
if warp_idx >= len(self.softmax_warp_ids) and warp_idx < len(self.softmax_warp_ids) + len(self.correction_warp_ids):
tmem.wait_for_alloc()
corr_idx = tidx % (32 * len(self.correction_warp_ids))
# Vec load
tStS_vec_layout = cute.composition(tStS.layout, cute.make_layout((128, 2)))
tStS_vec = cute.make_tensor(tStS.iterator + self.tmem_vec0_offset, tStS_vec_layout)
tmem_load_vec_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(2)), self.qk_acc_dtype)
tiled_tmem_load_vec = tcgen05.make_tmem_copy(tmem_load_vec_atom, tStS_vec)
thr_load_vec = tiled_tmem_load_vec.get_slice(corr_idx)
tTMEM_LOAD_VECtS = thr_load_vec.partition_S(tStS_vec)
tScS_vec_layout = cute.composition(tScS.layout, cute.make_layout((128, 2)))
tScS_vec = cute.make_tensor(tScS.iterator, tScS_vec_layout)
tTMEM_LOAD_VECcS = thr_load_vec.partition_D(tScS_vec)
# O load/store for correction_rescale (matching CUTLASS pattern)
cO = cute.make_identity_tensor((self.pv_mma_tiler[0], self.pv_mma_tiler[1]))
tOcO = pv_thr.partition_C(cO)
corr_tile_size = 16
tOtO_i_layout = cute.composition(tOtO.layout, cute.make_layout((128, corr_tile_size)))
tOcO_i_layout = cute.composition(tOcO.layout, cute.make_layout((128, corr_tile_size)))
tOtO_i = cute.make_tensor(tOtO.iterator, tOtO_i_layout)
tOcO_i = cute.make_tensor(tOcO.iterator, tOcO_i_layout)
tmem_load_o_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), self.pv_acc_dtype)
tmem_store_o_atom = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), self.pv_acc_dtype)
tiled_tmem_load_o = tcgen05.make_tmem_copy(tmem_load_o_atom, tOtO_i)
tiled_tmem_store_o = tcgen05.make_tmem_copy(tmem_store_o_atom, tOtO_i)
thr_load_o = tiled_tmem_load_o.get_slice(corr_idx)
thr_store_o = tiled_tmem_store_o.get_slice(corr_idx)
tTMEM_LOAD_OtO = thr_load_o.partition_S(tOtO_i)
tTMEM_LOAD_OcO = thr_load_o.partition_D(tOcO_i)
tTMEM_STORE_OtO = thr_store_o.partition_D(tOtO_i)
scale_log2 = Float32(self.scale_softmax_log2)
# First vec has no previous O to rescale
first_vec = s_corr_cons.wait_and_advance(); first_vec.release()
for kt in range(n_kv_tiles - 1):
vec = s_corr_cons.wait_and_advance()
# Read vec = [old_max, new_max]
tTMEM_LOAD_VECrS = cute.make_rmem_tensor(tTMEM_LOAD_VECcS.shape, self.qk_acc_dtype)
cute.copy(tiled_tmem_load_vec, tTMEM_LOAD_VECtS, tTMEM_LOAD_VECrS)
cute.arch.fence_view_async_tmem_load()
old_max = tTMEM_LOAD_VECrS[0]; new_max = tTMEM_LOAD_VECrS[1]
# scale = exp2((old_max - new_max) * scale_log2)
corr_scale = cute.math.exp2(scale_log2 * (old_max - new_max), fastmath=True)
# Wait for O from MMA, rescale O in TMEM
o_handle = mma_corr_cons.wait_and_advance()
o_col_tiles = self.pv_mma_tiler[1] // corr_tile_size
for i in range(o_col_tiles):
tTMEM_LOAD_O_i = cute.make_tensor(tTMEM_LOAD_OtO.iterator + i * corr_tile_size, tTMEM_LOAD_OtO.layout)
tTMEM_STORE_O_i = cute.make_tensor(tTMEM_STORE_OtO.iterator + i * corr_tile_size, tTMEM_STORE_OtO.layout)
tTMrO_i_ = cute.make_rmem_tensor(tTMEM_LOAD_OcO.shape, self.pv_acc_dtype)
tTMrO_i_layout = cute.composition(tTMrO_i_.layout, cute.make_layout(tTMEM_LOAD_OcO.shape[0]))
tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout)
cute.copy(tiled_tmem_load_o, tTMEM_LOAD_O_i, tTMrO_i)
for k in cutlass.range(cute.size(tTMrO_i), vectorize=True):
tTMrO_i[k] = tTMrO_i[k] * corr_scale
cute.copy(tiled_tmem_store_o, tTMrO_i, tTMEM_STORE_O_i)
cute.arch.fence_view_async_tmem_store()
o_handle.release(); vec.release()
# Final: read [row_sum, row_max], normalize O, write to SMEM
final_vec = s_corr_cons.wait_and_advance()
tTMEM_LOAD_VECrS = cute.make_rmem_tensor(tTMEM_LOAD_VECcS.shape, self.qk_acc_dtype)
cute.copy(tiled_tmem_load_vec, tTMEM_LOAD_VECtS, tTMEM_LOAD_VECrS)
cute.arch.fence_view_async_tmem_load()
row_sum = tTMEM_LOAD_VECrS[0]; row_max = tTMEM_LOAD_VECrS[1]
final_vec.release()
final_o = mma_corr_cons.wait_and_advance()
epi_handle = corr_epi_prod.acquire_and_advance()
# Correction epilog: load O from TMEM, normalize, convert to BF16, write SMEM
# Following CUTLASS correction_epilog pattern
corr_tile_size_epi = 32 * 8 // self.o_dtype.width
tOsO = pv_thr.partition_C(sC)
tOcO_epi = pv_thr.partition_C(cO)
tOtO_i_epi = cute.logical_divide(tOtO, cute.make_layout((128, corr_tile_size_epi)))
tOcO_i_epi = cute.logical_divide(tOcO_epi, cute.make_layout((128, corr_tile_size_epi)))
tOsO_i = cute.logical_divide(tOsO, cute.make_layout((128, corr_tile_size_epi)))
epi_subtile = (self.epi_tile[0], corr_tile_size_epi)
tmem_copy_atom = utils.sm100.get_tmem_load_op(self.pv_mma_tiler, self.c_layout, self.o_dtype, self.pv_acc_dtype, epi_subtile, use_2cta_instrs=False)
tiled_tmem_load_epi = tcgen05.make_tmem_copy(tmem_copy_atom, tOtO_i_epi[(None, None), 0])
thr_tmem_load_epi = tiled_tmem_load_epi.get_slice(corr_idx)
smem_copy_atom = utils.sm100.get_smem_store_op(self.c_layout, self.o_dtype, self.pv_acc_dtype, tiled_tmem_load_epi)
tiled_smem_store = cute.make_tiled_copy_D(smem_copy_atom, tiled_tmem_load_epi)
tTMEM_LOAD_EPItO = thr_tmem_load_epi.partition_S(tOtO_i_epi[(None, None), None])
tTMEM_LOAD_EPIdS = thr_tmem_load_epi.partition_D(tOsO_i[(None, None), None])
tTMEM_LOAD_EPIdO = thr_tmem_load_epi.partition_D(tOcO_i_epi[(None, None), None])
inv_row_sum = Float32(1.0) / row_sum
for i in range(self.pv_mma_tiler[1] // corr_tile_size_epi):
tTMrO = cute.make_rmem_tensor(tTMEM_LOAD_EPIdO[None, 0, 0, i].shape, self.pv_acc_dtype)
cute.copy(tiled_tmem_load_epi, tTMEM_LOAD_EPItO[None, 0, 0, i], tTMrO)
for k in cutlass.range(cute.size(tTMrO), vectorize=True):
tTMrO[k] = tTMrO[k] * inv_row_sum
tSMrO = cute.make_rmem_tensor(tTMrO.shape, self.o_dtype)
tSMrO.store(tTMrO.load().to(self.o_dtype))
cute.copy(tiled_smem_store, tSMrO, tTMEM_LOAD_EPIdS[None, 0, 0, i])
cute.arch.fence_proxy("async.shared", space="cta")
final_o.release()
epi_handle.commit()
cute.arch.mbarrier_arrive(st.tmem_dealloc)
# ==================== EPILOGUE WARP (10) ====================
if warp_idx == self.epilogue_warp_id:
epi_handle = corr_epi_cons.wait_and_advance()
# TMA store O from SMEM to GMEM
cute.copy(tma_c, sC, tCgC[(None, 0)])
cute.arch.cp_async_bulk_commit_group()
cute.arch.cp_async_bulk_wait_group(0, read=True)
epi_handle.release()
def test():
torch.manual_seed(42)
for n in [128]:
m, hd = 128, HEAD_DIM
q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')
k = torch.randn(n, hd, 1, dtype=torch.bfloat16, device='cuda')
v = torch.randn(n, hd, dtype=torch.bfloat16, device='cuda')
v_kernel = v.unsqueeze(-1)
c = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device='cuda')
# Reference: softmax(Q @ K^T) @ V
qf = q[:,:,0].float(); kf = k[:,:,0].float()
scale = 1.0 / math.sqrt(hd)
attn = qf @ kf.T * scale
attn = torch.softmax(attn, dim=-1)
ref = attn @ v.float()
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel))
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
kernel = FmhaV3StageC(s_k=n)
print(f'n={n}: Compiling...', flush=True)
compiled = cute.compile(kernel, mQ, mK, mV, mC, stream)
print(f'n={n}: tmem_offsets: s0={kernel.tmem_s0_offset} vec0={kernel.tmem_vec0_offset} p0={kernel.tmem_p0_offset} o0={kernel.tmem_o0_offset} alloc={kernel.num_tmem_alloc_cols}', flush=True)
print(f'n={n}: Running...', flush=True)
compiled(mQ, mK, mV, mC, stream)
torch.cuda.synchronize()
out = c[:,:,0].float()
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
print(f'FMHA Stage-C n={n}: cosine {cos:.6f} {"PASS" if cos >= 0.99 else "FAIL"}')
if cos < 0.99:
print(f' out[0,:4]={out[0,:4].tolist()} ref[0,:4]={ref[0,:4].tolist()}')
if __name__ == '__main__':
test()