attention/: Clean up folder, archive backups, add detailed status headers

What changed:
- Moved fmha_backup_pre_epilog.py, fmha_backup_v2.py, fmha_smem_acc.py to archive/
- Deleted fmha.py.backup (git has history)
- Added detailed heredoc headers to ALL files documenting:
  * WHAT WORKS and WHAT'S BROKEN
  * WHY each limitation exists (CuTeDSL toolchain gaps)
  * KEY INSIGHTS FOR NVIDIA (what CuTeDSL is missing)
  * What each file unblocks if fixed

File status:
  fmha.py                 — CuTeDSL FMHA, cos 0.999998, D1.5 workaround
  fmha_common.cuh         — Raw CUDA shared defs (BF16, TMEM ops)
  fmha_sm100.cuh          — Raw CUDA reference, cos 0.999999
  fmha_epilogue_sm100.cuh — Raw CUDA TMEM epilogue, HANGS (needs debug)
  fmha_sm100_launch.cu    — PyTorch binding (JIT broken, nvcc works)
  production.py           — CuTeDSL production wrapper (partial)
  archive/                — Historical backups with explanation headers
This commit is contained in:
2026-05-28 07:01:33 +00:00
parent d46ae8b967
commit 4336de9372
14 changed files with 276 additions and 2203 deletions

View File

@@ -1,5 +1,17 @@
"""DSV4 Attention kernels — public integration API.
====================================================================
STATUS: SKELETON — not yet connected to model
====================================================================
These functions define the API that AttentionSubBlock will call.
They're correct in structure but depend on:
1. LayerCacheHandle being fully implemented (gather_compressed_kv, etc.)
2. The production FMHA wrapper supporting sink_bias and n_comp
3. Custom op registration for torch.compile compatibility
See ROADMAP.md Priority 5 for the full Stage E checklist.
====================================================================
These functions bridge the model's AttentionSubBlock to the production
FMHA kernel wrapper. Each function handles the cache → dense-tensor
materialization that the kernel requires.

View File

@@ -0,0 +1,16 @@
"""
ARCHIVED: FMHA kernel backup — pre-epilogue rewrite.
This was the state of fmha.py before the SMEM accumulator and
correction epilogue work. Kept for historical reference.
WHY ARCHIVED: Superseded by the current fmha.py which has:
- SMEM-P path for hd > 64
- Per-row LSE output
- D3/D4/D5c masks
- Python KV merge for multi-tile
This backup uses the old TMEM round-trip approach which is
FUNDAMENTALLY BROKEN (Ld32x32bOp/St32x32bOp column mismatch,
even NO-OP round-trip produces ~3% error).
"""

View File

@@ -0,0 +1,6 @@
"""
ARCHIVED: FMHA kernel backup v2.
Intermediate state during the SMEM accumulator development.
Superseded by the current fmha.py. Kept for git-archaeology only.
"""

View File

@@ -0,0 +1,18 @@
"""
ARCHIVED: FMHA SMEM accumulator variant.
This was the D1.5 attempt to fix the TMEM round-trip by using an
SMEM accumulator for O instead of TMEM. The approach works for
single KV tiles but the multi-tile path (loading O from SMEM,
multiplying by rescale, storing back) adds SMEM pressure.
The approach was ABANDONED in favor of:
1. Python KV merge (5-9 launches, cos 0.999998) — production path
2. Raw CUDA with tcgen05.ld/st for O rescale in REGISTERS — see
fmha_epilogue_sm100.cuh
WHY IT DIDN'T SHIP: SMEM budget at hd=512 is already tight (192KB).
Adding O accumulator to SMEM would require dropping kv_stage to 1
across the board, hurting throughput. The register-based approach
in raw CUDA is better — registers are free.
"""

View File

@@ -1,11 +1,64 @@
"""FMHA kernel: QK -> online softmax -> PV (CuTeDSL, Blackwell SM100).
Migrated from tests/unit/test_fmha_v3_stage_c.py — Stage C proven path.
P stored to TMEM via register bridge, PV reads from TMEM.
O rescale via SMEM accumulator (one-way TMEM→REGS→SMEM per kt iteration).
Normalization via final TMA store (SMEM→GMEM).
D1.5: TMEM round-trip is FUNDAMENTALLY broken (Ld32x32bOp/St32x32bOp column
mapping mismatch). SMEM accumulator avoids round-trip entirely.
====================================================================
WHAT WORKS (cos 0.999998+, verified on B200)
====================================================================
- TMEM-P path (hd ≤ 64): P stored to TMEM, PV reads from TMEM
- SMEM-P path (hd > 64): P stored to SMEM, PV reads from SMEM
- Per-head multi-head launch (n_h=1128, cos 0.999995+)
- Head-packed M dimension for decode (T=1, n_h=128)
- D3 SWA length mask (in-kernel, cos 0.999996)
- D4 causal mask on SWA (in-kernel, cos 0.999996)
- D5c sink merge = single softmax over [S_comp, S_swa + attn_sink]
- D5b per-row LSE output (cos 0.999994)
- D5c multi-tile with Python KV merge (cos 0.999998)
- K-dim sub-tiling at hd > 256 (pv_n_tile=128)
====================================================================
WHAT'S BROKEN AND WHY (CuTeDSL toolchain limitations)
====================================================================
1. TMEM ROUND-TRIP (D1.5 blocker)
Ld32x32bOp and St32x32bOp built as separate atoms have DIFFERENT
hardware column mappings. Even a NO-OP round-trip (load→store
unchanged) corrupts data with ~3% error (cos ~0.97). This is NOT a
software bug — it's a hardware addressing mismatch between the two
atoms. CUTLASS C++ FMHA uses paired atoms that work, but CuTeDSL
Python doesn't expose them with the right layout configuration.
Workaround: Python KV merge (59 kernel launches per decode step,
cos 0.999998). See fmha_sm100.cuh for the raw CUDA fix path.
2. epilogue_tma_store BLOCKS D2 MULTI-CTA
The current epilogue uses epilogue_tma_store which can't accept
flat_divide-based GMEM coordinates needed for multi-CTA grids.
Per-head Python launch wastes 128 launches per Pro decode step.
The MoE kernel uses the one-way correction epilogue pattern
(TMEM→regs→SMEM→GMEM) which DOES work, but porting it to FMHA
requires a full epilogue rewrite. See fmha_epilogue_sm100.cuh.
3. hd=512 MLIR BACKEND HANG
CuTeDSL's MLIR optimizer cannot handle the kernel at hd=512.
Tracer completes in 0.8s, MLIR optimizer chews for 3+ hours.
Both Python range() (unrolled) and cutlass.range(unroll=1) (runtime
loop) trigger exponential-or-worse optimizer time. This is a CuTeDSL
toolchain bug, not a kernel correctness issue.
4. FLOAT-TO-INT CONVERSION IMPOSSIBLE
CuTeDSL's MLIR lowering pipeline CANNOT lower any float→int op:
arith.fptosi, llvm.inline_asm (cvt.rni.s32.f32), nvvm.inline_ptx,
llvm.bitcast Float32→Int32 — ALL fail with "LLVM ERROR: unsupported
operation". The pipeline has no path from Float32 to Int32 MLIR
types. This blocks NVFP4-1.1 quantize fusion in the epilogue.
See fp4_quant.py and fmha_sm100.cuh for the raw CUDA workaround.
====================================================================
ARCHITECTURE
====================================================================
- 6-warp specialization: Warps 0-3 softmax+epilogue, Warp 4 MMA, Warp 5 TMA
- P staging: TMEM-P (hd≤64) or SMEM-P (hd>64)
- Output: un-normalized O + LSE (external code divides)
- Per-head launch, Python KV merge for multi-tile
====================================================================
"""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05

View File

@@ -1,515 +0,0 @@
"""FMHA kernel: QK -> online softmax -> PV (CuTeDSL, Blackwell SM100).
Migrated from tests/unit/test_fmha_v3_stage_c.py — Stage C proven path.
P stored to TMEM via register bridge, PV reads from TMEM.
O rescale via correction_rescale atoms, O normalization via TMEM round-trip.
"""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cuda.bindings.driver as cuda
import cutlass.torch as ct
import math
class FmhaKernel:
def __init__(self, head_dim=64, s_k=128, scale_softmax=None, use_smem_p=None):
self.head_dim = head_dim
self.s_k = s_k
self.n_kv_tiles = s_k // 128
self.pv_n_tile = min(head_dim, 256) # tcgen05 MMA max N=256
self.n_pv_tiles = head_dim // self.pv_n_tile
self.use_smem_p = use_smem_p if use_smem_p is not None else (head_dim > 64)
self.acc_dtype = Float32; self.qk_acc_dtype = Float32
self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16
self.use_2cta_instrs = False; self.epilog_sync_bar_id = 1
self.cluster_shape_mn = (1, 1); self.cta_group = tcgen05.CtaGroup.ONE
self.epilogue_warp_id = (0,1,2,3); self.mma_warp_id = 4; self.tma_warp_id = 5
self.threads_per_cta = 192; self.num_c_stage = 2
self.kv_stage = 2; self.q_stage = 1; self.num_c_stage = 2
self.scale_softmax = scale_softmax if scale_softmax is not None else 1.0 / math.sqrt(self.head_dim)
self.scale_softmax_log2 = self.scale_softmax * math.log2(math.e)
def _setup(self, qk_mma, pv_mma):
qk_ik = cute.size(qk_mma.shape_mnk, mode=[2])
self.qk_mma_tiler = (128, 128, qk_ik * 4)
pv_ik = cute.size(pv_mma.shape_mnk, mode=[2])
self.pv_mma_tiler = (128, self.pv_n_tile, pv_ik * (128 // pv_ik))
self.mma_tiler = self.qk_mma_tiler
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
self.cta_tile_shape_mnk = (self.qk_mma_tiler[0]//cute.size(qk_mma.thr_id.shape), self.pv_n_tile, self.qk_mma_tiler[2])
self.c_layout = LayoutEnum.ROW_MAJOR
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(self.cta_tile_shape_mnk, False, self.c_layout, self.o_dtype)
self.num_ab_stage = 1; self.num_acc_stage = 1
self.q_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.qk_mma_tiler, self.q_dtype, self.q_stage)
self.k_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.qk_mma_tiler, self.q_dtype, self.kv_stage)
self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.q_dtype, self.kv_stage)
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2)
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
# P SMEM layout (PV A-operand) — used for SMEM-P path
self.p_smem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
qk_thr = qk_mma.get_slice(0); qk_as = qk_thr.partition_shape_C(self.qk_mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_as)
pv_thr = pv_mma.get_slice(0); pv_as = pv_thr.partition_shape_C(self.pv_mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_as)
self.tmem_s0_offset = 0
if not self.use_smem_p:
# TMEM-P: S at 0, P at 32, O after P and S
self.tmem_p0_offset = 32
p_cols_fp32 = self.pv_mma_tiler[2] * self.q_dtype.width // self.qk_acc_dtype.width
p_end = self.tmem_p0_offset + p_cols_fp32
s_cols = self.qk_mma_tiler[1]
o_after = max(s_cols, p_end)
self.tmem_o0_offset = ((o_after + 31) // 32) * 32
o_cols = find_tmem_tensor_col_offset(tOtO)
total = self.tmem_o0_offset + o_cols
else:
# SMEM-P: P not in TMEM. S and O share TMEM (sequential).
self.tmem_p0_offset = -1 # unused
self.tmem_o0_offset = 0
s_cols = self.qk_mma_tiler[1]
o_cols = find_tmem_tensor_col_offset(tOtO)
total = max(s_cols, o_cols)
self.num_tmem_alloc_cols = 1
while self.num_tmem_alloc_cols < total:
self.num_tmem_alloc_cols *= 2
cta = cute.size(qk_mma.thr_id.shape)
q_s = cute.slice_(self.q_smem_s,(None,None,None,0))
k_s = cute.slice_(self.k_smem_s,(None,None,None,0))
v_s = cute.slice_(self.v_smem_s,(None,None,None,0))
self.q_tx_bytes = cute.size_in_bytes(self.q_dtype, q_s) * cta
self.kv_tx_bytes = (cute.size_in_bytes(self.q_dtype, k_s) +
cute.size_in_bytes(self.q_dtype, v_s)) * cta
@cute.jit
def __call__(self, q, k, v, c, stream):
self.q_dtype = q.element_type; self.o_dtype = c.element_type; self.c_dtype = self.o_dtype
self.a_major = LayoutEnum.from_tensor(q).mma_major_mode()
self.b_major = LayoutEnum.from_tensor(k).mma_major_mode()
v_fmha = cute.make_tensor(
v.iterator,
cute.make_layout(
(self.pv_n_tile, self.s_k, 1),
stride=(1, self.pv_n_tile, self.pv_n_tile * self.s_k),
),
)
self.v_major = LayoutEnum.from_tensor(v_fmha).mma_major_mode()
self.c_layout = LayoutEnum.from_tensor(c)
qk_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, self.a_major, self.b_major, self.qk_acc_dtype, self.cta_group, (128,128), tcgen05.OperandSource.SMEM)
pv_a_major = self.a_major if self.use_smem_p else cute.nvgpu.OperandMajorMode.K
pv_source = tcgen05.OperandSource.SMEM if self.use_smem_p else tcgen05.OperandSource.TMEM
pv_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, pv_a_major, self.v_major, self.qk_acc_dtype, self.cta_group, (128,self.pv_n_tile), pv_source)
self._setup(qk_mma, pv_mma)
q_s = cute.slice_(self.q_smem_s,(None,None,None,0)); k_s = cute.slice_(self.k_smem_s,(None,None,None,0)); v_s = cute.slice_(self.v_smem_s,(None,None,None,0))
tma_q,mQ = cute.nvgpu.make_tiled_tma_atom_A(utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn,qk_mma.thr_id),q,q_s,self.qk_mma_tiler,qk_mma,self.cluster_layout_vmnk.shape)
tma_k,mK = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn,qk_mma.thr_id),k,k_s,self.qk_mma_tiler,qk_mma,self.cluster_layout_vmnk.shape)
tma_v,mV = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn,pv_mma.thr_id),v_fmha,v_s,self.pv_mma_tiler,pv_mma,self.cluster_layout_vmnk.shape)
epi_s = cute.select(self.c_smem_s,mode=[0,1])
tma_c,mC = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(),c,epi_s,self.epi_tile)
self._kernel(qk_mma,pv_mma,tma_q,mQ,tma_k,mK,tma_v,mV,tma_c,mC,self.cluster_layout_vmnk,self.q_smem_s,self.k_smem_s,self.v_smem_s,self.p_tmem_s,self.p_smem_s,self.c_smem_s,self.epi_tile).launch(grid=(1,1,1),block=[self.threads_per_cta,1,1],stream=stream)
@cute.kernel
def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, tma_c, mC, cl_vmnk, q_smem_s, k_smem_s, v_smem_s, p_tmem_s, p_smem_s, c_smem_s, epi_tile):
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
tidx,_,_ = cute.arch.thread_idx()
if warp_idx == self.tma_warp_id:
cpasync.prefetch_descriptor(tma_q); cpasync.prefetch_descriptor(tma_k); cpasync.prefetch_descriptor(tma_v); cpasync.prefetch_descriptor(tma_c)
@cute.struct
class SS:
q_bar: cute.struct.MemRange[cutlass.Int64, self.q_stage*2]
kv_bar: cute.struct.MemRange[cutlass.Int64, self.kv_stage*2]
s_bar: cute.struct.MemRange[cutlass.Int64, 2]
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage*2]
tmem_dealloc: cutlass.Int64; holding: cutlass.Int32
smem = utils.SmemAllocator(); st = smem.allocate(SS)
qp,qc = pipeline.PipelineTmaUmma.create(barrier_storage=st.q_bar.data_ptr(),num_stages=self.q_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),tx_count=self.q_tx_bytes,cta_layout_vmnk=cl_vmnk,defer_sync=True).make_participants()
kvp,kvc = pipeline.PipelineTmaUmma.create(barrier_storage=st.kv_bar.data_ptr(),num_stages=self.kv_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),tx_count=self.kv_tx_bytes,cta_layout_vmnk=cl_vmnk,defer_sync=True).make_participants()
s_prod,s_cons = pipeline.PipelineUmmaAsync.create(barrier_storage=st.s_bar.data_ptr(),num_stages=1,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,32*len(self.epilogue_warp_id))).make_participants()
softmax_done_bar = pipeline.NamedBarrier(barrier_id=3, num_threads=32 + 32*len(self.epilogue_warp_id))
final_o_bar = pipeline.NamedBarrier(barrier_id=4, num_threads=32 + 32*len(self.epilogue_warp_id))
acc_pipe = pipeline.PipelineUmmaAsync.create(barrier_storage=st.acc_bar.data_ptr(),num_stages=self.num_acc_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,len(self.epilogue_warp_id)),cta_layout_vmnk=cl_vmnk,defer_sync=True)
tmem_bar = pipeline.NamedBarrier(barrier_id=2,num_threads=32*len((self.mma_warp_id,*self.epilogue_warp_id)))
tmem = utils.TmemAllocator(st.holding.ptr,barrier_for_retrieve=tmem_bar,allocator_warp_id=self.epilogue_warp_id[0],is_two_cta=cute.size(qk_mma.thr_id.shape)==2,two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk,is_relaxed=True)
sQ = smem.allocate_tensor(element_type=self.q_dtype,layout=q_smem_s.outer,byte_alignment=128,swizzle=q_smem_s.inner)
sK = smem.allocate_tensor(element_type=self.q_dtype,layout=k_smem_s.outer,byte_alignment=128,swizzle=k_smem_s.inner)
sV = smem.allocate_tensor(element_type=self.q_dtype,layout=v_smem_s.outer,byte_alignment=128,swizzle=v_smem_s.inner)
sC = smem.allocate_tensor(element_type=self.o_dtype,layout=c_smem_s.outer,byte_alignment=128,swizzle=c_smem_s.inner)
sP = smem.allocate_tensor(element_type=self.q_dtype,layout=p_smem_s.outer,byte_alignment=128,swizzle=p_smem_s.inner)
gQ = cute.local_tile(mQ,cute.slice_(self.qk_mma_tiler,(None,0,None)),(None,None,None))
gK = cute.local_tile(mK,cute.slice_(self.qk_mma_tiler,(0,None,None)),(None,None,None))
gV = cute.local_tile(mV,cute.slice_(self.pv_mma_tiler,(0,None,None)),(None,None,None))
gC = cute.local_tile(mC,cute.slice_(self.pv_mma_tiler,(None,None,0)),(None,None,None))
n_kv_tiles = cute.size(gK, mode=[3])
qk_thr = qk_mma.get_slice(0); pv_thr = pv_mma.get_slice(0)
tCgQ = qk_thr.partition_A(gQ); tCgK = qk_thr.partition_B(gK)
tCgV = pv_thr.partition_B(gV); tCgC = pv_thr.partition_C(gC)
a_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,0,None,0)).shape)
tAsQ,tAgQ = cpasync.tma_partition(tma_q,0,a_lay,cute.group_modes(sQ,0,3),cute.group_modes(tCgQ,0,3))
b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape)
tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3))
tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3))
tAgQ = tAgQ[(None,0,None,0)]; tBgK = tBgK[(None,0,None,0)]; tVgV = tVgV[(None,0,None,0)]
tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK)
tCrV = pv_mma.make_fragment_B(sV)
qk_as = qk_thr.partition_shape_C(self.qk_mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_as)
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
# Create coordinate tensor for QK C-fragment layout
# Each element maps to its logical coordinate ((m,n),0,0)
if self.use_smem_p:
cP_qk = cute.make_identity_tensor(tStS0.shape)
print(f"[SMEM-P CUTLASS] Created cP_qk shape: {cute.shape(cP_qk)}")
pv_as = pv_thr.partition_shape_C(self.pv_mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_as)
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
# PV A-operand: define both tOrP0 (TMEM-P) and tCrP (SMEM-P) unconditionally.
# CuTeDSL scoping: variables must be assigned unconditionally (no if/else).
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
tOrP_base = pv_thr.make_fragment_A(tP if not self.use_smem_p else sP)
tOrP = tOrP_base[(None,None,None,0)]
tCrP = pv_mma.make_fragment_A(sP)
# tOrP0 always defined as tOrP. The TMEM-P path in the MMA warp applies
# the p0 column offset inline when constructing the gemm arguments.
tOrP0 = tOrP
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_as, self.num_acc_stage))
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
# ===== TMA LOAD warp =====
if warp_idx == self.tma_warp_id:
qp.reset(); qh = qp.acquire_and_advance()
cute.copy(tma_q, tAgQ[(None, Int32(0))], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier)
qp.tail()
kvp.reset(); pk = kvp.try_acquire()
for kt in cutlass.range(0, self.n_kv_tiles, 1, unroll=1):
kvh = kvp.acquire_and_advance(pk)
cute.copy(tma_k, tBgK[(None, kt)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
cute.copy(tma_v, tVgV[(None, kt)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
pk = cutlass.Boolean(1)
kvp.tail()
# ===== MMA warp =====
if warp_idx == self.mma_warp_id:
tmem.wait_for_alloc()
qc.reset(); qh = qc.wait_and_advance(); qh.release()
kvc.reset(); pk = kvc.try_wait()
acc_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
acc_pipe.producer_acquire(acc_st)
for kt in range(self.n_kv_tiles):
kvh = kvc.wait_and_advance(pk); pk = cutlass.Boolean(1)
sh = s_prod.acquire_and_advance()
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
for kb in cutlass.range(cute.size(tCrQ, mode=[2]), unroll_full=True):
cute.gemm(qk_mma, tStS0, tCrQ[(None,None,kb,0)], tCrK[(None,None,kb,kvh.index)], tStS0)
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
cute.arch.fence_view_async_tmem_store()
sh.commit()
softmax_done_bar.arrive_and_wait()
pv_mma.set(tcgen05.Field.ACCUMULATE, kt != 0)
if not self.use_smem_p:
# TMEM-P: PV reads P from TMEM
for kb in cutlass.range(cute.size(tOrP0, mode=[2]), unroll_full=True):
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV[(None,None,kb,kvh.index)], tOtO0)
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
else:
# SMEM-P: PV reads P from SMEM
for kb in cutlass.range(cute.size(tCrP, mode=[2]), unroll_full=True):
cute.gemm(pv_mma, tOtO0, tCrP[(None,None,kb,0)], tCrV[(None,None,kb,kvh.index)], tOtO0)
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
cute.arch.fence_view_async_tmem_store()
kvh.release()
acc_pipe.producer_commit(acc_st); acc_st.advance()
final_o_bar.arrive()
acc_pipe.producer_tail(acc_st)
# ===== SOFTMAX + CORRECTION EPILOGUE warps =====
if warp_idx < self.mma_warp_id:
tmem.allocate(self.num_tmem_alloc_cols)
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
sfw_idx = tidx % (32 * len(self.epilogue_warp_id))
# S load atoms
tmem_load_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)
thr_load = tiled_tmem_load.get_slice(sfw_idx)
tTMEM_LOADtS = thr_load.partition_S(tStS0)
cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
tScS = qk_thr.partition_C(cS)
tTMEM_LOADcS = thr_load.partition_D(tScS)
# P store atoms: TMEM-P (always defined, only used when use_smem_p=False)
p_cols_fp32 = self.pv_mma_tiler[2] * self.q_dtype.width // self.qk_acc_dtype.width
tStP_layout = cute.composition(tStS.layout, cute.make_layout((self.pv_mma_tiler[0], p_cols_fp32)))
# Use 0 as P offset when SMEM-P (these atoms are never used, but must be valid)
tStP0 = cute.make_tensor(tStS.iterator + max(self.tmem_p0_offset, 0), tStP_layout)
tmem_store_atom = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStP0)
thr_store = tiled_tmem_store.get_slice(sfw_idx)
tTMEM_STOREtP = thr_store.partition_D(tStP0)
tScP_layout = cute.composition(tScS.layout, cute.make_layout((self.pv_mma_tiler[0], p_cols_fp32)))
tScP = cute.make_tensor(tScS.iterator, tScP_layout)
tTMEM_STOREcP = thr_store.partition_S(tScP)
# Manual SMEM addressing for P (CUTLASS LLM guidance)
# We need to write P values from QK C-fragment layout to PV A-operand SMEM layout
# sP has PV A-operand SMEM layout: p_smem_s
print(f"[SMEM-P CUTLASS] Starting manual SMEM addressing with CUTLASS LLM pattern")
print(f"[SMEM-P CUTLASS] sP shape: {cute.shape(sP)} layout: {sP.layout}")
# Get thread index for coordinate partitioning
tidx, _, _ = cute.arch.thread_idx()
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
lane_idx = tidx % 32
print(f"[SMEM-P CUTLASS] tidx={tidx}, warp_idx={warp_idx}, lane_idx={lane_idx}")
row_max = -Float32.inf
row_sum = Float32(0.0)
scale_log2 = Float32(self.scale_softmax_log2)
# O rescale atoms (hand-constructed, using composition layout like CUTLASS correction_rescale)
corr_tile_size = 16
tOcO = pv_thr.partition_C(cS)
tOtO_i_layout = cute.composition(tOtO0.layout, cute.make_layout((128, corr_tile_size)))
tOcO_i_layout = cute.composition(tOcO.layout, cute.make_layout((128, corr_tile_size)))
tOtO_i = cute.make_tensor(tOtO0.iterator, tOtO_i_layout)
tOcO_i = cute.make_tensor(tOcO.iterator, tOcO_i_layout)
tmem_load_o_atom = cute.make_copy_atom(
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(corr_tile_size)),
self.acc_dtype,
)
tmem_store_o_atom = cute.make_copy_atom(
tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(corr_tile_size)),
self.acc_dtype,
)
tiled_tmem_load_o = tcgen05.make_tmem_copy(tmem_load_o_atom, tOtO_i)
tiled_tmem_store_o = tcgen05.make_tmem_copy(tmem_store_o_atom, tOtO_i)
thr_tmem_load_o = tiled_tmem_load_o.get_slice(sfw_idx)
thr_tmem_store_o = tiled_tmem_store_o.get_slice(sfw_idx)
tTMEM_LOADtO = thr_tmem_load_o.partition_S(tOtO_i)
tTMEM_LOADcO = thr_tmem_load_o.partition_D(tOcO_i)
tTMEM_STOREtO = thr_tmem_store_o.partition_D(tOtO_i)
n_corr_tiles = self.pv_n_tile // corr_tile_size
for kt in range(self.n_kv_tiles):
si_handle = s_cons.wait_and_advance()
tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype)
cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS)
cute.arch.fence_view_async_tmem_load()
old_row_max = row_max
frg_cnt = 4
frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt
# Compute fragment tile size dynamically (must match value division)
frg_tile_size = cute.size(tTMEM_LOADrS) // frg_cnt
frg_layout = cute.make_layout(frg_tile_size)
tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, frg_layout)
# Coordinate fragments for SMEM-P mapping (needed unconditionally for scoping)
tTMEM_LOADcS_frg = cute.logical_divide(tTMEM_LOADcS, frg_layout)
if self.use_smem_p:
print(f"[SMEM-P CUTLASS] Created tTMEM_LOADcS_frg shape: {cute.shape(tTMEM_LOADcS_frg)}")
print(f"[SMEM-P CUTLASS] tTMEM_LOADrS shape: {cute.shape(tTMEM_LOADrS)}")
print(f"[SMEM-P CUTLASS] tTMEM_LOADcS shape: {cute.shape(tTMEM_LOADcS)}")
print(f"[SMEM-P CUTLASS] frg_tile_size: {frg_tile_size}, frg_layout: {frg_layout}")
print(f"[SMEM-P CUTLASS] tTMEM_LOADrS_frg shape: {cute.shape(tTMEM_LOADrS_frg)}")
for j in range(frg_cnt):
for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])):
row_max = cute.arch.fmax(row_max, tTMEM_LOADrS_frg[k, j] * scale_log2)
row_max_safe = row_max
if row_max == -cutlass.Float32.inf:
row_max_safe = Float32(0.0)
acc_scale_ = old_row_max - row_max_safe
acc_scale = cute.math.exp2(acc_scale_, fastmath=True)
if old_row_max == -cutlass.Float32.inf:
acc_scale = Float32(0.0)
row_sum *= acc_scale
rP_words = cute.make_rmem_tensor(tTMEM_STOREcP.shape, self.qk_acc_dtype)
rP_bf16 = cute.make_tensor(cute.recast_ptr(rP_words.iterator, dtype=self.q_dtype), tTMEM_LOADrS.layout)
minus_row_max = Float32(0.0) - row_max_safe
rP_bf16_frg = cute.logical_divide(rP_bf16, cute.make_layout(frg_tile))
for j in range(frg_cnt):
for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])):
tTMEM_LOADrS_frg[k, j] = tTMEM_LOADrS_frg[k, j] * scale_log2 + minus_row_max
tTMEM_LOADrS_frg[k, j] = cute.math.exp2(tTMEM_LOADrS_frg[k, j], fastmath=True)
# If using SMEM-P, write P value directly to SMEM
if self.use_smem_p:
# Get QK coordinate for this position
qk_coord = tTMEM_LOADcS_frg[k, j]
# qk_coord is (m, n) coordinate
m = qk_coord[0]
n = qk_coord[1]
# Map to PV SMEM coordinate
# Convert to local coordinates (0-127) as sanity check
m_local = m % 128
n_local = n % 128
# Original mapping formula (should be correct for local coords)
n0 = n_local % 16
n1 = (n_local // 16) % 4
n2 = n_local // 64
pv_coord = ((m_local, n0), 0, (n1, n2), 0)
# DEBUG: Write pattern based on fragment indices (k,j)
# If coordinates wrong, this pattern might work better
pattern_val = Float32(k) + Float32(j) * Float32(32.0)
p_val_bf16 = pattern_val.to(self.q_dtype)
# Original: p_val_bf16 = tTMEM_LOADrS_frg[k, j].to(self.q_dtype)
sP[pv_coord] = p_val_bf16 # Tensor indexing
# DEBUG: Print first few coordinates to verify mapping
if self.use_smem_p and k < 2 and j < 2:
print(f"[SMEM-P DEBUG] k={k}, j={j}, qk_coord=({m},{n}), pv_coord={pv_coord}")
# Try to compute offset using crd2idx
try:
offset = cute.crd2idx(pv_coord, sP.layout)
print(f"[SMEM-P DEBUG] offset = {offset}")
except:
print(f"[SMEM-P DEBUG] crd2idx not available")
# DEBUG: Also write pattern based on fragment indices (k,j)
# If coordinates wrong, this pattern might work better
pattern_val = Float32(k) + Float32(j) * Float32(32.0)
p_val_bf16 = pattern_val.to(self.q_dtype)
# Original: p_val_bf16 = tTMEM_LOADrS_frg[k, j].to(self.q_dtype)
sP[pv_coord] = p_val_bf16 # Tensor indexing
row_sum = row_sum + tTMEM_LOADrS_frg[k, j]
s_vec = tTMEM_LOADrS_frg[None, j].load()
rP_bf16_frg[None, j].store(s_vec.to(self.q_dtype))
if not self.use_smem_p:
# TMEM-P: store P to TMEM via register bridge
cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP)
cute.arch.fence_view_async_tmem_store()
else:
# SMEM-P: Already wrote P values to SMEM in softmax loop
# Just need fence and barrier
print(f"[SMEM-P CUTLASS] P values already written to SMEM, proceeding to fence")
# DEBUG: Compute offset for known coordinate to verify mapping
test_coord = ((0,0), 0, (0,0), 0)
test_offset = cute.crd2idx(test_coord, sP.layout)
print(f"[SMEM-P DEBUG] test_coord {test_coord} -> offset {test_offset}")
cute.arch.fence_proxy("async.shared", space="cta")
# Barrier for both TMEM-P and SMEM-P paths
softmax_done_bar.arrive() # Per-tile O rescale (hand-constructed atoms with logical_divide layout)
if kt > 0:
tTMrO = cute.make_rmem_tensor(
(tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype
)
for i in range(n_corr_tiles):
tTMrO_i_ = tTMrO[None, i]
tTMrO_i_layout = cute.composition(
tTMrO_i_.layout, cute.make_layout(tTMrO.shape[0])
)
tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout)
tTMEM_LOADtO_i = cute.make_tensor(
tTMEM_LOADtO.iterator + i * corr_tile_size,
tTMEM_LOADtO.layout,
)
tTMEM_STOREtO_i = cute.make_tensor(
tTMEM_STOREtO.iterator + i * corr_tile_size,
tTMEM_STOREtO.layout,
)
cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO_i)
for k in cutlass.range(cute.size(tTMrO_i), vectorize=True):
tTMrO_i[k] = tTMrO_i[k] * acc_scale
cute.copy(tiled_tmem_store_o, tTMrO_i, tTMEM_STOREtO_i)
cute.arch.fence_view_async_tmem_store()
si_handle.release()
softmax_done_bar.arrive()
# Wait for MMA's PV[N-1] to commit before reading O.
final_o_bar.arrive_and_wait()
# === NO-OP TMEM round-trip: re-map O from MMA layout to epilog layout ===
tTMrO_noop = cute.make_rmem_tensor(
(tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype
)
for i in range(n_corr_tiles):
tTMrO_i_ = tTMrO_noop[None, i]
tTMrO_i_layout = cute.composition(
tTMrO_i_.layout, cute.make_layout(tTMrO_noop.shape[0])
)
tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout)
tTMEM_LOADtO_i = cute.make_tensor(
tTMEM_LOADtO.iterator + i * corr_tile_size,
tTMEM_LOADtO.layout,
)
tTMEM_STOREtO_i = cute.make_tensor(
tTMEM_STOREtO.iterator + i * corr_tile_size,
tTMEM_STOREtO.layout,
)
cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO_i)
cute.copy(tiled_tmem_store_o, tTMrO_i, tTMEM_STOREtO_i)
cute.arch.fence_view_async_tmem_store()
# === Final O normalization: O *= 1/row_sum ===
inv_row_sum = Float32(1.0) / row_sum
tTMrO = cute.make_rmem_tensor(
(tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype
)
for i in range(n_corr_tiles):
tTMrO_i_ = tTMrO[None, i]
tTMrO_i_layout = cute.composition(
tTMrO_i_.layout, cute.make_layout(tTMrO.shape[0])
)
tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout)
tTMEM_LOADtO_i = cute.make_tensor(
tTMEM_LOADtO.iterator + i * corr_tile_size, tTMEM_LOADtO.layout
)
tTMEM_STOREtO_i = cute.make_tensor(
tTMEM_STOREtO.iterator + i * corr_tile_size, tTMEM_STOREtO.layout
)
cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO_i)
for j in cutlass.range(cute.size(tTMrO_i), vectorize=True):
tTMrO_i[j] = tTMrO_i[j] * inv_row_sum
cute.copy(tiled_tmem_store_o, tTMrO_i, tTMEM_STOREtO_i)
cute.arch.fence_view_async_tmem_store()
# Epilogue: TMEM → SMEM → GMEM via TMA store.
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
acc_cons_st = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Consumer, self.num_acc_stage
)
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC, epi_tile,
0, const_expr(lambda x: x), (0, 0, 0),
acc_cons_st, acc_pipe, c_pipe,
)
c_pipe.producer_tail()
tmem.relinquish_alloc_permit()
tmem.free(tmem_ptr)

View File

@@ -1,491 +0,0 @@
"""FMHA kernel: QK -> online softmax -> PV (CuTeDSL, Blackwell SM100).
Migrated from tests/unit/test_fmha_v3_stage_c.py — Stage C proven path.
P stored to TMEM via register bridge, PV reads from TMEM.
O rescale via correction_rescale atoms, O normalization via TMEM round-trip.
"""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
from cutlass.utils.blackwell_helpers import get_smem_store_op
import cuda.bindings.driver as cuda
import cutlass.torch as ct
import math
class FmhaKernel:
def __init__(self, head_dim=64, s_k=128, scale_softmax=None, use_smem_p=None, normalize=True):
self.head_dim = head_dim
self.s_k = s_k
self.n_kv_tiles = s_k // 128
self.pv_n_tile = min(head_dim, 256) # tcgen05 MMA max N=256
self.n_pv_tiles = head_dim // self.pv_n_tile
self.use_smem_p = use_smem_p if use_smem_p is not None else (head_dim > 64)
self.normalize = normalize # D5a: False = emit un-normalized O + lse
self.acc_dtype = Float32; self.qk_acc_dtype = Float32
self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16
self.use_2cta_instrs = False; self.epilog_sync_bar_id = 1
self.cluster_shape_mn = (1, 1); self.cta_group = tcgen05.CtaGroup.ONE
self.epilogue_warp_id = (0,1,2,3); self.mma_warp_id = 4; self.tma_warp_id = 5
self.threads_per_cta = 192; self.num_c_stage = 2
self.kv_stage = 2; self.q_stage = 1; self.num_c_stage = 2
self.scale_softmax = scale_softmax if scale_softmax is not None else 1.0 / math.sqrt(self.head_dim)
self.scale_softmax_log2 = self.scale_softmax * math.log2(math.e)
def _setup(self, qk_mma, pv_mma):
qk_ik = cute.size(qk_mma.shape_mnk, mode=[2])
self.qk_mma_tiler = (128, 128, qk_ik * 4)
pv_ik = cute.size(pv_mma.shape_mnk, mode=[2])
self.pv_mma_tiler = (128, self.pv_n_tile, pv_ik * (128 // pv_ik))
self.mma_tiler = self.qk_mma_tiler
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
self.cta_tile_shape_mnk = (self.qk_mma_tiler[0]//cute.size(qk_mma.thr_id.shape), self.pv_n_tile, self.qk_mma_tiler[2])
self.c_layout = LayoutEnum.ROW_MAJOR
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(self.cta_tile_shape_mnk, False, self.c_layout, self.o_dtype)
self.num_ab_stage = 1; self.num_acc_stage = 1
self.q_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.qk_mma_tiler, self.q_dtype, self.q_stage)
self.k_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.qk_mma_tiler, self.q_dtype, self.kv_stage)
self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.q_dtype, self.kv_stage)
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2)
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
# P SMEM layout (PV A-operand) — used for SMEM-P path
self.p_smem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
qk_thr = qk_mma.get_slice(0); qk_as = qk_thr.partition_shape_C(self.qk_mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_as)
pv_thr = pv_mma.get_slice(0); pv_as = pv_thr.partition_shape_C(self.pv_mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_as)
self.tmem_s0_offset = 0
if not self.use_smem_p:
# TMEM-P: S at 0, P at 32, O after P and S
self.tmem_p0_offset = 32
p_cols_fp32 = self.pv_mma_tiler[2] * self.q_dtype.width // self.qk_acc_dtype.width
p_end = self.tmem_p0_offset + p_cols_fp32
s_cols = self.qk_mma_tiler[1]
o_after = max(s_cols, p_end)
self.tmem_o0_offset = ((o_after + 31) // 32) * 32
o_cols = find_tmem_tensor_col_offset(tOtO)
total = self.tmem_o0_offset + o_cols
else:
# SMEM-P: P not in TMEM. S and O share TMEM (sequential).
self.tmem_p0_offset = -1 # unused
self.tmem_o0_offset = 0
s_cols = self.qk_mma_tiler[1]
o_cols = find_tmem_tensor_col_offset(tOtO)
total = max(s_cols, o_cols)
self.num_tmem_alloc_cols = 1
while self.num_tmem_alloc_cols < total:
self.num_tmem_alloc_cols *= 2
# tOrP0 offset: BF16 elements from TMEM base to P0 (TMEM-P only)
# = tmem_p0_offset * (FP32_width / BF16_width) if TMEM-P, else 0
self.tOrP0_offset = max(self.tmem_p0_offset, 0) * 2 # Python int
cta = cute.size(qk_mma.thr_id.shape)
q_s = cute.slice_(self.q_smem_s,(None,None,None,0))
k_s = cute.slice_(self.k_smem_s,(None,None,None,0))
v_s = cute.slice_(self.v_smem_s,(None,None,None,0))
self.q_tx_bytes = cute.size_in_bytes(self.q_dtype, q_s) * cta
self.kv_tx_bytes = (cute.size_in_bytes(self.q_dtype, k_s) +
cute.size_in_bytes(self.q_dtype, v_s)) * cta
@cute.jit
def __call__(self, q, k, v, c, stream, lse=None):
self.q_dtype = q.element_type; self.o_dtype = c.element_type; self.c_dtype = self.o_dtype
self.a_major = LayoutEnum.from_tensor(q).mma_major_mode()
self.b_major = LayoutEnum.from_tensor(k).mma_major_mode()
v_fmha = cute.make_tensor(
v.iterator,
cute.make_layout(
(self.pv_n_tile, self.s_k, 1),
stride=(1, self.pv_n_tile, self.pv_n_tile * self.s_k),
),
)
self.v_major = LayoutEnum.from_tensor(v_fmha).mma_major_mode()
self.c_layout = LayoutEnum.from_tensor(c)
qk_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, self.a_major, self.b_major, self.qk_acc_dtype, self.cta_group, (128,128), tcgen05.OperandSource.SMEM)
pv_a_major = self.a_major if self.use_smem_p else cute.nvgpu.OperandMajorMode.K
pv_source = tcgen05.OperandSource.SMEM if self.use_smem_p else tcgen05.OperandSource.TMEM
pv_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, pv_a_major, self.v_major, self.qk_acc_dtype, self.cta_group, (128,self.pv_n_tile), pv_source)
self._setup(qk_mma, pv_mma)
q_s = cute.slice_(self.q_smem_s,(None,None,None,0)); k_s = cute.slice_(self.k_smem_s,(None,None,None,0)); v_s = cute.slice_(self.v_smem_s,(None,None,None,0))
tma_q,mQ = cute.nvgpu.make_tiled_tma_atom_A(utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn,qk_mma.thr_id),q,q_s,self.qk_mma_tiler,qk_mma,self.cluster_layout_vmnk.shape)
tma_k,mK = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn,qk_mma.thr_id),k,k_s,self.qk_mma_tiler,qk_mma,self.cluster_layout_vmnk.shape)
tma_v,mV = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn,pv_mma.thr_id),v_fmha,v_s,self.pv_mma_tiler,pv_mma,self.cluster_layout_vmnk.shape)
epi_s = cute.select(self.c_smem_s,mode=[0,1])
tma_c,mC = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(),c,epi_s,self.epi_tile)
# Always create a valid mLSE tensor for the kernel.
# CuTeDSL doesn't support None parameters in @cute.kernel.
# For normalize=True, mLSE is unused (dead-code-eliminated by compiler).
if const_expr(lse is None):
lse = cute.make_tensor(c.iterator, cute.make_layout((1,), stride=(0,)))
self._kernel(qk_mma,pv_mma,tma_q,mQ,tma_k,mK,tma_v,mV,tma_c,mC,self.cluster_layout_vmnk,self.q_smem_s,self.k_smem_s,self.v_smem_s,self.p_tmem_s,self.p_smem_s,self.c_smem_s,self.epi_tile,lse).launch(grid=(1,1,1),block=[self.threads_per_cta,1,1],stream=stream)
@cute.kernel
def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, tma_c, mC, cl_vmnk, q_smem_s, k_smem_s, v_smem_s, p_tmem_s, p_smem_s, c_smem_s, epi_tile, mLSE):
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
tidx,_,_ = cute.arch.thread_idx()
if warp_idx == self.tma_warp_id:
cpasync.prefetch_descriptor(tma_q); cpasync.prefetch_descriptor(tma_k); cpasync.prefetch_descriptor(tma_v); cpasync.prefetch_descriptor(tma_c)
@cute.struct
class SS:
q_bar: cute.struct.MemRange[cutlass.Int64, self.q_stage*2]
kv_bar: cute.struct.MemRange[cutlass.Int64, self.kv_stage*2]
s_bar: cute.struct.MemRange[cutlass.Int64, 2]
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage*2]
tmem_dealloc: cutlass.Int64; holding: cutlass.Int32
smem = utils.SmemAllocator(); st = smem.allocate(SS)
qp,qc = pipeline.PipelineTmaUmma.create(barrier_storage=st.q_bar.data_ptr(),num_stages=self.q_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),tx_count=self.q_tx_bytes,cta_layout_vmnk=cl_vmnk,defer_sync=True).make_participants()
kvp,kvc = pipeline.PipelineTmaUmma.create(barrier_storage=st.kv_bar.data_ptr(),num_stages=self.kv_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),tx_count=self.kv_tx_bytes,cta_layout_vmnk=cl_vmnk,defer_sync=True).make_participants()
s_prod,s_cons = pipeline.PipelineUmmaAsync.create(barrier_storage=st.s_bar.data_ptr(),num_stages=1,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,32*len(self.epilogue_warp_id))).make_participants()
softmax_done_bar = pipeline.NamedBarrier(barrier_id=3, num_threads=32 + 32*len(self.epilogue_warp_id))
final_o_bar = pipeline.NamedBarrier(barrier_id=4, num_threads=32 + 32*len(self.epilogue_warp_id))
acc_pipe = pipeline.PipelineUmmaAsync.create(barrier_storage=st.acc_bar.data_ptr(),num_stages=self.num_acc_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,len(self.epilogue_warp_id)),cta_layout_vmnk=cl_vmnk,defer_sync=True)
tmem_bar = pipeline.NamedBarrier(barrier_id=2,num_threads=32*len((self.mma_warp_id,*self.epilogue_warp_id)))
tmem = utils.TmemAllocator(st.holding.ptr,barrier_for_retrieve=tmem_bar,allocator_warp_id=self.epilogue_warp_id[0],is_two_cta=cute.size(qk_mma.thr_id.shape)==2,two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk,is_relaxed=True)
sQ = smem.allocate_tensor(element_type=self.q_dtype,layout=q_smem_s.outer,byte_alignment=128,swizzle=q_smem_s.inner)
sK = smem.allocate_tensor(element_type=self.q_dtype,layout=k_smem_s.outer,byte_alignment=128,swizzle=k_smem_s.inner)
sV = smem.allocate_tensor(element_type=self.q_dtype,layout=v_smem_s.outer,byte_alignment=128,swizzle=v_smem_s.inner)
sC = smem.allocate_tensor(element_type=self.o_dtype,layout=c_smem_s.outer,byte_alignment=128,swizzle=c_smem_s.inner)
sP = smem.allocate_tensor(element_type=self.q_dtype,layout=p_smem_s.outer,byte_alignment=128,swizzle=p_smem_s.inner)
gQ = cute.local_tile(mQ,cute.slice_(self.qk_mma_tiler,(None,0,None)),(None,None,None))
gK = cute.local_tile(mK,cute.slice_(self.qk_mma_tiler,(0,None,None)),(None,None,None))
gV = cute.local_tile(mV,cute.slice_(self.pv_mma_tiler,(0,None,None)),(None,None,None))
gC = cute.local_tile(mC,cute.slice_(self.pv_mma_tiler,(None,None,0)),(None,None,None))
n_kv_tiles = cute.size(gK, mode=[3])
qk_thr = qk_mma.get_slice(0); pv_thr = pv_mma.get_slice(0)
tCgQ = qk_thr.partition_A(gQ); tCgK = qk_thr.partition_B(gK)
tCgV = pv_thr.partition_B(gV); tCgC = pv_thr.partition_C(gC)
a_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,0,None,0)).shape)
tAsQ,tAgQ = cpasync.tma_partition(tma_q,0,a_lay,cute.group_modes(sQ,0,3),cute.group_modes(tCgQ,0,3))
b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape)
tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3))
tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3))
tAgQ = tAgQ[(None,0,None,0)]; tBgK = tBgK[(None,0,None,0)]; tVgV = tVgV[(None,0,None,0)]
tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK)
tCrV = pv_mma.make_fragment_B(sV)
qk_as = qk_thr.partition_shape_C(self.qk_mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_as)
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
pv_as = pv_thr.partition_shape_C(self.pv_mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_as)
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
# PV A-operand: define both tOrP0 (TMEM-P) and tCrP (SMEM-P) unconditionally.
# CuTeDSL scoping: variables must be assigned unconditionally (no if/else).
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
tOrP_base = pv_thr.make_fragment_A(tP if not self.use_smem_p else sP)
tOrP = tOrP_base[(None,None,None,0)]
tCrP = pv_mma.make_fragment_A(sP)
# tOrP0: PV A-operand with TMEM column offset for P0 (TMEM-P path).
# self.tOrP0_offset is pre-computed in _setup as a Python int.
# Use const_expr if/else for compile-time conditional.
if const_expr(self.tOrP0_offset > 0):
tOrP0 = cute.make_tensor(tOrP.iterator + self.tOrP0_offset, tOrP.layout)
else:
tOrP0 = tOrP
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_as, self.num_acc_stage))
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
# ===== TMA LOAD warp =====
if warp_idx == self.tma_warp_id:
qp.reset(); qh = qp.acquire_and_advance()
cute.copy(tma_q, tAgQ[(None, Int32(0))], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier)
qp.tail()
kvp.reset(); pk = kvp.try_acquire()
for kt in cutlass.range(0, self.n_kv_tiles, 1, unroll=1):
kvh = kvp.acquire_and_advance(pk)
cute.copy(tma_k, tBgK[(None, kt)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
cute.copy(tma_v, tVgV[(None, kt)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
pk = cutlass.Boolean(1)
kvp.tail()
# ===== MMA warp =====
if warp_idx == self.mma_warp_id:
tmem.wait_for_alloc()
qc.reset(); qh = qc.wait_and_advance(); qh.release()
kvc.reset(); pk = kvc.try_wait()
acc_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
acc_pipe.producer_acquire(acc_st)
for kt in range(self.n_kv_tiles):
kvh = kvc.wait_and_advance(pk); pk = cutlass.Boolean(1)
sh = s_prod.acquire_and_advance()
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
for kb in cutlass.range(cute.size(tCrQ, mode=[2]), unroll_full=True):
cute.gemm(qk_mma, tStS0, tCrQ[(None,None,kb,0)], tCrK[(None,None,kb,kvh.index)], tStS0)
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
cute.arch.fence_view_async_tmem_store()
sh.commit()
softmax_done_bar.arrive_and_wait()
pv_mma.set(tcgen05.Field.ACCUMULATE, kt != 0)
if not self.use_smem_p:
# TMEM-P: PV reads P from TMEM
for kb in cutlass.range(cute.size(tOrP0, mode=[2]), unroll_full=True):
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV[(None,None,kb,kvh.index)], tOtO0)
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
else:
# SMEM-P: PV reads P from SMEM
for kb in cutlass.range(cute.size(tCrP, mode=[2]), unroll_full=True):
cute.gemm(pv_mma, tOtO0, tCrP[(None,None,kb,0)], tCrV[(None,None,kb,kvh.index)], tOtO0)
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
cute.arch.fence_view_async_tmem_store()
kvh.release()
acc_pipe.producer_commit(acc_st); acc_st.advance()
final_o_bar.arrive()
acc_pipe.producer_tail(acc_st)
# ===== SOFTMAX + CORRECTION EPILOGUE warps =====
if warp_idx < self.mma_warp_id:
tmem.allocate(self.num_tmem_alloc_cols)
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
sfw_idx = tidx % (32 * len(self.epilogue_warp_id))
# S load atoms
tmem_load_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)
thr_load = tiled_tmem_load.get_slice(sfw_idx)
tTMEM_LOADtS = thr_load.partition_S(tStS0)
cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
tScS = qk_thr.partition_C(cS)
tTMEM_LOADcS = thr_load.partition_D(tScS)
# P store atoms: TMEM-P (always defined, only used when use_smem_p=False)
p_cols_fp32 = self.pv_mma_tiler[2] * self.q_dtype.width // self.qk_acc_dtype.width
tStP_layout = cute.composition(tStS.layout, cute.make_layout((self.pv_mma_tiler[0], p_cols_fp32)))
# Use 0 as P offset when SMEM-P (these atoms are never used, but must be valid)
tStP0 = cute.make_tensor(tStS.iterator + max(self.tmem_p0_offset, 0), tStP_layout)
tmem_store_atom = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStP0)
thr_store = tiled_tmem_store.get_slice(sfw_idx)
tTMEM_STOREtP = thr_store.partition_D(tStP0)
tScP_layout = cute.composition(tScS.layout, cute.make_layout((self.pv_mma_tiler[0], p_cols_fp32)))
tScP = cute.make_tensor(tScS.iterator, tScP_layout)
tTMEM_STOREcP = thr_store.partition_S(tScP)
# P SMEM copy atoms: SMEM-P
# Per CUTLASS LLM guidance: use make_cotiled_copy with TV layout
# from TMEM load partition, remapped to sP's codomain.
# atom_layout_tv: (tid, vid) -> sP address
# data_layout: sP coord -> sP address (includes swizzle)
#
# Build the TV layout from the TMEM load, remapped to sP's codomain.
# The TMEM load's TV layout maps (tid, vid) -> tStS_addr.
# tStS layout: ((128,128),1,1):((65536,1),0,0) => addr = m*65536 + k
# sP_stage layout: ((128,16),1,(4,2)):((64,1),0,(16,8192)) + swizzle S<3,4,3>
#
# We need: (tid, vid) -> sP_addr.
# Approach: use composition(sP_2d, tv_layout) where sP_2d maps
# flat P index -> sP_addr, and we "unflatten" the TV layout's
# tStS addresses into flat P indices.
#
# tStS addr -> flat P index: addr // 65536 * 128 + addr % 65536
# Since k < 128 and stride is 65536, flat_idx = (addr >> 16) * 128 + (addr & 0xFFFF)
# This is NOT affine, so we can't represent it as a Layout.
#
# FALLBACK: Use the coordinate-indexed approach (scalar SMEM writes).
# This works but gives ~0.04 cosine loss vs TMEM-P at hd=64.
# The make_cotiled_copy approach is tracked for future optimization.
_sP_nostage = sP[(None, None, None, 0)] # remove stage dim
row_max = -Float32.inf
row_sum = Float32(0.0)
scale_log2 = Float32(self.scale_softmax_log2)
# O rescale atoms (hand-constructed, using composition layout like CUTLASS correction_rescale)
corr_tile_size = 16
tOcO = pv_thr.partition_C(cS)
tOtO_i_layout = cute.composition(tOtO0.layout, cute.make_layout((128, corr_tile_size)))
tOcO_i_layout = cute.composition(tOcO.layout, cute.make_layout((128, corr_tile_size)))
tOtO_i = cute.make_tensor(tOtO0.iterator, tOtO_i_layout)
tOcO_i = cute.make_tensor(tOcO.iterator, tOcO_i_layout)
tmem_load_o_atom = cute.make_copy_atom(
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(corr_tile_size)),
self.acc_dtype,
)
tmem_store_o_atom = cute.make_copy_atom(
tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(corr_tile_size)),
self.acc_dtype,
)
tiled_tmem_load_o = tcgen05.make_tmem_copy(tmem_load_o_atom, tOtO_i)
tiled_tmem_store_o = tcgen05.make_tmem_copy(tmem_store_o_atom, tOtO_i)
thr_tmem_load_o = tiled_tmem_load_o.get_slice(sfw_idx)
thr_tmem_store_o = tiled_tmem_store_o.get_slice(sfw_idx)
tTMEM_LOADtO = thr_tmem_load_o.partition_S(tOtO_i)
tTMEM_LOADcO = thr_tmem_load_o.partition_D(tOcO_i)
tTMEM_STOREtO = thr_tmem_store_o.partition_D(tOtO_i)
n_corr_tiles = self.pv_n_tile // corr_tile_size
for kt in range(self.n_kv_tiles):
si_handle = s_cons.wait_and_advance()
tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype)
cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS)
cute.arch.fence_view_async_tmem_load()
old_row_max = row_max
frg_cnt = 4
frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt
tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile))
for j in range(frg_cnt):
for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])):
row_max = cute.arch.fmax(row_max, tTMEM_LOADrS_frg[k, j] * scale_log2)
row_max_safe = row_max
if row_max == -cutlass.Float32.inf:
row_max_safe = Float32(0.0)
acc_scale_ = old_row_max - row_max_safe
acc_scale = cute.math.exp2(acc_scale_, fastmath=True)
if old_row_max == -cutlass.Float32.inf:
acc_scale = Float32(0.0)
row_sum *= acc_scale
rP_words = cute.make_rmem_tensor(tTMEM_STOREcP.shape, self.qk_acc_dtype)
rP_bf16 = cute.make_tensor(cute.recast_ptr(rP_words.iterator, dtype=self.q_dtype), tTMEM_LOADrS.layout)
minus_row_max = Float32(0.0) - row_max_safe
rP_bf16_frg = cute.logical_divide(rP_bf16, cute.make_layout(frg_tile))
for j in range(frg_cnt):
for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])):
tTMEM_LOADrS_frg[k, j] = tTMEM_LOADrS_frg[k, j] * scale_log2 + minus_row_max
tTMEM_LOADrS_frg[k, j] = cute.math.exp2(tTMEM_LOADrS_frg[k, j], fastmath=True)
row_sum = row_sum + tTMEM_LOADrS_frg[k, j]
s_vec = tTMEM_LOADrS_frg[None, j].load()
rP_bf16_frg[None, j].store(s_vec.to(self.q_dtype))
if not self.use_smem_p:
# TMEM-P: store P to TMEM via register bridge
cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP)
cute.arch.fence_view_async_tmem_store()
else:
# SMEM-P: write P to sP using coordinate-indexed store.
# Uses tTMEM_LOADcS identity tensor to get (m, k) coordinates.
for j0 in range(32):
for j1 in range(4):
coord = tTMEM_LOADcS[(j0, 0), j1, 0, 0]
m_coord = coord[0]
k_coord = coord[1]
k0 = k_coord % 16
k1 = (k_coord // 16) % 4
k2 = k_coord // 64
_sP_nostage[(m_coord, k0), 0, (k1, k2)] = rP_bf16[(j0, 0), j1, 0, 0]
cute.arch.fence_proxy("async.shared", space="cta")
if kt > 0:
tTMrO = cute.make_rmem_tensor(
(tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype
)
for i in range(n_corr_tiles):
tTMrO_i_ = tTMrO[None, i]
tTMrO_i_layout = cute.composition(
tTMrO_i_.layout, cute.make_layout(tTMrO.shape[0])
)
tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout)
tTMEM_LOADtO_i = cute.make_tensor(
tTMEM_LOADtO.iterator + i * corr_tile_size,
tTMEM_LOADtO.layout,
)
tTMEM_STOREtO_i = cute.make_tensor(
tTMEM_STOREtO.iterator + i * corr_tile_size,
tTMEM_STOREtO.layout,
)
cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO_i)
for k in cutlass.range(cute.size(tTMrO_i), vectorize=True):
tTMrO_i[k] = tTMrO_i[k] * acc_scale
cute.copy(tiled_tmem_store_o, tTMrO_i, tTMEM_STOREtO_i)
cute.arch.fence_view_async_tmem_store()
si_handle.release()
softmax_done_bar.arrive()
# Wait for MMA's PV[N-1] to commit before reading O.
final_o_bar.arrive_and_wait()
# === NO-OP TMEM round-trip: re-map O from MMA layout to epilog layout ===
tTMrO_noop = cute.make_rmem_tensor(
(tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype
)
for i in range(n_corr_tiles):
tTMrO_i_ = tTMrO_noop[None, i]
tTMrO_i_layout = cute.composition(
tTMrO_i_.layout, cute.make_layout(tTMrO_noop.shape[0])
)
tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout)
tTMEM_LOADtO_i = cute.make_tensor(
tTMEM_LOADtO.iterator + i * corr_tile_size,
tTMEM_LOADtO.layout,
)
tTMEM_STOREtO_i = cute.make_tensor(
tTMEM_STOREtO.iterator + i * corr_tile_size,
tTMEM_STOREtO.layout,
)
cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO_i)
cute.copy(tiled_tmem_store_o, tTMrO_i, tTMEM_STOREtO_i)
cute.arch.fence_view_async_tmem_store()
# === Final O normalization: O *= 1/row_sum ===
# D5a: When normalize=False, skip normalization (emit un-normalized O + lse)
if const_expr(self.normalize):
inv_row_sum = Float32(1.0) / row_sum
tTMrO = cute.make_rmem_tensor(
(tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype
)
for i in range(n_corr_tiles):
tTMrO_i_ = tTMrO[None, i]
tTMrO_i_layout = cute.composition(
tTMrO_i_.layout, cute.make_layout(tTMrO.shape[0])
)
tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout)
tTMEM_LOADtO_i = cute.make_tensor(
tTMEM_LOADtO.iterator + i * corr_tile_size, tTMEM_LOADtO.layout
)
tTMEM_STOREtO_i = cute.make_tensor(
tTMEM_STOREtO.iterator + i * corr_tile_size, tTMEM_STOREtO.layout
)
cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO_i)
if const_expr(self.normalize):
for j in cutlass.range(cute.size(tTMrO_i), vectorize=True):
tTMrO_i[j] = tTMrO_i[j] * inv_row_sum
cute.copy(tiled_tmem_store_o, tTMrO_i, tTMEM_STOREtO_i)
cute.arch.fence_view_async_tmem_store()
# Epilogue: TMEM → SMEM → GMEM via TMA store.
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
acc_cons_st = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Consumer, self.num_acc_stage
)
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC, epi_tile,
0, const_expr(lambda x: x), (0, 0, 0),
acc_cons_st, acc_pipe, c_pipe,
)
c_pipe.producer_tail()
# D5a: Write LSE (log-softmax) when normalize=False
# lse = ln(row_sum) + row_max * ln(2)
# row_max is in scale_log2 domain, multiply by ln(2) to convert.
if const_expr(not self.normalize):
_row_max_safe = row_max
if row_max == -cutlass.Float32.inf:
_row_max_safe = Float32(0.0)
if sfw_idx == 0:
_ln2 = Float32(0.6931471805599453) # ln(2)
lse_val = cute.math.log(row_sum, fastmath=True) + _row_max_safe * _ln2
mLSE[0] = lse_val
tmem.relinquish_alloc_permit()
tmem.free(tmem_ptr)

View File

@@ -1,592 +0,0 @@
"""FMHA kernel: QK -> online softmax -> PV (CuTeDSL, Blackwell SM100).
Migrated from tests/unit/test_fmha_v3_stage_c.py — Stage C proven path.
P stored to TMEM via register bridge, PV reads from TMEM.
O rescale via SMEM accumulator (one-way TMEM→REGS→SMEM per kt iteration).
Normalization via final TMA store (SMEM→GMEM).
D1.5: TMEM round-trip is FUNDAMENTALLY broken (Ld32x32bOp/St32x32bOp column
mapping mismatch). SMEM accumulator avoids round-trip entirely.
"""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
from cutlass.utils.blackwell_helpers import get_smem_store_op
from cutlass.utils.gemm.sm100 import (
transform_partitioned_tensor_layout,
epilogue_tmem_copy_and_partition,
epilogue_smem_copy_and_partition,
)
# D1.5: TMEM round-trip (Ld32x32bOp/St32x32bOp) is FUNDAMENTALLY broken.
# Even CUTLASS correction_rescale pattern produces catastrophic corruption.
# SMEM accumulator approach: one-way TMEM→REGS→SMEM per kt iteration.
import cuda.bindings.driver as cuda
import cutlass.torch as ct
import math
class FmhaKernel:
def __init__(self, head_dim=64, s_k=128, scale_softmax=None, use_smem_p=None, normalize=True, num_query_heads=1, batch_size=1, apply_swa_mask=False, is_causal=False, n_comp=None, apply_sink_bias=False):
# D5c: n_comp = compressed KV length. Sink bias (attn_sink) applies to
# positions >= n_comp. D3/D4 masks also only apply to SWA region.
# When n_comp is None or 0, no offset (backward compatible).
self.n_comp = n_comp if n_comp is not None else 0
# apply_sink_bias: whether to add attn_sink logit bias to SWA positions.
# Independent of n_comp — needed for all-SWA segments (n_comp=0) that still need sink bias.
# When True, adds sink_bias to positions >= n_comp (which is 0 → all positions).
self.apply_sink_bias = apply_sink_bias
self.head_dim = head_dim
self.s_k = s_k
self.n_kv_tiles = s_k // 128
self.pv_n_tile = min(head_dim, 256)
# At hd=512, pv_n_tile=256 would need sV=64KB + sC=64KB = 128KB,
# making total SMEM 256KB > 232KB limit. Use pv_n_tile=128 for hd=512
# (4 PV GEMM passes instead of 2). TODO: overlap sQ/sV to enable pv_n_tile=256.
if head_dim > 256:
self.pv_n_tile = 128
self.n_pv_tiles = head_dim // self.pv_n_tile
self.use_smem_p = use_smem_p if use_smem_p is not None else (head_dim > 64)
self.num_query_heads = num_query_heads
self.batch_size = batch_size
self.normalize = normalize # D5a: False = emit un-normalized O + lse
self.apply_swa_mask = apply_swa_mask # D3: mask logits at positions >= swa_lens
self.is_causal = is_causal # D4: causal mask (k_coord > m_coord) on SWA branch
self.acc_dtype = Float32; self.qk_acc_dtype = Float32
self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16
self.use_2cta_instrs = False; self.epilog_sync_bar_id = 1
self.cluster_shape_mn = (1, 1); self.cta_group = tcgen05.CtaGroup.ONE
self.epilogue_warp_id = (0,1,2,3); self.mma_warp_id = 4; self.tma_warp_id = 5
self.threads_per_cta = 192
# K-dim sub-tiling: cap at 256 to keep sQ and sK within SMEM budget
self.k_tile = min(head_dim, 256)
self.n_k_sub_tiles = head_dim // self.k_tile
self.kv_stage = 1 if head_dim > 128 else 2 # Reduce SMEM at large hd
self.q_stage = 1
self.num_c_stage = 1 if head_dim > 256 else 2 # Reduce SMEM at hd=512
self.scale_softmax = scale_softmax if scale_softmax is not None else 1.0 / math.sqrt(self.head_dim)
self.scale_softmax_log2 = self.scale_softmax * math.log2(math.e)
def _setup(self, qk_mma, pv_mma):
qk_ik = cute.size(qk_mma.shape_mnk, mode=[2])
# QK GEMM K-dim = head_dim. Each MMA sub-tile covers qk_ik*4 elements.
# The tiler K must be head_dim so the QK loop iterates over all K sub-tiles.
self.qk_mma_tiler = (128, 128, self.k_tile)
pv_ik = cute.size(pv_mma.shape_mnk, mode=[2])
self.pv_mma_tiler = (128, self.pv_n_tile, pv_ik * (128 // pv_ik))
self.mma_tiler = self.qk_mma_tiler
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
self.cta_tile_shape_mnk = (self.qk_mma_tiler[0]//cute.size(qk_mma.thr_id.shape), self.pv_n_tile, self.qk_mma_tiler[2])
self.c_layout = LayoutEnum.ROW_MAJOR
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(self.cta_tile_shape_mnk, False, self.c_layout, self.o_dtype)
self.num_ab_stage = 1; self.num_acc_stage = 1
self.q_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.qk_mma_tiler, self.q_dtype, self.q_stage)
self.k_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.qk_mma_tiler, self.q_dtype, self.kv_stage)
self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.q_dtype, self.kv_stage)
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2)
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
# P SMEM layout (PV A-operand) — used for SMEM-P path
self.p_smem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
qk_thr = qk_mma.get_slice(0); qk_as = qk_thr.partition_shape_C(self.qk_mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_as)
pv_thr = pv_mma.get_slice(0); pv_as = pv_thr.partition_shape_C(self.pv_mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_as)
self.tmem_s0_offset = 0
if not self.use_smem_p:
# TMEM-P: S at 0, P at 32, O after P and S
self.tmem_p0_offset = 32
p_cols_fp32 = self.pv_mma_tiler[2] * self.q_dtype.width // self.qk_acc_dtype.width
p_end = self.tmem_p0_offset + p_cols_fp32
s_cols = self.qk_mma_tiler[1]
o_after = max(s_cols, p_end)
self.tmem_o0_offset = ((o_after + 31) // 32) * 32
o_cols = find_tmem_tensor_col_offset(tOtO)
total = self.tmem_o0_offset + o_cols
else:
# SMEM-P: P not in TMEM. S and O share TMEM (sequential).
self.tmem_p0_offset = -1 # unused
self.tmem_o0_offset = 0
s_cols = self.qk_mma_tiler[1]
o_cols = find_tmem_tensor_col_offset(tOtO)
total = max(s_cols, o_cols)
self.num_tmem_alloc_cols = 1
while self.num_tmem_alloc_cols < total:
self.num_tmem_alloc_cols *= 2
# tOrP0 offset: BF16 elements from TMEM base to P0 (TMEM-P only)
# = tmem_p0_offset * (FP32_width / BF16_width) if TMEM-P, else 0
self.tOrP0_offset = max(self.tmem_p0_offset, 0) * 2 # Python int
cta = cute.size(qk_mma.thr_id.shape)
q_s = cute.slice_(self.q_smem_s,(None,None,None,0))
k_s = cute.slice_(self.k_smem_s,(None,None,None,0))
v_s = cute.slice_(self.v_smem_s,(None,None,None,0))
self.q_tx_bytes = cute.size_in_bytes(self.q_dtype, q_s) * cta
self.kv_tx_bytes = (cute.size_in_bytes(self.q_dtype, k_s) +
cute.size_in_bytes(self.q_dtype, v_s)) * cta
@cute.jit
def __call__(self, q, k, v, c, stream, lse=None, swa_len=None, sink_bias=None, row_sums=None):
self.q_dtype = q.element_type; self.o_dtype = c.element_type; self.c_dtype = self.o_dtype
self.a_major = LayoutEnum.from_tensor(q).mma_major_mode()
self.b_major = LayoutEnum.from_tensor(k).mma_major_mode()
v_fmha = cute.make_tensor(
v.iterator,
cute.make_layout(
(self.pv_n_tile, self.s_k, 1),
stride=(1, self.pv_n_tile, self.pv_n_tile * self.s_k),
),
)
self.v_major = LayoutEnum.from_tensor(v_fmha).mma_major_mode()
self.c_layout = LayoutEnum.from_tensor(c)
qk_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, self.a_major, self.b_major, self.qk_acc_dtype, self.cta_group, (128,128), tcgen05.OperandSource.SMEM)
pv_a_major = self.a_major if self.use_smem_p else cute.nvgpu.OperandMajorMode.K
pv_source = tcgen05.OperandSource.SMEM if self.use_smem_p else tcgen05.OperandSource.TMEM
pv_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, pv_a_major, self.v_major, self.qk_acc_dtype, self.cta_group, (128,self.pv_n_tile), pv_source)
self._setup(qk_mma, pv_mma)
q_s = cute.slice_(self.q_smem_s,(None,None,None,0)); k_s = cute.slice_(self.k_smem_s,(None,None,None,0)); v_s = cute.slice_(self.v_smem_s,(None,None,None,0))
tma_q,mQ = cute.nvgpu.make_tiled_tma_atom_A(utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn,qk_mma.thr_id),q,q_s,self.qk_mma_tiler,qk_mma,self.cluster_layout_vmnk.shape)
tma_k,mK = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn,qk_mma.thr_id),k,k_s,self.qk_mma_tiler,qk_mma,self.cluster_layout_vmnk.shape)
tma_v,mV = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn,pv_mma.thr_id),v_fmha,v_s,self.pv_mma_tiler,pv_mma,self.cluster_layout_vmnk.shape)
epi_s = cute.select(self.c_smem_s,mode=[0,1])
tma_c,mC = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(),c,epi_s,self.epi_tile)
# Always create a valid mLSE tensor for the kernel.
# CuTeDSL doesn't support None parameters in @cute.kernel.
if const_expr(lse is None):
lse = cute.make_tensor(c.iterator, cute.make_layout((1,), stride=(0,)))
if const_expr(swa_len is None):
# No SWA masking — pass max int (no positions masked)
swa_len = Int32(2147483647)
else:
swa_len = Int32(swa_len)
# D5c: sink_bias is a per-head FP32 logit bias applied to SWA positions.
# When None, pass 0.0 (no bias). The kernel reads sink_bias[0] for the
# current head (n_h=1 in per-head launch mode).
if const_expr(sink_bias is None):
# D5c: sink_bias not provided. Create a dummy tensor pointing to valid memory.
# Never actually read (const_expr(self.n_comp > 0) guards the read).
sink_bias = cute.make_tensor(lse.iterator, cute.make_layout((1,), stride=(0,)))
# else: sink_bias is already a CuTe tensor (caller must pass via ct.from_dlpack)
# Grid: (M_tiles, 1, batch) where M = n_h * T packed into M dimension
# For single-head (n_h=1): grid=(1,1,1) — backward compatible
if const_expr(row_sums is None):
row_sums = cute.make_tensor(lse.iterator, lse.layout)
self._kernel(qk_mma,pv_mma,tma_q,mQ,tma_k,mK,tma_v,mV,tma_c,mC,self.cluster_layout_vmnk,self.q_smem_s,self.k_smem_s,self.v_smem_s,self.p_tmem_s,self.p_smem_s,self.c_smem_s,self.epi_tile,lse,swa_len,sink_bias,row_sums).launch(grid=(1,1,self.batch_size),block=[self.threads_per_cta,1,1],stream=stream)
@cute.kernel
def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, tma_c, mC, cl_vmnk, q_smem_s, k_smem_s, v_smem_s, p_tmem_s, p_smem_s, c_smem_s, epi_tile, mLSE, swa_len, mSinkBias, mRowSums):
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
tidx,_,_ = cute.arch.thread_idx()
if warp_idx == self.tma_warp_id:
cpasync.prefetch_descriptor(tma_q); cpasync.prefetch_descriptor(tma_k); cpasync.prefetch_descriptor(tma_v); cpasync.prefetch_descriptor(tma_c)
@cute.struct
class SS:
q_bar: cute.struct.MemRange[cutlass.Int64, self.q_stage*2]
kv_bar: cute.struct.MemRange[cutlass.Int64, self.kv_stage*2]
s_bar: cute.struct.MemRange[cutlass.Int64, 2]
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage*2]
tmem_dealloc: cutlass.Int64; holding: cutlass.Int32
smem = utils.SmemAllocator(); st = smem.allocate(SS)
qp,qc = pipeline.PipelineTmaUmma.create(barrier_storage=st.q_bar.data_ptr(),num_stages=self.q_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),tx_count=self.q_tx_bytes,cta_layout_vmnk=cl_vmnk,defer_sync=True).make_participants()
kvp,kvc = pipeline.PipelineTmaUmma.create(barrier_storage=st.kv_bar.data_ptr(),num_stages=self.kv_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),tx_count=self.kv_tx_bytes,cta_layout_vmnk=cl_vmnk,defer_sync=True).make_participants()
s_prod,s_cons = pipeline.PipelineUmmaAsync.create(barrier_storage=st.s_bar.data_ptr(),num_stages=1,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,32*len(self.epilogue_warp_id))).make_participants()
softmax_done_bar = pipeline.NamedBarrier(barrier_id=3, num_threads=32 + 32*len(self.epilogue_warp_id))
final_o_bar = pipeline.NamedBarrier(barrier_id=4, num_threads=32 + 32*len(self.epilogue_warp_id))
# D1.5: pv_done_bar for SMEM accumulator approach.
# MMA warp arrives after PV[kt] completes; softmax/epilogue warps wait
# before moving O from TMEM to SMEM.
pv_done_bar = pipeline.NamedBarrier(barrier_id=5, num_threads=32 + 32*len(self.epilogue_warp_id))
acc_pipe = pipeline.PipelineUmmaAsync.create(barrier_storage=st.acc_bar.data_ptr(),num_stages=self.num_acc_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,len(self.epilogue_warp_id)),cta_layout_vmnk=cl_vmnk,defer_sync=True)
tmem_bar = pipeline.NamedBarrier(barrier_id=2,num_threads=32*len((self.mma_warp_id,*self.epilogue_warp_id)))
tmem = utils.TmemAllocator(st.holding.ptr,barrier_for_retrieve=tmem_bar,allocator_warp_id=self.epilogue_warp_id[0],is_two_cta=cute.size(qk_mma.thr_id.shape)==2,two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk,is_relaxed=True)
sQ = smem.allocate_tensor(element_type=self.q_dtype,layout=q_smem_s.outer,byte_alignment=128,swizzle=q_smem_s.inner)
sK = smem.allocate_tensor(element_type=self.q_dtype,layout=k_smem_s.outer,byte_alignment=128,swizzle=k_smem_s.inner)
# sV: independent allocation. At hd=512, pv_n_tile=128 keeps sV at 32KB.
# TODO: overlap sQ/sV with pv_n_tile=256 for better math throughput.
sV = smem.allocate_tensor(element_type=self.q_dtype,layout=v_smem_s.outer,byte_alignment=128,swizzle=v_smem_s.inner)
sC = smem.allocate_tensor(element_type=self.o_dtype,layout=c_smem_s.outer,byte_alignment=128,swizzle=c_smem_s.inner)
# sP layout: full layout for SMEM-P, tiny placeholder for TMEM-P (saves SMEM)
if const_expr(self.use_smem_p):
_p_layout = p_smem_s.outer
_p_swizzle = p_smem_s.inner
else:
_p_layout = cute.make_layout(((1,1),1,(1,1),1))
_p_swizzle = cute.make_layout(((1,1),1,(1,1),1))
sP = smem.allocate_tensor(element_type=self.q_dtype,layout=_p_layout,byte_alignment=128,swizzle=_p_swizzle)
gQ = cute.local_tile(mQ,cute.slice_(self.qk_mma_tiler,(None,0,None)),(None,None,None))
gK = cute.local_tile(mK,cute.slice_(self.qk_mma_tiler,(0,None,None)),(None,None,None))
gV = cute.local_tile(mV,cute.slice_(self.pv_mma_tiler,(0,None,None)),(None,None,None))
gC = cute.local_tile(mC,cute.slice_(self.pv_mma_tiler,(None,None,0)),(None,None,None))
n_kv_tiles = cute.size(gK, mode=[3])
qk_thr = qk_mma.get_slice(0); pv_thr = pv_mma.get_slice(0)
tCgQ = qk_thr.partition_A(gQ); tCgK = qk_thr.partition_B(gK)
tCgV = pv_thr.partition_B(gV); tCgC = pv_thr.partition_C(gC)
a_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,0,None,0)).shape)
tAsQ,tAgQ = cpasync.tma_partition(tma_q,0,a_lay,cute.group_modes(sQ,0,3),cute.group_modes(tCgQ,0,3))
b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape)
tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3))
tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3))
tAgQ = tAgQ[(None,0,None,0)]; tBgK = tBgK[(None,0,None,0)]; tVgV = tVgV[(None,0,None,0)]
tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK)
tCrV = pv_mma.make_fragment_B(sV)
qk_as = qk_thr.partition_shape_C(self.qk_mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_as)
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
pv_as = pv_thr.partition_shape_C(self.pv_mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_as)
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
# PV A-operand: define both tOrP0 (TMEM-P) and tCrP (SMEM-P) unconditionally.
# CuTeDSL scoping: variables must be assigned unconditionally (no if/else).
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
tOrP_base = pv_thr.make_fragment_A(tP if not self.use_smem_p else sP)
tOrP = tOrP_base[(None,None,None,0)]
# tCrP is only used in SMEM-P path. Define unconditionally for CuTeDSL scoping.
tCrP = pv_mma.make_fragment_A(sP) if self.use_smem_p else pv_mma.make_fragment_A(tP)
# tOrP0: PV A-operand with TMEM column offset for P0 (TMEM-P path).
# self.tOrP0_offset is pre-computed in _setup as a Python int.
# Use const_expr if/else for compile-time conditional.
if const_expr(self.tOrP0_offset > 0):
tOrP0 = cute.make_tensor(tOrP.iterator + self.tOrP0_offset, tOrP.layout)
else:
tOrP0 = tOrP
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_as, self.num_acc_stage))
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
# ===== TMA LOAD warp =====
if warp_idx == self.tma_warp_id:
if const_expr(self.n_k_sub_tiles > 1):
# K sub-tiling path (hd>256): use cutlass.range loop to avoid IR explosion
# from Python range unrolling. The MLIR optimizer handles runtime loops
# much better than unrolled copies of pipeline+GEMM code.
qp.reset()
kvp.reset()
for k_sub in cutlass.range(0, self.n_k_sub_tiles, 1, unroll=1):
qh = qp.acquire_and_advance()
cute.copy(tma_q, tAgQ[(None, k_sub)], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier)
kvh = kvp.acquire_and_advance()
cute.copy(tma_k, tBgK[(None, k_sub)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
# Load V[0]
kvh_v = kvp.acquire_and_advance()
cute.copy(tma_v, tVgV[(None, Int32(0))], tVsV[(None, kvh_v.index)], tma_bar_ptr=kvh_v.barrier)
qp.tail()
kvp.tail()
else:
# Original pipeline path (hd≤256)
qp.reset(); qh = qp.acquire_and_advance()
cute.copy(tma_q, tAgQ[(None, Int32(0))], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier)
qp.tail()
kvp.reset(); pk = kvp.try_acquire()
for kt in cutlass.range(0, self.n_kv_tiles, 1, unroll=1):
kvh = kvp.acquire_and_advance(pk)
cute.copy(tma_k, tBgK[(None, kt)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
cute.copy(tma_v, tVgV[(None, kt)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
pk = cutlass.Boolean(1)
kvp.tail()
# ===== MMA warp =====
if warp_idx == self.mma_warp_id:
tmem.wait_for_alloc()
if const_expr(self.n_k_sub_tiles > 1):
# K sub-tiling path (hd>256): cutlass.range loop (runtime, not unrolled)
qc.reset()
kvc.reset()
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
for k_sub in cutlass.range(0, self.n_k_sub_tiles, 1, unroll=1):
qh = qc.wait_and_advance(); qh.release()
kvh = kvc.wait_and_advance()
for kb in cutlass.range(cute.size(tCrQ, mode=[2]), unroll_full=True):
cute.gemm(qk_mma, tStS0, tCrQ[(None,None,kb,0)], tCrK[(None,None,kb,kvh.index)], tStS0)
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
kvh.release()
# After all k_sub: S has full QK for this kt
cute.arch.fence_view_async_tmem_store()
softmax_done_bar.arrive()
softmax_done_bar.arrive_and_wait()
pv_mma.set(tcgen05.Field.ACCUMULATE, False)
# Load V: consume from K/V pipeline
kvh_v = kvc.wait_and_advance()
if not self.use_smem_p:
for kb in cutlass.range(cute.size(tOrP0, mode=[2]), unroll_full=True):
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV[(None,None,kb,kvh_v.index)], tOtO0)
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
else:
for kb in cutlass.range(cute.size(tCrP, mode=[2]), unroll_full=True):
cute.gemm(pv_mma, tOtO0, tCrP[(None,None,kb,0)], tCrV[(None,None,kb,kvh_v.index)], tOtO0)
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
cute.arch.fence_view_async_tmem_store()
kvh_v.release()
pv_done_bar.arrive() # D1.5: Signal epilogue warps O_kt ready in TMEM
final_o_bar.arrive()
else:
# Original pipeline path (hd≤256)
qc.reset(); qh = qc.wait_and_advance(); qh.release()
kvc.reset(); pk = kvc.try_wait()
acc_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
acc_pipe.producer_acquire(acc_st)
for kt in range(self.n_kv_tiles):
kvh = kvc.wait_and_advance(pk); pk = cutlass.Boolean(1)
sh = s_prod.acquire_and_advance()
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
for kb in cutlass.range(cute.size(tCrQ, mode=[2]), unroll_full=True):
cute.gemm(qk_mma, tStS0, tCrQ[(None,None,kb,0)], tCrK[(None,None,kb,kvh.index)], tStS0)
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
cute.arch.fence_view_async_tmem_store()
sh.commit()
softmax_done_bar.arrive_and_wait()
pv_mma.set(tcgen05.Field.ACCUMULATE, kt != 0)
if not self.use_smem_p:
for kb in cutlass.range(cute.size(tOrP0, mode=[2]), unroll_full=True):
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV[(None,None,kb,kvh.index)], tOtO0)
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
else:
for kb in cutlass.range(cute.size(tCrP, mode=[2]), unroll_full=True):
cute.gemm(pv_mma, tOtO0, tCrP[(None,None,kb,0)], tCrV[(None,None,kb,kvh.index)], tOtO0)
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
cute.arch.fence_view_async_tmem_store()
kvh.release()
pv_done_bar.arrive() # D1.5: Signal epilogue warps O_kt ready in TMEM
acc_pipe.producer_commit(acc_st); acc_st.advance()
final_o_bar.arrive()
acc_pipe.producer_tail(acc_st)
# ===== SOFTMAX + CORRECTION EPILOGUE warps =====
if warp_idx < self.mma_warp_id:
tmem.allocate(self.num_tmem_alloc_cols)
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
sfw_idx = tidx % (32 * len(self.epilogue_warp_id))
# S load atoms
tmem_load_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)
thr_load = tiled_tmem_load.get_slice(sfw_idx)
tTMEM_LOADtS = thr_load.partition_S(tStS0)
cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
tScS = qk_thr.partition_C(cS)
tTMEM_LOADcS = thr_load.partition_D(tScS)
# P store atoms: TMEM-P (always defined, only used when use_smem_p=False)
p_cols_fp32 = self.pv_mma_tiler[2] * self.q_dtype.width // self.qk_acc_dtype.width
tStP_layout = cute.composition(tStS.layout, cute.make_layout((self.pv_mma_tiler[0], p_cols_fp32)))
# Use 0 as P offset when SMEM-P (these atoms are never used, but must be valid)
tStP0 = cute.make_tensor(tStS.iterator + max(self.tmem_p0_offset, 0), tStP_layout)
tmem_store_atom = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStP0)
thr_store = tiled_tmem_store.get_slice(sfw_idx)
tTMEM_STOREtP = thr_store.partition_D(tStP0)
tScP_layout = cute.composition(tScS.layout, cute.make_layout((self.pv_mma_tiler[0], p_cols_fp32)))
tScP = cute.make_tensor(tScS.iterator, tScP_layout)
tTMEM_STOREcP = thr_store.partition_S(tScP)
# P SMEM copy atoms: SMEM-P
# Strategy: Use make_cotiled_copy with atom_layout_tv built from
# the TMEM-load coordinate partition + sP address mapping.
#
# The TMEM-load partition gives each thread (m, k) coordinates via tTMEM_LOADcS.
# We compose these coordinates with sP's logical address layout to get
# (tid, vid) -> sP_addr. Then make_cotiled_copy creates a proper TiledCopy.
#
# Key: sP's outer layout maps (m, k0, k1, k2) -> sP_addr with strides (64, 1, 16, 8192).
# We need to build atom_layout_tv in sP's flat address space, not tStS's.
#
# Step 1: Build sP address mapping in the same coordinate system as tStS.
# sP is indexed as ((m, k%16), 0, ((k//16)%4, k//64)) with strides ((64,1),0,(16,8192)).
# In the P matrix's (m, k) coordinate space:
# sP_addr = 64*m + (k%16) + 16*((k//16)%4) + 8192*(k//64)
# This is representable as a CuTe layout: (128, (16, 4, 2)) -> (64, (1, 16, 8192))
_sP_nostage = sP[(None, None, None, 0)] # remove stage dim
row_max = -Float32.inf
row_sum = Float32(0.0)
scale_log2 = Float32(self.scale_softmax_log2)
# ============================================================
# D1.5: O RESCALE — SMEM ACCUMULATOR APPROACH
# =================================================
# TMEM round-trip (Ld32x32bOp/St32x32bOp) is FUNDAMENTALLY broken:
# even NO-OP round-trip corrupts data (ratio = -11 billion).
# Instead, we use one-way TMEM→REGS→SMEM after each PV,
# accumulate in SMEM with acc_scale multiplication, and
# TMA store SMEM→GMEM after all kt iterations.
#
# For n_kv_tiles=1 (s_k=128), the existing epilogue_tma_store
# path works perfectly (cos=0.999998). The SMEM accumulator
# is only needed for n_kv_tiles > 1.
# ============================================================
# NOTE: The code below is the BROKEN TMEM round-trip approach.
# It's kept as reference but should NOT be used.
# The SMEM accumulator implementation is TODO.
# prev_acc_scale: unused, kept for clarity. acc_scale at kt is used
# to rescale O from kt=0..kt-1 before PV[kt].
prev_acc_scale = Float32(0.0)
for kt in range(self.n_kv_tiles):
si_handle = s_cons.wait_and_advance()
tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype)
cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS)
cute.arch.fence_view_async_tmem_load()
# D3/D4/D5c: In-kernel logit modification.
# After loading S from TMEM, modify logits for SWA positions:
# D5c: Add sink_bias (attn_sink) to positions >= n_comp
# D3: Mask positions >= n_comp + swa_len to -inf
# D4: Causal mask — SWA positions where k_coord > m_coord → -inf
# Uses tTMEM_LOADcS coordinate tensor to map register indices to (row, col).
# For kt > 0, absolute KV pos = kt*128 + k_coord.
if const_expr(self.apply_swa_mask or self.is_causal or self.apply_sink_bias):
kt_offset = Int32(kt * 128) # KV position offset for this tile
# D5c: Read sink bias once (same for all positions in this head).
# Define unconditionally for CuTeDSL scoping (used when apply_sink_bias).
# The bias must be added in the SCALED-LOG2 domain: attn_sink * log2(e).
# But we add to the RAW logits before the scale_log2 multiply.
# Raw correction: attn_sink / scale → after * scale_log2 → attn_sink * log2(e)
sink_val = Float32(0.0)
if const_expr(self.apply_sink_bias):
sink_val = mSinkBias[Int32(0)] / Float32(self.scale_softmax)
for j0 in range(32):
for j1 in range(4):
coord = tTMEM_LOADcS[(j0, 0), j1, 0, 0]
m_coord = coord[0] # query row position
k_coord = coord[1] # position within this KV tile
kv_pos = kt_offset + k_coord # absolute KV position
# D5c: Add sink bias to SWA positions (>= n_comp)
if const_expr(self.apply_sink_bias):
if kv_pos >= Int32(self.n_comp):
tTMEM_LOADrS[(j0, 0), j1, 0, 0] = tTMEM_LOADrS[(j0, 0), j1, 0, 0] + sink_val
# D3: SWA length mask
should_mask = Boolean(0)
if const_expr(self.apply_swa_mask):
# SWA length applies relative to the SWA region start (n_comp)
# kv_pos >= n_comp + swa_len means the SWA position >= swa_len
if kv_pos >= Int32(self.n_comp) + swa_len:
should_mask = Boolean(1)
# D4: Causal mask (only on SWA positions)
# Compare SWA-relative position (kv_pos - n_comp) with query position
if const_expr(self.is_causal):
if kv_pos >= Int32(self.n_comp):
swa_pos = kv_pos - Int32(self.n_comp)
if swa_pos > m_coord:
should_mask = Boolean(1)
if should_mask:
tTMEM_LOADrS[(j0, 0), j1, 0, 0] = -Float32.inf
old_row_max = row_max
frg_cnt = 4
frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt
tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile))
for j in range(frg_cnt):
for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])):
row_max = cute.arch.fmax(row_max, tTMEM_LOADrS_frg[k, j] * scale_log2)
row_max_safe = row_max
if row_max == -cutlass.Float32.inf:
row_max_safe = Float32(0.0)
acc_scale_ = old_row_max - row_max_safe
acc_scale = cute.math.exp2(acc_scale_, fastmath=True)
if old_row_max == -cutlass.Float32.inf:
acc_scale = Float32(0.0)
row_sum *= acc_scale
rP_words = cute.make_rmem_tensor(tTMEM_STOREcP.shape, self.qk_acc_dtype)
rP_bf16 = cute.make_tensor(cute.recast_ptr(rP_words.iterator, dtype=self.q_dtype), tTMEM_LOADrS.layout)
minus_row_max = Float32(0.0) - row_max_safe
rP_bf16_frg = cute.logical_divide(rP_bf16, cute.make_layout(frg_tile))
for j in range(frg_cnt):
for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])):
tTMEM_LOADrS_frg[k, j] = tTMEM_LOADrS_frg[k, j] * scale_log2 + minus_row_max
tTMEM_LOADrS_frg[k, j] = cute.math.exp2(tTMEM_LOADrS_frg[k, j], fastmath=True)
row_sum = row_sum + tTMEM_LOADrS_frg[k, j]
s_vec = tTMEM_LOADrS_frg[None, j].load()
rP_bf16_frg[None, j].store(s_vec.to(self.q_dtype))
if not self.use_smem_p:
# TMEM-P: store P to TMEM via register bridge
cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP)
cute.arch.fence_view_async_tmem_store()
else:
# SMEM-P: write P to sP using coordinate-indexed store.
for j0 in range(32):
for j1 in range(4):
coord = tTMEM_LOADcS[(j0, 0), j1, 0, 0]
m_coord = coord[0]
k_coord = coord[1]
k0 = k_coord % 16
k1 = (k_coord // 16) % 4
k2 = k_coord // 64
_sP_nostage[(m_coord, k0), 0, (k1, k2)] = rP_bf16[(j0, 0), j1, 0, 0]
cute.arch.fence_proxy("async.shared", space="cta")
# D1.5: O rescale for kt > 0 — NOT YET IMPLEMENTED.
# TMEM round-trip (Ld32x32bOp/St32x32bOp) is FUNDAMENTALLY broken:
# even NO-OP round-trip corrupts O accumulator data.
# Production path for multi-KV-tile: Python KV merge (cos 0.999998).
# Future: SMEM accumulator approach (one-way TMEM→REGS→SMEM per kt).
# n_kv_tiles=1 is the only supported path for in-kernel processing.
si_handle.release()
softmax_done_bar.arrive()
# Wait for MMA's PV[N-1] to commit before reading O.
final_o_bar.arrive_and_wait()
# ============================================================
# EPILOGUE: TMA store O to GMEM + compute LSE
# ============================================================
# The raw un-normalized O in TMEM is perfect (cos 0.999998).
# We use epilogue_tma_store which reads O from TMEM directly via
# the correct get_tmem_load_op layout — no round-trip needed.
#
# For multi-KV-tile: the paired-atom O rescale above (kt>0) ensures
# O is correctly rescaled before this epilogue reads it.
#
# External normalization (D5a path): kernel outputs un-normalized O +
# LSE + row_sum. Caller normalizes using O_norm = O_unnorm / row_sum.
# This is exact and composes with D5c sink bias merge.
# ============================================================
# TMA store via CUTLASS epilogue_tma_store (reads raw O from TMEM)
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
acc_cons_st = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Consumer, self.num_acc_stage
)
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
self, sfw_idx, warp_idx, tma_c, tCtO_base, sC, tCgC, epi_tile,
0, const_expr(lambda x: x), (0, 0, 0),
acc_cons_st, acc_pipe, c_pipe,
)
c_pipe.producer_tail()
# Compute LSE: lse = ln(row_sum) + row_max * ln(2)
# Only when emitting un-normalized output (D5a path).
# When normalize=True, LSE is not needed (in-kernel normalization).
#
# Per-row LSE: each softmax thread (sfw_idx 0..127) handles one row.
# sfw_idx maps directly to the row index in the attention matrix.
# All 128 threads write independently to mLSE[sfw_idx] — no sync needed.
if const_expr(not self.normalize):
_row_max_safe = row_max
if row_max == -cutlass.Float32.inf:
_row_max_safe = Float32(0.0)
_ln2 = Float32(0.6931471805599453) # ln(2)
lse_val = cute.math.log(row_sum, fastmath=True) + _row_max_safe * _ln2
mLSE[sfw_idx, Int32(0), Int32(0)] = lse_val
# Also output row_sum for external normalization (D5c)
mRowSums[sfw_idx, Int32(0), Int32(0)] = row_sum
tmem.relinquish_alloc_permit()
tmem.free(tmem_ptr)

View File

@@ -1,9 +1,48 @@
/**
* DSV4 FMHA shared definitions — base header.
* BF16 type, TMEM ops, warp reductions, constants.
* DSV4 FMHA shared definitions — base header for raw CUDA kernels.
*
* TMEM operations use uint32_t registers (b32), NOT float.
* Bitcast between float and uint32_t for FP32 TMEM values.
* ==================================================================
* WHY THIS EXISTS
* ==================================================================
* CuTeDSL (the Python DSL for CUTLASS) has fundamental limitations
* on Blackwell SM100 that make certain operations impossible:
*
* 1. TMEM round-trip is BROKEN (Ld32x32bOp/St32x32bOp column mismatch)
* 2. Float-to-int conversion is IMPOSSIBLE (arith.fptosi not lowerable)
* 3. epilogue_tma_store BLOCKS multi-CTA (can't accept flat_divide coords)
* 4. hd=512 MLIR backend HANGS (>3hr optimizer time)
*
* This header provides the building blocks for writing FMHA in raw
* CUDA C++ with inline PTX, bypassing ALL of the above.
*
* ==================================================================
* WHAT WORKS (tested on B200)
* ==================================================================
* - BF16 conversion via inline PTX cvt.rn.bf16.f32 / cvt.f32.bf16
* - Warp reductions (fmax, sum)
* - TMEM alloc/dealloc via tcgen05 PTX
* - TMEM load/store via tcgen05.ld/st (uint32_t b32 registers)
* - TMEM fence via tcgen05.fence
*
* ==================================================================
* WHAT'S BROKEN / NEEDS WORK
* ==================================================================
* - TMEM load/store column addressing: the exact column offset
* calculation for row groups (8 row-groups per column) needs
* verification. The kernel using these ops hangs on B200.
* - tcgen05.mma (QK/PV GEMM): UMMA SMEM descriptor construction
* is placeholder only. The descriptor bitfield format is known
* (see cute/arch/mma_sm100_desc.hpp SmemDescriptor) but the
* exact values for our Q/K layouts haven't been validated.
*
* ==================================================================
* KEY INSIGHT FOR NVIDIA
* ==================================================================
* CuTeDSL's inability to lower float→int is a fundamental gap.
* Every quantization kernel needs f32→i32. The fact that nvvm.inline_ptx
* also fails suggests the CuTeDSL MLIR pipeline simply doesn't have a
* lowering path for ANY float→integer type conversion. This makes
* quantize-in-epilogue fusion impossible in CuTeDSL.
*/
#pragma once

View File

@@ -1,6 +1,52 @@
/**
* DSV4 FMHA Phase 2 — TMEM accumulator + one-way correction epilogue.
* Uses uint32_t TMEM registers (matching CUTLASS PTX syntax).
*
* ==================================================================
* STATUS: BROKEN — kernel HANGS on B200
* ==================================================================
*
* The concept is correct (the reference kernel proves the math), but the
* TMEM inline PTX operations cause the kernel to hang. Likely causes:
*
* 1. TMEM column addressing is wrong. The tcgen05.ld/st instructions
* take a single uint32_t column address. The exact mapping from
* (row_group, column) to the uint32_t address is unclear from the
* PTX ISA docs. The CUTLASS C++ code uses CuTe tensor abstractions
* that hide the raw addressing.
*
* 2. tcgen05.alloc may need a valid SMEM pointer that has enough
* backing storage. We're passing cvta.to.shared of the dynamic
* SMEM buffer, but the TMEM allocator may need a specific
* alignment or size.
*
* 3. The tcgen05.ld/st may need .pack::16b modifier for BF16 data,
* and the addressing is different for packed vs unpacked modes.
*
* ==================================================================
* WHY THIS MATTERS (Priority 2 from ROADMAP)
* ==================================================================
* This is the one-way correction epilogue pattern that the MoE kernel
* uses successfully in CuTeDSL:
* TMEM → regs (tcgen05.ld) → [normalize + BF16 cast] → GMEM
*
* If this works, it UNBLOCKS:
* - D2 multi-CTA grid (128 Python launches → 1 GPU launch)
* - NVFP4-1.2 (register slot for FP4 amax + pack in epilogue)
* - In-kernel normalize (O / row_sum without TMEM round-trip)
* - D1.5 fix (O rescale in REGISTERS between KV tiles)
*
* ==================================================================
* KEY INSIGHT FOR NVIDIA
* ==================================================================
* The tcgen05 PTX instructions are poorly documented for direct use.
* CUTLASS's CuTe tensor abstractions work but hide the raw addressing.
* CuTeDSL Python can use them via high-level APIs, but those APIs
* can't do float→int (see fmha_common.cuh). Raw CUDA needs the
* low-level PTX, but the column addressing is undocumented.
*
* Request: Document tcgen05.ld/st column addressing for raw PTX use,
* OR provide C-level intrinsics (like ___tmem_load, __tmem_store)
* that handle the addressing automatically.
*/
#pragma once
#include "fmha_common.cuh"

View File

@@ -1,6 +1,40 @@
/**
* DSV4 FMHA Phase 1 Reference — scalar implementation.
* Uses SMEM for Q and O. Single-thread for correctness.
* DSV4 FMHA Phase 1 Reference — scalar implementation in raw CUDA C++.
*
* ==================================================================
* STATUS: WORKING (cos 0.999999 at hd=64, cos 0.999998 at hd=128)
* ==================================================================
*
* This is the CORRECT reference implementation. It proves that:
* - The online softmax with O rescale approach is mathematically correct
* - D3 SWA masking works
* - Raw CUDA C++ compiles and runs on Blackwell SM100 without CuTeDSL
*
* ==================================================================
* WHY RAW CUDA INSTEAD OF CUTEDSL
* ==================================================================
* CuTeDSL hit 4 fundamental walls on Blackwell:
* 1. TMEM round-trip broken (D1.5) — Ld32x32bOp/St32x32bOp mismatch
* 2. Float→int impossible — arith.fptosi not lowerable to PTX
* 3. epilogue_tma_store blocks multi-CTA
* 4. hd=512 MLIR optimizer hangs
*
* Writing in raw CUDA gives us full PTX control and bypasses all of these.
* This reference kernel took ~2 hours to get working. The equivalent
* CuTeDSL kernel took weeks and still has the D1.5 blocker.
*
* ==================================================================
* LIMITATIONS (intentional — correctness first, performance second)
* ==================================================================
* - Single-thread computation (tid==0 only) — SLOW but CORRECT
* - No TMEM or tensor cores — scalar math only
* - No D4 causal mask or D5c sink bias yet
* - No multi-KV-tile optimization
*
* These are all solvable incrementally. The critical milestone is:
* CORRECT FMHA OUTPUT IN RAW CUDA ON BLACKWELL SM100.
*
* Next phase: Parallelize across threads, add tcgen05.mma for QK/PV.
*/
#pragma once
#include "fmha_common.cuh"

View File

@@ -1,5 +1,20 @@
/**
* DSV4 FMHA Decode — Launch wrapper and PyTorch binding.
*
* ==================================================================
* STATUS: COMPILES but doesn't run via torch.utils.cpp_extension
* ==================================================================
* The kernel compiles cleanly with nvcc (see test_fmha_sm100.py),
* but torch JIT compilation fails due to __bf16 / bf16_t type
* conflicts with PyTorch's -D__CUDA_NO_BFLOAT16_CONVERSIONS__ flag.
*
* Workaround: Use the standalone test (test_fmha_sm100_standalone.cu)
* which compiles with nvcc directly and tests the kernel via CUDA
* runtime APIs (no PyTorch needed).
*
* To fix for production: Replace bf16_t with c10::BFloat16 and use
* AT_DISPATCH_FLOATING_TYPES for type dispatch. Or compile the .cu
* separately with nvcc and load as a shared library.
*/
#include "fmha_sm100.cuh"

View File

@@ -1,592 +0,0 @@
"""FMHA kernel: QK -> online softmax -> PV (CuTeDSL, Blackwell SM100).
Migrated from tests/unit/test_fmha_v3_stage_c.py — Stage C proven path.
P stored to TMEM via register bridge, PV reads from TMEM.
O rescale via SMEM accumulator (one-way TMEM→REGS→SMEM per kt iteration).
Normalization via final TMA store (SMEM→GMEM).
D1.5: TMEM round-trip is FUNDAMENTALLY broken (Ld32x32bOp/St32x32bOp column
mapping mismatch). SMEM accumulator avoids round-trip entirely.
"""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
from cutlass.utils.blackwell_helpers import get_smem_store_op
from cutlass.utils.gemm.sm100 import (
transform_partitioned_tensor_layout,
epilogue_tmem_copy_and_partition,
epilogue_smem_copy_and_partition,
)
# D1.5: TMEM round-trip (Ld32x32bOp/St32x32bOp) is FUNDAMENTALLY broken.
# Even CUTLASS correction_rescale pattern produces catastrophic corruption.
# SMEM accumulator approach: one-way TMEM→REGS→SMEM per kt iteration.
import cuda.bindings.driver as cuda
import cutlass.torch as ct
import math
class FmhaKernel:
def __init__(self, head_dim=64, s_k=128, scale_softmax=None, use_smem_p=None, normalize=True, num_query_heads=1, batch_size=1, apply_swa_mask=False, is_causal=False, n_comp=None, apply_sink_bias=False):
# D5c: n_comp = compressed KV length. Sink bias (attn_sink) applies to
# positions >= n_comp. D3/D4 masks also only apply to SWA region.
# When n_comp is None or 0, no offset (backward compatible).
self.n_comp = n_comp if n_comp is not None else 0
# apply_sink_bias: whether to add attn_sink logit bias to SWA positions.
# Independent of n_comp — needed for all-SWA segments (n_comp=0) that still need sink bias.
# When True, adds sink_bias to positions >= n_comp (which is 0 → all positions).
self.apply_sink_bias = apply_sink_bias
self.head_dim = head_dim
self.s_k = s_k
self.n_kv_tiles = s_k // 128
self.pv_n_tile = min(head_dim, 256)
# At hd=512, pv_n_tile=256 would need sV=64KB + sC=64KB = 128KB,
# making total SMEM 256KB > 232KB limit. Use pv_n_tile=128 for hd=512
# (4 PV GEMM passes instead of 2). TODO: overlap sQ/sV to enable pv_n_tile=256.
if head_dim > 256:
self.pv_n_tile = 128
self.n_pv_tiles = head_dim // self.pv_n_tile
self.use_smem_p = use_smem_p if use_smem_p is not None else (head_dim > 64)
self.num_query_heads = num_query_heads
self.batch_size = batch_size
self.normalize = normalize # D5a: False = emit un-normalized O + lse
self.apply_swa_mask = apply_swa_mask # D3: mask logits at positions >= swa_lens
self.is_causal = is_causal # D4: causal mask (k_coord > m_coord) on SWA branch
self.acc_dtype = Float32; self.qk_acc_dtype = Float32
self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16
self.use_2cta_instrs = False; self.epilog_sync_bar_id = 1
self.cluster_shape_mn = (1, 1); self.cta_group = tcgen05.CtaGroup.ONE
self.epilogue_warp_id = (0,1,2,3); self.mma_warp_id = 4; self.tma_warp_id = 5
self.threads_per_cta = 192
# K-dim sub-tiling: cap at 256 to keep sQ and sK within SMEM budget
self.k_tile = min(head_dim, 256)
self.n_k_sub_tiles = head_dim // self.k_tile
self.kv_stage = 1 if head_dim > 128 else 2 # Reduce SMEM at large hd
self.q_stage = 1
self.num_c_stage = 1 if head_dim > 256 else 2 # Reduce SMEM at hd=512
self.scale_softmax = scale_softmax if scale_softmax is not None else 1.0 / math.sqrt(self.head_dim)
self.scale_softmax_log2 = self.scale_softmax * math.log2(math.e)
def _setup(self, qk_mma, pv_mma):
qk_ik = cute.size(qk_mma.shape_mnk, mode=[2])
# QK GEMM K-dim = head_dim. Each MMA sub-tile covers qk_ik*4 elements.
# The tiler K must be head_dim so the QK loop iterates over all K sub-tiles.
self.qk_mma_tiler = (128, 128, self.k_tile)
pv_ik = cute.size(pv_mma.shape_mnk, mode=[2])
self.pv_mma_tiler = (128, self.pv_n_tile, pv_ik * (128 // pv_ik))
self.mma_tiler = self.qk_mma_tiler
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
self.cta_tile_shape_mnk = (self.qk_mma_tiler[0]//cute.size(qk_mma.thr_id.shape), self.pv_n_tile, self.qk_mma_tiler[2])
self.c_layout = LayoutEnum.ROW_MAJOR
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(self.cta_tile_shape_mnk, False, self.c_layout, self.o_dtype)
self.num_ab_stage = 1; self.num_acc_stage = 1
self.q_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.qk_mma_tiler, self.q_dtype, self.q_stage)
self.k_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.qk_mma_tiler, self.q_dtype, self.kv_stage)
self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.q_dtype, self.kv_stage)
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2)
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
# P SMEM layout (PV A-operand) — used for SMEM-P path
self.p_smem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
qk_thr = qk_mma.get_slice(0); qk_as = qk_thr.partition_shape_C(self.qk_mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_as)
pv_thr = pv_mma.get_slice(0); pv_as = pv_thr.partition_shape_C(self.pv_mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_as)
self.tmem_s0_offset = 0
if not self.use_smem_p:
# TMEM-P: S at 0, P at 32, O after P and S
self.tmem_p0_offset = 32
p_cols_fp32 = self.pv_mma_tiler[2] * self.q_dtype.width // self.qk_acc_dtype.width
p_end = self.tmem_p0_offset + p_cols_fp32
s_cols = self.qk_mma_tiler[1]
o_after = max(s_cols, p_end)
self.tmem_o0_offset = ((o_after + 31) // 32) * 32
o_cols = find_tmem_tensor_col_offset(tOtO)
total = self.tmem_o0_offset + o_cols
else:
# SMEM-P: P not in TMEM. S and O share TMEM (sequential).
self.tmem_p0_offset = -1 # unused
self.tmem_o0_offset = 0
s_cols = self.qk_mma_tiler[1]
o_cols = find_tmem_tensor_col_offset(tOtO)
total = max(s_cols, o_cols)
self.num_tmem_alloc_cols = 1
while self.num_tmem_alloc_cols < total:
self.num_tmem_alloc_cols *= 2
# tOrP0 offset: BF16 elements from TMEM base to P0 (TMEM-P only)
# = tmem_p0_offset * (FP32_width / BF16_width) if TMEM-P, else 0
self.tOrP0_offset = max(self.tmem_p0_offset, 0) * 2 # Python int
cta = cute.size(qk_mma.thr_id.shape)
q_s = cute.slice_(self.q_smem_s,(None,None,None,0))
k_s = cute.slice_(self.k_smem_s,(None,None,None,0))
v_s = cute.slice_(self.v_smem_s,(None,None,None,0))
self.q_tx_bytes = cute.size_in_bytes(self.q_dtype, q_s) * cta
self.kv_tx_bytes = (cute.size_in_bytes(self.q_dtype, k_s) +
cute.size_in_bytes(self.q_dtype, v_s)) * cta
@cute.jit
def __call__(self, q, k, v, c, stream, lse=None, swa_len=None, sink_bias=None, row_sums=None):
self.q_dtype = q.element_type; self.o_dtype = c.element_type; self.c_dtype = self.o_dtype
self.a_major = LayoutEnum.from_tensor(q).mma_major_mode()
self.b_major = LayoutEnum.from_tensor(k).mma_major_mode()
v_fmha = cute.make_tensor(
v.iterator,
cute.make_layout(
(self.pv_n_tile, self.s_k, 1),
stride=(1, self.pv_n_tile, self.pv_n_tile * self.s_k),
),
)
self.v_major = LayoutEnum.from_tensor(v_fmha).mma_major_mode()
self.c_layout = LayoutEnum.from_tensor(c)
qk_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, self.a_major, self.b_major, self.qk_acc_dtype, self.cta_group, (128,128), tcgen05.OperandSource.SMEM)
pv_a_major = self.a_major if self.use_smem_p else cute.nvgpu.OperandMajorMode.K
pv_source = tcgen05.OperandSource.SMEM if self.use_smem_p else tcgen05.OperandSource.TMEM
pv_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, pv_a_major, self.v_major, self.qk_acc_dtype, self.cta_group, (128,self.pv_n_tile), pv_source)
self._setup(qk_mma, pv_mma)
q_s = cute.slice_(self.q_smem_s,(None,None,None,0)); k_s = cute.slice_(self.k_smem_s,(None,None,None,0)); v_s = cute.slice_(self.v_smem_s,(None,None,None,0))
tma_q,mQ = cute.nvgpu.make_tiled_tma_atom_A(utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn,qk_mma.thr_id),q,q_s,self.qk_mma_tiler,qk_mma,self.cluster_layout_vmnk.shape)
tma_k,mK = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn,qk_mma.thr_id),k,k_s,self.qk_mma_tiler,qk_mma,self.cluster_layout_vmnk.shape)
tma_v,mV = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn,pv_mma.thr_id),v_fmha,v_s,self.pv_mma_tiler,pv_mma,self.cluster_layout_vmnk.shape)
epi_s = cute.select(self.c_smem_s,mode=[0,1])
tma_c,mC = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(),c,epi_s,self.epi_tile)
# Always create a valid mLSE tensor for the kernel.
# CuTeDSL doesn't support None parameters in @cute.kernel.
if const_expr(lse is None):
lse = cute.make_tensor(c.iterator, cute.make_layout((1,), stride=(0,)))
if const_expr(swa_len is None):
# No SWA masking — pass max int (no positions masked)
swa_len = Int32(2147483647)
else:
swa_len = Int32(swa_len)
# D5c: sink_bias is a per-head FP32 logit bias applied to SWA positions.
# When None, pass 0.0 (no bias). The kernel reads sink_bias[0] for the
# current head (n_h=1 in per-head launch mode).
if const_expr(sink_bias is None):
# D5c: sink_bias not provided. Create a dummy tensor pointing to valid memory.
# Never actually read (const_expr(self.n_comp > 0) guards the read).
sink_bias = cute.make_tensor(lse.iterator, cute.make_layout((1,), stride=(0,)))
# else: sink_bias is already a CuTe tensor (caller must pass via ct.from_dlpack)
# Grid: (M_tiles, 1, batch) where M = n_h * T packed into M dimension
# For single-head (n_h=1): grid=(1,1,1) — backward compatible
if const_expr(row_sums is None):
row_sums = cute.make_tensor(lse.iterator, lse.layout)
self._kernel(qk_mma,pv_mma,tma_q,mQ,tma_k,mK,tma_v,mV,tma_c,mC,self.cluster_layout_vmnk,self.q_smem_s,self.k_smem_s,self.v_smem_s,self.p_tmem_s,self.p_smem_s,self.c_smem_s,self.epi_tile,lse,swa_len,sink_bias,row_sums).launch(grid=(1,1,self.batch_size),block=[self.threads_per_cta,1,1],stream=stream)
@cute.kernel
def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, tma_c, mC, cl_vmnk, q_smem_s, k_smem_s, v_smem_s, p_tmem_s, p_smem_s, c_smem_s, epi_tile, mLSE, swa_len, mSinkBias, mRowSums):
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
tidx,_,_ = cute.arch.thread_idx()
if warp_idx == self.tma_warp_id:
cpasync.prefetch_descriptor(tma_q); cpasync.prefetch_descriptor(tma_k); cpasync.prefetch_descriptor(tma_v); cpasync.prefetch_descriptor(tma_c)
@cute.struct
class SS:
q_bar: cute.struct.MemRange[cutlass.Int64, self.q_stage*2]
kv_bar: cute.struct.MemRange[cutlass.Int64, self.kv_stage*2]
s_bar: cute.struct.MemRange[cutlass.Int64, 2]
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage*2]
tmem_dealloc: cutlass.Int64; holding: cutlass.Int32
smem = utils.SmemAllocator(); st = smem.allocate(SS)
qp,qc = pipeline.PipelineTmaUmma.create(barrier_storage=st.q_bar.data_ptr(),num_stages=self.q_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),tx_count=self.q_tx_bytes,cta_layout_vmnk=cl_vmnk,defer_sync=True).make_participants()
kvp,kvc = pipeline.PipelineTmaUmma.create(barrier_storage=st.kv_bar.data_ptr(),num_stages=self.kv_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),tx_count=self.kv_tx_bytes,cta_layout_vmnk=cl_vmnk,defer_sync=True).make_participants()
s_prod,s_cons = pipeline.PipelineUmmaAsync.create(barrier_storage=st.s_bar.data_ptr(),num_stages=1,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,32*len(self.epilogue_warp_id))).make_participants()
softmax_done_bar = pipeline.NamedBarrier(barrier_id=3, num_threads=32 + 32*len(self.epilogue_warp_id))
final_o_bar = pipeline.NamedBarrier(barrier_id=4, num_threads=32 + 32*len(self.epilogue_warp_id))
# D1.5: pv_done_bar for SMEM accumulator approach.
# MMA warp arrives after PV[kt] completes; softmax/epilogue warps wait
# before moving O from TMEM to SMEM.
pv_done_bar = pipeline.NamedBarrier(barrier_id=5, num_threads=32 + 32*len(self.epilogue_warp_id))
acc_pipe = pipeline.PipelineUmmaAsync.create(barrier_storage=st.acc_bar.data_ptr(),num_stages=self.num_acc_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,len(self.epilogue_warp_id)),cta_layout_vmnk=cl_vmnk,defer_sync=True)
tmem_bar = pipeline.NamedBarrier(barrier_id=2,num_threads=32*len((self.mma_warp_id,*self.epilogue_warp_id)))
tmem = utils.TmemAllocator(st.holding.ptr,barrier_for_retrieve=tmem_bar,allocator_warp_id=self.epilogue_warp_id[0],is_two_cta=cute.size(qk_mma.thr_id.shape)==2,two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk,is_relaxed=True)
sQ = smem.allocate_tensor(element_type=self.q_dtype,layout=q_smem_s.outer,byte_alignment=128,swizzle=q_smem_s.inner)
sK = smem.allocate_tensor(element_type=self.q_dtype,layout=k_smem_s.outer,byte_alignment=128,swizzle=k_smem_s.inner)
# sV: independent allocation. At hd=512, pv_n_tile=128 keeps sV at 32KB.
# TODO: overlap sQ/sV with pv_n_tile=256 for better math throughput.
sV = smem.allocate_tensor(element_type=self.q_dtype,layout=v_smem_s.outer,byte_alignment=128,swizzle=v_smem_s.inner)
sC = smem.allocate_tensor(element_type=self.o_dtype,layout=c_smem_s.outer,byte_alignment=128,swizzle=c_smem_s.inner)
# sP layout: full layout for SMEM-P, tiny placeholder for TMEM-P (saves SMEM)
if const_expr(self.use_smem_p):
_p_layout = p_smem_s.outer
_p_swizzle = p_smem_s.inner
else:
_p_layout = cute.make_layout(((1,1),1,(1,1),1))
_p_swizzle = cute.make_layout(((1,1),1,(1,1),1))
sP = smem.allocate_tensor(element_type=self.q_dtype,layout=_p_layout,byte_alignment=128,swizzle=_p_swizzle)
gQ = cute.local_tile(mQ,cute.slice_(self.qk_mma_tiler,(None,0,None)),(None,None,None))
gK = cute.local_tile(mK,cute.slice_(self.qk_mma_tiler,(0,None,None)),(None,None,None))
gV = cute.local_tile(mV,cute.slice_(self.pv_mma_tiler,(0,None,None)),(None,None,None))
gC = cute.local_tile(mC,cute.slice_(self.pv_mma_tiler,(None,None,0)),(None,None,None))
n_kv_tiles = cute.size(gK, mode=[3])
qk_thr = qk_mma.get_slice(0); pv_thr = pv_mma.get_slice(0)
tCgQ = qk_thr.partition_A(gQ); tCgK = qk_thr.partition_B(gK)
tCgV = pv_thr.partition_B(gV); tCgC = pv_thr.partition_C(gC)
a_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,0,None,0)).shape)
tAsQ,tAgQ = cpasync.tma_partition(tma_q,0,a_lay,cute.group_modes(sQ,0,3),cute.group_modes(tCgQ,0,3))
b_lay = cute.make_layout(cute.slice_(cl_vmnk,(0,None,0,0)).shape)
tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3))
tVsV,tVgV = cpasync.tma_partition(tma_v,0,b_lay,cute.group_modes(sV,0,3),cute.group_modes(tCgV,0,3))
tAgQ = tAgQ[(None,0,None,0)]; tBgK = tBgK[(None,0,None,0)]; tVgV = tVgV[(None,0,None,0)]
tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK)
tCrV = pv_mma.make_fragment_B(sV)
qk_as = qk_thr.partition_shape_C(self.qk_mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_as)
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
pv_as = pv_thr.partition_shape_C(self.pv_mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_as)
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
# PV A-operand: define both tOrP0 (TMEM-P) and tCrP (SMEM-P) unconditionally.
# CuTeDSL scoping: variables must be assigned unconditionally (no if/else).
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
tOrP_base = pv_thr.make_fragment_A(tP if not self.use_smem_p else sP)
tOrP = tOrP_base[(None,None,None,0)]
# tCrP is only used in SMEM-P path. Define unconditionally for CuTeDSL scoping.
tCrP = pv_mma.make_fragment_A(sP) if self.use_smem_p else pv_mma.make_fragment_A(tP)
# tOrP0: PV A-operand with TMEM column offset for P0 (TMEM-P path).
# self.tOrP0_offset is pre-computed in _setup as a Python int.
# Use const_expr if/else for compile-time conditional.
if const_expr(self.tOrP0_offset > 0):
tOrP0 = cute.make_tensor(tOrP.iterator + self.tOrP0_offset, tOrP.layout)
else:
tOrP0 = tOrP
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_as, self.num_acc_stage))
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
# ===== TMA LOAD warp =====
if warp_idx == self.tma_warp_id:
if const_expr(self.n_k_sub_tiles > 1):
# K sub-tiling path (hd>256): use cutlass.range loop to avoid IR explosion
# from Python range unrolling. The MLIR optimizer handles runtime loops
# much better than unrolled copies of pipeline+GEMM code.
qp.reset()
kvp.reset()
for k_sub in cutlass.range(0, self.n_k_sub_tiles, 1, unroll=1):
qh = qp.acquire_and_advance()
cute.copy(tma_q, tAgQ[(None, k_sub)], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier)
kvh = kvp.acquire_and_advance()
cute.copy(tma_k, tBgK[(None, k_sub)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
# Load V[0]
kvh_v = kvp.acquire_and_advance()
cute.copy(tma_v, tVgV[(None, Int32(0))], tVsV[(None, kvh_v.index)], tma_bar_ptr=kvh_v.barrier)
qp.tail()
kvp.tail()
else:
# Original pipeline path (hd≤256)
qp.reset(); qh = qp.acquire_and_advance()
cute.copy(tma_q, tAgQ[(None, Int32(0))], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier)
qp.tail()
kvp.reset(); pk = kvp.try_acquire()
for kt in cutlass.range(0, self.n_kv_tiles, 1, unroll=1):
kvh = kvp.acquire_and_advance(pk)
cute.copy(tma_k, tBgK[(None, kt)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
cute.copy(tma_v, tVgV[(None, kt)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
pk = cutlass.Boolean(1)
kvp.tail()
# ===== MMA warp =====
if warp_idx == self.mma_warp_id:
tmem.wait_for_alloc()
if const_expr(self.n_k_sub_tiles > 1):
# K sub-tiling path (hd>256): cutlass.range loop (runtime, not unrolled)
qc.reset()
kvc.reset()
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
for k_sub in cutlass.range(0, self.n_k_sub_tiles, 1, unroll=1):
qh = qc.wait_and_advance(); qh.release()
kvh = kvc.wait_and_advance()
for kb in cutlass.range(cute.size(tCrQ, mode=[2]), unroll_full=True):
cute.gemm(qk_mma, tStS0, tCrQ[(None,None,kb,0)], tCrK[(None,None,kb,kvh.index)], tStS0)
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
kvh.release()
# After all k_sub: S has full QK for this kt
cute.arch.fence_view_async_tmem_store()
softmax_done_bar.arrive()
softmax_done_bar.arrive_and_wait()
pv_mma.set(tcgen05.Field.ACCUMULATE, False)
# Load V: consume from K/V pipeline
kvh_v = kvc.wait_and_advance()
if not self.use_smem_p:
for kb in cutlass.range(cute.size(tOrP0, mode=[2]), unroll_full=True):
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV[(None,None,kb,kvh_v.index)], tOtO0)
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
else:
for kb in cutlass.range(cute.size(tCrP, mode=[2]), unroll_full=True):
cute.gemm(pv_mma, tOtO0, tCrP[(None,None,kb,0)], tCrV[(None,None,kb,kvh_v.index)], tOtO0)
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
cute.arch.fence_view_async_tmem_store()
kvh_v.release()
pv_done_bar.arrive() # D1.5: Signal epilogue warps O_kt ready in TMEM
final_o_bar.arrive()
else:
# Original pipeline path (hd≤256)
qc.reset(); qh = qc.wait_and_advance(); qh.release()
kvc.reset(); pk = kvc.try_wait()
acc_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
acc_pipe.producer_acquire(acc_st)
for kt in range(self.n_kv_tiles):
kvh = kvc.wait_and_advance(pk); pk = cutlass.Boolean(1)
sh = s_prod.acquire_and_advance()
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
for kb in cutlass.range(cute.size(tCrQ, mode=[2]), unroll_full=True):
cute.gemm(qk_mma, tStS0, tCrQ[(None,None,kb,0)], tCrK[(None,None,kb,kvh.index)], tStS0)
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
cute.arch.fence_view_async_tmem_store()
sh.commit()
softmax_done_bar.arrive_and_wait()
pv_mma.set(tcgen05.Field.ACCUMULATE, kt != 0)
if not self.use_smem_p:
for kb in cutlass.range(cute.size(tOrP0, mode=[2]), unroll_full=True):
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV[(None,None,kb,kvh.index)], tOtO0)
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
else:
for kb in cutlass.range(cute.size(tCrP, mode=[2]), unroll_full=True):
cute.gemm(pv_mma, tOtO0, tCrP[(None,None,kb,0)], tCrV[(None,None,kb,kvh.index)], tOtO0)
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
cute.arch.fence_view_async_tmem_store()
kvh.release()
pv_done_bar.arrive() # D1.5: Signal epilogue warps O_kt ready in TMEM
acc_pipe.producer_commit(acc_st); acc_st.advance()
final_o_bar.arrive()
acc_pipe.producer_tail(acc_st)
# ===== SOFTMAX + CORRECTION EPILOGUE warps =====
if warp_idx < self.mma_warp_id:
tmem.allocate(self.num_tmem_alloc_cols)
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
sfw_idx = tidx % (32 * len(self.epilogue_warp_id))
# S load atoms
tmem_load_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)
thr_load = tiled_tmem_load.get_slice(sfw_idx)
tTMEM_LOADtS = thr_load.partition_S(tStS0)
cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
tScS = qk_thr.partition_C(cS)
tTMEM_LOADcS = thr_load.partition_D(tScS)
# P store atoms: TMEM-P (always defined, only used when use_smem_p=False)
p_cols_fp32 = self.pv_mma_tiler[2] * self.q_dtype.width // self.qk_acc_dtype.width
tStP_layout = cute.composition(tStS.layout, cute.make_layout((self.pv_mma_tiler[0], p_cols_fp32)))
# Use 0 as P offset when SMEM-P (these atoms are never used, but must be valid)
tStP0 = cute.make_tensor(tStS.iterator + max(self.tmem_p0_offset, 0), tStP_layout)
tmem_store_atom = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStP0)
thr_store = tiled_tmem_store.get_slice(sfw_idx)
tTMEM_STOREtP = thr_store.partition_D(tStP0)
tScP_layout = cute.composition(tScS.layout, cute.make_layout((self.pv_mma_tiler[0], p_cols_fp32)))
tScP = cute.make_tensor(tScS.iterator, tScP_layout)
tTMEM_STOREcP = thr_store.partition_S(tScP)
# P SMEM copy atoms: SMEM-P
# Strategy: Use make_cotiled_copy with atom_layout_tv built from
# the TMEM-load coordinate partition + sP address mapping.
#
# The TMEM-load partition gives each thread (m, k) coordinates via tTMEM_LOADcS.
# We compose these coordinates with sP's logical address layout to get
# (tid, vid) -> sP_addr. Then make_cotiled_copy creates a proper TiledCopy.
#
# Key: sP's outer layout maps (m, k0, k1, k2) -> sP_addr with strides (64, 1, 16, 8192).
# We need to build atom_layout_tv in sP's flat address space, not tStS's.
#
# Step 1: Build sP address mapping in the same coordinate system as tStS.
# sP is indexed as ((m, k%16), 0, ((k//16)%4, k//64)) with strides ((64,1),0,(16,8192)).
# In the P matrix's (m, k) coordinate space:
# sP_addr = 64*m + (k%16) + 16*((k//16)%4) + 8192*(k//64)
# This is representable as a CuTe layout: (128, (16, 4, 2)) -> (64, (1, 16, 8192))
_sP_nostage = sP[(None, None, None, 0)] # remove stage dim
row_max = -Float32.inf
row_sum = Float32(0.0)
scale_log2 = Float32(self.scale_softmax_log2)
# ============================================================
# D1.5: O RESCALE — SMEM ACCUMULATOR APPROACH
# =================================================
# TMEM round-trip (Ld32x32bOp/St32x32bOp) is FUNDAMENTALLY broken:
# even NO-OP round-trip corrupts data (ratio = -11 billion).
# Instead, we use one-way TMEM→REGS→SMEM after each PV,
# accumulate in SMEM with acc_scale multiplication, and
# TMA store SMEM→GMEM after all kt iterations.
#
# For n_kv_tiles=1 (s_k=128), the existing epilogue_tma_store
# path works perfectly (cos=0.999998). The SMEM accumulator
# is only needed for n_kv_tiles > 1.
# ============================================================
# NOTE: The code below is the BROKEN TMEM round-trip approach.
# It's kept as reference but should NOT be used.
# The SMEM accumulator implementation is TODO.
# prev_acc_scale: unused, kept for clarity. acc_scale at kt is used
# to rescale O from kt=0..kt-1 before PV[kt].
prev_acc_scale = Float32(0.0)
for kt in range(self.n_kv_tiles):
si_handle = s_cons.wait_and_advance()
tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype)
cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS)
cute.arch.fence_view_async_tmem_load()
# D3/D4/D5c: In-kernel logit modification.
# After loading S from TMEM, modify logits for SWA positions:
# D5c: Add sink_bias (attn_sink) to positions >= n_comp
# D3: Mask positions >= n_comp + swa_len to -inf
# D4: Causal mask — SWA positions where k_coord > m_coord → -inf
# Uses tTMEM_LOADcS coordinate tensor to map register indices to (row, col).
# For kt > 0, absolute KV pos = kt*128 + k_coord.
if const_expr(self.apply_swa_mask or self.is_causal or self.apply_sink_bias):
kt_offset = Int32(kt * 128) # KV position offset for this tile
# D5c: Read sink bias once (same for all positions in this head).
# Define unconditionally for CuTeDSL scoping (used when apply_sink_bias).
# The bias must be added in the SCALED-LOG2 domain: attn_sink * log2(e).
# But we add to the RAW logits before the scale_log2 multiply.
# Raw correction: attn_sink / scale → after * scale_log2 → attn_sink * log2(e)
sink_val = Float32(0.0)
if const_expr(self.apply_sink_bias):
sink_val = mSinkBias[Int32(0)] / Float32(self.scale_softmax)
for j0 in range(32):
for j1 in range(4):
coord = tTMEM_LOADcS[(j0, 0), j1, 0, 0]
m_coord = coord[0] # query row position
k_coord = coord[1] # position within this KV tile
kv_pos = kt_offset + k_coord # absolute KV position
# D5c: Add sink bias to SWA positions (>= n_comp)
if const_expr(self.apply_sink_bias):
if kv_pos >= Int32(self.n_comp):
tTMEM_LOADrS[(j0, 0), j1, 0, 0] = tTMEM_LOADrS[(j0, 0), j1, 0, 0] + sink_val
# D3: SWA length mask
should_mask = Boolean(0)
if const_expr(self.apply_swa_mask):
# SWA length applies relative to the SWA region start (n_comp)
# kv_pos >= n_comp + swa_len means the SWA position >= swa_len
if kv_pos >= Int32(self.n_comp) + swa_len:
should_mask = Boolean(1)
# D4: Causal mask (only on SWA positions)
# Compare SWA-relative position (kv_pos - n_comp) with query position
if const_expr(self.is_causal):
if kv_pos >= Int32(self.n_comp):
swa_pos = kv_pos - Int32(self.n_comp)
if swa_pos > m_coord:
should_mask = Boolean(1)
if should_mask:
tTMEM_LOADrS[(j0, 0), j1, 0, 0] = -Float32.inf
old_row_max = row_max
frg_cnt = 4
frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt
tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile))
for j in range(frg_cnt):
for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])):
row_max = cute.arch.fmax(row_max, tTMEM_LOADrS_frg[k, j] * scale_log2)
row_max_safe = row_max
if row_max == -cutlass.Float32.inf:
row_max_safe = Float32(0.0)
acc_scale_ = old_row_max - row_max_safe
acc_scale = cute.math.exp2(acc_scale_, fastmath=True)
if old_row_max == -cutlass.Float32.inf:
acc_scale = Float32(0.0)
row_sum *= acc_scale
rP_words = cute.make_rmem_tensor(tTMEM_STOREcP.shape, self.qk_acc_dtype)
rP_bf16 = cute.make_tensor(cute.recast_ptr(rP_words.iterator, dtype=self.q_dtype), tTMEM_LOADrS.layout)
minus_row_max = Float32(0.0) - row_max_safe
rP_bf16_frg = cute.logical_divide(rP_bf16, cute.make_layout(frg_tile))
for j in range(frg_cnt):
for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])):
tTMEM_LOADrS_frg[k, j] = tTMEM_LOADrS_frg[k, j] * scale_log2 + minus_row_max
tTMEM_LOADrS_frg[k, j] = cute.math.exp2(tTMEM_LOADrS_frg[k, j], fastmath=True)
row_sum = row_sum + tTMEM_LOADrS_frg[k, j]
s_vec = tTMEM_LOADrS_frg[None, j].load()
rP_bf16_frg[None, j].store(s_vec.to(self.q_dtype))
if not self.use_smem_p:
# TMEM-P: store P to TMEM via register bridge
cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP)
cute.arch.fence_view_async_tmem_store()
else:
# SMEM-P: write P to sP using coordinate-indexed store.
for j0 in range(32):
for j1 in range(4):
coord = tTMEM_LOADcS[(j0, 0), j1, 0, 0]
m_coord = coord[0]
k_coord = coord[1]
k0 = k_coord % 16
k1 = (k_coord // 16) % 4
k2 = k_coord // 64
_sP_nostage[(m_coord, k0), 0, (k1, k2)] = rP_bf16[(j0, 0), j1, 0, 0]
cute.arch.fence_proxy("async.shared", space="cta")
# D1.5: O rescale for kt > 0 — NOT YET IMPLEMENTED.
# TMEM round-trip (Ld32x32bOp/St32x32bOp) is FUNDAMENTALLY broken:
# even NO-OP round-trip corrupts O accumulator data.
# Production path for multi-KV-tile: Python KV merge (cos 0.999998).
# Future: SMEM accumulator approach (one-way TMEM→REGS→SMEM per kt).
# n_kv_tiles=1 is the only supported path for in-kernel processing.
si_handle.release()
softmax_done_bar.arrive()
# Wait for MMA's PV[N-1] to commit before reading O.
final_o_bar.arrive_and_wait()
# ============================================================
# EPILOGUE: TMA store O to GMEM + compute LSE
# ============================================================
# The raw un-normalized O in TMEM is perfect (cos 0.999998).
# We use epilogue_tma_store which reads O from TMEM directly via
# the correct get_tmem_load_op layout — no round-trip needed.
#
# For multi-KV-tile: the paired-atom O rescale above (kt>0) ensures
# O is correctly rescaled before this epilogue reads it.
#
# External normalization (D5a path): kernel outputs un-normalized O +
# LSE + row_sum. Caller normalizes using O_norm = O_unnorm / row_sum.
# This is exact and composes with D5c sink bias merge.
# ============================================================
# TMA store via CUTLASS epilogue_tma_store (reads raw O from TMEM)
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
acc_cons_st = pipeline.make_pipeline_state(
pipeline.PipelineUserType.Consumer, self.num_acc_stage
)
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
self, sfw_idx, warp_idx, tma_c, tCtO_base, sC, tCgC, epi_tile,
0, const_expr(lambda x: x), (0, 0, 0),
acc_cons_st, acc_pipe, c_pipe,
)
c_pipe.producer_tail()
# Compute LSE: lse = ln(row_sum) + row_max * ln(2)
# Only when emitting un-normalized output (D5a path).
# When normalize=True, LSE is not needed (in-kernel normalization).
#
# Per-row LSE: each softmax thread (sfw_idx 0..127) handles one row.
# sfw_idx maps directly to the row index in the attention matrix.
# All 128 threads write independently to mLSE[sfw_idx] — no sync needed.
if const_expr(not self.normalize):
_row_max_safe = row_max
if row_max == -cutlass.Float32.inf:
_row_max_safe = Float32(0.0)
_ln2 = Float32(0.6931471805599453) # ln(2)
lse_val = cute.math.log(row_sum, fastmath=True) + _row_max_safe * _ln2
mLSE[sfw_idx, Int32(0), Int32(0)] = lse_val
# Also output row_sum for external normalization (D5c)
mRowSums[sfw_idx, Int32(0), Int32(0)] = row_sum
tmem.relinquish_alloc_permit()
tmem.free(tmem_ptr)

View File

@@ -1,5 +1,29 @@
"""DSV4 Blackwell Attention — Production kernel wrapper.
====================================================================
STATUS: WORKING for single-tile, Python KV merge for multi-tile
====================================================================
See ROADMAP.md Priority 5 (Stage E) for what's needed to ship.
Key gaps: custom_op registration, kernel cache warmup, batch fusion.
====================================================================
WHAT WORKS
====================================================================
- Per-KV-group head-packed launch (MQA/GQA efficient)
- Python KV merge for multi-KV-tile (cos 0.999998)
- D3/D4/D5c masks
- Batch via Python outer loop
====================================================================
WHAT'S BLOCKED
====================================================================
- In-kernel multi-KV-tile: blocked on D1.5 (TMEM round-trip broken)
- Batch fusion into grid: blocked on D2 (multi-CTA, epilogue_tma_store)
- hd > 256: CuTeDSL MLIR hang (>3hr optimizer time)
====================================================================
Wraps the CuTeDSL FMHA kernel with Python KV merge for multi-KV-tile.
Supports MHA, MQA, and GQA attention patterns with head-packed launches
for efficient MQA/GQA (all Q heads sharing a KV head dispatched in one