Files
nvfp4-megamoe-kernel/tests/archive/test_fmha_pipeline.py
biondizzle 9cbdc92744 Restructure: cutedsl/ -> dsv4/ with proper layering
- Split bridge.py -> ops/quantize.py, ops/layouts.py, ops/gemm_runner.py
- Renamed classes: CuTeDSLNvfp4Linear -> Nvfp4Linear, etc.
- Moved kernel code to dsv4/kernels/ (gemm, attention, compressor, decode, cuda)
- Moved PyTorch bridges to dsv4/ops/
- Moved nn.Module layers to dsv4/layers/
- Moved reference implementations to dsv4/reference/
- Moved vendored CUTLASS code to vendored/
- Archived ~190 debug tests to tests/archive/
- Kept ~15 canonical tests in tests/unit/
- Updated all import paths
- Added stubs for future components (model/, cache/, loader/)
- Updated pyproject.toml: dsv4-inference package name
2026-05-21 17:30:44 +00:00

354 lines
16 KiB
Python

"""
Stage B — FMHA-style KV-tile interleaved attention kernel.
Following CUTLASS FMHA reference architecture:
- Q: (seq_q, head_dim) — loaded once
- K, V: tiled over sequence dimension, V overwrites K in SMEM (FMHA trick)
- For each KV-tile:
1. TMA load K[tile] into sK SMEM
2. QK MMA: sQ @ sK^T → S in TMEM
3. Softmax: S → P in TMEM (with online softmax rescaling of O in TMEM)
4. V overwrites sK SMEM (after QK, K no longer needed)
5. PV MMA: P @ sV → O in TMEM (accumulate)
- Epilogue: divide O by row_sum, store to GMEM
This layout handles PV tiles whose shape is not (128, 128): V is loaded into
SMEM right before each PV MMA, so V SMEM always holds the current KV tile's
data rather than a stale tile loaded at the start of the kernel.
Warp layout:
Warp 0-3: Softmax (4 warps)
Warp 4: MMA
Warp 5: TMA load
"""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cuda.bindings.driver as cuda
import cutlass.torch as ct
class FmhaPipelineKernel:
def __init__(self, qk_mma_tiler, pv_mma_tiler):
    """Record the static configuration of the attention kernel.

    Args:
        qk_mma_tiler: MMA tiler for S = Q @ K^T. Only its first two modes
            survive; ``_setup`` rewrites the K mode once the MMA exists.
        pv_mma_tiler: MMA tiler for O = P @ V, treated the same way.
    """
    # Default element types (``__call__`` re-reads them from the actual
    # tensors): BF16 operands/outputs with FP32 accumulation.
    self.q_dtype = BFloat16
    self.o_dtype = BFloat16
    self.c_dtype = BFloat16
    self.acc_dtype = Float32
    self.qk_acc_dtype = Float32

    # Caller-supplied tile shapes.
    self.qk_mma_tiler = qk_mma_tiler
    self.pv_mma_tiler = pv_mma_tiler

    # Single-CTA execution: no 2-CTA MMA instructions, 1x1 cluster.
    self.use_2cta_instrs = False
    self.cta_group = tcgen05.CtaGroup.ONE
    self.cluster_shape_mn = (1, 1)
    self.epilog_sync_bar_id = 1

    # Warp roles: four softmax warps, then one MMA warp and one TMA warp.
    self.softmax_warp_ids = (0, 1, 2, 3)
    self.mma_warp_id = 4
    self.tma_warp_id = 5
    # softmax warps + MMA warp + TMA warp, 32 threads each (= 192).
    self.threads_per_cta = 32 * (len(self.softmax_warp_ids) + 2)

    # Pipeline depths.
    self.q_stage = 1
    self.kv_stage = 2  # double-buffered KV
def _setup(self, qk_mma, pv_mma):
    """Derive tilers, SMEM/TMEM layouts and TMA byte counts from the MMAs.

    Called by ``__call__`` after both tiled MMAs have been built.

    Args:
        qk_mma: tiled MMA computing S = Q @ K^T (operands from SMEM).
        pv_mma: tiled MMA computing O = P @ V (P operand read from TMEM).

    Sets, among others: the final ``qk_mma_tiler`` / ``pv_mma_tiler``,
    ``cta_tile_shape_mnk``, the staged layouts ``q_smem_s``, ``k_smem_s``,
    ``v_smem_s``, ``c_smem_s`` and ``p_tmem_s``, the TMEM column offsets
    ``tmem_s0_offset`` / ``tmem_p0_offset`` / ``tmem_o0_offset``,
    ``num_tmem_alloc_cols`` and the TMA transaction byte counts.
    """
    # The K mode of each tiler spans four MMA instructions.
    qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
    self.qk_mma_tiler = (*self.qk_mma_tiler[:2], qk_inst_k * 4)
    pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2])
    self.pv_mma_tiler = (*self.pv_mma_tiler[:2], pv_inst_k * 4)
    self.mma_tiler = self.qk_mma_tiler
    self.cluster_layout_vmnk = cute.tiled_divide(
        cute.make_layout((1, 1, 1)), (qk_mma.thr_id.shape,))
    self.epi_tile = self.pv_mma_tiler[:2]
    self.cta_tile_shape_mnk = (
        self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
        self.pv_mma_tiler[1],
        self.qk_mma_tiler[2])
    # NOTE(review): this overrides the c_layout that __call__ derives from
    # the output tensor right before calling _setup, so a column-major C
    # would still get a row-major epilogue layout. Confirm this is intended.
    self.c_layout = LayoutEnum.ROW_MAJOR
    self.num_ab_stage = 1
    self.num_acc_stage = 1
    self.num_c_stage = 2

    # Staged SMEM layouts (Q/K for the QK MMA, V for the PV MMA, C for the
    # epilogue) plus the layout of P as the A operand of the PV MMA.
    self.q_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.qk_mma_tiler, self.q_dtype, self.q_stage)
    self.k_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.qk_mma_tiler, self.q_dtype, self.kv_stage)
    self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
    self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.q_dtype, self.kv_stage)
    self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, self.num_c_stage)

    # Number of TMEM columns one S accumulator occupies, i.e. the column
    # offset at which the next TMEM tensor may start.
    qk_thr = qk_mma.get_slice(0)
    qk_acc_shape = qk_thr.partition_shape_C(self.qk_mma_tiler[:2])
    tStS = qk_thr.make_fragment_C(qk_acc_shape)
    s_cols = find_tmem_tensor_col_offset(tStS)
    pv_thr = pv_mma.get_slice(0)
    pv_acc_shape = pv_thr.partition_shape_C(self.pv_mma_tiler[:2])

    # TMEM column map:
    #   S: starts at column 0.
    #   P: BF16 and aliases the S region, starting 32 FP32 columns in
    #      (the kernel rescales this offset to BF16 element units).
    #   O: placed directly after S.
    # Using O's own width as its offset would overlap S whenever the PV
    # tile is narrower than the QK tile.
    self.tmem_s0_offset = 0
    self.tmem_p0_offset = 32
    self.tmem_o0_offset = self.tmem_s0_offset + s_cols

    # Width of a P row measured in FP32 TMEM columns (presumably used by the
    # softmax stage when writing P back to TMEM).
    self.tilePlikeFP32 = self.qk_mma_tiler[1] // Float32.width * self.o_dtype.width

    # The allocation covers S followed by O, which is exactly the
    # [0, s_cols + o_cols) range used by the offsets above.
    tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage))
    tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage))
    self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100")

    # TMA transaction bytes for one pipeline stage, scaled by the number of
    # CTAs participating in each MMA.
    n_mma_ctas = cute.size(qk_mma.thr_id.shape)
    q_bytes = cute.size_in_bytes(self.q_dtype, cute.slice_(self.q_smem_s, (None, None, None, 0)))
    k_bytes = cute.size_in_bytes(self.q_dtype, cute.slice_(self.k_smem_s, (None, None, None, 0)))
    v_bytes = cute.size_in_bytes(self.q_dtype, cute.slice_(self.v_smem_s, (None, None, None, 0)))
    # Combined Q + K count (kept for existing users).
    self.num_tma_load_bytes = (q_bytes + k_bytes) * n_mma_ctas
    # NOTE(review): _kernel passes num_tma_load_bytes as tx_count to BOTH the
    # Q pipeline (whose stage only receives the Q tile) and the KV pipeline
    # (whose stage receives a K tile and a V tile). The counts below match
    # what each barrier actually receives.
    self.num_q_load_bytes = q_bytes * n_mma_ctas
    self.num_kv_load_bytes = (k_bytes + v_bytes) * n_mma_ctas
@cute.jit
def __call__(self, q, k, v, c, stream):
    """Build the MMA and TMA atoms for the given tensors and launch the kernel.

    Args:
        q: query tensor, TMA source of the QK MMA A operand.
        k: key tensor, TMA source of the QK MMA B operand.
        v: value tensor, TMA source of the PV MMA B operand.
        c: output tensor, TMA store destination of the epilogue.
        stream: CUDA stream the kernel is launched on.
    """
    # Element types and major modes come from the actual tensors.
    self.q_dtype = q.element_type
    self.o_dtype = c.element_type
    self.c_dtype = self.o_dtype
    self.a_major = LayoutEnum.from_tensor(q).mma_major_mode()
    self.b_major = LayoutEnum.from_tensor(k).mma_major_mode()
    self.v_major = LayoutEnum.from_tensor(v).mma_major_mode()
    self.c_layout = LayoutEnum.from_tensor(c)

    # S = Q @ K^T with both operands read from SMEM.
    qk_mma = utils.sm100.make_trivial_tiled_mma(
        self.q_dtype, self.q_dtype, self.a_major, self.b_major,
        self.qk_acc_dtype, self.cta_group, self.qk_mma_tiler[:2],
        tcgen05.OperandSource.SMEM)
    # O = P @ V with the K-major P operand read from TMEM.
    pv_mma = utils.sm100.make_trivial_tiled_mma(
        self.q_dtype, self.q_dtype, cute.nvgpu.OperandMajorMode.K, self.v_major,
        self.qk_acc_dtype, self.cta_group, self.pv_mma_tiler[:2],
        tcgen05.OperandSource.TMEM)
    self._setup(qk_mma, pv_mma)

    # TMA atoms are built against a single (stage-0) SMEM buffer.
    q_smem = cute.slice_(self.q_smem_s, (None, None, None, 0))
    k_smem = cute.slice_(self.k_smem_s, (None, None, None, 0))
    v_smem = cute.slice_(self.v_smem_s, (None, None, None, 0))
    tma_q, tma_tq = cute.nvgpu.make_tiled_tma_atom_A(
        utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
        q, q_smem, self.qk_mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
    # K and V are B operands, so they use the B-operand counterpart of the
    # make_tiled_tma_atom_A call above.
    tma_k, tma_tk = cute.nvgpu.make_tiled_tma_atom_B(
        utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
        k, k_smem, self.qk_mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
    tma_v, tma_tv = cute.nvgpu.make_tiled_tma_atom_B(
        utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, pv_mma.thr_id),
        v, v_smem, self.pv_mma_tiler, pv_mma, self.cluster_layout_vmnk.shape)

    # Epilogue: SMEM -> GMEM bulk tensor store of one epi_tile.
    epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
    tma_c, tma_tc = cpasync.make_tiled_tma_atom(
        cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)

    # Single-CTA launch.
    self._kernel(
        qk_mma, pv_mma, tma_q, tma_tq, tma_k, tma_tk, tma_v, tma_tv,
        tma_c, tma_tc, self.cluster_layout_vmnk,
        self.q_smem_s, self.k_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile
    ).launch(grid=(1, 1, 1), block=[self.threads_per_cta, 1, 1], stream=stream)
@cute.kernel
def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV,
tma_c, mC, cl_vmnk, q_smem_s, k_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile):
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
tidx, _, _ = cute.arch.thread_idx()
use_2cta = cute.size(qk_mma.thr_id.shape) == 2
if warp_idx == self.tma_warp_id:
cpasync.prefetch_descriptor(tma_q)
cpasync.prefetch_descriptor(tma_k)
cpasync.prefetch_descriptor(tma_v)
cpasync.prefetch_descriptor(tma_c)
@cute.struct
class SS:
q_bar: cute.struct.MemRange[cutlass.Int64, self.q_stage * 2]
kv_bar: cute.struct.MemRange[cutlass.Int64, self.kv_stage * 2]
mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
tmem_dealloc: cutlass.Int64
holding: cutlass.Int32
smem = utils.SmemAllocator()
st = smem.allocate(SS)
q_prod, q_cons = pipeline.PipelineTmaUmma.create(
barrier_storage=st.q_bar.data_ptr(), num_stages=self.q_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
).make_participants()
kv_prod, kv_cons = pipeline.PipelineTmaUmma.create(
barrier_storage=st.kv_bar.data_ptr(), num_stages=self.kv_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
).make_participants()
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(
barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(
pipeline.Agent.Thread, 32 * len(self.softmax_warp_ids)),
).make_participants()
acc_pipe = pipeline.PipelineUmmaAsync.create(
barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(
pipeline.Agent.Thread, len(self.softmax_warp_ids) * (2 if use_2cta else 1)),
cta_layout_vmnk=cl_vmnk, defer_sync=True)
tmem_bar = pipeline.NamedBarrier(
barrier_id=2,
num_threads=32 * len((self.mma_warp_id, *self.softmax_warp_ids)))
tmem = utils.TmemAllocator(
st.holding.ptr, barrier_for_retrieve=tmem_bar,
allocator_warp_id=self.softmax_warp_ids[0], is_two_cta=use_2cta,
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
sQ = smem.allocate_tensor(element_type=self.q_dtype, layout=q_smem_s.outer, byte_alignment=128, swizzle=q_smem_s.inner)
sK = smem.allocate_tensor(element_type=self.q_dtype, layout=k_smem_s.outer, byte_alignment=128, swizzle=k_smem_s.inner)
# V overwrites K SMEM (FMHA trick)
sV_ptr = cute.recast_ptr(sK.iterator, v_smem_s.inner)
sV = cute.make_tensor(sV_ptr, v_smem_s.outer)
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
gQ = cute.local_tile(mQ, cute.slice_(self.qk_mma_tiler, (None, 0, None)), (None, None, None))
gK = cute.local_tile(mK, cute.slice_(self.qk_mma_tiler, (0, None, None)), (None, None, None))
gV = cute.local_tile(mV, cute.slice_(self.pv_mma_tiler, (0, None, None)), (None, None, None))
gC = cute.local_tile(mC, cute.slice_(self.pv_mma_tiler, (None, None, 0)), (None, None, None))
n_kv_tiles = cute.size(gK, mode=[3])
qk_thr = qk_mma.get_slice(0)
pv_thr = pv_mma.get_slice(0)
tCgQ = qk_thr.partition_A(gQ)
tCgK = qk_thr.partition_B(gK)
tCgV = pv_thr.partition_B(gV)
tCgC = pv_thr.partition_C(gC)
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0, 0, None, 0)).shape)
tAsQ, tAgQ = cpasync.tma_partition(tma_q, 0, a_lay, cute.group_modes(sQ, 0, 3), cute.group_modes(tCgQ, 0, 3))
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0, None, 0, 0)).shape)
tBsK, tBgK = cpasync.tma_partition(tma_k, 0, b_lay, cute.group_modes(sK, 0, 3), cute.group_modes(tCgK, 0, 3))
tVsV, tVgV = cpasync.tma_partition(tma_v, 0, b_lay, cute.group_modes(sV, 0, 3), cute.group_modes(tCgV, 0, 3))
tAgQ = tAgQ[(None, 0, None, 0)]
tBgK = tBgK[(None, 0, None, 0)]
tVgV = tVgV[(None, 0, None, 0)]
tCrQ = qk_mma.make_fragment_A(sQ)
tCrK = qk_mma.make_fragment_B(sK)
tCrV = pv_mma.make_fragment_B(sV)
qk_acc_shape = qk_thr.partition_shape_C(self.qk_mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_acc_shape)
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
pv_acc_shape = pv_thr.partition_shape_C(self.pv_mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
tOrP_base = pv_thr.make_fragment_A(tP)
tOrP = tOrP_base[(None, None, None, 0)]
tOrP0 = cute.make_tensor(
tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
tOrP.layout)
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage))
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage))
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
# ═══ TMA LOAD WARP ═══
if warp_idx == self.tma_warp_id:
# Load Q once
q_prod.reset()
qh = q_prod.acquire_and_advance()
cute.copy(tma_q, tAgQ[(None, qh.count)], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier)
q_prod.tail()
# Load KV tiles: for each tile, load K then V
# K and V share SMEM, so V overwrites K after QK consumes it
kv_prod.reset()
peek = kv_prod.try_acquire()
for kt in cutlass.range(n_kv_tiles, unroll=1):
# Load K[tile]
kvh = kv_prod.acquire_and_advance(peek)
cute.copy(tma_k, tBgK[(None, kvh.count)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
# Load V[tile] into the SAME SMEM (overwrites K after QK)
# Wait — we need QK to finish before V overwrites K.
# FMHA uses a SEPARATE pipeline entry for V. The MMA warp
# consumes K first (QK), then V (PV). The pipeline ordering
# ensures V doesn't overwrite K before QK is done.
cute.copy(tma_v, tVgV[(None, kvh.count)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
peek = cutlass.Boolean(1)
if kvh.count + 1 < 2 * n_kv_tiles:
peek = kv_prod.try_acquire()
kv_prod.tail()
# ═══ MMA WARP ═══
if warp_idx == self.mma_warp_id:
tmem.wait_for_alloc()
q_cons.reset()
qh = q_cons.wait_and_advance()
qh.release()
kv_cons.reset()
peek = kv_cons.try_wait()
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
acc_pipe.producer_acquire(acc_prod_st)
for kt in range(n_kv_tiles):
# Wait for K[tile]
kvh = kv_cons.wait_and_advance(peek)
peek = cutlass.Boolean(1)
# ─── QK: Q @ K[tile]^T → S ───
s0_handle = mma_si_prod.acquire_and_advance()
qk_mma.set(tcgen05.Field.ACCUMULATE, kt != 0)
nblk = cute.size(tCrQ, mode=[2])
for kb in cutlass.range(nblk, unroll_full=True):
cute.gemm(qk_mma, tStS0,
tCrQ[(None, None, kb, 0)],
tCrK[(None, None, kb, kvh.index)],
tStS0)
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
cute.arch.fence_view_async_tmem_store()
s0_handle.commit()
# ─── Wait for softmax: S → P done ───
s0_handle = mma_si_prod.acquire_and_advance()
# ─── Wait for V[tile] ───
vvh = kv_cons.wait_and_advance(peek)
peek = cutlass.Boolean(1)
# ─── PV: P @ V[tile] → O ───
pv_mma.set(tcgen05.Field.ACCUMULATE, kt != 0)
nblk_pv = cute.size(tOrP0, mode=[2])
for kb in cutlass.range(nblk_pv, unroll_full=True):
cute.gemm(pv_mma, tOtO0,
tOrP0[(None, None, kb)],
tCrV[(None, None, kb, vvh.index)],
tOtO0)
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
kvh.release()
vvh.release()
acc_pipe.producer_commit(acc_prod_st)
acc_prod_st.advance()
acc_pipe.producer_tail(acc_prod_st)
# ═══ SOFTMAX WARPS ═══
if warp_idx < self.mma_warp_id:
tmem.allocate(self.num_tmem_alloc_cols)
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
sfw_idx = tidx % (32 * len(self.softmax_warp_ids))
tmem_load_atom = cute.make_copy_atom(
tcgen05.copy.Ld32x32bOp(tcgen