FMHA v3: per-row min test + explicit loop replacements

- test_fmha_v3_per_row_min.py: minimal per-row test (no C6/C9, no barriers)
  Still hangs — likely CuTe DSL issue with logical_divide + explicit loops
- Replaced .load().reduce() on sliced tensors with explicit loops
- Very long compilation times suggest CuTe DSL is struggling

Key conclusion: per-row fix requires correction warp group.
The 6-warp code cant bridge 4 QK rows to 1 PV row per thread.
Need 128 correction threads (1 per output row) reading TMEM vector.
This commit is contained in:
2026-05-22 07:29:04 +00:00
parent 791bdc53a0
commit 7e1ba2b525

View File

@@ -0,0 +1,465 @@
"""
FMHA v3 + Stage C: QK -> online softmax -> PV with KV-tile interleaving.
Stage C: row_max, exp2, O rescale, row_sum, final normalization.
FMHA pattern P store preserved from Stage B.
"""
import math
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cuda.bindings.driver as cuda
import cutlass.torch as ct
HEAD_DIM = 64
class FmhaV3Softmax:
def __init__(self, s_k: int = 128):
self.s_k = s_k
self.acc_dtype = Float32; self.qk_acc_dtype = Float32
self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16
self.use_2cta_instrs = False; self.epilog_sync_bar_id = 1
self.cluster_shape_mn = (1, 1); self.cta_group = tcgen05.CtaGroup.ONE
self.epilogue_warp_id = (0,1,2,3); self.mma_warp_id = 4; self.tma_warp_id = 5
self.threads_per_cta = 192; self.num_c_stage = 2
self.kv_stage = 2; self.q_stage = 1; self.num_c_stage = 2
def _setup(self, qk_mma, pv_mma):
qk_ik = cute.size(qk_mma.shape_mnk, mode=[2])
self.qk_mma_tiler = (128, 128, qk_ik * 4)
pv_ik = cute.size(pv_mma.shape_mnk, mode=[2])
self.pv_mma_tiler = (128, HEAD_DIM, pv_ik * (128 // pv_ik))
self.mma_tiler = self.qk_mma_tiler
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
self.cta_tile_shape_mnk = (self.qk_mma_tiler[0]//cute.size(qk_mma.thr_id.shape), HEAD_DIM, self.qk_mma_tiler[2])
self.c_layout = LayoutEnum.ROW_MAJOR
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(self.cta_tile_shape_mnk, False, self.c_layout, self.o_dtype)
self.num_ab_stage = 1; self.num_acc_stage = 1
self.q_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.qk_mma_tiler, self.q_dtype, self.q_stage)
self.k_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.qk_mma_tiler, self.q_dtype, self.kv_stage)
self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.q_dtype, self.kv_stage)
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2)
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
qk_thr = qk_mma.get_slice(0); qk_as = qk_thr.partition_shape_C(self.qk_mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_as)
pv_thr = pv_mma.get_slice(0); pv_as = pv_thr.partition_shape_C(self.pv_mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_as)
self.tmem_s0_offset = 0; self.tmem_p0_offset = 32
# P occupies [tmem_p0_offset, tmem_p0_offset + p_cols_fp32)
# S occupies [0, qk_mma_tiler[1]) = [0, 128)
# O must NOT overlap P. Place O after max(S end, P end), aligned to 32.
p_cols_fp32 = self.pv_mma_tiler[2] * self.q_dtype.width // self.qk_acc_dtype.width
p_end = self.tmem_p0_offset + p_cols_fp32 # 32 + 64 = 96
s_cols = self.qk_mma_tiler[1] # 128
o_after = max(s_cols, p_end) # 128
self.tmem_o0_offset = ((o_after + 31) // 32) * 32
self.tmem_vec_offset = 0 # Reuse S region for per-row inv_row_sum vector # align to 32 = 128
self.tmem_vec_offset = 0 # Reuse S region (free after softmax loop)
o_cols = find_tmem_tensor_col_offset(tOtO) # footprint of O
total = self.tmem_o0_offset + o_cols
# Must be multiple of 32 AND power of 2
self.num_tmem_alloc_cols = 1
while self.num_tmem_alloc_cols < total:
self.num_tmem_alloc_cols *= 2
cta = cute.size(qk_mma.thr_id.shape)
q_s = cute.slice_(self.q_smem_s,(None,None,None,0)); k_s = cute.slice_(self.k_smem_s,(None,None,None,0))
self.q_tx_bytes = cute.size_in_bytes(self.q_dtype, q_s) * cta
self.kv_tx_bytes = cute.size_in_bytes(self.q_dtype, k_s) * cta
self.scale_softmax_log2 = Float32(1.0 / math.sqrt(HEAD_DIM) * math.log2(math.e))
@cute.jit
def __call__(self, q, k, v, c, stream):
self.q_dtype = q.element_type; self.o_dtype = c.element_type; self.c_dtype = self.o_dtype
self.a_major = LayoutEnum.from_tensor(q).mma_major_mode()
self.b_major = LayoutEnum.from_tensor(k).mma_major_mode()
# # s_k hardcoded # BROKEN in @cute.jit
# FMHA-style V: reconstruct as (HEAD_DIM, s_k, 1) MN-major
v_fmha = cute.make_tensor(
v.iterator,
cute.make_layout(
(HEAD_DIM, self.s_k, 1),
stride=(1, HEAD_DIM, HEAD_DIM * self.s_k),
),
)
self.v_major = LayoutEnum.from_tensor(v_fmha).mma_major_mode()
self.c_layout = LayoutEnum.from_tensor(c)
qk_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, self.a_major, self.b_major, self.qk_acc_dtype, self.cta_group, (128,128), tcgen05.OperandSource.SMEM)
pv_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, cute.nvgpu.OperandMajorMode.K, self.v_major, self.qk_acc_dtype, self.cta_group, (128,HEAD_DIM), tcgen05.OperandSource.TMEM)
self._setup(qk_mma, pv_mma)
q_s = cute.slice_(self.q_smem_s,(None,None,None,0)); k_s = cute.slice_(self.k_smem_s,(None,None,None,0)); v_s = cute.slice_(self.v_smem_s,(None,None,None,0))
tma_q,mQ = cute.nvgpu.make_tiled_tma_atom_A(utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn,qk_mma.thr_id),q,q_s,self.qk_mma_tiler,qk_mma,self.cluster_layout_vmnk.shape)
tma_k,mK = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn,qk_mma.thr_id),k,k_s,self.qk_mma_tiler,qk_mma,self.cluster_layout_vmnk.shape)
tma_v,mV = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn,pv_mma.thr_id),v_fmha,v_s,self.pv_mma_tiler,pv_mma,self.cluster_layout_vmnk.shape)
epi_s = cute.select(self.c_smem_s,mode=[0,1])
tma_c,mC = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(),c,epi_s,self.epi_tile)
self._kernel(qk_mma,pv_mma,tma_q,mQ,tma_k,mK,tma_v,mV,tma_c,mC,self.cluster_layout_vmnk,self.q_smem_s,self.k_smem_s,self.v_smem_s,self.p_tmem_s,self.c_smem_s,self.epi_tile).launch(grid=(1,1,1),block=[self.threads_per_cta,1,1],stream=stream)
@cute.kernel
def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, tma_c, mC, cl_vmnk, q_smem_s, k_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile):
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
tidx,_,_ = cute.arch.thread_idx()
if warp_idx == self.tma_warp_id:
cpasync.prefetch_descriptor(tma_q); cpasync.prefetch_descriptor(tma_k); cpasync.prefetch_descriptor(tma_v); cpasync.prefetch_descriptor(tma_c)
@cute.struct
class SS:
q_bar: cute.struct.MemRange[cutlass.Int64, self.q_stage*2]
kv_bar: cute.struct.MemRange[cutlass.Int64, self.kv_stage*2]
s_bar: cute.struct.MemRange[cutlass.Int64, 2]
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage*2]
tmem_dealloc: cutlass.Int64; holding: cutlass.Int32
smem = utils.SmemAllocator(); st = smem.allocate(SS)
qp,qc = pipeline.PipelineTmaUmma.create(barrier_storage=st.q_bar.data_ptr(),num_stages=self.q_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),tx_count=self.q_tx_bytes,cta_layout_vmnk=cl_vmnk,defer_sync=True).make_participants()
kvp,kvc = pipeline.PipelineTmaUmma.create(barrier_storage=st.kv_bar.data_ptr(),num_stages=self.kv_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),tx_count=self.kv_tx_bytes,cta_layout_vmnk=cl_vmnk,defer_sync=True).make_participants()
s_prod,s_cons = pipeline.PipelineUmmaAsync.create(barrier_storage=st.s_bar.data_ptr(),num_stages=1,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,32*len(self.epilogue_warp_id))).make_participants()
softmax_done_bar = pipeline.NamedBarrier(barrier_id=3, num_threads=32 + 32*len(self.epilogue_warp_id))
pv_done_bar = pipeline.NamedBarrier(barrier_id=4, num_threads=32 + 32*len(self.epilogue_warp_id))
vec_handoff_bar = pipeline.NamedBarrier(barrier_id=5, num_threads=32*len(self.epilogue_warp_id))
acc_pipe = pipeline.PipelineUmmaAsync.create(barrier_storage=st.acc_bar.data_ptr(),num_stages=self.num_acc_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,32*len(self.epilogue_warp_id)),cta_layout_vmnk=cl_vmnk,defer_sync=True)
tmem_bar = pipeline.NamedBarrier(barrier_id=2,num_threads=32*len((self.mma_warp_id,*self.epilogue_warp_id)))
tmem = utils.TmemAllocator(st.holding.ptr,barrier_for_retrieve=tmem_bar,allocator_warp_id=self.epilogue_warp_id[0],is_two_cta=cute.size(qk_mma.thr_id.shape)==2,two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk,is_relaxed=True)
sQ = smem.allocate_tensor(element_type=self.q_dtype,layout=q_smem_s.outer,byte_alignment=128,swizzle=q_smem_s.inner)
sK = smem.allocate_tensor(element_type=self.q_dtype,layout=k_smem_s.outer,byte_alignment=128,swizzle=k_smem_s.inner)
sV = smem.allocate_tensor(element_type=self.q_dtype,layout=v_smem_s.outer,byte_alignment=128,swizzle=v_smem_s.inner)
sC = smem.allocate_tensor(element_type=self.o_dtype,layout=c_smem_s.outer,byte_alignment=128,swizzle=c_smem_s.inner)
gQ = cute.local_tile(mQ,cute.slice_(self.qk_mma_tiler,(None,0,None)),(None,None,None))
gK = cute.local_tile(mK,cute.slice_(self.qk_mma_tiler,(0,None,None)),(None,None,None))
gV = cute.local_tile(mV,cute.slice_(self.pv_mma_tiler,(0,None,None)),(None,None,None))
gC = cute.local_tile(mC,cute.slice_(self.pv_mma_tiler,(None,None,0)),(None,None,None))
n_kv_tiles = cute.size(gK, mode=[3])
qk_thr = qk_mma.get_slice(0); pv_thr = pv_mma.get_slice(0)
tCgQ = qk_thr.partition_A(gQ); tCgK = qk_thr.partition_B(gK)
tCgV = pv_thr.partition_B(gV); tCgC = pv_thr.partition_C(gC)
a_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,0,None,0)).shape)
tAsQ,tAgQ = cpasync.tma_partition(tma_q,0,a_lay,cute.group_modes(sQ,0,3),cute.group_modes(tCgQ,0,3))
b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape)
tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3))
tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3))
tAgQ = tAgQ[(None,0,None,0)]; tBgK = tBgK[(None,0,None,0)]; tVgV = tVgV[(None,0,None,0)]
tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK)
tCrV = pv_mma.make_fragment_B(sV)
qk_as = qk_thr.partition_shape_C(self.qk_mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_as)
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
pv_as = pv_thr.partition_shape_C(self.pv_mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_as)
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
# --- PV read view (for MMA only, NOT for softmax store) ---
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
tOrP_base = pv_thr.make_fragment_A(tP)
tOrP = tOrP_base[(None,None,None,0)]
tOrP0 = cute.make_tensor(
tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
tOrP.layout)
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_as, self.num_acc_stage))
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_as, self.num_acc_stage))
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
# TMA LOAD
if warp_idx == self.tma_warp_id:
qp.reset(); qh = qp.acquire_and_advance()
cute.copy(tma_q,tAgQ[(None,qh.count)],tAsQ[(None,qh.index)],tma_bar_ptr=qh.barrier)
qp.tail()
kvp.reset(); pk = kvp.try_acquire()
for kt in cutlass.range(n_kv_tiles,unroll=1):
kh = kvp.acquire_and_advance(pk)
cute.copy(tma_k,tBgK[(None,kh.count)],tBsK[(None,kh.index)],tma_bar_ptr=kh.barrier)
pk = cutlass.Boolean(1)
vh = kvp.acquire_and_advance(pk)
cute.copy(tma_v,tVgV[(None,vh.count)],tVsV[(None,vh.index)],tma_bar_ptr=vh.barrier)
pk = cutlass.Boolean(1)
kvp.tail()
# MMA
if warp_idx == self.mma_warp_id:
tmem.wait_for_alloc()
qc.reset(); qh = qc.wait_and_advance(); qh.release()
kvc.reset(); pk = kvc.try_wait()
acc_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
acc_pipe.producer_acquire(acc_st)
for kt in range(n_kv_tiles):
kh = kvc.wait_and_advance(pk); pk = cutlass.Boolean(1)
sh = s_prod.acquire_and_advance()
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
for kb in cutlass.range(cute.size(tCrQ,mode=[2]), unroll_full=True):
cute.gemm(qk_mma, tStS0, tCrQ[(None,None,kb,0)], tCrK[(None,None,kb,kh.index)], tStS0)
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
cute.arch.fence_view_async_tmem_store()
sh.commit(); kh.release()
# softmax_done_bar skipped
vh = kvc.wait_and_advance(pk); pk = cutlass.Boolean(1)
pv_mma.set(tcgen05.Field.ACCUMULATE, kt != 0)
for kb in cutlass.range(cute.size(tOrP0,mode=[2]), unroll_full=True):
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV[(None,None,kb,vh.index)], tOtO0)
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
cute.arch.fence_view_async_tmem_store()
vh.release()
# pv_done_bar skipped
acc_pipe.producer_commit(acc_st); acc_st.advance()
acc_pipe.producer_tail(acc_st)
# ===================== EPILOGUE WARPS (STAGE C: ONLINE SOFTMAX) =====================
if warp_idx < self.mma_warp_id:
tmem.allocate(self.num_tmem_alloc_cols)
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
sfw_idx = tidx % (32 * len(self.epilogue_warp_id))
# --- S load (QK C-fragment) ---
tmem_load_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)
thr_load = tiled_tmem_load.get_slice(sfw_idx)
tTMEM_LOADtS = thr_load.partition_S(tStS0)
cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
tScS = qk_thr.partition_C(cS)
tTMEM_LOADcS = thr_load.partition_D(tScS)
# --- P store (QK C-fragment composition, FMHA pattern) ---
p_cols_fp32 = self.pv_mma_tiler[2] * self.q_dtype.width // self.qk_acc_dtype.width
tStP_layout = cute.composition(tStS.layout, cute.make_layout((self.pv_mma_tiler[0], p_cols_fp32)))
tStP0 = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStP_layout)
tmem_store_atom = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStP0)
thr_store = tiled_tmem_store.get_slice(sfw_idx)
tTMEM_STOREtP = thr_store.partition_D(tStP0)
tScP_layout = cute.composition(tScS.layout, cute.make_layout((self.pv_mma_tiler[0], p_cols_fp32)))
tScP = cute.make_tensor(tScS.iterator, tScP_layout)
tTMEM_STOREcP = thr_store.partition_S(tScP)
# --- Vector TMEM (per-row row_sum storage, FMHA pattern) ---
# composition(tStS.layout, (128, 2)) = 2 FP32 columns per logical row
# vec[0] = row_sum (final, after loop), vec[1] = unused
# Reuses S TMEM region (offset 0), free after softmax loop writes
tStS_vec_layout = cute.composition(tStS.layout, cute.make_layout((128, 2)))
tStS_vec = cute.make_tensor(tStS.iterator + self.tmem_vec_offset, tStS_vec_layout)
tScS_vec_layout = cute.composition(tScS.layout, cute.make_layout((128, 2)))
tScS_vec = cute.make_tensor(tScS.iterator, tScS_vec_layout)
tmem_store_vec_atom = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(2)), self.qk_acc_dtype)
tiled_tmem_store_vec = tcgen05.make_tmem_copy(tmem_store_vec_atom, tStS_vec)
thr_tmem_store_vec = tiled_tmem_store_vec.get_slice(sfw_idx)
tTMEM_STORE_VECtS = thr_tmem_store_vec.partition_D(tStS_vec)
tTMEM_STORE_VECcS = thr_tmem_store_vec.partition_S(tScS_vec)
tmem_load_vec_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(2)), self.qk_acc_dtype)
tiled_tmem_load_vec = tcgen05.make_tmem_copy(tmem_load_vec_atom, tStS_vec)
thr_tmem_load_vec = tiled_tmem_load_vec.get_slice(sfw_idx)
tTMEM_LOAD_VECtS = thr_tmem_load_vec.partition_S(tStS_vec)
tTMEM_LOAD_VECcS = thr_tmem_load_vec.partition_D(tScS_vec)
# --- C6: O TMEM load/store for rescale (correction_rescale pattern) ---
corr_tile_size = 16
cO = cute.make_identity_tensor((self.pv_mma_tiler[0], self.pv_mma_tiler[1]))
tOcO = pv_thr.partition_C(cO)
o_tmem_load_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), self.qk_acc_dtype)
o_tmem_store_atom = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), self.qk_acc_dtype)
tOtO_i_layout = cute.composition(tOtO0.layout, cute.make_layout((128, corr_tile_size)))
tOcO_i_layout = cute.composition(tOcO.layout, cute.make_layout((128, corr_tile_size)))
tOtO_i = cute.make_tensor(tOtO0.iterator, tOtO_i_layout)
tOcO_i = cute.make_tensor(tOcO.iterator, tOcO_i_layout)
o_tiled_tmem_load = tcgen05.make_tmem_copy(o_tmem_load_atom, tOtO_i)
o_tiled_tmem_store = tcgen05.make_tmem_copy(o_tmem_store_atom, tOtO_i)
o_thr_load = o_tiled_tmem_load.get_slice(sfw_idx)
o_thr_store = o_tiled_tmem_store.get_slice(sfw_idx)
tTMEM_LOADtO = o_thr_load.partition_S(tOtO_i)
tTMEM_LOADcO = o_thr_load.partition_D(tOcO_i)
tTMEM_STOREtO = o_thr_store.partition_D(tOtO_i)
o_col_tiles = self.pv_mma_tiler[1] // corr_tile_size
# --- C2: Per-QK-fragment-row state (persist across KV tiles) ---
# The QK TMEM load fragment is logically 4 rows x 32 columns for each
# softmax thread. The old scalar row_max/row_sum reduced across all
# 4 rows and therefore produced a row_sum around 4.0. Keep one
# online-softmax state per local QK row.
qk_frg_cnt = 4
qk_frg_tile = cute.size(tTMEM_LOADcS) // qk_frg_cnt
tTMEM_LOADcS_frg = cute.logical_divide(tTMEM_LOADcS, cute.make_layout(qk_frg_tile))
qk_row0 = tTMEM_LOADcS_frg[0, 0][0]
qk_row1 = tTMEM_LOADcS_frg[0, 1][0]
qk_row2 = tTMEM_LOADcS_frg[0, 2][0]
qk_row3 = tTMEM_LOADcS_frg[0, 3][0]
row_max0 = -cutlass.Float32.inf
row_max1 = -cutlass.Float32.inf
row_max2 = -cutlass.Float32.inf
row_max3 = -cutlass.Float32.inf
row_sum0 = cutlass.Float32(0.0)
row_sum1 = cutlass.Float32(0.0)
row_sum2 = cutlass.Float32(0.0)
row_sum3 = cutlass.Float32(0.0)
# --- C3: QK scale = 1/sqrt(HEAD_DIM) * log2(e) for exp2 ---
scale = self.scale_softmax_log2
# =============================================================
# Per-KV-tile online softmax loop
# =============================================================
for kt in range(n_kv_tiles):
si_handle = s_cons.wait_and_advance()
# Load S from TMEM (FP32, QK C-fragment layout). Because the
# vector buffer reuses the S columns, all softmax threads must
# finish this load before any thread writes vector data.
tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype)
cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS)
cute.arch.fence_view_async_tmem_load()
# vec_handoff_bar skipped
frg_cnt = 4
frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt
tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile))
# --- C4: Compute tile_max independently for each local QK row ---
old_row_max0 = row_max0
old_row_max1 = row_max1
old_row_max2 = row_max2
old_row_max3 = row_max3
# Per-row max: explicit loop (avoid .load().reduce on sliced tensor)
for k in cutlass.range(cute.size(tTMEM_LOADrS_frg, mode=[0])):
row_max0 = cutlass.Float32.max(row_max0, tTMEM_LOADrS_frg[k, 0])
row_max1 = cutlass.Float32.max(row_max1, tTMEM_LOADrS_frg[k, 1])
row_max2 = cutlass.Float32.max(row_max2, tTMEM_LOADrS_frg[k, 2])
row_max3 = cutlass.Float32.max(row_max3, tTMEM_LOADrS_frg[k, 3])
row_max0_safe = row_max0
row_max1_safe = row_max1
row_max2_safe = row_max2
row_max3_safe = row_max3
if row_max0 == -cutlass.Float32.inf:
row_max0_safe = cutlass.Float32(0.0)
if row_max1 == -cutlass.Float32.inf:
row_max1_safe = cutlass.Float32(0.0)
if row_max2 == -cutlass.Float32.inf:
row_max2_safe = cutlass.Float32(0.0)
if row_max3 == -cutlass.Float32.inf:
row_max3_safe = cutlass.Float32(0.0)
# --- C5: Per-row O-rescale factors for the already-accumulated O ---
acc_scale0 = cute.math.exp2(scale * (old_row_max0 - row_max0_safe), fastmath=True)
acc_scale1 = cute.math.exp2(scale * (old_row_max1 - row_max1_safe), fastmath=True)
acc_scale2 = cute.math.exp2(scale * (old_row_max2 - row_max2_safe), fastmath=True)
acc_scale3 = cute.math.exp2(scale * (old_row_max3 - row_max3_safe), fastmath=True)
# --- C6: SKIPPED (minimal test) ---
if kt > 0:
pass # pv_done_bar skipped
else:
pass # pv_done_bar skipped
# --- C7: Compute P = exp2((S - row_max[row]) * scale), per row ---
minus_row_max_scale0 = (cutlass.Float32(0.0) - row_max0_safe) * scale
minus_row_max_scale1 = (cutlass.Float32(0.0) - row_max1_safe) * scale
minus_row_max_scale2 = (cutlass.Float32(0.0) - row_max2_safe) * scale
minus_row_max_scale3 = (cutlass.Float32(0.0) - row_max3_safe) * scale
rP_words = cute.make_rmem_tensor(tTMEM_STOREcP.shape, self.qk_acc_dtype)
rP_bf16 = cute.make_tensor(cute.recast_ptr(rP_words.iterator, dtype=self.q_dtype), tTMEM_LOADrS.layout)
rP_bf16_frg = cute.logical_divide(rP_bf16, cute.make_layout(frg_tile))
for k in cutlass.range(cute.size(tTMEM_LOADrS_frg, mode=[0]), vectorize=True):
tTMEM_LOADrS_frg[k, 0] = tTMEM_LOADrS_frg[k, 0] * scale + minus_row_max_scale0
tTMEM_LOADrS_frg[k, 0] = cute.math.exp2(tTMEM_LOADrS_frg[k, 0], fastmath=True)
s_vec0 = tTMEM_LOADrS_frg[None, 0].load()
rP_bf16_frg[None, 0].store(s_vec0.to(self.q_dtype))
for k in cutlass.range(cute.size(tTMEM_LOADrS_frg, mode=[0]), vectorize=True):
tTMEM_LOADrS_frg[k, 1] = tTMEM_LOADrS_frg[k, 1] * scale + minus_row_max_scale1
tTMEM_LOADrS_frg[k, 1] = cute.math.exp2(tTMEM_LOADrS_frg[k, 1], fastmath=True)
s_vec1 = tTMEM_LOADrS_frg[None, 1].load()
rP_bf16_frg[None, 1].store(s_vec1.to(self.q_dtype))
for k in cutlass.range(cute.size(tTMEM_LOADrS_frg, mode=[0]), vectorize=True):
tTMEM_LOADrS_frg[k, 2] = tTMEM_LOADrS_frg[k, 2] * scale + minus_row_max_scale2
tTMEM_LOADrS_frg[k, 2] = cute.math.exp2(tTMEM_LOADrS_frg[k, 2], fastmath=True)
s_vec2 = tTMEM_LOADrS_frg[None, 2].load()
rP_bf16_frg[None, 2].store(s_vec2.to(self.q_dtype))
for k in cutlass.range(cute.size(tTMEM_LOADrS_frg, mode=[0]), vectorize=True):
tTMEM_LOADrS_frg[k, 3] = tTMEM_LOADrS_frg[k, 3] * scale + minus_row_max_scale3
tTMEM_LOADrS_frg[k, 3] = cute.math.exp2(tTMEM_LOADrS_frg[k, 3], fastmath=True)
s_vec3 = tTMEM_LOADrS_frg[None, 3].load()
rP_bf16_frg[None, 3].store(s_vec3.to(self.q_dtype))
# Store P to TMEM.
cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP)
cute.arch.fence_view_async_tmem_store()
si_handle.release()
# softmax_done_bar skipped
# --- C8: Row sum accumulation, independently for each local QK row ---
# Per-row sum: explicit loop
tile_sum0 = cutlass.Float32(0.0)
tile_sum1 = cutlass.Float32(0.0)
tile_sum2 = cutlass.Float32(0.0)
tile_sum3 = cutlass.Float32(0.0)
for k in cutlass.range(cute.size(tTMEM_LOADrS_frg, mode=[0])):
tile_sum0 = tile_sum0 + tTMEM_LOADrS_frg[k, 0]
tile_sum1 = tile_sum1 + tTMEM_LOADrS_frg[k, 1]
tile_sum2 = tile_sum2 + tTMEM_LOADrS_frg[k, 2]
tile_sum3 = tile_sum3 + tTMEM_LOADrS_frg[k, 3]
row_sum0 = row_sum0 + tile_sum0
row_sum1 = row_sum1 + tile_sum1
row_sum2 = row_sum2 + tile_sum2
row_sum3 = row_sum3 + tile_sum3
# --- C9: SKIPPED (minimal test) ---
# pv_done_bar skipped
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC, epi_tile, 0,
const_expr(lambda x: x),
(0,0,0), acc_cons_st, acc_pipe, c_pipe)
c_pipe.producer_tail()
tmem.relinquish_alloc_permit()
tmem.free(tmem_ptr)
def test():
import math
torch.manual_seed(42)
for n in [128, 256, 384]:
m, hd = 128, HEAD_DIM
q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device="cuda")
k = torch.randn(n, hd, 1, dtype=torch.bfloat16, device="cuda")
v = torch.randn(n, hd, dtype=torch.bfloat16, device="cuda")
v_kernel = v.unsqueeze(-1)
c = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device="cuda")
qf = q[:,:,0].float(); kf = k[:,:,0].float()
attn = qf @ kf.T / math.sqrt(hd)
ref = torch.softmax(attn, dim=-1) @ v.float()
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel))
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
kernel = FmhaV3Softmax(s_k=n)
print(f"n={n}: Compiling...", flush=True)
compiled = cute.compile(kernel, mQ, mK, mV, mC, stream)
print(f"n={n}: tmem: s0={kernel.tmem_s0_offset} p0={kernel.tmem_p0_offset} o0={kernel.tmem_o0_offset} vec={kernel.tmem_vec_offset} alloc={kernel.num_tmem_alloc_cols}", flush=True)
print(f"n={n}: Running...", flush=True)
compiled(mQ, mK, mV, mC, stream)
torch.cuda.synchronize()
out = c[:,:,0].float()
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
max_err = (out - ref).abs().max().item()
print(f"FMHA softmax n={n}: cosine {cos:.6f} max_err {max_err:.6f} {'PASS' if cos >= 0.999 else 'FAIL'}", flush=True)
if __name__ == "__main__":
test()