- Split bridge.py -> ops/quantize.py, ops/layouts.py, ops/gemm_runner.py
- Renamed classes: CuTeDSLNvfp4Linear -> Nvfp4Linear, etc.
- Moved kernel code to dsv4/kernels/ (gemm, attention, compressor, decode, cuda)
- Moved PyTorch bridges to dsv4/ops/
- Moved nn.Module layers to dsv4layers/
- Moved reference implementations to dsv4/reference/
- Moved vendored CUTLASS code to vendored/
- Archived ~190 debug tests to tests/archive/
- Kept ~15 canonical tests in tests/unit/
- Updated all import paths
- Added stubs for future components (model/, cache/, loader/)
- Updated pyproject.toml: dsv4-inference package name
354 lines
16 KiB
Python
354 lines
16 KiB
Python
"""
Stage B — FMHA-style KV-tile interleaved attention kernel.

Following CUTLASS FMHA reference architecture:
- Q: (seq_q, head_dim) — loaded once
- K, V: tiled over sequence dimension, V overwrites K in SMEM (FMHA trick)
- For each KV-tile:
    1. TMA load K[tile] into sK SMEM
    2. QK MMA: sQ @ sK^T → S in TMEM
    3. Softmax: S → P in TMEM (with online softmax rescaling of O in TMEM)
    4. V overwrites sK SMEM (after QK, K no longer needed)
    5. PV MMA: P @ sV → O in TMEM (accumulate)
- Epilogue: divide O by row_sum, store to GMEM

This properly handles non-(128,128) PV because V SMEM always has the correct
data for the current KV-tile — it's loaded right before PV, not stale from
the beginning.

Warp layout:
    Warp 0-3: Softmax (4 warps)
    Warp 4: MMA
    Warp 5: TMA load
"""
|
|
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
|
|
from cutlass.cute.nvgpu import cpasync, tcgen05
|
|
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
|
|
from cutlass.utils import LayoutEnum
|
|
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
|
|
import cuda.bindings.driver as cuda
|
|
import cutlass.torch as ct
|
|
|
|
|
|
class FmhaPipelineKernel:
|
|
def __init__(self, qk_mma_tiler, pv_mma_tiler):
    """Store the caller's MMA tile shapes and the kernel's fixed configuration.

    Args:
        qk_mma_tiler: Tile shape for the Q @ K^T MMA. Stored as given; only
            its first two entries survive the later ``_setup`` call.
        pv_mma_tiler: Tile shape for the P @ V MMA. Stored as given; only
            its first two entries survive the later ``_setup`` call.
    """
    # Tile shapes supplied by the caller.
    self.qk_mma_tiler, self.pv_mma_tiler = qk_mma_tiler, pv_mma_tiler

    # Element types: both accumulators are Float32; Q, O and C default to
    # BFloat16 (``__call__`` replaces these with the tensors' actual types).
    self.acc_dtype = self.qk_acc_dtype = Float32
    self.q_dtype = self.o_dtype = self.c_dtype = BFloat16

    # Single-CTA configuration: no 2-CTA instructions, 1x1 cluster.
    self.use_2cta_instrs = False
    self.cta_group = tcgen05.CtaGroup.ONE
    self.cluster_shape_mn = (1, 1)
    self.epilog_sync_bar_id = 1

    # Warp roles: warps 0-3 run softmax, warp 4 issues MMAs, warp 5 drives TMA.
    self.softmax_warp_ids = (0, 1, 2, 3)
    self.mma_warp_id = 4
    self.tma_warp_id = 5
    all_warps = (*self.softmax_warp_ids, self.mma_warp_id, self.tma_warp_id)
    self.threads_per_cta = 32 * len(all_warps)  # 6 warps -> 192 threads

    # Pipeline depths.
    self.q_stage = 1
    self.kv_stage = 2  # double-buffered KV
|
|
|
|
def _setup(self, qk_mma, pv_mma):
    """Derive tile shapes, SMEM/TMEM layouts and TMA byte counts from the MMAs.

    All results are stored as attributes on ``self``; nothing is returned.

    Args:
        qk_mma: Tiled MMA used for the Q @ K^T product.
        pv_mma: Tiled MMA used for the P @ V product.
    """
    # Rebuild each tiler as (M, N, 4 * instruction-K): keep the caller's M/N
    # and set the K extent to four MMA instructions.
    qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
    self.qk_mma_tiler = (*self.qk_mma_tiler[:2], qk_inst_k * 4)
    pv_inst_k = cute.size(pv_mma.shape_mnk, mode=[2])
    self.pv_mma_tiler = (*self.pv_mma_tiler[:2], pv_inst_k * 4)
    self.mma_tiler = self.qk_mma_tiler

    self.cluster_layout_vmnk = cute.tiled_divide(
        cute.make_layout((1, 1, 1)), (qk_mma.thr_id.shape,))
    self.epi_tile = self.pv_mma_tiler[:2]
    self.cta_tile_shape_mnk = (
        self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
        self.pv_mma_tiler[1],
        self.qk_mma_tiler[2])
    # NOTE(review): this replaces the c_layout that __call__ derives from the
    # output tensor just before calling _setup — confirm that forcing
    # row-major here is intended.
    self.c_layout = LayoutEnum.ROW_MAJOR
    self.num_ab_stage = 1
    self.num_acc_stage = 1
    self.num_c_stage = 2

    # Staged layouts. K and V layouts are built with q_dtype, i.e. K and V
    # are treated as having the same element type as Q.
    self.q_smem_s = utils.sm100.make_smem_layout_a(
        qk_mma, self.qk_mma_tiler, self.q_dtype, self.q_stage)
    self.k_smem_s = utils.sm100.make_smem_layout_b(
        qk_mma, self.qk_mma_tiler, self.q_dtype, self.kv_stage)
    # Layout of P as the A operand of the PV MMA (single stage).
    self.p_tmem_s = utils.sm100.make_smem_layout_a(
        pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
    self.v_smem_s = utils.sm100.make_smem_layout_b(
        pv_mma, self.pv_mma_tiler, self.q_dtype, self.kv_stage)
    self.c_smem_s = utils.sm100.make_smem_layout_epi(
        self.o_dtype, self.c_layout, self.epi_tile, self.num_c_stage)

    # TMEM column extent of the S accumulator (QK result).
    qk_thr = qk_mma.get_slice(0)
    qk_acc_shape = qk_thr.partition_shape_C(self.qk_mma_tiler[:2])
    tStS = qk_thr.make_fragment_C(qk_acc_shape)
    s_cols = find_tmem_tensor_col_offset(tStS)

    pv_thr = pv_mma.get_slice(0)
    pv_acc_shape = pv_thr.partition_shape_C(self.pv_mma_tiler[:2])

    # TMEM column offsets. S starts at column 0 and O is placed directly
    # after it, so the offset of O is the extent of S. (Previously O's own
    # extent was used, which only coincides with S's extent when both
    # accumulators have the same number of columns.)
    self.tmem_s0_offset = 0
    # NOTE(review): 32 looks like a fixed column offset of P inside the S
    # region — confirm against the softmax code that writes P.
    self.tmem_p0_offset = 32
    self.tmem_o0_offset = s_cols
    # NOTE(review): presumably the N extent of P expressed in Float32-sized
    # columns — confirm against its consumer.
    self.tilePlikeFP32 = self.qk_mma_tiler[1] // Float32.width * self.o_dtype.width

    # Number of TMEM columns to allocate for one S plus one O accumulator.
    tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage))
    tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage))
    self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols(
        [tCtS_fake, tCtO_fake], arch="sm_100")

    # Byte count of one Q stage plus one K stage, scaled by the MMA thread-id
    # extent.
    # NOTE(review): the same value is used as tx_count for both the Q and
    # the KV pipelines in _kernel — confirm it matches the bytes each
    # pipeline stage actually receives.
    a_smem = cute.slice_(self.q_smem_s, (None, None, None, 0))
    b_smem = cute.slice_(self.k_smem_s, (None, None, None, 0))
    self.num_tma_load_bytes = (
        cute.size_in_bytes(self.q_dtype, a_smem) +
        cute.size_in_bytes(self.q_dtype, b_smem)
    ) * cute.size(qk_mma.thr_id.shape)
|
|
|
|
@cute.jit
def __call__(self, q, k, v, c, stream):
    """Build the MMAs and TMA descriptors for the given tensors and launch.

    Args:
        q: Query tensor; its element type is used for both MMA operands.
        k: Key tensor (B operand of the QK MMA).
        v: Value tensor (B operand of the PV MMA).
        c: Output tensor; its element type becomes the output/epilogue type.
        stream: Stream the kernel is launched on.

    The kernel is launched with a single CTA (grid ``(1, 1, 1)``).
    """
    # Take element types and major modes from the actual tensors.
    self.q_dtype = q.element_type
    self.o_dtype = c.element_type
    self.c_dtype = self.o_dtype
    self.a_major = LayoutEnum.from_tensor(q).mma_major_mode()
    self.b_major = LayoutEnum.from_tensor(k).mma_major_mode()
    self.v_major = LayoutEnum.from_tensor(v).mma_major_mode()
    self.c_layout = LayoutEnum.from_tensor(c)

    # QK MMA reads both operands from SMEM.
    qk_mma = utils.sm100.make_trivial_tiled_mma(
        self.q_dtype, self.q_dtype, self.a_major, self.b_major,
        self.qk_acc_dtype, self.cta_group, self.qk_mma_tiler[:2],
        tcgen05.OperandSource.SMEM)
    # PV MMA reads its A operand (P) from TMEM, K-major.
    pv_mma = utils.sm100.make_trivial_tiled_mma(
        self.q_dtype, self.q_dtype, cute.nvgpu.OperandMajorMode.K, self.v_major,
        self.qk_acc_dtype, self.cta_group, self.pv_mma_tiler[:2],
        tcgen05.OperandSource.TMEM)
    self._setup(qk_mma, pv_mma)

    # Per-stage SMEM layouts (stage 0 slice) used to shape the TMA atoms.
    q_smem = cute.slice_(self.q_smem_s, (None, None, None, 0))
    k_smem = cute.slice_(self.k_smem_s, (None, None, None, 0))
    v_smem = cute.slice_(self.v_smem_s, (None, None, None, 0))

    # TMA load atoms. K and V use the B-operand counterpart of the helper
    # used for Q, so all three descriptors are built the same way.
    tma_q, tma_tq = cute.nvgpu.make_tiled_tma_atom_A(
        utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
        q, q_smem, self.qk_mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
    tma_k, tma_tk = cute.nvgpu.make_tiled_tma_atom_B(
        utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
        k, k_smem, self.qk_mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
    tma_v, tma_tv = cute.nvgpu.make_tiled_tma_atom_B(
        utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, pv_mma.thr_id),
        v, v_smem, self.pv_mma_tiler, pv_mma, self.cluster_layout_vmnk.shape)

    # TMA store atom for the epilogue (SMEM -> GMEM).
    epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
    tma_c, tma_tc = cpasync.make_tiled_tma_atom(
        cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)

    self._kernel(
        qk_mma, pv_mma, tma_q, tma_tq, tma_k, tma_tk, tma_v, tma_tv,
        tma_c, tma_tc, self.cluster_layout_vmnk,
        self.q_smem_s, self.k_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile
    ).launch(grid=(1, 1, 1), block=[self.threads_per_cta, 1, 1], stream=stream)
|
|
|
|
@cute.kernel
|
|
def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV,
|
|
tma_c, mC, cl_vmnk, q_smem_s, k_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile):
|
|
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
|
|
tidx, _, _ = cute.arch.thread_idx()
|
|
use_2cta = cute.size(qk_mma.thr_id.shape) == 2
|
|
|
|
if warp_idx == self.tma_warp_id:
|
|
cpasync.prefetch_descriptor(tma_q)
|
|
cpasync.prefetch_descriptor(tma_k)
|
|
cpasync.prefetch_descriptor(tma_v)
|
|
cpasync.prefetch_descriptor(tma_c)
|
|
|
|
@cute.struct
|
|
class SS:
|
|
q_bar: cute.struct.MemRange[cutlass.Int64, self.q_stage * 2]
|
|
kv_bar: cute.struct.MemRange[cutlass.Int64, self.kv_stage * 2]
|
|
mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
|
|
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
|
|
tmem_dealloc: cutlass.Int64
|
|
holding: cutlass.Int32
|
|
|
|
smem = utils.SmemAllocator()
|
|
st = smem.allocate(SS)
|
|
|
|
q_prod, q_cons = pipeline.PipelineTmaUmma.create(
|
|
barrier_storage=st.q_bar.data_ptr(), num_stages=self.q_stage,
|
|
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
|
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
|
|
tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
|
|
).make_participants()
|
|
|
|
kv_prod, kv_cons = pipeline.PipelineTmaUmma.create(
|
|
barrier_storage=st.kv_bar.data_ptr(), num_stages=self.kv_stage,
|
|
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
|
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
|
|
tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
|
|
).make_participants()
|
|
|
|
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(
|
|
barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1,
|
|
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
|
consumer_group=pipeline.CooperativeGroup(
|
|
pipeline.Agent.Thread, 32 * len(self.softmax_warp_ids)),
|
|
).make_participants()
|
|
|
|
acc_pipe = pipeline.PipelineUmmaAsync.create(
|
|
barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage,
|
|
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
|
|
consumer_group=pipeline.CooperativeGroup(
|
|
pipeline.Agent.Thread, len(self.softmax_warp_ids) * (2 if use_2cta else 1)),
|
|
cta_layout_vmnk=cl_vmnk, defer_sync=True)
|
|
|
|
tmem_bar = pipeline.NamedBarrier(
|
|
barrier_id=2,
|
|
num_threads=32 * len((self.mma_warp_id, *self.softmax_warp_ids)))
|
|
tmem = utils.TmemAllocator(
|
|
st.holding.ptr, barrier_for_retrieve=tmem_bar,
|
|
allocator_warp_id=self.softmax_warp_ids[0], is_two_cta=use_2cta,
|
|
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
|
|
|
|
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
|
|
|
|
sQ = smem.allocate_tensor(element_type=self.q_dtype, layout=q_smem_s.outer, byte_alignment=128, swizzle=q_smem_s.inner)
|
|
sK = smem.allocate_tensor(element_type=self.q_dtype, layout=k_smem_s.outer, byte_alignment=128, swizzle=k_smem_s.inner)
|
|
# V overwrites K SMEM (FMHA trick)
|
|
sV_ptr = cute.recast_ptr(sK.iterator, v_smem_s.inner)
|
|
sV = cute.make_tensor(sV_ptr, v_smem_s.outer)
|
|
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
|
|
|
|
gQ = cute.local_tile(mQ, cute.slice_(self.qk_mma_tiler, (None, 0, None)), (None, None, None))
|
|
gK = cute.local_tile(mK, cute.slice_(self.qk_mma_tiler, (0, None, None)), (None, None, None))
|
|
gV = cute.local_tile(mV, cute.slice_(self.pv_mma_tiler, (0, None, None)), (None, None, None))
|
|
gC = cute.local_tile(mC, cute.slice_(self.pv_mma_tiler, (None, None, 0)), (None, None, None))
|
|
n_kv_tiles = cute.size(gK, mode=[3])
|
|
|
|
qk_thr = qk_mma.get_slice(0)
|
|
pv_thr = pv_mma.get_slice(0)
|
|
|
|
tCgQ = qk_thr.partition_A(gQ)
|
|
tCgK = qk_thr.partition_B(gK)
|
|
tCgV = pv_thr.partition_B(gV)
|
|
tCgC = pv_thr.partition_C(gC)
|
|
|
|
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0, 0, None, 0)).shape)
|
|
tAsQ, tAgQ = cpasync.tma_partition(tma_q, 0, a_lay, cute.group_modes(sQ, 0, 3), cute.group_modes(tCgQ, 0, 3))
|
|
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0, None, 0, 0)).shape)
|
|
tBsK, tBgK = cpasync.tma_partition(tma_k, 0, b_lay, cute.group_modes(sK, 0, 3), cute.group_modes(tCgK, 0, 3))
|
|
tVsV, tVgV = cpasync.tma_partition(tma_v, 0, b_lay, cute.group_modes(sV, 0, 3), cute.group_modes(tCgV, 0, 3))
|
|
tAgQ = tAgQ[(None, 0, None, 0)]
|
|
tBgK = tBgK[(None, 0, None, 0)]
|
|
tVgV = tVgV[(None, 0, None, 0)]
|
|
|
|
tCrQ = qk_mma.make_fragment_A(sQ)
|
|
tCrK = qk_mma.make_fragment_B(sK)
|
|
tCrV = pv_mma.make_fragment_B(sV)
|
|
|
|
qk_acc_shape = qk_thr.partition_shape_C(self.qk_mma_tiler[:2])
|
|
tStS = qk_thr.make_fragment_C(qk_acc_shape)
|
|
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
|
|
|
|
pv_acc_shape = pv_thr.partition_shape_C(self.pv_mma_tiler[:2])
|
|
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
|
|
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
|
|
|
|
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
|
|
tOrP_base = pv_thr.make_fragment_A(tP)
|
|
tOrP = tOrP_base[(None, None, None, 0)]
|
|
tOrP0 = cute.make_tensor(
|
|
tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
|
|
tOrP.layout)
|
|
|
|
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage))
|
|
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage))
|
|
|
|
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
|
|
|
|
# ═══ TMA LOAD WARP ═══
|
|
if warp_idx == self.tma_warp_id:
|
|
# Load Q once
|
|
q_prod.reset()
|
|
qh = q_prod.acquire_and_advance()
|
|
cute.copy(tma_q, tAgQ[(None, qh.count)], tAsQ[(None, qh.index)], tma_bar_ptr=qh.barrier)
|
|
q_prod.tail()
|
|
|
|
# Load KV tiles: for each tile, load K then V
|
|
# K and V share SMEM, so V overwrites K after QK consumes it
|
|
kv_prod.reset()
|
|
peek = kv_prod.try_acquire()
|
|
for kt in cutlass.range(n_kv_tiles, unroll=1):
|
|
# Load K[tile]
|
|
kvh = kv_prod.acquire_and_advance(peek)
|
|
cute.copy(tma_k, tBgK[(None, kvh.count)], tBsK[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
|
|
# Load V[tile] into the SAME SMEM (overwrites K after QK)
|
|
# Wait — we need QK to finish before V overwrites K.
|
|
# FMHA uses a SEPARATE pipeline entry for V. The MMA warp
|
|
# consumes K first (QK), then V (PV). The pipeline ordering
|
|
# ensures V doesn't overwrite K before QK is done.
|
|
cute.copy(tma_v, tVgV[(None, kvh.count)], tVsV[(None, kvh.index)], tma_bar_ptr=kvh.barrier)
|
|
peek = cutlass.Boolean(1)
|
|
if kvh.count + 1 < 2 * n_kv_tiles:
|
|
peek = kv_prod.try_acquire()
|
|
kv_prod.tail()
|
|
|
|
# ═══ MMA WARP ═══
|
|
if warp_idx == self.mma_warp_id:
|
|
tmem.wait_for_alloc()
|
|
|
|
q_cons.reset()
|
|
qh = q_cons.wait_and_advance()
|
|
qh.release()
|
|
|
|
kv_cons.reset()
|
|
peek = kv_cons.try_wait()
|
|
|
|
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
|
|
acc_pipe.producer_acquire(acc_prod_st)
|
|
|
|
for kt in range(n_kv_tiles):
|
|
# Wait for K[tile]
|
|
kvh = kv_cons.wait_and_advance(peek)
|
|
peek = cutlass.Boolean(1)
|
|
|
|
# ─── QK: Q @ K[tile]^T → S ───
|
|
s0_handle = mma_si_prod.acquire_and_advance()
|
|
qk_mma.set(tcgen05.Field.ACCUMULATE, kt != 0)
|
|
nblk = cute.size(tCrQ, mode=[2])
|
|
for kb in cutlass.range(nblk, unroll_full=True):
|
|
cute.gemm(qk_mma, tStS0,
|
|
tCrQ[(None, None, kb, 0)],
|
|
tCrK[(None, None, kb, kvh.index)],
|
|
tStS0)
|
|
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
|
|
cute.arch.fence_view_async_tmem_store()
|
|
s0_handle.commit()
|
|
|
|
# ─── Wait for softmax: S → P done ───
|
|
s0_handle = mma_si_prod.acquire_and_advance()
|
|
|
|
# ─── Wait for V[tile] ───
|
|
vvh = kv_cons.wait_and_advance(peek)
|
|
peek = cutlass.Boolean(1)
|
|
|
|
# ─── PV: P @ V[tile] → O ───
|
|
pv_mma.set(tcgen05.Field.ACCUMULATE, kt != 0)
|
|
nblk_pv = cute.size(tOrP0, mode=[2])
|
|
for kb in cutlass.range(nblk_pv, unroll_full=True):
|
|
cute.gemm(pv_mma, tOtO0,
|
|
tOrP0[(None, None, kb)],
|
|
tCrV[(None, None, kb, vvh.index)],
|
|
tOtO0)
|
|
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
|
|
|
|
kvh.release()
|
|
vvh.release()
|
|
|
|
acc_pipe.producer_commit(acc_prod_st)
|
|
acc_prod_st.advance()
|
|
acc_pipe.producer_tail(acc_prod_st)
|
|
|
|
# ═══ SOFTMAX WARPS ═══
|
|
if warp_idx < self.mma_warp_id:
|
|
tmem.allocate(self.num_tmem_alloc_cols)
|
|
tmem.wait_for_alloc()
|
|
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
|
|
sfw_idx = tidx % (32 * len(self.softmax_warp_ids))
|
|
|
|
tmem_load_atom = cute.make_copy_atom(
|
|
tcgen05.copy.Ld32x32bOp(tcgen |