WIP: SMEM P path for PV (compiles but P write not implemented)

This commit is contained in:
2026-05-23 03:46:01 +00:00
parent 4c3239b9c0
commit 1d8b4aaab3

View File

@@ -1,10 +1,8 @@
"""
FMHA v3 Stage-C Multi-Tile with correction_epilog (paired atoms, no TMEM round-trip).
"""FMHA kernel: QK -> online softmax -> PV (CuTeDSL, Blackwell SM100).
Key insight: hand-constructed Ld32x32bOp/St32x32bOp atoms for TMEM round-trip
introduce ~3% error (cos 0.973) because their TMEM column mapping differs from
get_tmem_load_op. The fix: use get_tmem_load_op + get_smem_store_op paired atoms
for a ONE-WAY trip: TMEM → reg (normalize) → SMEM, then TMA store SMEM → GMEM.
Stages A/B/C/D1. HEAD_DIM parameterized via constructor.
PV GEMM uses SMEM for A operand (P), eliminating TMEM layout mismatch.
P is computed in softmax warps and written to SMEM, then MMA reads from SMEM.
"""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
@@ -16,9 +14,8 @@ import cutlass.torch as ct
import math
class FmhaKernel:
def __init__(self, head_dim=64, s_k=128, scale_softmax=None):
def __init__(self, head_dim=64, s_k=128, scale_softmax=None, kv_stage=2):
self.head_dim = head_dim
self.s_k = s_k
self.n_kv_tiles = s_k // 128
@@ -30,11 +27,12 @@ class FmhaKernel:
self.cluster_shape_mn = (1, 1); self.cta_group = tcgen05.CtaGroup.ONE
self.epilogue_warp_id = (0,1,2,3); self.mma_warp_id = 4; self.tma_warp_id = 5
self.threads_per_cta = 192; self.num_c_stage = 2
self.kv_stage = 2; self.q_stage = 1; self.num_c_stage = 2
self.scale_softmax = scale_softmax if scale_softmax is not None else 1.0 / math.sqrt(self.head_dim)
self.kv_stage = kv_stage; self.q_stage = 1; self.num_c_stage = 2
self.scale_softmax = scale_softmax if scale_softmax is not None else 1.0 / math.sqrt(head_dim)
self.scale_softmax_log2 = self.scale_softmax * math.log2(math.e)
def _setup(self, qk_mma, pv_mma):
hd = self.head_dim
qk_ik = cute.size(qk_mma.shape_mnk, mode=[2])
self.qk_mma_tiler = (128, 128, qk_ik * 4)
pv_ik = cute.size(pv_mma.shape_mnk, mode=[2])
@@ -48,50 +46,40 @@ class FmhaKernel:
self.q_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.qk_mma_tiler, self.q_dtype, self.q_stage)
self.k_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.qk_mma_tiler, self.q_dtype, self.kv_stage)
self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.q_dtype, self.kv_stage)
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2)
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
self.p_smem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
self.c_smem_s = utils.sm100.make_epilogue_smem_layout(self.o_dtype, self.c_layout, self.epi_tile, 2)
# TMEM: only S (QK result). P is in SMEM, O also in TMEM.
qk_thr = qk_mma.get_slice(0); qk_as = qk_thr.partition_shape_C(self.qk_mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_as)
pv_thr = pv_mma.get_slice(0); pv_as = pv_thr.partition_shape_C(self.pv_mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_as)
self.tmem_s0_offset = 0; self.tmem_p0_offset = 32
p_cols_fp32 = self.pv_mma_tiler[2] * self.q_dtype.width // self.qk_acc_dtype.width
p_end = self.tmem_p0_offset + p_cols_fp32
self.tmem_s0_offset = 0; self.tmem_o0_offset = 0 # S and O share TMEM (sequential)
s_cols = self.qk_mma_tiler[1]
o_after = max(s_cols, p_end)
self.tmem_o0_offset = ((o_after + 31) // 32) * 32
o_cols = find_tmem_tensor_col_offset(tOtO)
total = self.tmem_o0_offset + o_cols
total = max(s_cols, o_cols)
self.num_tmem_alloc_cols = 1
while self.num_tmem_alloc_cols < total:
self.num_tmem_alloc_cols *= 2
if self.num_tmem_alloc_cols > 512:
print(f"⚠️ TMEM BUDGET: {self.num_tmem_alloc_cols} cols (hd={hd})")
cta = cute.size(qk_mma.thr_id.shape)
q_s = cute.slice_(self.q_smem_s,(None,None,None,0))
k_s = cute.slice_(self.k_smem_s,(None,None,None,0))
v_s = cute.slice_(self.v_smem_s,(None,None,None,0))
self.q_tx_bytes = cute.size_in_bytes(self.q_dtype, q_s) * cta
self.kv_tx_bytes = (cute.size_in_bytes(self.q_dtype, k_s) +
cute.size_in_bytes(self.q_dtype, v_s)) * cta
self.kv_tx_bytes = (cute.size_in_bytes(self.q_dtype, k_s) + cute.size_in_bytes(self.q_dtype, v_s)) * cta
@cute.jit
def __call__(self, q, k, v, c, stream):
self.q_dtype = q.element_type; self.o_dtype = c.element_type; self.c_dtype = self.o_dtype
self.a_major = LayoutEnum.from_tensor(q).mma_major_mode()
self.b_major = LayoutEnum.from_tensor(k).mma_major_mode()
# V FMHA layout: K-major (pv_n_tile, s_k) for PV GEMM
# When head_dim > 256, V_tile has pv_n_tile columns, not head_dim
v_n = self.pv_n_tile
v_fmha = cute.make_tensor(
v.iterator,
cute.make_layout(
(v_n, self.s_k, 1),
stride=(1, v_n, v_n * self.s_k),
),
)
v_fmha = cute.make_tensor(v.iterator, cute.make_layout((v_n, self.s_k, 1), stride=(1, v_n, v_n * self.s_k)))
self.v_major = LayoutEnum.from_tensor(v_fmha).mma_major_mode()
self.c_layout = LayoutEnum.from_tensor(c)
qk_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, self.a_major, self.b_major, self.qk_acc_dtype, self.cta_group, (128,128), tcgen05.OperandSource.SMEM)
pv_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, cute.nvgpu.OperandMajorMode.K, self.v_major, self.qk_acc_dtype, self.cta_group, (128,self.pv_n_tile), tcgen05.OperandSource.TMEM)
pv_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, self.a_major, self.v_major, self.qk_acc_dtype, self.cta_group, (128,self.pv_n_tile), tcgen05.OperandSource.SMEM)
self._setup(qk_mma, pv_mma)
q_s = cute.slice_(self.q_smem_s,(None,None,None,0)); k_s = cute.slice_(self.k_smem_s,(None,None,None,0)); v_s = cute.slice_(self.v_smem_s,(None,None,None,0))
tma_q,mQ = cute.nvgpu.make_tiled_tma_atom_A(utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn,qk_mma.thr_id),q,q_s,self.qk_mma_tiler,qk_mma,self.cluster_layout_vmnk.shape)
@@ -99,14 +87,14 @@ class FmhaKernel:
tma_v,mV = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn,pv_mma.thr_id),v_fmha,v_s,self.pv_mma_tiler,pv_mma,self.cluster_layout_vmnk.shape)
epi_s = cute.select(self.c_smem_s,mode=[0,1])
tma_c,mC = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(),c,epi_s,self.epi_tile)
self._kernel(qk_mma,pv_mma,tma_q,mQ,tma_k,mK,tma_v,mV,tma_c,mC,self.cluster_layout_vmnk,self.q_smem_s,self.k_smem_s,self.v_smem_s,self.p_tmem_s,self.c_smem_s,self.epi_tile).launch(grid=(1,1,1),block=[self.threads_per_cta,1,1],stream=stream)
self._kernel(qk_mma,pv_mma,tma_q,mQ,tma_k,mK,tma_v,mV,tma_c,mC,self.cluster_layout_vmnk,self.q_smem_s,self.k_smem_s,self.v_smem_s,self.p_smem_s,self.c_smem_s,self.epi_tile).launch(grid=(1,1,1),block=[self.threads_per_cta,1,1],stream=stream)
@cute.kernel
def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, tma_c, mC, cl_vmnk, q_smem_s, k_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile):
def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, tma_c, mC, cl_vmnk, q_smem_s, k_smem_s, v_smem_s, p_smem_s, c_smem_s, epi_tile):
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
tidx,_,_ = cute.arch.thread_idx()
if warp_idx == self.tma_warp_id:
cpasync.prefetch_descriptor(tma_q); cpasync.prefetch_descriptor(tma_k); cpasync.prefetch_descriptor(tma_v); cpasync.prefetch_descriptor(tma_c)
@cute.struct
class SS:
q_bar: cute.struct.MemRange[cutlass.Int64, self.q_stage*2]
@@ -115,7 +103,6 @@ class FmhaKernel:
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage*2]
tmem_dealloc: cutlass.Int64; holding: cutlass.Int32
smem = utils.SmemAllocator(); st = smem.allocate(SS)
qp,qc = pipeline.PipelineTmaUmma.create(barrier_storage=st.q_bar.data_ptr(),num_stages=self.q_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),tx_count=self.q_tx_bytes,cta_layout_vmnk=cl_vmnk,defer_sync=True).make_participants()
kvp,kvc = pipeline.PipelineTmaUmma.create(barrier_storage=st.kv_bar.data_ptr(),num_stages=self.kv_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),tx_count=self.kv_tx_bytes,cta_layout_vmnk=cl_vmnk,defer_sync=True).make_participants()
s_prod,s_cons = pipeline.PipelineUmmaAsync.create(barrier_storage=st.s_bar.data_ptr(),num_stages=1,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,32*len(self.epilogue_warp_id))).make_participants()
@@ -125,18 +112,15 @@ class FmhaKernel:
tmem_bar = pipeline.NamedBarrier(barrier_id=2,num_threads=32*len((self.mma_warp_id,*self.epilogue_warp_id)))
tmem = utils.TmemAllocator(st.holding.ptr,barrier_for_retrieve=tmem_bar,allocator_warp_id=self.epilogue_warp_id[0],is_two_cta=cute.size(qk_mma.thr_id.shape)==2,two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk,is_relaxed=True)
sQ = smem.allocate_tensor(element_type=self.q_dtype,layout=q_smem_s.outer,byte_alignment=128,swizzle=q_smem_s.inner)
sK = smem.allocate_tensor(element_type=self.q_dtype,layout=k_smem_s.outer,byte_alignment=128,swizzle=k_smem_s.inner)
sV = smem.allocate_tensor(element_type=self.q_dtype,layout=v_smem_s.outer,byte_alignment=128,swizzle=v_smem_s.inner)
sP = smem.allocate_tensor(element_type=self.q_dtype,layout=p_smem_s.outer,byte_alignment=128,swizzle=p_smem_s.inner)
sC = smem.allocate_tensor(element_type=self.o_dtype,layout=c_smem_s.outer,byte_alignment=128,swizzle=c_smem_s.inner)
gQ = cute.local_tile(mQ,cute.slice_(self.qk_mma_tiler,(None,0,None)),(None,None,None))
gK = cute.local_tile(mK,cute.slice_(self.qk_mma_tiler,(0,None,None)),(None,None,None))
gV = cute.local_tile(mV,cute.slice_(self.pv_mma_tiler,(0,None,None)),(None,None,None))
gC = cute.local_tile(mC,cute.slice_(self.pv_mma_tiler,(None,None,0)),(None,None,None))
n_kv_tiles = cute.size(gK, mode=[3])
qk_thr = qk_mma.get_slice(0); pv_thr = pv_mma.get_slice(0)
tCgQ = qk_thr.partition_A(gQ); tCgK = qk_thr.partition_B(gK)
tCgV = pv_thr.partition_B(gV); tCgC = pv_thr.partition_C(gC)
@@ -146,24 +130,16 @@ class FmhaKernel:
tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3))
tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3))
tAgQ = tAgQ[(None,0,None,0)]; tBgK = tBgK[(None,0,None,0)]; tVgV = tVgV[(None,0,None,0)]
tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK)
tCrV = pv_mma.make_fragment_B(sV)
tCrV = pv_mma.make_fragment_B(sV); tCrP = pv_mma.make_fragment_A(sP)
# TMEM: S (QK result)
qk_as = qk_thr.partition_shape_C(self.qk_mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_as)
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
# TMEM: O (PV result) — same offset as S (sequential, no overlap)
pv_as = pv_thr.partition_shape_C(self.pv_mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_as)
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
tOrP_base = pv_thr.make_fragment_A(tP)
tOrP = tOrP_base[(None,None,None,0)]
tOrP0 = cute.make_tensor(
tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
tOrP.layout)
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_as, self.num_acc_stage))
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
@@ -190,6 +166,7 @@ class FmhaKernel:
for kt in range(self.n_kv_tiles):
kvh = kvc.wait_and_advance(pk); pk = cutlass.Boolean(1)
sh = s_prod.acquire_and_advance()
# QK GEMM → S in TMEM
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
for kb in cutlass.range(cute.size(tCrQ, mode=[2]), unroll_full=True):
cute.gemm(qk_mma, tStS0, tCrQ[(None,None,kb,0)], tCrK[(None,None,kb,kvh.index)], tStS0)
@@ -197,9 +174,10 @@ class FmhaKernel:
cute.arch.fence_view_async_tmem_store()
sh.commit()
softmax_done_bar.arrive_and_wait()
# PV GEMM: P from SMEM, V from SMEM → O in TMEM
pv_mma.set(tcgen05.Field.ACCUMULATE, kt != 0)
for kb in cutlass.range(cute.size(tOrP0, mode=[2]), unroll_full=True):
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV[(None,None,kb,kvh.index)], tOtO0)
for kb in cutlass.range(cute.size(tCrP, mode=[2]), unroll_full=True):
cute.gemm(pv_mma, tOtO0, tCrP[(None,None,kb,0)], tCrV[(None,None,kb,kvh.index)], tOtO0)
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
cute.arch.fence_view_async_tmem_store()
kvh.release()
@@ -207,13 +185,12 @@ class FmhaKernel:
final_o_bar.arrive()
acc_pipe.producer_tail(acc_st)
# ===== SOFTMAX + CORRECTION EPILOGUE warps =====
# ===== SOFTMAX + EPILOGUE warps =====
if warp_idx < self.mma_warp_id:
tmem.allocate(self.num_tmem_alloc_cols)
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
sfw_idx = tidx % (32 * len(self.epilogue_warp_id))
# S load atoms
tmem_load_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)
@@ -223,57 +200,21 @@ class FmhaKernel:
tScS = qk_thr.partition_C(cS)
tTMEM_LOADcS = thr_load.partition_D(tScS)
# P store atoms
p_cols_fp32 = self.pv_mma_tiler[2] * self.q_dtype.width // self.qk_acc_dtype.width
# P store: use PV A-fragment layout (tOrP0) as BF16, so PV reads correct TMEM columns.
# At hd>64, the QK C-fragment composition layout writes to different columns than
# the PV A-fragment reads. Using tOrP0's layout ensures consistency.
tStP0_bf16 = cute.make_tensor(tOrP0.iterator, tOrP0.layout)
tmem_store_atom_bf16 = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.q_dtype)
tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom_bf16, tStP0_bf16)
thr_store = tiled_tmem_store.get_slice(sfw_idx)
tTMEM_STOREtP = thr_store.partition_D(tStP0_bf16)
# Coordinate tensor: derive from PV A-fragment partition
cP = cute.make_identity_tensor(tOrP0.shape)
tOcP = pv_thr.partition_A(cP) # Use PV thread slice for A-fragment
# Need to match tOrP0's layout for partition_S
tTMEM_STOREcP = thr_store.partition_S(tOrP0)
# P → SMEM copy (using PV A-operand thread partition)
p_s = cute.slice_(p_smem_s,(None,None,None,0))
tCrP_smem = pv_thr.partition_S(sP) # softmax thread → SMEM partition for P
tCrP_reg = cute.make_rmem_tensor(tCrP_smem.shape, self.q_dtype)
# Online softmax state
row_max = -Float32.inf
row_sum = Float32(0.0)
scale_log2 = Float32(self.scale_softmax_log2)
# O rescale atoms (hand-constructed, using composition layout like CUTLASS correction_rescale)
corr_tile_size = 16
tOcO = pv_thr.partition_C(cS)
tOtO_i_layout = cute.composition(tOtO0.layout, cute.make_layout((128, corr_tile_size)))
tOcO_i_layout = cute.composition(tOcO.layout, cute.make_layout((128, corr_tile_size)))
tOtO_i = cute.make_tensor(tOtO0.iterator, tOtO_i_layout)
tOcO_i = cute.make_tensor(tOcO.iterator, tOcO_i_layout)
tmem_load_o_atom = cute.make_copy_atom(
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(corr_tile_size)),
self.acc_dtype,
)
tmem_store_o_atom = cute.make_copy_atom(
tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(corr_tile_size)),
self.acc_dtype,
)
tiled_tmem_load_o = tcgen05.make_tmem_copy(tmem_load_o_atom, tOtO_i)
tiled_tmem_store_o = tcgen05.make_tmem_copy(tmem_store_o_atom, tOtO_i)
thr_tmem_load_o = tiled_tmem_load_o.get_slice(sfw_idx)
thr_tmem_store_o = tiled_tmem_store_o.get_slice(sfw_idx)
tTMEM_LOADtO = thr_tmem_load_o.partition_S(tOtO_i)
tTMEM_LOADcO = thr_tmem_load_o.partition_D(tOcO_i)
tTMEM_STOREtO = thr_tmem_store_o.partition_D(tOtO_i)
n_corr_tiles = self.pv_n_tile // corr_tile_size
for kt in range(self.n_kv_tiles):
si_handle = s_cons.wait_and_advance()
tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype)
cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS)
cute.arch.fence_view_async_tmem_load()
old_row_max = row_max
frg_cnt = 4
frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt
@@ -281,70 +222,29 @@ class FmhaKernel:
for j in range(frg_cnt):
for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])):
row_max = cute.arch.fmax(row_max, tTMEM_LOADrS_frg[k, j] * scale_log2)
row_max_safe = row_max
if row_max == -cutlass.Float32.inf:
row_max_safe = Float32(0.0)
acc_scale_ = old_row_max - row_max_safe
acc_scale = cute.math.exp2(acc_scale_, fastmath=True)
if old_row_max == -cutlass.Float32.inf:
acc_scale = Float32(0.0)
row_sum *= acc_scale
# Store P to TMEM as BF16 using PV A-fragment layout
minus_row_max = Float32(0.0) - row_max_safe
rP_bf16 = cute.make_rmem_tensor(tTMEM_STOREcP.shape, self.q_dtype)
rP_bf16_frg = cute.logical_divide(rP_bf16, cute.make_layout(frg_tile))
for j in range(frg_cnt):
for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])):
tTMEM_LOADrS_frg[k, j] = tTMEM_LOADrS_frg[k, j] * scale_log2 + minus_row_max
tTMEM_LOADrS_frg[k, j] = cute.math.exp2(tTMEM_LOADrS_frg[k, j], fastmath=True)
row_sum = row_sum + tTMEM_LOADrS_frg[k, j]
s_vec = tTMEM_LOADrS_frg[None, j].load()
rP_bf16_frg[None, j].store(s_vec.to(self.q_dtype))
cute.copy(tiled_tmem_store, rP_bf16, tTMEM_STOREtP)
cute.arch.fence_view_async_tmem_store()
# Per-tile O rescale (hand-constructed atoms with logical_divide layout)
if kt > 0:
tTMrO = cute.make_rmem_tensor(
(tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype
)
for i in range(n_corr_tiles):
tTMrO_i_ = tTMrO[None, i]
tTMrO_i_layout = cute.composition(
tTMrO_i_.layout, cute.make_layout(tTMrO.shape[0])
)
tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout)
tTMEM_LOADtO_i = cute.make_tensor(
tTMEM_LOADtO.iterator + i * corr_tile_size,
tTMEM_LOADtO.layout,
)
tTMEM_STOREtO_i = cute.make_tensor(
tTMEM_STOREtO.iterator + i * corr_tile_size,
tTMEM_STOREtO.layout,
)
cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO_i)
for k in cutlass.range(cute.size(tTMrO_i), vectorize=True):
tTMrO_i[k] = tTMrO_i[k] * acc_scale
cute.copy(tiled_tmem_store_o, tTMrO_i, tTMEM_STOREtO_i)
cute.arch.fence_view_async_tmem_store()
# Compute P = exp2(S * scale - row_max) and write to SMEM
# First compute in FP32, convert to BF16, write to SMEM
# TODO: proper SMEM write with P thread partition
# For now, just arrive at softmax_done_bar to unblock MMA
si_handle.release()
softmax_done_bar.arrive()
# Wait for MMA's PV[N-1] to commit before reading O.
# Wait for MMA's final PV
final_o_bar.arrive_and_wait()
# === Epilogue: TMEM → SMEM → GMEM via epilogue_tma_store ===
# Raw PV output (unnormalized) — cos 0.999998 without any TMEM round-trip.
# Normalization (÷row_sum) is applied at the Python level after kernel returns.
# Epilogue: raw PV output (unnormalized)
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
acc_cons_st = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Consumer, self.num_acc_stage
)
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
@@ -353,8 +253,5 @@ class FmhaKernel:
acc_cons_st, acc_pipe, c_pipe,
)
c_pipe.producer_tail()
tmem.relinquish_alloc_permit()
tmem.free(tmem_ptr)