Stage C: online softmax kernel (WIP) - test_fmha_v3_softmax.py

- C1: Real softmax reference (torch.softmax, not identity)
- C2: Per-thread row_max/row_sum registers
- C3: QK scale folded (1/sqrt(d) * log2(e))
- C4: Row max via .reduce(MAX)
- C5: Rescale factor (exp2(old_max - new_max))
- C6: O rescale in TMEM (correction_rescale pattern)
- C7: Real exp2 for P computation
- C8: Row sum via packed f32x2 reduction
- C9: Final normalization (1/row_sum in epilogue)
- Dynamic s_k for V FMHA reconstruction
- fastmath=False for correctness first
This commit is contained in:
2026-05-21 17:10:58 +00:00
parent 20564425ec
commit df04ba40ee

View File

@@ -0,0 +1,364 @@
"""
FMHA v3 + Stage C: QK -> online softmax -> PV with KV-tile interleaving.
Stage C: row_max, exp2, O rescale, row_sum, final normalization.
FMHA pattern P store preserved from Stage B.
"""
import math
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cuda.bindings.driver as cuda
import cutlass.torch as ct
HEAD_DIM = 64
class FmhaV3Softmax:
def __init__(self):
self.acc_dtype = Float32; self.qk_acc_dtype = Float32
self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16
self.use_2cta_instrs = False; self.epilog_sync_bar_id = 1
self.cluster_shape_mn = (1, 1); self.cta_group = tcgen05.CtaGroup.ONE
self.epilogue_warp_id = (0,1,2,3); self.mma_warp_id = 4; self.tma_warp_id = 5
self.threads_per_cta = 192; self.num_c_stage = 2
self.kv_stage = 2; self.q_stage = 1; self.num_c_stage = 2
def _setup(self, qk_mma, pv_mma):
qk_ik = cute.size(qk_mma.shape_mnk, mode=[2])
self.qk_mma_tiler = (128, 128, qk_ik * 4)
pv_ik = cute.size(pv_mma.shape_mnk, mode=[2])
self.pv_mma_tiler = (128, HEAD_DIM, pv_ik * (128 // pv_ik))
self.mma_tiler = self.qk_mma_tiler
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
self.cta_tile_shape_mnk = (self.qk_mma_tiler[0]//cute.size(qk_mma.thr_id.shape), HEAD_DIM, self.qk_mma_tiler[2])
self.c_layout = LayoutEnum.ROW_MAJOR
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(self.cta_tile_shape_mnk, False, self.c_layout, self.o_dtype)
self.num_ab_stage = 1; self.num_acc_stage = 1
self.q_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.qk_mma_tiler, self.q_dtype, self.q_stage)
self.k_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.qk_mma_tiler, self.q_dtype, self.kv_stage)
self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.q_dtype, self.kv_stage)
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2)
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
qk_thr = qk_mma.get_slice(0); qk_as = qk_thr.partition_shape_C(self.qk_mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_as)
pv_thr = pv_mma.get_slice(0); pv_as = pv_thr.partition_shape_C(self.pv_mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_as)
self.tmem_s0_offset = 0; self.tmem_p0_offset = 32
# P occupies [tmem_p0_offset, tmem_p0_offset + p_cols_fp32)
# S occupies [0, qk_mma_tiler[1]) = [0, 128)
# O must NOT overlap P. Place O after max(S end, P end), aligned to 32.
p_cols_fp32 = self.pv_mma_tiler[2] * self.q_dtype.width // self.qk_acc_dtype.width
p_end = self.tmem_p0_offset + p_cols_fp32 # 32 + 64 = 96
s_cols = self.qk_mma_tiler[1] # 128
o_after = max(s_cols, p_end) # 128
self.tmem_o0_offset = ((o_after + 31) // 32) * 32 # align to 32 = 128
o_cols = find_tmem_tensor_col_offset(tOtO) # footprint of O
total = self.tmem_o0_offset + o_cols
# Must be multiple of 32 AND power of 2
self.num_tmem_alloc_cols = 1
while self.num_tmem_alloc_cols < total:
self.num_tmem_alloc_cols *= 2
cta = cute.size(qk_mma.thr_id.shape)
q_s = cute.slice_(self.q_smem_s,(None,None,None,0)); k_s = cute.slice_(self.k_smem_s,(None,None,None,0))
self.q_tx_bytes = cute.size_in_bytes(self.q_dtype, q_s) * cta
self.kv_tx_bytes = cute.size_in_bytes(self.q_dtype, k_s) * cta
@cute.jit
def __call__(self, q, k, v, c, stream):
self.q_dtype = q.element_type; self.o_dtype = c.element_type; self.c_dtype = self.o_dtype
self.a_major = LayoutEnum.from_tensor(q).mma_major_mode()
self.b_major = LayoutEnum.from_tensor(k).mma_major_mode()
s_k = cute.size(v, mode=[0])
# FMHA-style V: reconstruct as (HEAD_DIM, s_k, 1) MN-major
v_fmha = cute.make_tensor(
v.iterator,
cute.make_layout(
(HEAD_DIM, s_k, 1),
stride=(1, HEAD_DIM, HEAD_DIM * s_k),
),
)
self.v_major = LayoutEnum.from_tensor(v_fmha).mma_major_mode()
self.c_layout = LayoutEnum.from_tensor(c)
qk_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, self.a_major, self.b_major, self.qk_acc_dtype, self.cta_group, (128,128), tcgen05.OperandSource.SMEM)
pv_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, cute.nvgpu.OperandMajorMode.K, self.v_major, self.qk_acc_dtype, self.cta_group, (128,HEAD_DIM), tcgen05.OperandSource.TMEM)
self._setup(qk_mma, pv_mma)
q_s = cute.slice_(self.q_smem_s,(None,None,None,0)); k_s = cute.slice_(self.k_smem_s,(None,None,None,0)); v_s = cute.slice_(self.v_smem_s,(None,None,None,0))
tma_q,mQ = cute.nvgpu.make_tiled_tma_atom_A(utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn,qk_mma.thr_id),q,q_s,self.qk_mma_tiler,qk_mma,self.cluster_layout_vmnk.shape)
tma_k,mK = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn,qk_mma.thr_id),k,k_s,self.qk_mma_tiler,qk_mma,self.cluster_layout_vmnk.shape)
tma_v,mV = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn,pv_mma.thr_id),v_fmha,v_s,self.pv_mma_tiler,pv_mma,self.cluster_layout_vmnk.shape)
epi_s = cute.select(self.c_smem_s,mode=[0,1])
tma_c,mC = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(),c,epi_s,self.epi_tile)
self._kernel(qk_mma,pv_mma,tma_q,mQ,tma_k,mK,tma_v,mV,tma_c,mC,self.cluster_layout_vmnk,self.q_smem_s,self.k_smem_s,self.v_smem_s,self.p_tmem_s,self.c_smem_s,self.epi_tile).launch(grid=(1,1,1),block=[self.threads_per_cta,1,1],stream=stream)
@cute.kernel
def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, tma_c, mC, cl_vmnk, q_smem_s, k_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile):
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
tidx,_,_ = cute.arch.thread_idx()
if warp_idx == self.tma_warp_id:
cpasync.prefetch_descriptor(tma_q); cpasync.prefetch_descriptor(tma_k); cpasync.prefetch_descriptor(tma_v); cpasync.prefetch_descriptor(tma_c)
@cute.struct
class SS:
q_bar: cute.struct.MemRange[cutlass.Int64, self.q_stage*2]
kv_bar: cute.struct.MemRange[cutlass.Int64, self.kv_stage*2]
s_bar: cute.struct.MemRange[cutlass.Int64, 2]
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage*2]
tmem_dealloc: cutlass.Int64; holding: cutlass.Int32
smem = utils.SmemAllocator(); st = smem.allocate(SS)
qp,qc = pipeline.PipelineTmaUmma.create(barrier_storage=st.q_bar.data_ptr(),num_stages=self.q_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),tx_count=self.q_tx_bytes,cta_layout_vmnk=cl_vmnk,defer_sync=True).make_participants()
kvp,kvc = pipeline.PipelineTmaUmma.create(barrier_storage=st.kv_bar.data_ptr(),num_stages=self.kv_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),tx_count=self.kv_tx_bytes,cta_layout_vmnk=cl_vmnk,defer_sync=True).make_participants()
s_prod,s_cons = pipeline.PipelineUmmaAsync.create(barrier_storage=st.s_bar.data_ptr(),num_stages=1,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,32*len(self.epilogue_warp_id))).make_participants()
softmax_done_bar = pipeline.NamedBarrier(barrier_id=3, num_threads=32 + 32*len(self.epilogue_warp_id))
acc_pipe = pipeline.PipelineUmmaAsync.create(barrier_storage=st.acc_bar.data_ptr(),num_stages=self.num_acc_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,len(self.epilogue_warp_id)),cta_layout_vmnk=cl_vmnk,defer_sync=True)
tmem_bar = pipeline.NamedBarrier(barrier_id=2,num_threads=32*len((self.mma_warp_id,*self.epilogue_warp_id)))
tmem = utils.TmemAllocator(st.holding.ptr,barrier_for_retrieve=tmem_bar,allocator_warp_id=self.epilogue_warp_id[0],is_two_cta=cute.size(qk_mma.thr_id.shape)==2,two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk,is_relaxed=True)
sQ = smem.allocate_tensor(element_type=self.q_dtype,layout=q_smem_s.outer,byte_alignment=128,swizzle=q_smem_s.inner)
sK = smem.allocate_tensor(element_type=self.q_dtype,layout=k_smem_s.outer,byte_alignment=128,swizzle=k_smem_s.inner)
sV = smem.allocate_tensor(element_type=self.q_dtype,layout=v_smem_s.outer,byte_alignment=128,swizzle=v_smem_s.inner)
sC = smem.allocate_tensor(element_type=self.o_dtype,layout=c_smem_s.outer,byte_alignment=128,swizzle=c_smem_s.inner)
gQ = cute.local_tile(mQ,cute.slice_(self.qk_mma_tiler,(None,0,None)),(None,None,None))
gK = cute.local_tile(mK,cute.slice_(self.qk_mma_tiler,(0,None,None)),(None,None,None))
gV = cute.local_tile(mV,cute.slice_(self.pv_mma_tiler,(0,None,None)),(None,None,None))
gC = cute.local_tile(mC,cute.slice_(self.pv_mma_tiler,(None,None,0)),(None,None,None))
n_kv_tiles = cute.size(gK, mode=[3])
qk_thr = qk_mma.get_slice(0); pv_thr = pv_mma.get_slice(0)
tCgQ = qk_thr.partition_A(gQ); tCgK = qk_thr.partition_B(gK)
tCgV = pv_thr.partition_B(gV); tCgC = pv_thr.partition_C(gC)
a_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,0,None,0)).shape)
tAsQ,tAgQ = cpasync.tma_partition(tma_q,0,a_lay,cute.group_modes(sQ,0,3),cute.group_modes(tCgQ,0,3))
b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape)
tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3))
tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3))
tAgQ = tAgQ[(None,0,None,0)]; tBgK = tBgK[(None,0,None,0)]; tVgV = tVgV[(None,0,None,0)]
tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK)
tCrV = pv_mma.make_fragment_B(sV)
qk_as = qk_thr.partition_shape_C(self.qk_mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_as)
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
pv_as = pv_thr.partition_shape_C(self.pv_mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_as)
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
# --- PV read view (for MMA only, NOT for softmax store) ---
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
tOrP_base = pv_thr.make_fragment_A(tP)
tOrP = tOrP_base[(None,None,None,0)]
tOrP0 = cute.make_tensor(
tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
tOrP.layout)
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_as, self.num_acc_stage))
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_as, self.num_acc_stage))
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
# TMA LOAD
if warp_idx == self.tma_warp_id:
qp.reset(); qh = qp.acquire_and_advance()
cute.copy(tma_q,tAgQ[(None,qh.count)],tAsQ[(None,qh.index)],tma_bar_ptr=qh.barrier)
qp.tail()
kvp.reset(); pk = kvp.try_acquire()
for kt in cutlass.range(n_kv_tiles,unroll=1):
kh = kvp.acquire_and_advance(pk)
cute.copy(tma_k,tBgK[(None,kh.count)],tBsK[(None,kh.index)],tma_bar_ptr=kh.barrier)
pk = cutlass.Boolean(1)
vh = kvp.acquire_and_advance(pk)
cute.copy(tma_v,tVgV[(None,vh.count)],tVsV[(None,vh.index)],tma_bar_ptr=vh.barrier)
pk = cutlass.Boolean(1)
kvp.tail()
# MMA
if warp_idx == self.mma_warp_id:
tmem.wait_for_alloc()
qc.reset(); qh = qc.wait_and_advance(); qh.release()
kvc.reset(); pk = kvc.try_wait()
acc_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
acc_pipe.producer_acquire(acc_st)
for kt in range(n_kv_tiles):
kh = kvc.wait_and_advance(pk); pk = cutlass.Boolean(1)
sh = s_prod.acquire_and_advance()
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
for kb in cutlass.range(cute.size(tCrQ,mode=[2]), unroll_full=True):
cute.gemm(qk_mma, tStS0, tCrQ[(None,None,kb,0)], tCrK[(None,None,kb,kh.index)], tStS0)
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
cute.arch.fence_view_async_tmem_store()
sh.commit(); kh.release()
softmax_done_bar.arrive_and_wait()
vh = kvc.wait_and_advance(pk); pk = cutlass.Boolean(1)
pv_mma.set(tcgen05.Field.ACCUMULATE, kt != 0)
for kb in cutlass.range(cute.size(tOrP0,mode=[2]), unroll_full=True):
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV[(None,None,kb,vh.index)], tOtO0)
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
cute.arch.fence_view_async_tmem_store()
vh.release()
acc_pipe.producer_commit(acc_st); acc_st.advance()
acc_pipe.producer_tail(acc_st)
# ===================== EPILOGUE WARPS (STAGE C: ONLINE SOFTMAX) =====================
if warp_idx < self.mma_warp_id:
tmem.allocate(self.num_tmem_alloc_cols)
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
sfw_idx = tidx % (32 * len(self.epilogue_warp_id))
# --- S load (QK C-fragment) ---
tmem_load_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)
thr_load = tiled_tmem_load.get_slice(sfw_idx)
tTMEM_LOADtS = thr_load.partition_S(tStS0)
cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
tScS = qk_thr.partition_C(cS)
tTMEM_LOADcS = thr_load.partition_D(tScS)
# --- P store (QK C-fragment composition, FMHA pattern) ---
p_cols_fp32 = self.pv_mma_tiler[2] * self.q_dtype.width // self.qk_acc_dtype.width
tStP_layout = cute.composition(tStS.layout, cute.make_layout((self.pv_mma_tiler[0], p_cols_fp32)))
tStP0 = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStP_layout)
tmem_store_atom = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStP0)
thr_store = tiled_tmem_store.get_slice(sfw_idx)
tTMEM_STOREtP = thr_store.partition_D(tStP0)
tScP_layout = cute.composition(tScS.layout, cute.make_layout((self.pv_mma_tiler[0], p_cols_fp32)))
tScP = cute.make_tensor(tScS.iterator, tScP_layout)
tTMEM_STOREcP = thr_store.partition_S(tScP)
# --- C6: O TMEM load/store for rescale (correction_rescale pattern) ---
corr_tile_size = 16
cO = cute.make_identity_tensor((self.pv_mma_tiler[0], self.pv_mma_tiler[1]))
tOcO = pv_thr.partition_C(cO)
o_tmem_load_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), self.qk_acc_dtype)
o_tmem_store_atom = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(corr_tile_size)), self.qk_acc_dtype)
tOtO_i_layout = cute.composition(tOtO0.layout, cute.make_layout((128, corr_tile_size)))
tOcO_i_layout = cute.composition(tOcO.layout, cute.make_layout((128, corr_tile_size)))
tOtO_i = cute.make_tensor(tOtO0.iterator, tOtO_i_layout)
tOcO_i = cute.make_tensor(tOcO.iterator, tOcO_i_layout)
o_tiled_tmem_load = tcgen05.make_tmem_copy(o_tmem_load_atom, tOtO_i)
o_tiled_tmem_store = tcgen05.make_tmem_copy(o_tmem_store_atom, tOtO_i)
o_thr_load = o_tiled_tmem_load.get_slice(sfw_idx)
o_thr_store = o_tiled_tmem_store.get_slice(sfw_idx)
tTMEM_LOADtO = o_thr_load.partition_S(tOtO_i)
tTMEM_LOADcO = o_thr_load.partition_D(tOcO_i)
tTMEM_STOREtO = o_thr_store.partition_D(tOtO_i)
o_col_tiles = self.pv_mma_tiler[1] // corr_tile_size
# --- C2: Per-thread row state (persist across KV tiles) ---
row_max = -cutlass.Float32.inf
row_sum = cutlass.Float32(0.0)
# --- C3: QK scale = 1/sqrt(HEAD_DIM) * log2(e) for exp2 ---
scale = self.scale_softmax_log2
# =============================================================
# Per-KV-tile online softmax loop
# =============================================================
for kt in range(n_kv_tiles):
si_handle = s_cons.wait_and_advance()
# Load S from TMEM (FP32, QK C-fragment layout)
tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype)
cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS)
# --- C4: Compute tile_max via .reduce(MAX) ---
old_row_max = row_max
row_max = tTMEM_LOADrS.load().reduce(cute.ReductionOp.MAX, row_max, 0)
row_max_safe = row_max
if row_max == -cutlass.Float32.inf:
row_max_safe = cutlass.Float32(0.0)
# --- C5: Compute rescale factor ---
acc_scale = cute.math.exp2(old_row_max - row_max_safe, fastmath=False)
# --- C6: Rescale O in TMEM (load O, multiply by acc_scale, store O) ---
if kt > 0:
tTMrO = cute.make_rmem_tensor((tTMEM_LOADcO.shape, o_col_tiles), self.qk_acc_dtype)
for i in range(o_col_tiles):
tTMrO_i_ = tTMrO[None, i]
tTMrO_i_layout = cute.composition(tTMrO_i_.layout, cute.make_layout(tTMrO.shape[0]))
tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout)
tTMEM_LOADtO_i = cute.make_tensor(tTMEM_LOADtO.iterator + i * corr_tile_size, tTMEM_LOADtO.layout)
tTMEM_STOREtO_i = cute.make_tensor(tTMEM_STOREtO.iterator + i * corr_tile_size, tTMEM_STOREtO.layout)
cute.copy(o_tiled_tmem_load, tTMEM_LOADtO_i, tTMrO_i)
for j in cutlass.range(cute.size(tTMrO_i), vectorize=True):
tTMrO_i[j] = tTMrO_i[j] * acc_scale
cute.copy(o_tiled_tmem_store, tTMrO_i, tTMEM_STOREtO_i)
cute.arch.fence_view_async_tmem_store()
# Rescale row_sum
row_sum = row_sum * acc_scale
# --- C7: Compute P = exp2((S - row_max_safe) * scale) ---
minus_row_max_scale = (cutlass.Float32(0.0) - row_max_safe) * scale
# Register bridge (FMHA pattern: FP32 backing + BF16 view)
rP_words = cute.make_rmem_tensor(tTMEM_STOREcP.shape, self.qk_acc_dtype)
rP_bf16 = cute.make_tensor(cute.recast_ptr(rP_words.iterator, dtype=self.q_dtype), tTMEM_LOADrS.layout)
frg_cnt = 4
frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt
tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile))
rP_bf16_frg = cute.logical_divide(rP_bf16, cute.make_layout(frg_tile))
# Scale S, compute exp2, store through register bridge
for j in range(frg_cnt):
for k in cutlass.range(cute.size(tTMEM_LOADrS_frg, mode=[0]), vectorize=True):
tTMEM_LOADrS_frg[k, j] = tTMEM_LOADrS_frg[k, j] * scale + minus_row_max_scale
tTMEM_LOADrS_frg[k, j] = cute.math.exp2(tTMEM_LOADrS_frg[k, j], fastmath=False)
s_vec = tTMEM_LOADrS_frg[None, j].load()
rP_bf16_frg[None, j].store(s_vec.to(self.q_dtype))
# Store P to TMEM
cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP)
cute.arch.fence_view_async_tmem_store()
si_handle.release()
softmax_done_bar.arrive()
# --- C8: Row sum accumulation (CUTLASS FMHA packed f32x2 pattern) ---
# P values still in tTMEM_LOADrS registers.
# 4 accumulators for 4 reduction_unroll columns.
local_row_sum_0 = (cutlass.Float32(0.0), cutlass.Float32(0.0))
local_row_sum_1 = (cutlass.Float32(0.0), cutlass.Float32(0.0))
local_row_sum_2 = (cutlass.Float32(0.0), cutlass.Float32(0.0))
local_row_sum_3 = (cutlass.Float32(0.0), cutlass.Float32(0.0))
reduction_unroll = 4
rfrg_tile = cute.size(tTMEM_LOADrS) // reduction_unroll
tTMEM_LOADrS_rfrg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(rfrg_tile))
for j in cutlass.range_constexpr(0, cute.size(tTMEM_LOADrS_rfrg, mode=[0]), 2):
local_row_sum_0 = cute.arch.add_packed_f32x2(
local_row_sum_0, (tTMEM_LOADrS_rfrg[j, 0], tTMEM_LOADrS_rfrg[j + 1, 0]))
local_row_sum_1 = cute.arch.add_packed_f32x2(
local_row_sum_1, (tTMEM_LOADrS_rfrg[j, 1], tTMEM_LOADrS_rfrg[j + 1, 1]))
local_row_sum_2 = cute.arch.add_packed_f32x2(
local_row_sum_2, (tTMEM_LOADrS_rfrg[j, 2], tTMEM_LOADrS_rfrg[j + 1, 2]))
local_row_sum_3 = cute.arch.add_packed_f32x2(
local_row_sum_3, (tTMEM_LOADrS_rfrg[j, 3], tTMEM_LOADrS_rfrg[j + 1, 3]))
local_row_sum_0 = cute.arch.add_packed_f32x2(local_row_sum_0, local_row_sum_1)
local_row_sum_2 = cute.arch.add_packed_f32x2(local_row_sum_2, local_row_sum_3)
local_row_sum_0 = cute.arch.add_packed_f32x2(local_row_sum_0, local_row_sum_2)
tile_sum = local_row_sum_0[0] + local_row_sum_0[1]
row_sum = row_sum + tile_sum
# --- C9: Final normalization + epilogue TMA store ---
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
# C9: Normalize by 1/row_sum
inv_row_sum = cutlass.Float32(1.0) / row_sum
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC, epi_tile, 0,
const_expr(lambda x, s=inv_row_sum: x * s),
(0,0,0), acc_cons_st, acc_pipe, c_pipe)
c_pipe.producer_tail()
tmem.relinquish_alloc_permit()
tmem.free(tmem_ptr)