D1.5: SMEM accumulator FMHA kernel — one-way TMEM→REGS→SMEM, no round-trip

TMEM round-trip (Ld32x32bOp/St32x32bOp) is FUNDAMENTALLY BROKEN.
Even NO-OP (multiply by 1.0) corrupts data.

New approach:
- PV always ACCUMULATE=False (fresh TMEM each kt)
- After pv_done_bar: one-way Ld32x32bOp load O_kt from TMEM→REGS
- Coordinate-indexed SMEM accumulation: sO_acc = acc_scale * sO_acc + O_kt
- sO_acc: FP32 [128, pv_n_tile] row-major (32KB at hd=64, 64KB at hd=128)
- Final: normalize, cast BF16, write to sC, TMA store to GMEM
This commit is contained in:
2026-05-27 04:53:40 +00:00
parent 81acf1593c
commit 6a621bdf64
2 changed files with 551 additions and 23 deletions

View File

@@ -1,25 +1,13 @@
"""FMHA kernel: QK -> online softmax -> PV (CuTeDSL, Blackwell SM100).
"""FMHA kernel: SMEM accumulator approach for multi-KV-tile O rescale.
SMEM accumulator approach for multi-KV-tile O rescale.
Instead of TMEM round-trip (which corrupts data), we move O from TMEM
to SMEM after each PV GEMM via one-way epilogue, and accumulate in SMEM.
This avoids the D1.5 TMEM round-trip bug entirely.
TMEM round-trip is FUNDAMENTALLY BROKEN (Ld32x32bOp/St32x32bOp column
mapping mismatch, even NO-OP corrupts). This kernel avoids it entirely.
Architecture:
- 6-warp specialization: 4 softmax+epilogue, 1 MMA, 1 TMA
- After PV[kt]: one-way TMEM→REGS→SMEM with acc_scale multiplication
- SMEM accumulator persists across kt iterations
- Final TMA store: SMEM→GMEM
Per-kt flow:
1. Softmax warps: compute P[kt], acc_scale[kt]
2. Signal softmax_done_bar
3. MMA warp: PV[kt] GEMM (ACCUMULATE=False, fresh TMEM)
4. Signal pv_done_bar
5. Softmax/epilogue warps: TMEM→REGS, acc_scale*O_acc + O_kt, REGS→SMEM
6. Repeat for next kt
7. After all kt: SMEM→GMEM via TMA
- 6-warp: 4 softmax+epilogue, 1 MMA, 1 TMA
- PV always ACCUMULATE=False (fresh TMEM each kt)
- After pv_done_bar: one-way TMEM->REGS load O_kt, accumulate in SMEM
- Final: normalize sO_acc -> sC (BF16) -> TMA store to GMEM
"""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
@@ -59,8 +47,17 @@ class FmhaKernel:
self.acc_dtype = Float32; self.qk_acc_dtype = Float32
self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16
self.use_2cta_instrs = False; self.epilog_sync_bar_id = 1
self.cluster_shape_mn = (1, 1); self.cta_group = tcgen05.CtaGroup.ONE
self.epilogue_warp_id = (0,1,2,3); self.mma_warp_id = 4; self.tma_warp_id = 5
self.threads_per_cta = 192
self.k_tile = min(head_dim, 256)
self.n_k_sub_tiles = head_dim // self.k_tile
self.kv_stage = 1 if head_dim > 128 else 2
self.q_stage = 1
self.num_c_stage = 1 if head_dim > 256 else 2
self.scale_softmax = scale_softmax if scale_softmax is not None else 1.0 / math.sqrt(self.head_dim)
self.scale_softmax_log2 = self.scale_softmax * math.log2(math.e)
self.smem_acc_dtype = Float32
def _setup(self, qk_mma, pv_mma):
qk_ik = cute.size(qk_mma.shape_mnk, mode=[2])
@@ -76,7 +73,7 @@ class FmhaKernel:
self.q_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.qk_mma_tiler, self.q_dtype, self.q_stage)
self.k_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.qk_mma_tiler, self.q_dtype, self.kv_stage)
self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.q_dtype, self.kv_stage)
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2)
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, self.num_c_stage)
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
self.p_smem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
qk_thr = qk_mma.get_slice(0); qk_as = qk_thr.partition_shape_C(self.qk_mma_tiler[:2])
@@ -124,10 +121,11 @@ class FmhaKernel:
),
)
self.v_major = LayoutEnum.from_tensor(v_fmha).mma_major_mode()
qk_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, self.a_major, self.b_major, self.qk_acc_dtype, tcgen05.CtaGroup.ONE, (128,128), tcgen05.OperandSource.SMEM)
self.c_layout = LayoutEnum.from_tensor(c)
qk_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, self.a_major, self.b_major, self.qk_acc_dtype, self.cta_group, (128,128), tcgen05.OperandSource.SMEM)
pv_a_major = self.a_major if self.use_smem_p else cute.nvgpu.OperandMajorMode.K
pv_source = tcgen05.OperandSource.SMEM if self.use_smem_p else tcgen05.OperandSource.TMEM
pv_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, pv_a_major, self.v_major, self.qk_acc_dtype, tcgen05.CtaGroup.ONE, (128,self.pv_n_tile), pv_source)
pv_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, pv_a_major, self.v_major, self.qk_acc_dtype, self.cta_group, (128,self.pv_n_tile), pv_source)
self._setup(qk_mma, pv_mma)
q_s = cute.slice_(self.q_smem_s,(None,None,None,0)); k_s = cute.slice_(self.k_smem_s,(None,None,None,0)); v_s = cute.slice_(self.v_smem_s,(None,None,None,0))
tma_q,mQ = cute.nvgpu.make_tiled_tma_atom_A(utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn,qk_mma.thr_id),q,q_s,self.qk_mma_tiler,qk_mma,self.cluster_layout_vmnk.shape)
@@ -148,4 +146,422 @@ class FmhaKernel:
self._kernel(qk_mma,pv_mma,tma_q,mQ,tma_k,mK,tma_v,mV,tma_c,mC,self.cluster_layout_vmnk,self.q_smem_s,self.k_smem_s,self.v_smem_s,self.p_tmem_s,self.p_smem_s,self.c_smem_s,self.epi_tile,lse,swa_len,sink_bias,row_sums).launch(grid=(1,1,self.batch_size),block=[self.threads_per_cta,1,1],stream=stream)
# ... rest of kernel to be implemented
@cute.kernel
def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, tma_c, mC, cl_vmnk, q_smem_s, k_smem_s, v_smem_s, p_tmem_s, p_smem_s, c_smem_s, epi_tile, mLSE, swa_len, mSinkBias, mRowSums):
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
tidx,_,_ = cute.arch.thread_idx()
if warp_idx == self.tma_warp_id:
cpasync.prefetch_descriptor(tma_q); cpasync.prefetch_descriptor(tma_k); cpasync.prefetch_descriptor(tma_v); cpasync.prefetch_descriptor(tma_c)
@cute.struct
class SS:
q_bar: cute.struct.MemRange[cutlass.Int64, self.q_stage*2]
kv_bar: cute.struct.MemRange[cutlass.Int64, self.kv_stage*2]
s_bar: cute.struct.MemRange[cutlass.Int64, 2]
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage*2]
tmem_dealloc: cutlass.Int64; holding: cutlass.Int32
smem = utils.SmemAllocator(); st = smem.allocate(SS)
qp,qc = pipeline.PipelineTmaUmma.create(barrier_storage=st.q_bar.data_ptr(),num_stages=self.q_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),tx_count=self.q_tx_bytes,cta_layout_vmnk=cl_vmnk,defer_sync=True).make_participants()
kvp,kvc = pipeline.PipelineTmaUmma.create(barrier_storage=st.kv_bar.data_ptr(),num_stages=self.kv_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),tx_count=self.kv_tx_bytes,cta_layout_vmnk=cl_vmnk,defer_sync=True).make_participants()
s_prod,s_cons = pipeline.PipelineUmmaAsync.create(barrier_storage=st.s_bar.data_ptr(),num_stages=1,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,32*len(self.epilogue_warp_id))).make_participants()
softmax_done_bar = pipeline.NamedBarrier(barrier_id=3, num_threads=32 + 32*len(self.epilogue_warp_id))
final_o_bar = pipeline.NamedBarrier(barrier_id=4, num_threads=32 + 32*len(self.epilogue_warp_id))
pv_done_bar = pipeline.NamedBarrier(barrier_id=5, num_threads=32 + 32*len(self.epilogue_warp_id))
acc_pipe = pipeline.PipelineUmmaAsync.create(barrier_storage=st.acc_bar.data_ptr(),num_stages=self.num_acc_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,len(self.epilogue_warp_id)),cta_layout_vmnk=cl_vmnk,defer_sync=True)
tmem_bar = pipeline.NamedBarrier(barrier_id=2,num_threads=32*len((self.mma_warp_id,*self.epilogue_warp_id)))
tmem = utils.TmemAllocator(st.holding.ptr,barrier_for_retrieve=tmem_bar,allocator_warp_id=self.epilogue_warp_id[0],is_two_cta=cute.size(qk_mma.thr_id.shape)==2,two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk,is_relaxed=True)
sQ = smem.allocate_tensor(element_type=self.q_dtype,layout=q_smem_s.outer,byte_alignment=128,swizzle=q_smem_s.inner)
sK = smem.allocate_tensor(element_type=self.q_dtype,layout=k_smem_s.outer,byte_alignment=128,swizzle=k_smem_s.inner)
sV = smem.allocate_tensor(element_type=self.q_dtype,layout=v_smem_s.outer,byte_alignment=128,swizzle=v_smem_s.inner)
sC = smem.allocate_tensor(element_type=self.o_dtype,layout=c_smem_s.outer,byte_alignment=128,swizzle=c_smem_s.inner)
if const_expr(self.use_smem_p):
_p_layout = p_smem_s.outer
_p_swizzle = p_smem_s.inner
else:
_p_layout = cute.make_layout(((1,1),1,(1,1),1))
_p_swizzle = cute.make_layout(((1,1),1,(1,1),1))
sP = smem.allocate_tensor(element_type=self.q_dtype,layout=_p_layout,byte_alignment=128,swizzle=_p_swizzle)
# SMEM accumulator: FP32 [128, pv_n_tile] row-major
sO_acc_layout = cute.make_layout((128, self.pv_n_tile), stride=(self.pv_n_tile, 1))
sO_acc = smem.allocate_tensor(element_type=self.smem_acc_dtype, layout=sO_acc_layout, byte_alignment=128)
# Zero-initialize sO_acc
if warp_idx < self.mma_warp_id:
for i in cutlass.range(0, cute.size(sO_acc), unroll=1):
sO_acc[i] = Float32(0.0)
gQ = cute.local_tile(mQ,cute.slice_(self.qk_mma_tiler,(None,0,None)),(None,None,None))
gK = cute.local_tile(mK,cute.slice_(self.qk_mma_tiler,(0,None,None)),(None,None,None))
gV = cute.local_tile(mV,cute.slice_(self.pv_mma_tiler,(0,None,None)),(None,None,None))
gC = cute.local_tile(mC,cute.slice_(self.pv_mma_tiler,(None,None,0)),(None,None,None))
n_kv_tiles = cute.size(gK, mode=[3])
qk_thr = qk_mma.get_slice(0); pv_thr = pv_mma.get_slice(0)
tCgQ = qk_thr.partition_A(gQ); tCgK = qk_thr.partition_B(gK)
tCgV = pv_thr.partition_B(gV); tCgC = pv_thr.partition_C(gC)
a_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,0,None,0)).shape)
tAsQ,tAgQ = cpasync.tma_partition(tma_q,0,a_lay,cute.group_modes(sQ,0,3),cute.group_modes(tCgQ,0,3))
b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape)
tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3))
tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3))
tAgQ = tAgQ[(None,0,None,0)]; tBgK = tBgK[(None,0,None,0)]; tVgV = tVgV[(None,0,None,0)]
tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK)
tCrV = pv_mma.make_fragment_B(sV)
qk_as = qk_thr.partition_shape_C(self.qk_mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_as)
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
pv_as = pv_thr.partition_shape_C(self.pv_mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_as)
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
tOrP_base = pv_thr.make_fragment_A(tP if not self.use_smem_p else sP)
tOrP = tOrP_base[(None,None,None,0)]
tCrP = pv_mma.make_fragment_A(sP) if self.use_smem_p else pv_mma.make_fragment_A(tP)
if const_expr(self.tOrP0_offset > 0):
tOrP0 = cute.make_tensor(tOrP.iterator + self.tOrP0_offset, tOrP.layout)
else:
tOrP0 = tOrP
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_as, self.num_acc_stage))
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
# ===== TMA LOAD warp =====
if warp_idx == self.tma_warp_id:
if const_expr(self.n_k_sub_tiles > 1):
qp.reset(); kvp.reset()
for k_sub in cutlass.range(0, self.n_k_sub_tiles, 1, unroll=1):
qh = qp.acquire_and_advance()
cute.copy(tma_q, tAgQ[(None, k_sub)], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier)
kvh = kvp.acquire_and_advance()
cute.copy(tma_k, tBgK[(None, k_sub)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
kvh_v = kvp.acquire_and_advance()
cute.copy(tma_v, tVgV[(None, Int32(0))], tVsV[(None, kvh_v.index)], tma_bar_ptr=kvh_v.barrier)
qp.tail(); kvp.tail()
else:
qp.reset(); qh = qp.acquire_and_advance()
cute.copy(tma_q, tAgQ[(None, Int32(0))], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier)
qp.tail()
kvp.reset(); pk = kvp.try_acquire()
for kt in cutlass.range(0, self.n_kv_tiles, 1, unroll=1):
kvh = kvp.acquire_and_advance(pk)
cute.copy(tma_k, tBgK[(None, kt)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
cute.copy(tma_v, tVgV[(None, kt)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
pk = cutlass.Boolean(1)
kvp.tail()
# ===== MMA warp =====
if warp_idx == self.mma_warp_id:
tmem.wait_for_alloc()
if const_expr(self.n_k_sub_tiles > 1):
qc.reset(); kvc.reset()
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
for k_sub in cutlass.range(0, self.n_k_sub_tiles, 1, unroll=1):
qh = qc.wait_and_advance(); qh.release()
kvh = kvc.wait_and_advance()
for kb in cutlass.range(cute.size(tCrQ, mode=[2]), unroll_full=True):
cute.gemm(qk_mma, tStS0, tCrQ[(None,None,kb,0)], tCrK[(None,None,kb,kvh.index)], tStS0)
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
kvh.release()
cute.arch.fence_view_async_tmem_store()
softmax_done_bar.arrive()
softmax_done_bar.arrive_and_wait()
# PV: ACCUMULATE=False for SMEM accumulator
pv_mma.set(tcgen05.Field.ACCUMULATE, False)
kvh_v = kvc.wait_and_advance()
if not self.use_smem_p:
for kb in cutlass.range(cute.size(tOrP0, mode=[2]), unroll_full=True):
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV[(None,None,kb,kvh_v.index)], tOtO0)
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
else:
for kb in cutlass.range(cute.size(tCrP, mode=[2]), unroll_full=True):
cute.gemm(pv_mma, tOtO0, tCrP[(None,None,kb,0)], tCrV[(None,None,kb,kvh_v.index)], tOtO0)
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
cute.arch.fence_view_async_tmem_store()
kvh_v.release()
pv_done_bar.arrive()
final_o_bar.arrive()
else:
qc.reset(); qh = qc.wait_and_advance(); qh.release()
kvc.reset(); pk = kvc.try_wait()
acc_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
acc_pipe.producer_acquire(acc_st)
for kt in range(self.n_kv_tiles):
kvh = kvc.wait_and_advance(pk); pk = cutlass.Boolean(1)
sh = s_prod.acquire_and_advance()
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
for kb in cutlass.range(cute.size(tCrQ, mode=[2]), unroll_full=True):
cute.gemm(qk_mma, tStS0, tCrQ[(None,None,kb,0)], tCrK[(None,None,kb,kvh.index)], tStS0)
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
cute.arch.fence_view_async_tmem_store()
sh.commit()
softmax_done_bar.arrive_and_wait()
# PV: ACCUMULATE=False for SMEM accumulator
pv_mma.set(tcgen05.Field.ACCUMULATE, False)
if not self.use_smem_p:
for kb in cutlass.range(cute.size(tOrP0, mode=[2]), unroll_full=True):
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV[(None,None,kb,kvh.index)], tOtO0)
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
else:
for kb in cutlass.range(cute.size(tCrP, mode=[2]), unroll_full=True):
cute.gemm(pv_mma, tOtO0, tCrP[(None,None,kb,0)], tCrV[(None,None,kb,kvh.index)], tOtO0)
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
cute.arch.fence_view_async_tmem_store()
kvh.release()
pv_done_bar.arrive()
acc_pipe.producer_commit(acc_st); acc_st.advance()
final_o_bar.arrive()
acc_pipe.producer_tail(acc_st)
# ===== SOFTMAX + SMEM ACCUMULATOR EPILOGUE warps =====
if warp_idx < self.mma_warp_id:
tmem.allocate(self.num_tmem_alloc_cols)
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
sfw_idx = tidx % (32 * len(self.epilogue_warp_id))
# S load atoms
tmem_load_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)
thr_load = tiled_tmem_load.get_slice(sfw_idx)
tTMEM_LOADtS = thr_load.partition_S(tStS0)
cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
tScS = qk_thr.partition_C(cS)
tTMEM_LOADcS = thr_load.partition_D(tScS)
# O load atoms: one-way TMEM->REGS read of O after PV
# Uses same Ld32x32bOp pattern from O's TMEM offset (tOtO0)
o_load_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tiled_o_load = tcgen05.make_tmem_copy(o_load_atom, tOtO0)
thr_o_load = tiled_o_load.get_slice(sfw_idx)
tTMEM_LOADtO = thr_o_load.partition_S(tOtO0)
# Coordinate tensor for O: maps register positions to (row, col)
cO = cute.make_identity_tensor((self.qk_mma_tiler[0], self.pv_mma_tiler[1]))
tOcO = pv_thr.partition_C(cO)
tTMEM_LOADcO = thr_o_load.partition_D(tOcO)
# P store atoms (TMEM-P path)
p_cols_fp32 = self.pv_mma_tiler[2] * self.q_dtype.width // self.qk_acc_dtype.width
tStP_layout = cute.composition(tStS.layout, cute.make_layout((self.pv_mma_tiler[0], p_cols_fp32)))
tStP0 = cute.make_tensor(tStS.iterator + max(self.tmem_p0_offset, 0), tStP_layout)
tmem_store_atom = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStP0)
thr_store = tiled_tmem_store.get_slice(sfw_idx)
tTMEM_STOREtP = thr_store.partition_D(tStP0)
tScP_layout = cute.composition(tScS.layout, cute.make_layout((self.pv_mma_tiler[0], p_cols_fp32)))
tScP = cute.make_tensor(tScS.iterator, tScP_layout)
tTMEM_STOREcP = thr_store.partition_S(tScP)
_sP_nostage = sP[(None, None, None, 0)]
row_max = -Float32.inf
row_sum = Float32(0.0)
scale_log2 = Float32(self.scale_softmax_log2)
# ============================================================
# MAIN LOOP: softmax + SMEM accumulator
# ============================================================
for kt in range(self.n_kv_tiles):
si_handle = s_cons.wait_and_advance()
# --- Load S from TMEM ---
tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype)
cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS)
cute.arch.fence_view_async_tmem_load()
# D3/D4/D5c: logit modification
if const_expr(self.apply_swa_mask or self.is_causal or self.apply_sink_bias):
kt_offset = Int32(kt * 128)
sink_val = Float32(0.0)
if const_expr(self.apply_sink_bias):
sink_val = mSinkBias[Int32(0)] / Float32(self.scale_softmax)
for j0 in range(32):
for j1 in range(4):
coord = tTMEM_LOADcS[(j0, 0), j1, 0, 0]
m_coord = coord[0]; k_coord = coord[1]
kv_pos = kt_offset + k_coord
if const_expr(self.apply_sink_bias):
if kv_pos >= Int32(self.n_comp):
tTMEM_LOADrS[(j0, 0), j1, 0, 0] = tTMEM_LOADrS[(j0, 0), j1, 0, 0] + sink_val
should_mask = Boolean(0)
if const_expr(self.apply_swa_mask):
if kv_pos >= Int32(self.n_comp) + swa_len:
should_mask = Boolean(1)
if const_expr(self.is_causal):
if kv_pos >= Int32(self.n_comp):
swa_pos = kv_pos - Int32(self.n_comp)
if swa_pos > m_coord:
should_mask = Boolean(1)
if should_mask:
tTMEM_LOADrS[(j0, 0), j1, 0, 0] = -Float32.inf
# --- Online softmax ---
old_row_max = row_max
frg_cnt = 4
frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt
tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile))
for j in range(frg_cnt):
for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])):
row_max = cute.arch.fmax(row_max, tTMEM_LOADrS_frg[k, j] * scale_log2)
row_max_safe = row_max
if row_max == -cutlass.Float32.inf:
row_max_safe = Float32(0.0)
acc_scale_ = old_row_max - row_max_safe
acc_scale = cute.math.exp2(acc_scale_, fastmath=True)
if old_row_max == -cutlass.Float32.inf:
acc_scale = Float32(0.0)
row_sum *= acc_scale
# --- Compute P = softmax(S) ---
rP_words = cute.make_rmem_tensor(tTMEM_STOREcP.shape, self.qk_acc_dtype)
rP_bf16 = cute.make_tensor(cute.recast_ptr(rP_words.iterator, dtype=self.q_dtype), tTMEM_LOADrS.layout)
minus_row_max = Float32(0.0) - row_max_safe
rP_bf16_frg = cute.logical_divide(rP_bf16, cute.make_layout(frg_tile))
for j in range(frg_cnt):
for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])):
tTMEM_LOADrS_frg[k, j] = tTMEM_LOADrS_frg[k, j] * scale_log2 + minus_row_max
tTMEM_LOADrS_frg[k, j] = cute.math.exp2(tTMEM_LOADrS_frg[k, j], fastmath=True)
row_sum = row_sum + tTMEM_LOADrS_frg[k, j]
s_vec = tTMEM_LOADrS_frg[None, j].load()
rP_bf16_frg[None, j].store(s_vec.to(self.q_dtype))
# --- Store P to TMEM or SMEM ---
if not self.use_smem_p:
cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP)
cute.arch.fence_view_async_tmem_store()
else:
for j0 in range(32):
for j1 in range(4):
coord = tTMEM_LOADcS[(j0, 0), j1, 0, 0]
m_coord = coord[0]; k_coord = coord[1]
k0 = k_coord % 16
k1 = (k_coord // 16) % 4
k2 = k_coord // 64
_sP_nostage[(m_coord, k0), 0, (k1, k2)] = rP_bf16[(j0, 0), j1, 0, 0]
cute.arch.fence_proxy("async.shared", space="cta")
si_handle.release()
softmax_done_bar.arrive()
# --- Wait for PV[kt] to complete ---
pv_done_bar.arrive_and_wait()
# ========================================================
# SMEM ACCUMULATOR: load O_kt from TMEM, accumulate in SMEM
# ========================================================
# O_kt is in TMEM (PV with ACCUMULATE=False → fresh output)
# Load via one-way Ld32x32bOp (read-only, NO write-back to TMEM)
# Then: sO_acc = acc_scale * sO_acc + O_kt
# Using coordinate-indexed writes to sO_acc
# ========================================================
rO = cute.make_rmem_tensor(tTMEM_LOADcO.shape, self.qk_acc_dtype)
cute.copy(tiled_o_load, tTMEM_LOADtO, rO)
cute.arch.fence_view_async_tmem_load()
# Rescale existing sO_acc and add O_kt
# Use coordinate tensor to map each register to (row, col) in sO_acc
for j0 in range(32):
for j1 in range(4):
coord = tTMEM_LOADcO[(j0, 0), j1, 0, 0]
row = coord[0]
col = coord[1]
old_val = sO_acc[row, col]
new_val = acc_scale * old_val + rO[(j0, 0), j1, 0, 0]
sO_acc[row, col] = new_val
# Wait for MMA's final signal
final_o_bar.arrive_and_wait()
# ============================================================
# EPILOGUE: normalize sO_acc, cast to BF16, TMA store to GMEM
# ============================================================
# sO_acc has the un-normalized O accumulated across all kt.
# Normalize: O_norm = O_unnorm / row_sum
# Then cast to BF16 and write to sC for TMA store.
# ============================================================
# Compute LSE and row_sum output
if const_expr(not self.normalize):
_row_max_safe = row_max
if row_max == -cutlass.Float32.inf:
_row_max_safe = Float32(0.0)
_ln2 = Float32(0.6931471805599453)
lse_val = cute.math.log(row_sum, fastmath=True) + _row_max_safe * _ln2
mLSE[sfw_idx, Int32(0), Int32(0)] = lse_val
mRowSums[sfw_idx, Int32(0), Int32(0)] = row_sum
# Normalize and cast sO_acc -> sC
# Each thread handles its rows (sfw_idx maps to rows in sO_acc)
# sO_acc is (128, pv_n_tile), sC layout may differ
# Use coordinate-based write to sC via epi_tile
#
# For TMA store, we need data in sC in the layout expected by tma_c.
# We can't easily do coordinate-indexed writes to sC (swizzled layout).
# Instead: normalize in sO_acc, then bulk-copy to sC via SMEM copy.
#
# Simpler approach for n_kv_tiles=1 compatibility:
# For n_kv_tiles=1, we can use the existing epilogue_tma_store path.
# For n_kv_tiles>1, we use the sO_acc -> sC -> TMA path.
#
# For now: normalize sO_acc in-place, then copy to sC (BF16), then TMA store.
if const_expr(self.normalize):
# Normalize: divide by row_sum
# Each of the 128 softmax threads handles one row
inv_row_sum = Float32(1.0) / row_sum
for col in cutlass.range(0, self.pv_n_tile, unroll=1):
row = sfw_idx
if row < Int32(128):
sO_acc[row, col] = sO_acc[row, col] * inv_row_sum
# Copy sO_acc (FP32) -> sC (BF16) using SMEM copy
# sC has swizzled layout from compute_epilogue_tile_shape,
# but we can write to it using the epi_tile coordinate mapping.
#
# Alternative: use TMA store directly from a properly laid out SMEM buffer.
# The simplest correct approach: use epilogue_tma_store but read from
# a SMEM buffer instead of TMEM.
#
# For the MVP, we use the existing sC layout and write via
# the epi_tile partition that TMA expects.
# Use epilogue_tma_store to write sO_acc -> GMEM
# But epilogue_tma_store reads from TMEM, not SMEM.
# We need a different TMA store path.
#
# Simplest: use cpasync.bulk_copy (SMEM->GMEM) with sC as source.
# First: copy sO_acc -> sC (FP32->BF16 cast)
# Then: TMA bulk copy sC -> GMEM
#
# Write to sC row by row using the epi_tile coordinate mapping.
# The epi_tile shape is derived from cta_tile_shape_mnk.
# For hd=64 with pv_n_tile=64: epi_tile covers (128, 64).
# For each row assigned to this thread, cast FP32->BF16
# and write to sC using flat index mapping.
# sC is 2-stage: sC[128, pv_n_tile, num_c_stage] in BF16
c_stage0 = cute.slice_(sC, (None, None, 0)) # First stage of sC
for col in cutlass.range(0, self.pv_n_tile, unroll=1):
row = sfw_idx
if row < Int32(128):
c_stage0[row, col] = sO_acc[row, col].to(self.o_dtype)
# TMA store sC -> GMEM
cute.arch.fence_proxy("async.shared", space="cta")
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
c_pipe.producer_acquire()
cute.copy(tma_c, c_stage0, tCgC[(None, None, Int32(0))])
c_pipe.producer_commit()
c_pipe.producer_tail()
tmem.relinquish_alloc_permit()
tmem.free(tmem_ptr)

112
tests/unit/test_smem_acc.py Normal file
View File

@@ -0,0 +1,112 @@
"""
Test SMEM accumulator FMHA kernel: multi-KV-tile with in-kernel O accumulation.
No Python KV merge needed — the kernel handles acc_scale internally.
"""
import torch, math, sys
import cutlass.cute as cute
import cutlass.torch as ct
import cuda.bindings.driver as cuda
from dsv4.kernels.attention.fmha_smem_acc import FmhaKernel
def test_smem_acc(hd=64, s_k=256, use_smem_p=False, normalize=True):
m = 128
n_kv_tiles = s_k // 128
torch.manual_seed(42)
q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')
k = torch.randn(s_k, hd, 1, dtype=torch.bfloat16, device='cuda')
v = torch.randn(s_k, hd, dtype=torch.bfloat16, device='cuda')
c = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device='cuda')
# FP32 reference
qf = q[:, :, 0].float()
kf = k[:, :, 0].float()
scale = 1.0 / math.sqrt(hd)
attn_max = (qf @ kf.T * scale).max(dim=-1, keepdim=True)[0]
attn_exp = torch.exp(qf @ kf.T * scale - attn_max)
attn_sum = attn_exp.sum(dim=-1, keepdim=True)
ref_norm = (attn_exp / attn_sum) @ v.float()
ref_unnorm = attn_exp @ v.float()
lse_tensor = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')
row_sums_tensor = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')
kernel = FmhaKernel(head_dim=hd, s_k=s_k, use_smem_p=use_smem_p, normalize=normalize)
pv_n_tile = kernel.pv_n_tile
n_pv_tiles = kernel.n_pv_tiles
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
# Compile
v_tile = v[:, 0:pv_n_tile].contiguous()
v_kernel = v_tile.unsqueeze(-1)
c_tile = torch.zeros(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda')
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel))
mC = ct.from_dlpack(c_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c_tile))
mLSE = ct.from_dlpack(lse_tensor).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse_tensor))
mRS = ct.from_dlpack(row_sums_tensor).mark_layout_dynamic(leading_dim=ct.get_leading_dim(row_sums_tensor))
print(f' hd={hd}, s_k={s_k} ({n_kv_tiles} KV tiles, pv_n_tile={pv_n_tile}, n_pv_tiles={n_pv_tiles}): Compiling...', flush=True)
compiled = cute.compile(kernel, mQ, mK, mV, mC, stream, mLSE, mRS)
for nt in range(n_pv_tiles):
v_start = nt * pv_n_tile
v_end = v_start + pv_n_tile
v_tile = v[:, v_start:v_end].contiguous()
v_kernel = v_tile.unsqueeze(-1)
c_tile = torch.zeros(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda')
lse_tensor.zero_()
row_sums_tensor.zero_()
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel))
mC = ct.from_dlpack(c_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c_tile))
mLSE = ct.from_dlpack(lse_tensor).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse_tensor))
mRS = ct.from_dlpack(row_sums_tensor).mark_layout_dynamic(leading_dim=ct.get_leading_dim(row_sums_tensor))
compiled(mQ, mK, mV, mC, stream, mLSE, mRS)
torch.cuda.synchronize()
c[:, v_start:v_end, :] = c_tile
out = c[:, :, 0].float()
if normalize:
cos = torch.nn.functional.cosine_similarity(
out.flatten().unsqueeze(0), ref_norm.flatten().unsqueeze(0)
).item()
ref = ref_norm
else:
cos = torch.nn.functional.cosine_similarity(
out.flatten().unsqueeze(0), ref_unnorm.flatten().unsqueeze(0)
).item()
ref = ref_unnorm
status = "PASS" if cos >= 0.99 else "FAIL"
print(f' hd={hd}, s_k={s_k} ({n_kv_tiles} tiles): cos {cos:.6f} {status}')
return cos
def test():
print("=== SMEM Accumulator FMHA: In-Kernel Multi-KV-Tile O Accumulation ===\n")
# Single KV tile (s_k=128): should work like fmha.py
print("--- Single KV tile (s_k=128) ---")
test_smem_acc(64, 128)
test_smem_acc(128, 128)
# Multi KV tile: the SMEM accumulator approach should handle this correctly
print("\n--- Multi KV tile (s_k=256+) ---")
test_smem_acc(64, 256)
test_smem_acc(64, 384)
test_smem_acc(64, 512)
test_smem_acc(128, 256)
if __name__ == '__main__':
test()