Stage C: integrate example3 multi-tile fixes into unit test

- Combined K+V barrier (one acquire per kt, kvh.count == kt)
- O rescale for kt > 0 (online softmax O correction)
- final_o_bar sync (MMA signals before producer_tail)
- s_k as constructor param (compile-time for V layout)
- kv_tx_bytes covers both K and V transfers
- Test covers n=128, 256, 512, 1024
This commit is contained in:
2026-05-22 16:39:45 +00:00
parent 793e3243d5
commit c9fe26a5fc
2 changed files with 624 additions and 100 deletions

View File

@@ -0,0 +1,456 @@
"""
FMHA v3 Stage-C Multi-Tile (combined K+V barrier).
Replaces the interleaved K-then-V acquires with a single acquire per kt that
loads K and V onto the SAME barrier slot. tx_count is sized for K+V together.
With one acquire per tile, the pipeline `count` returned by acquire_and_advance
goes 0, 1, 2, ... and matches the KV tile index directly — no interleaving
problem, and no need for Python ints or integer-division gymnastics in the
TMA coordinate. kvh.count stays a first-class pipeline state value, which is
the form cute.copy accepts.
Changes vs the single-tile file:
1. s_k MUST equal actual n. v_fmha layout uses s_k as the V sequence dim.
2. kv pipeline carries combined K+V per stage:
- tx_count = K_bytes + V_bytes
- producer: one acquire per kt, K and V copies share kvh.barrier
- consumer: one wait per kt, kvh.index used for both sK and sV reads
- release happens after PV (no separate K-early-release path)
Bonus: this also fixes the unused-SMEM-slot quirk where the original kernel
only ever used sK[0] and sV[1] because of the interleaved count.
3. O rescale between KV tiles re-enabled (gated on kt > 0). Lives in softmax
body BEFORE softmax_done_bar.arrive(), so MMA's PV[kt] reads a rescaled O.
4. Explicit MMA→softmax sync before the final normalize.
final_o_bar is a NamedBarrier with 32 MMA + 128 softmax threads. MMA
.arrive() between acc_pipe.producer_commit and producer_tail; softmax
.arrive_and_wait() before reading O. Without this, softmax can race MMA's PV[N-1] and divide a
partially-accumulated O by row_sum. The single-tile test masked the race
because the timing happened to work.
Alternative if combined-barrier ever bites: keep the interleaved pipeline and
index GMEM by `kh.count // 2` / `vh.count // 2`. Requires CuTeDSL to support
Int32 floor-division in a TMA coordinate. Not used here.
"""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cuda.bindings.driver as cuda
import cutlass.torch as ct
import math
HEAD_DIM = 64
class FmhaV3StageCMulti:
def __init__(self, s_k=128, scale_softmax=None):
# s_k MUST equal actual sequence length n.
self.s_k = s_k
self.acc_dtype = Float32; self.qk_acc_dtype = Float32
self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16
self.use_2cta_instrs = False; self.epilog_sync_bar_id = 1
self.cluster_shape_mn = (1, 1); self.cta_group = tcgen05.CtaGroup.ONE
self.epilogue_warp_id = (0,1,2,3); self.mma_warp_id = 4; self.tma_warp_id = 5
self.threads_per_cta = 192; self.num_c_stage = 2
self.kv_stage = 2; self.q_stage = 1; self.num_c_stage = 2
self.scale_softmax = scale_softmax if scale_softmax is not None else 1.0 / math.sqrt(HEAD_DIM)
self.scale_softmax_log2 = self.scale_softmax * math.log2(math.e)
def _setup(self, qk_mma, pv_mma):
qk_ik = cute.size(qk_mma.shape_mnk, mode=[2])
self.qk_mma_tiler = (128, 128, qk_ik * 4)
pv_ik = cute.size(pv_mma.shape_mnk, mode=[2])
self.pv_mma_tiler = (128, HEAD_DIM, pv_ik * (128 // pv_ik))
self.mma_tiler = self.qk_mma_tiler
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
self.cta_tile_shape_mnk = (self.qk_mma_tiler[0]//cute.size(qk_mma.thr_id.shape), HEAD_DIM, self.qk_mma_tiler[2])
self.c_layout = LayoutEnum.ROW_MAJOR
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(self.cta_tile_shape_mnk, False, self.c_layout, self.o_dtype)
self.num_ab_stage = 1; self.num_acc_stage = 1
self.q_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.qk_mma_tiler, self.q_dtype, self.q_stage)
self.k_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.qk_mma_tiler, self.q_dtype, self.kv_stage)
self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.q_dtype, self.kv_stage)
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2)
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
qk_thr = qk_mma.get_slice(0); qk_as = qk_thr.partition_shape_C(self.qk_mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_as)
pv_thr = pv_mma.get_slice(0); pv_as = pv_thr.partition_shape_C(self.pv_mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_as)
self.tmem_s0_offset = 0; self.tmem_p0_offset = 32
p_cols_fp32 = self.pv_mma_tiler[2] * self.q_dtype.width // self.qk_acc_dtype.width
p_end = self.tmem_p0_offset + p_cols_fp32
s_cols = self.qk_mma_tiler[1]
o_after = max(s_cols, p_end)
self.tmem_o0_offset = ((o_after + 31) // 32) * 32
o_cols = find_tmem_tensor_col_offset(tOtO)
total = self.tmem_o0_offset + o_cols
self.num_tmem_alloc_cols = 1
while self.num_tmem_alloc_cols < total:
self.num_tmem_alloc_cols *= 2
cta = cute.size(qk_mma.thr_id.shape)
q_s = cute.slice_(self.q_smem_s,(None,None,None,0))
k_s = cute.slice_(self.k_smem_s,(None,None,None,0))
v_s = cute.slice_(self.v_smem_s,(None,None,None,0))
self.q_tx_bytes = cute.size_in_bytes(self.q_dtype, q_s) * cta
# Combined barrier: tx_count covers BOTH K and V transfers per acquire.
self.kv_tx_bytes = (cute.size_in_bytes(self.q_dtype, k_s) +
cute.size_in_bytes(self.q_dtype, v_s)) * cta
@cute.jit
def __call__(self, q, k, v, c, stream):
self.q_dtype = q.element_type; self.o_dtype = c.element_type; self.c_dtype = self.o_dtype
self.a_major = LayoutEnum.from_tensor(q).mma_major_mode()
self.b_major = LayoutEnum.from_tensor(k).mma_major_mode()
v_fmha = cute.make_tensor(
v.iterator,
cute.make_layout(
(HEAD_DIM, self.s_k, 1),
stride=(1, HEAD_DIM, HEAD_DIM * self.s_k),
),
)
self.v_major = LayoutEnum.from_tensor(v_fmha).mma_major_mode()
self.c_layout = LayoutEnum.from_tensor(c)
qk_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, self.a_major, self.b_major, self.qk_acc_dtype, self.cta_group, (128,128), tcgen05.OperandSource.SMEM)
pv_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, cute.nvgpu.OperandMajorMode.K, self.v_major, self.qk_acc_dtype, self.cta_group, (128,HEAD_DIM), tcgen05.OperandSource.TMEM)
self._setup(qk_mma, pv_mma)
q_s = cute.slice_(self.q_smem_s,(None,None,None,0)); k_s = cute.slice_(self.k_smem_s,(None,None,None,0)); v_s = cute.slice_(self.v_smem_s,(None,None,None,0))
tma_q,mQ = cute.nvgpu.make_tiled_tma_atom_A(utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn,qk_mma.thr_id),q,q_s,self.qk_mma_tiler,qk_mma,self.cluster_layout_vmnk.shape)
tma_k,mK = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn,qk_mma.thr_id),k,k_s,self.qk_mma_tiler,qk_mma,self.cluster_layout_vmnk.shape)
tma_v,mV = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn,pv_mma.thr_id),v_fmha,v_s,self.pv_mma_tiler,pv_mma,self.cluster_layout_vmnk.shape)
epi_s = cute.select(self.c_smem_s,mode=[0,1])
tma_c,mC = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(),c,epi_s,self.epi_tile)
self._kernel(qk_mma,pv_mma,tma_q,mQ,tma_k,mK,tma_v,mV,tma_c,mC,self.cluster_layout_vmnk,self.q_smem_s,self.k_smem_s,self.v_smem_s,self.p_tmem_s,self.c_smem_s,self.epi_tile).launch(grid=(1,1,1),block=[self.threads_per_cta,1,1],stream=stream)
@cute.kernel
def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, tma_c, mC, cl_vmnk, q_smem_s, k_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile):
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
tidx,_,_ = cute.arch.thread_idx()
if warp_idx == self.tma_warp_id:
cpasync.prefetch_descriptor(tma_q); cpasync.prefetch_descriptor(tma_k); cpasync.prefetch_descriptor(tma_v); cpasync.prefetch_descriptor(tma_c)
@cute.struct
class SS:
q_bar: cute.struct.MemRange[cutlass.Int64, self.q_stage*2]
kv_bar: cute.struct.MemRange[cutlass.Int64, self.kv_stage*2]
s_bar: cute.struct.MemRange[cutlass.Int64, 2]
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage*2]
tmem_dealloc: cutlass.Int64; holding: cutlass.Int32
smem = utils.SmemAllocator(); st = smem.allocate(SS)
qp,qc = pipeline.PipelineTmaUmma.create(barrier_storage=st.q_bar.data_ptr(),num_stages=self.q_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),tx_count=self.q_tx_bytes,cta_layout_vmnk=cl_vmnk,defer_sync=True).make_participants()
# Combined K+V pipeline: each stage carries BOTH K and V loaded together.
kvp,kvc = pipeline.PipelineTmaUmma.create(barrier_storage=st.kv_bar.data_ptr(),num_stages=self.kv_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),tx_count=self.kv_tx_bytes,cta_layout_vmnk=cl_vmnk,defer_sync=True).make_participants()
s_prod,s_cons = pipeline.PipelineUmmaAsync.create(barrier_storage=st.s_bar.data_ptr(),num_stages=1,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,32*len(self.epilogue_warp_id))).make_participants()
softmax_done_bar = pipeline.NamedBarrier(barrier_id=3, num_threads=32 + 32*len(self.epilogue_warp_id))
# Final-O sync: MMA arrives between producer_commit and producer_tail;
# softmax arrives_and_waits before reading O for the final normalize.
final_o_bar = pipeline.NamedBarrier(barrier_id=4, num_threads=32 + 32*len(self.epilogue_warp_id))
acc_pipe = pipeline.PipelineUmmaAsync.create(barrier_storage=st.acc_bar.data_ptr(),num_stages=self.num_acc_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,len(self.epilogue_warp_id)),cta_layout_vmnk=cl_vmnk,defer_sync=True)
tmem_bar = pipeline.NamedBarrier(barrier_id=2,num_threads=32*len((self.mma_warp_id,*self.epilogue_warp_id)))
tmem = utils.TmemAllocator(st.holding.ptr,barrier_for_retrieve=tmem_bar,allocator_warp_id=self.epilogue_warp_id[0],is_two_cta=cute.size(qk_mma.thr_id.shape)==2,two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk,is_relaxed=True)
sQ = smem.allocate_tensor(element_type=self.q_dtype,layout=q_smem_s.outer,byte_alignment=128,swizzle=q_smem_s.inner)
sK = smem.allocate_tensor(element_type=self.q_dtype,layout=k_smem_s.outer,byte_alignment=128,swizzle=k_smem_s.inner)
sV = smem.allocate_tensor(element_type=self.q_dtype,layout=v_smem_s.outer,byte_alignment=128,swizzle=v_smem_s.inner)
sC = smem.allocate_tensor(element_type=self.o_dtype,layout=c_smem_s.outer,byte_alignment=128,swizzle=c_smem_s.inner)
gQ = cute.local_tile(mQ,cute.slice_(self.qk_mma_tiler,(None,0,None)),(None,None,None))
gK = cute.local_tile(mK,cute.slice_(self.qk_mma_tiler,(0,None,None)),(None,None,None))
gV = cute.local_tile(mV,cute.slice_(self.pv_mma_tiler,(0,None,None)),(None,None,None))
gC = cute.local_tile(mC,cute.slice_(self.pv_mma_tiler,(None,None,0)),(None,None,None))
n_kv_tiles = cute.size(gK, mode=[3])
qk_thr = qk_mma.get_slice(0); pv_thr = pv_mma.get_slice(0)
tCgQ = qk_thr.partition_A(gQ); tCgK = qk_thr.partition_B(gK)
tCgV = pv_thr.partition_B(gV); tCgC = pv_thr.partition_C(gC)
a_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,0,None,0)).shape)
tAsQ,tAgQ = cpasync.tma_partition(tma_q,0,a_lay,cute.group_modes(sQ,0,3),cute.group_modes(tCgQ,0,3))
b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape)
tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3))
tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3))
tAgQ = tAgQ[(None,0,None,0)]; tBgK = tBgK[(None,0,None,0)]; tVgV = tVgV[(None,0,None,0)]
tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK)
tCrV = pv_mma.make_fragment_B(sV)
qk_as = qk_thr.partition_shape_C(self.qk_mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_as)
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
pv_as = pv_thr.partition_shape_C(self.pv_mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_as)
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
tOrP_base = pv_thr.make_fragment_A(tP)
tOrP = tOrP_base[(None,None,None,0)]
tOrP0 = cute.make_tensor(
tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
tOrP.layout)
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_as, self.num_acc_stage))
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_as, self.num_acc_stage))
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
# ===== TMA LOAD warp =====
# One acquire per kt; K and V both target kvh.barrier. kvh.count == kt.
if warp_idx == self.tma_warp_id:
qp.reset(); qh = qp.acquire_and_advance()
cute.copy(tma_q, tAgQ[(None, qh.count)], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier)
qp.tail()
kvp.reset(); pk = kvp.try_acquire()
for kt in cutlass.range(n_kv_tiles, unroll=1):
kvh = kvp.acquire_and_advance(pk)
# Both transfers decrement the same barrier's tx_count.
# kvh.count is a pipeline-state Int32 (the form cute.copy accepts).
cute.copy(tma_k, tBgK[(None, kvh.count)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
cute.copy(tma_v, tVgV[(None, kvh.count)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
pk = cutlass.Boolean(1)
kvp.tail()
# ===== MMA warp =====
# One wait per kt; same slot index used for both K (QK) and V (PV).
# Release happens AFTER PV — combined slot stays held across QK+PV.
if warp_idx == self.mma_warp_id:
tmem.wait_for_alloc()
qc.reset(); qh = qc.wait_and_advance(); qh.release()
kvc.reset(); pk = kvc.try_wait()
acc_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
acc_pipe.producer_acquire(acc_st)
for kt in range(n_kv_tiles):
kvh = kvc.wait_and_advance(pk); pk = cutlass.Boolean(1)
sh = s_prod.acquire_and_advance()
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
for kb in cutlass.range(cute.size(tCrQ, mode=[2]), unroll_full=True):
cute.gemm(qk_mma, tStS0, tCrQ[(None,None,kb,0)], tCrK[(None,None,kb,kvh.index)], tStS0)
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
cute.arch.fence_view_async_tmem_store()
sh.commit()
softmax_done_bar.arrive_and_wait()
pv_mma.set(tcgen05.Field.ACCUMULATE, kt != 0)
for kb in cutlass.range(cute.size(tOrP0, mode=[2]), unroll_full=True):
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV[(None,None,kb,kvh.index)], tOtO0)
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
cute.arch.fence_view_async_tmem_store()
kvh.release()
acc_pipe.producer_commit(acc_st); acc_st.advance()
# Signal softmax FIRST so it can run normalize + epilogue. Then
# wait for the epilogue's consumer-release in producer_tail.
# Reverse order deadlocks: producer_tail blocks waiting for
# consumer release; softmax blocks at final_o_bar waiting for
# MMA arrive; the epilogue (which does the release) is gated
# behind softmax's final_o_bar wait. Cycle.
final_o_bar.arrive()
acc_pipe.producer_tail(acc_st)
# ===== SOFTMAX + EPILOGUE warps =====
if warp_idx < self.mma_warp_id:
tmem.allocate(self.num_tmem_alloc_cols)
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
sfw_idx = tidx % (32 * len(self.epilogue_warp_id))
# S load
tmem_load_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)
thr_load = tiled_tmem_load.get_slice(sfw_idx)
tTMEM_LOADtS = thr_load.partition_S(tStS0)
cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
tScS = qk_thr.partition_C(cS)
tTMEM_LOADcS = thr_load.partition_D(tScS)
# P store
p_cols_fp32 = self.pv_mma_tiler[2] * self.q_dtype.width // self.qk_acc_dtype.width
tStP_layout = cute.composition(tStS.layout, cute.make_layout((self.pv_mma_tiler[0], p_cols_fp32)))
tStP0 = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStP_layout)
tmem_store_atom = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStP0)
thr_store = tiled_tmem_store.get_slice(sfw_idx)
tTMEM_STOREtP = thr_store.partition_D(tStP0)
tScP_layout = cute.composition(tScS.layout, cute.make_layout((self.pv_mma_tiler[0], p_cols_fp32)))
tScP = cute.make_tensor(tScS.iterator, tScP_layout)
tTMEM_STOREcP = thr_store.partition_S(tScP)
# O rescale / normalize path
cO = cute.make_identity_tensor((self.pv_mma_tiler[0], self.pv_mma_tiler[1]))
tOcO = pv_thr.partition_C(cO)
corr_tile_size = 16
tOtO_i_layout = cute.composition(tOtO.layout, cute.make_layout((128, corr_tile_size)))
tOcO_i_layout = cute.composition(tOcO.layout, cute.make_layout((128, corr_tile_size)))
tOtO_i = cute.make_tensor(tOtO.iterator, tOtO_i_layout)
tOcO_i = cute.make_tensor(tOcO.iterator, tOcO_i_layout)
tmem_load_o_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), self.acc_dtype)
tmem_store_o_atom = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), self.acc_dtype)
tiled_tmem_load_o = tcgen05.make_tmem_copy(tmem_load_o_atom, tOtO_i)
tiled_tmem_store_o = tcgen05.make_tmem_copy(tmem_store_o_atom, tOtO_i)
thr_load_o = tiled_tmem_load_o.get_slice(sfw_idx)
thr_store_o = tiled_tmem_store_o.get_slice(sfw_idx)
tTMEM_LOAD_OtO = thr_load_o.partition_S(tOtO_i)
tTMEM_LOAD_OcO = thr_load_o.partition_D(tOcO_i)
tTMEM_STORE_OtO = thr_store_o.partition_D(tOtO_i)
o_col_tiles = self.pv_mma_tiler[1] // corr_tile_size
row_max = -Float32.inf
row_sum = Float32(0.0)
scale_log2 = Float32(self.scale_softmax_log2)
for kt in range(n_kv_tiles):
si_handle = s_cons.wait_and_advance()
# Load S[kt]
tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype)
cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS)
cute.arch.fence_view_async_tmem_load()
# Pass 1: update row_max
old_row_max = row_max
frg_cnt = 4
frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt
tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile))
for j in range(frg_cnt):
for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])):
row_max = cute.arch.fmax(row_max, tTMEM_LOADrS_frg[k, j] * scale_log2)
row_max_safe = row_max
if row_max == -cutlass.Float32.inf:
row_max_safe = Float32(0.0)
# acc_scale used for both row_sum rescale and O rescale.
acc_scale_ = scale_log2 * (old_row_max - row_max_safe)
acc_scale = cute.math.exp2(acc_scale_, fastmath=True)
if old_row_max == -cutlass.Float32.inf:
acc_scale = Float32(0.0)
row_sum *= acc_scale
# Pass 2: P = exp2((S - new_max) * log2), accumulate row_sum,
# store BF16 P through the FP32-backed register bridge.
rP_words = cute.make_rmem_tensor(tTMEM_STOREcP.shape, self.qk_acc_dtype)
rP_bf16 = cute.make_tensor(cute.recast_ptr(rP_words.iterator, dtype=self.q_dtype), tTMEM_LOADrS.layout)
minus_row_max_scale = (Float32(0.0) - row_max_safe) * scale_log2
rP_bf16_frg = cute.logical_divide(rP_bf16, cute.make_layout(frg_tile))
for j in range(frg_cnt):
for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])):
tTMEM_LOADrS_frg[k, j] = tTMEM_LOADrS_frg[k, j] * scale_log2 + minus_row_max_scale
tTMEM_LOADrS_frg[k, j] = cute.math.exp2(tTMEM_LOADrS_frg[k, j], fastmath=True)
row_sum = row_sum + tTMEM_LOADrS_frg[k, j]
s_vec = tTMEM_LOADrS_frg[None, j].load()
rP_bf16_frg[None, j].store(s_vec.to(self.q_dtype))
cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP)
cute.arch.fence_view_async_tmem_store()
# O rescale for kt > 0. Reads O written by MMA's PV[kt-1];
# visibility is provided by s_cons.wait_and_advance above
# (acquires on MMA's S[kt] commit, which orders PV[kt-1] before).
if kt > 0:
for i in range(o_col_tiles):
tTMEM_LOAD_O_i = cute.make_tensor(
tTMEM_LOAD_OtO.iterator + i * corr_tile_size,
tTMEM_LOAD_OtO.layout,
)
tTMEM_STORE_O_i = cute.make_tensor(
tTMEM_STORE_OtO.iterator + i * corr_tile_size,
tTMEM_STORE_OtO.layout,
)
tTMrO = cute.make_rmem_tensor(tTMEM_LOAD_OcO.shape, self.acc_dtype)
cute.copy(tiled_tmem_load_o, tTMEM_LOAD_O_i, tTMrO)
cute.arch.fence_view_async_tmem_load()
for k in cutlass.range(cute.size(tTMrO), vectorize=True):
tTMrO[k] = tTMrO[k] * acc_scale
cute.copy(tiled_tmem_store_o, tTMrO, tTMEM_STORE_O_i)
cute.arch.fence_view_async_tmem_store()
si_handle.release()
softmax_done_bar.arrive()
# Wait for MMA's last PV to commit before reading O for normalize.
# Without this barrier softmax can race MMA's PV[N-1].
final_o_bar.arrive_and_wait()
# Final O = O / row_sum
inv_row_sum = Float32(1.0) / row_sum
for i in range(o_col_tiles):
tTMEM_LOAD_O_i = cute.make_tensor(
tTMEM_LOAD_OtO.iterator + i * corr_tile_size,
tTMEM_LOAD_OtO.layout,
)
tTMEM_STORE_O_i = cute.make_tensor(
tTMEM_STORE_OtO.iterator + i * corr_tile_size,
tTMEM_STORE_OtO.layout,
)
tTMrO = cute.make_rmem_tensor(tTMEM_LOAD_OcO.shape, self.acc_dtype)
cute.copy(tiled_tmem_load_o, tTMEM_LOAD_O_i, tTMrO)
cute.arch.fence_view_async_tmem_load()
for k in cutlass.range(cute.size(tTMrO), vectorize=True):
tTMrO[k] = tTMrO[k] * inv_row_sum
cute.copy(tiled_tmem_store_o, tTMrO, tTMEM_STORE_O_i)
cute.arch.fence_view_async_tmem_store()
# Epilogue: TMEM -> SMEM -> GMEM via TMA store
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC, epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
c_pipe.producer_tail()
tmem.relinquish_alloc_permit()
tmem.free(tmem_ptr)
def test():
torch.manual_seed(42)
for n in [128, 256, 512, 1024]:
torch.manual_seed(42)
m, hd = 128, HEAD_DIM
q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')
k = torch.randn(n, hd, 1, dtype=torch.bfloat16, device='cuda')
v = torch.randn(n, hd, dtype=torch.bfloat16, device='cuda')
v_kernel = v.unsqueeze(-1)
c = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device='cuda')
qf = q[:, :, 0].float()
kf = k[:, :, 0].float()
scale = 1.0 / math.sqrt(hd)
attn = qf @ kf.T * scale
attn = torch.softmax(attn, dim=-1)
ref = attn @ v.float()
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel))
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
# Each n requires its own compiled kernel (s_k is compile-time).
kernel = FmhaV3StageCMulti(s_k=n)
print(f'n={n}: Compiling...', flush=True)
compiled = cute.compile(kernel, mQ, mK, mV, mC, stream)
print(f'n={n}: tmem s0={kernel.tmem_s0_offset} p0={kernel.tmem_p0_offset} '
f'o0={kernel.tmem_o0_offset} alloc={kernel.num_tmem_alloc_cols} '
f'kv_tx_bytes={kernel.kv_tx_bytes}', flush=True)
compiled(mQ, mK, mV, mC, stream)
torch.cuda.synchronize()
out = c[:, :, 0].float()
cos = torch.nn.functional.cosine_similarity(
out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)
).item()
max_abs = (out - ref).abs().max().item()
n_tiles = n // 128
print(f'FMHA Stage-C Multi n={n} ({n_tiles} kv tiles): '
f'cos {cos:.6f} max_abs {max_abs:.4f} '
f'{"PASS" if cos >= 0.99 else "FAIL"}')
if cos < 0.99:
print(f' out[0,:4]={out[0,:4].tolist()}')
print(f' ref[0,:4]={ref[0,:4].tolist()}')
if __name__ == '__main__':
test()

View File

@@ -1,8 +1,26 @@
"""
FMHA v3 Stage-C: Real softmax + O normalization.
Builds on the 12w identity-softmax test by replacing identity softmax with
online softmax (row_max, exp2 scaling, P store) and adding O normalization
by row_sum before the epilogue writes to GMEM.
FMHA v3 Stage-C: Real online softmax with full multi-tile support.
Integrated from:
- test_fmha_v3.py (Stage A+B: QK→identity softmax→PV)
- test_fmha_v3_stage_c_full.py (Stage C: single-tile real softmax)
- fmha_v3_stage_c_example3.py (Multi-tile: combined K+V barrier, O rescale, final_o_bar)
Architecture (6-warp, 192 threads):
Warps 0-3: Softmax + Epilogue (row_max, exp2, P store, O rescale, O normalize, TMA store)
Warp 4: MMA (QK GEMM, PV GEMM)
Warp 5: TMA (Q, K, V load)
Multi-tile key changes vs single-tile:
1. Combined K+V barrier: one acquire_and_advance per kt, both K and V share kvh.barrier.
kvh.count == kt naturally — no interleaving, no Python int in TMA coordinates.
2. kv_tx_bytes covers BOTH K and V transfers per stage.
3. V FMHA layout uses s_k as the sequence dimension (compile-time, not hardcoded 128).
4. MMA: same slot index (kvh.index) for both K (QK) and V (PV). Release after PV.
5. O rescale for kt > 0: exp2((old_max - new_max) * scale_log2) on O in TMEM.
6. final_o_bar: MMA arrives between producer_commit and producer_tail;
softmax arrives_and_wait before reading O for final normalize.
7. s_k is a constructor parameter — each sequence length requires its own compiled kernel.
"""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
@@ -15,8 +33,10 @@ import math
HEAD_DIM = 64
class FmhaV3StageC:
def __init__(self, s_k=128, scale_softmax=None):
# s_k MUST equal actual sequence length n (compile-time constant for V layout).
self.s_k = s_k
self.acc_dtype = Float32; self.qk_acc_dtype = Float32
self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16
@@ -49,24 +69,24 @@ class FmhaV3StageC:
pv_thr = pv_mma.get_slice(0); pv_as = pv_thr.partition_shape_C(self.pv_mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_as)
self.tmem_s0_offset = 0; self.tmem_p0_offset = 32
# P occupies [tmem_p0_offset, tmem_p0_offset + p_cols_fp32)
# S occupies [0, qk_mma_tiler[1]) = [0, 128)
# O must NOT overlap P. Place O after max(S end, P end), aligned to 32.
p_cols_fp32 = self.pv_mma_tiler[2] * self.q_dtype.width // self.qk_acc_dtype.width
p_end = self.tmem_p0_offset + p_cols_fp32 # 32 + 64 = 96
s_cols = self.qk_mma_tiler[1] # 128
o_after = max(s_cols, p_end) # 128
self.tmem_o0_offset = ((o_after + 31) // 32) * 32 # align to 32 = 128
o_cols = find_tmem_tensor_col_offset(tOtO) # footprint of O
p_end = self.tmem_p0_offset + p_cols_fp32
s_cols = self.qk_mma_tiler[1]
o_after = max(s_cols, p_end)
self.tmem_o0_offset = ((o_after + 31) // 32) * 32
o_cols = find_tmem_tensor_col_offset(tOtO)
total = self.tmem_o0_offset + o_cols
# Must be multiple of 32 AND power of 2
self.num_tmem_alloc_cols = 1
while self.num_tmem_alloc_cols < total:
self.num_tmem_alloc_cols *= 2
cta = cute.size(qk_mma.thr_id.shape)
q_s = cute.slice_(self.q_smem_s,(None,None,None,0)); k_s = cute.slice_(self.k_smem_s,(None,None,None,0))
q_s = cute.slice_(self.q_smem_s,(None,None,None,0))
k_s = cute.slice_(self.k_smem_s,(None,None,None,0))
v_s = cute.slice_(self.v_smem_s,(None,None,None,0))
self.q_tx_bytes = cute.size_in_bytes(self.q_dtype, q_s) * cta
self.kv_tx_bytes = cute.size_in_bytes(self.q_dtype, k_s) * cta
# Combined barrier: tx_count covers BOTH K and V transfers per acquire.
self.kv_tx_bytes = (cute.size_in_bytes(self.q_dtype, k_s) +
cute.size_in_bytes(self.q_dtype, v_s)) * cta
@cute.jit
def __call__(self, q, k, v, c, stream):
@@ -74,6 +94,7 @@ class FmhaV3StageC:
self.a_major = LayoutEnum.from_tensor(q).mma_major_mode()
self.b_major = LayoutEnum.from_tensor(k).mma_major_mode()
# FMHA-style V: reconstruct as (HEAD_DIM, s_k, 1) MN-major
# s_k is compile-time — must match actual sequence length n.
v_fmha = cute.make_tensor(
v.iterator,
cute.make_layout(
@@ -111,9 +132,17 @@ class FmhaV3StageC:
smem = utils.SmemAllocator(); st = smem.allocate(SS)
qp,qc = pipeline.PipelineTmaUmma.create(barrier_storage=st.q_bar.data_ptr(),num_stages=self.q_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),tx_count=self.q_tx_bytes,cta_layout_vmnk=cl_vmnk,defer_sync=True).make_participants()
# Combined K+V pipeline: each stage carries BOTH K and V loaded together.
# One acquire per kt → kvh.count == kt (pipeline state value, accepted by cute.copy).
# No interleaving problem. No Python int in TMA coordinates.
kvp,kvc = pipeline.PipelineTmaUmma.create(barrier_storage=st.kv_bar.data_ptr(),num_stages=self.kv_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),tx_count=self.kv_tx_bytes,cta_layout_vmnk=cl_vmnk,defer_sync=True).make_participants()
s_prod,s_cons = pipeline.PipelineUmmaAsync.create(barrier_storage=st.s_bar.data_ptr(),num_stages=1,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,32*len(self.epilogue_warp_id))).make_participants()
softmax_done_bar = pipeline.NamedBarrier(barrier_id=3, num_threads=32 + 32*len(self.epilogue_warp_id))
# Final-O sync: MMA arrives between producer_commit and producer_tail;
# softmax arrives_and_waits before reading O for the final normalize.
# This prevents softmax from racing MMA's PV[N-1] and dividing a
# partially-accumulated O by row_sum.
final_o_bar = pipeline.NamedBarrier(barrier_id=4, num_threads=32 + 32*len(self.epilogue_warp_id))
acc_pipe = pipeline.PipelineUmmaAsync.create(barrier_storage=st.acc_bar.data_ptr(),num_stages=self.num_acc_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,len(self.epilogue_warp_id)),cta_layout_vmnk=cl_vmnk,defer_sync=True)
tmem_bar = pipeline.NamedBarrier(barrier_id=2,num_threads=32*len((self.mma_warp_id,*self.epilogue_warp_id)))
tmem = utils.TmemAllocator(st.holding.ptr,barrier_for_retrieve=tmem_bar,allocator_warp_id=self.epilogue_warp_id[0],is_two_cta=cute.size(qk_mma.thr_id.shape)==2,two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
@@ -162,22 +191,25 @@ class FmhaV3StageC:
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_as, self.num_acc_stage))
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
# TMA LOAD
# ===== TMA LOAD warp =====
# One acquire per kt; K and V both target kvh.barrier. kvh.count == kt.
if warp_idx == self.tma_warp_id:
qp.reset(); qh = qp.acquire_and_advance()
cute.copy(tma_q,tAgQ[(None,qh.count)],tAsQ[(None,qh.index)],tma_bar_ptr=qh.barrier)
cute.copy(tma_q, tAgQ[(None, qh.count)], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier)
qp.tail()
kvp.reset(); pk = kvp.try_acquire()
for kt in cutlass.range(n_kv_tiles,unroll=1):
kh = kvp.acquire_and_advance(pk)
cute.copy(tma_k,tBgK[(None,kh.count)],tBsK[(None,kh.index)],tma_bar_ptr=kh.barrier)
pk = cutlass.Boolean(1)
vh = kvp.acquire_and_advance(pk)
cute.copy(tma_v,tVgV[(None,vh.count)],tVsV[(None,vh.index)],tma_bar_ptr=vh.barrier)
for kt in cutlass.range(n_kv_tiles, unroll=1):
kvh = kvp.acquire_and_advance(pk)
# Both transfers decrement the same barrier's tx_count.
# kvh.count is a pipeline-state Int32 (the form cute.copy accepts).
cute.copy(tma_k, tBgK[(None, kvh.count)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
cute.copy(tma_v, tVgV[(None, kvh.count)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
pk = cutlass.Boolean(1)
kvp.tail()
# MMA
# ===== MMA warp =====
# One wait per kt; same slot index used for both K (QK) and V (PV).
# Release happens AFTER PV — combined slot stays held across QK+PV.
if warp_idx == self.mma_warp_id:
tmem.wait_for_alloc()
qc.reset(); qh = qc.wait_and_advance(); qh.release()
@@ -185,46 +217,48 @@ class FmhaV3StageC:
acc_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
acc_pipe.producer_acquire(acc_st)
for kt in range(n_kv_tiles):
kh = kvc.wait_and_advance(pk); pk = cutlass.Boolean(1)
kvh = kvc.wait_and_advance(pk); pk = cutlass.Boolean(1)
sh = s_prod.acquire_and_advance()
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
for kb in cutlass.range(cute.size(tCrQ,mode=[2]), unroll_full=True):
cute.gemm(qk_mma, tStS0, tCrQ[(None,None,kb,0)], tCrK[(None,None,kb,kh.index)], tStS0)
for kb in cutlass.range(cute.size(tCrQ, mode=[2]), unroll_full=True):
cute.gemm(qk_mma, tStS0, tCrQ[(None,None,kb,0)], tCrK[(None,None,kb,kvh.index)], tStS0)
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
cute.arch.fence_view_async_tmem_store()
sh.commit(); kh.release()
sh.commit()
softmax_done_bar.arrive_and_wait()
vh = kvc.wait_and_advance(pk); pk = cutlass.Boolean(1)
pv_mma.set(tcgen05.Field.ACCUMULATE, kt != 0)
for kb in cutlass.range(cute.size(tOrP0,mode=[2]), unroll_full=True):
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV[(None,None,kb,vh.index)], tOtO0)
for kb in cutlass.range(cute.size(tOrP0, mode=[2]), unroll_full=True):
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV[(None,None,kb,kvh.index)], tOtO0)
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
cute.arch.fence_view_async_tmem_store()
vh.release()
kvh.release()
acc_pipe.producer_commit(acc_st); acc_st.advance()
# Signal softmax FIRST so it can run normalize + epilogue. Then
# wait for the epilogue's consumer-release in producer_tail.
# Reverse order deadlocks: producer_tail blocks waiting for
# consumer release; softmax blocks at final_o_bar waiting for
# MMA arrive; the epilogue (which does the release) is gated
# behind softmax's final_o_bar wait. Cycle.
final_o_bar.arrive()
acc_pipe.producer_tail(acc_st)
# SOFTMAX + EPILOGUE (warps 0-3)
# Step 1: Real online softmax: load S, compute row_max, exp2 scale, store P
# Step 2: After all KV tiles, normalize O in TMEM by row_sum
# Step 3: Epilogue writes normalized O from TMEM to GMEM
# ===== SOFTMAX + EPILOGUE warps =====
if warp_idx < self.mma_warp_id:
tmem.allocate(self.num_tmem_alloc_cols)
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
sfw_idx = tidx % (32 * len(self.epilogue_warp_id))
# --- S load (QK C-fragment layout) ---
# S load (QK C-fragment layout)
tmem_load_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)
thr_load = tiled_tmem_load.get_slice(sfw_idx)
tTMEM_LOADtS = thr_load.partition_S(tStS0)
cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
tScS = qk_thr.partition_C(cS)
tTMEM_LOADcS = thr_load.partition_D(tScS)
# --- P store (QK C-fragment layout composition, FMHA pattern) ---
# P store (QK C-fragment layout composition, FMHA pattern)
p_cols_fp32 = self.pv_mma_tiler[2] * self.q_dtype.width // self.qk_acc_dtype.width
tStP_layout = cute.composition(tStS.layout, cute.make_layout((self.pv_mma_tiler[0], p_cols_fp32)))
tStP0 = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStP_layout)
@@ -236,8 +270,7 @@ class FmhaV3StageC:
tScP = cute.make_tensor(tScS.iterator, tScP_layout)
tTMEM_STOREcP = thr_store.partition_S(tScP)
# --- Online softmax loop ---
# --- O rescale + normalize setup (before softmax loop) ---
# O rescale / normalize path
cO = cute.make_identity_tensor((self.pv_mma_tiler[0], self.pv_mma_tiler[1]))
tOcO = pv_thr.partition_C(cO)
corr_tile_size = 16
@@ -255,18 +288,21 @@ class FmhaV3StageC:
tTMEM_LOAD_OcO = thr_load_o.partition_D(tOcO_i)
tTMEM_STORE_OtO = thr_store_o.partition_D(tOtO_i)
row_max = -Float32.inf; row_sum = Float32(0.0)
o_col_tiles = self.pv_mma_tiler[1] // corr_tile_size
row_max = -Float32.inf
row_sum = Float32(0.0)
scale_log2 = Float32(self.scale_softmax_log2)
for kt in range(n_kv_tiles):
si_handle = s_cons.wait_and_advance()
# Load S from TMEM
# Load S[kt]
tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype)
cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS)
cute.arch.fence_view_async_tmem_load()
# Compute row_max from S (element-wise, before exp2)
# Pass 1: update row_max
old_row_max = row_max
frg_cnt = 4
frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt
@@ -276,15 +312,18 @@ class FmhaV3StageC:
row_max = cute.arch.fmax(row_max, tTMEM_LOADrS_frg[k, j] * scale_log2)
row_max_safe = row_max
if row_max == -cutlass.Float32.inf: row_max_safe = Float32(0.0)
if row_max == -cutlass.Float32.inf:
row_max_safe = Float32(0.0)
# Scale existing row_sum
# acc_scale for both row_sum rescale and O rescale.
acc_scale_ = scale_log2 * (old_row_max - row_max_safe)
acc_scale = cute.math.exp2(acc_scale_, fastmath=True)
if old_row_max == -cutlass.Float32.inf: acc_scale = Float32(0.0)
if old_row_max == -cutlass.Float32.inf:
acc_scale = Float32(0.0)
row_sum *= acc_scale
# Second pass: compute P and accumulate row_sum
# Pass 2: P = exp2((S - new_max) * log2), accumulate row_sum,
# store BF16 P through the FP32-backed register bridge.
rP_words = cute.make_rmem_tensor(tTMEM_STOREcP.shape, self.qk_acc_dtype)
rP_bf16 = cute.make_tensor(cute.recast_ptr(rP_words.iterator, dtype=self.q_dtype), tTMEM_LOADrS.layout)
minus_row_max_scale = (Float32(0.0) - row_max_safe) * scale_log2
@@ -301,38 +340,54 @@ class FmhaV3StageC:
cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP)
cute.arch.fence_view_async_tmem_store()
# O rescale: if kt > 0, rescale O in TMEM by exp2((old_max - new_max) * scale_log2)
# TEMPORARILY DISABLED FOR DEBUGGING
# if kt > 0:
# corr_scale = cute.math.exp2(old_row_max - row_max_safe, fastmath=True)
# o_col_tiles = self.pv_mma_tiler[1] // corr_tile_size
# for i in range(o_col_tiles):
# tTMEM_LOAD_O_i = cute.make_tensor(tTMEM_LOAD_OtO.iterator + i * corr_tile_size, tTMEM_LOAD_OtO.layout)
# tTMEM_STORE_O_i = cute.make_tensor(tTMEM_STORE_OtO.iterator + i * corr_tile_size, tTMEM_STORE_OtO.layout)
# tTMrO = cute.make_rmem_tensor(tTMEM_LOAD_OcO.shape, self.acc_dtype)
# cute.copy(tiled_tmem_load_o, tTMEM_LOAD_O_i, tTMrO)
# for k in cutlass.range(cute.size(tTMrO), vectorize=True):
# tTMrO[k] = tTMrO[k] * corr_scale
# cute.copy(tiled_tmem_store_o, tTMrO, tTMEM_STORE_O_i)
# cute.arch.fence_view_async_tmem_store()
# O rescale for kt > 0. Reads O written by MMA's PV[kt-1];
# visibility is provided by s_cons.wait_and_advance above
# (acquires on MMA's S[kt] commit, which orders PV[kt-1] before).
if kt > 0:
for i in range(o_col_tiles):
tTMEM_LOAD_O_i = cute.make_tensor(
tTMEM_LOAD_OtO.iterator + i * corr_tile_size,
tTMEM_LOAD_OtO.layout,
)
tTMEM_STORE_O_i = cute.make_tensor(
tTMEM_STORE_OtO.iterator + i * corr_tile_size,
tTMEM_STORE_OtO.layout,
)
tTMrO = cute.make_rmem_tensor(tTMEM_LOAD_OcO.shape, self.acc_dtype)
cute.copy(tiled_tmem_load_o, tTMEM_LOAD_O_i, tTMrO)
cute.arch.fence_view_async_tmem_load()
for k in cutlass.range(cute.size(tTMrO), vectorize=True):
tTMrO[k] = tTMrO[k] * acc_scale
cute.copy(tiled_tmem_store_o, tTMrO, tTMEM_STORE_O_i)
cute.arch.fence_view_async_tmem_store()
si_handle.release()
softmax_done_bar.arrive()
# --- Final O normalization by row_sum ---
# Wait for MMA's last PV to commit before reading O for normalize.
# Without this barrier softmax can race MMA's PV[N-1].
final_o_bar.arrive_and_wait()
# Final O = O / row_sum
inv_row_sum = Float32(1.0) / row_sum
o_col_tiles = self.pv_mma_tiler[1] // corr_tile_size
for i in range(o_col_tiles):
tTMEM_LOAD_O_i = cute.make_tensor(tTMEM_LOAD_OtO.iterator + i * corr_tile_size, tTMEM_LOAD_OtO.layout)
tTMEM_STORE_O_i = cute.make_tensor(tTMEM_STORE_OtO.iterator + i * corr_tile_size, tTMEM_STORE_OtO.layout)
tTMEM_LOAD_O_i = cute.make_tensor(
tTMEM_LOAD_OtO.iterator + i * corr_tile_size,
tTMEM_LOAD_OtO.layout,
)
tTMEM_STORE_O_i = cute.make_tensor(
tTMEM_STORE_OtO.iterator + i * corr_tile_size,
tTMEM_STORE_OtO.layout,
)
tTMrO = cute.make_rmem_tensor(tTMEM_LOAD_OcO.shape, self.acc_dtype)
cute.copy(tiled_tmem_load_o, tTMEM_LOAD_O_i, tTMrO)
cute.arch.fence_view_async_tmem_load()
for k in cutlass.range(cute.size(tTMrO), vectorize=True):
tTMrO[k] = tTMrO[k] * inv_row_sum
cute.copy(tiled_tmem_store_o, tTMrO, tTMEM_STORE_O_i)
cute.arch.fence_view_async_tmem_store()
# --- Epilogue: write O from TMEM to GMEM ---
# Epilogue: TMEM -> SMEM -> GMEM via TMA store
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
@@ -345,39 +400,52 @@ class FmhaV3StageC:
def test():
torch.manual_seed(42)
# Test single tile (n=128) and multi-tile (n=256)
for n in [128, 256]:
for seed in [42]:
torch.manual_seed(seed)
m, hd = 128, HEAD_DIM
q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')
k = torch.randn(n, hd, 1, dtype=torch.bfloat16, device='cuda')
v = torch.randn(n, hd, dtype=torch.bfloat16, device='cuda')
v_kernel = v.unsqueeze(-1)
c = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device='cuda')
qf = q[:,:,0].float(); kf = k[:,:,0].float()
scale = 1.0 / math.sqrt(hd)
attn = qf @ kf.T * scale
attn = torch.softmax(attn, dim=-1)
ref = attn @ v.float()
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel))
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
kernel = FmhaV3StageC(s_k=128)
if seed == 42:
print(f'seed={seed}: Compiling...', flush=True)
compiled = cute.compile(kernel, mQ, mK, mV, mC, stream)
if seed == 42:
print(f'tmem_offsets: s0={kernel.tmem_s0_offset} p0={kernel.tmem_p0_offset} o0={kernel.tmem_o0_offset} alloc={kernel.num_tmem_alloc_cols}', flush=True)
compiled(mQ, mK, mV, mC, stream)
torch.cuda.synchronize()
out = c[:,:,0].float()
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
print(f'FMHA Stage-C n={n} seed={seed}: cosine {cos:.6f} {"PASS" if cos >= 0.99 else "FAIL"}')
if cos < 0.99:
print(f' out[0,:4]={out[0,:4].tolist()} ref[0,:4]={ref[0,:4].tolist()}')
for n in [128, 256, 512, 1024]:
torch.manual_seed(42)
m, hd = 128, HEAD_DIM
q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')
k = torch.randn(n, hd, 1, dtype=torch.bfloat16, device='cuda')
v = torch.randn(n, hd, dtype=torch.bfloat16, device='cuda')
v_kernel = v.unsqueeze(-1)
c = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device='cuda')
# PyTorch reference
qf = q[:, :, 0].float()
kf = k[:, :, 0].float()
scale = 1.0 / math.sqrt(hd)
attn = qf @ kf.T * scale
attn = torch.softmax(attn, dim=-1)
ref = attn @ v.float()
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel))
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
# Each n requires its own compiled kernel (s_k is compile-time).
kernel = FmhaV3StageC(s_k=n)
print(f'n={n}: Compiling...', flush=True)
compiled = cute.compile(kernel, mQ, mK, mV, mC, stream)
print(f'n={n}: tmem s0={kernel.tmem_s0_offset} p0={kernel.tmem_p0_offset} '
f'o0={kernel.tmem_o0_offset} alloc={kernel.num_tmem_alloc_cols} '
f'kv_tx_bytes={kernel.kv_tx_bytes}', flush=True)
compiled(mQ, mK, mV, mC, stream)
torch.cuda.synchronize()
out = c[:, :, 0].float()
cos = torch.nn.functional.cosine_similarity(
out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)
).item()
max_abs = (out - ref).abs().max().item()
n_tiles = n // 128
print(f'FMHA Stage-C n={n} ({n_tiles} kv tiles): '
f'cos {cos:.6f} max_abs {max_abs:.4f} '
f'{"PASS" if cos >= 0.99 else "FAIL"}')
if cos < 0.99:
print(f' out[0,:4]={out[0,:4].tolist()}')
print(f' ref[0,:4]={ref[0,:4].tolist()}')
if __name__ == '__main__':
test()