Our kernel design choices
Attention kernel (FmhaKernel)
6-warp specialization. Warps 0–3 handle softmax + correction + epilogue. Warp 4 is the MMA warp (QK + PV). Warp 5 is the TMA warp (Q/K/V loads, output store via pipeline).
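The warp-role split amounts to an index-to-role mapping. This is illustrative Python, not CuTe DSL, and FmhaRole and fmha_role are hypothetical names. Four softmax warps of 32 threads give 128 threads, which lines up with the 128-row M tile.

```python
from collections import Counter
from enum import Enum


class FmhaRole(Enum):
    SOFTMAX_CORRECTION_EPILOGUE = 0  # warps 0-3
    MMA = 1                          # warp 4: issues QK and PV
    TMA = 2                          # warp 5: Q/K/V loads, output store


WARPS_PER_CTA = 6
THREADS_PER_WARP = 32


def fmha_role(warp_idx: int) -> FmhaRole:
    """Return the specialized role of a warp in the 6-warp FMHA CTA."""
    if 0 <= warp_idx <= 3:
        return FmhaRole.SOFTMAX_CORRECTION_EPILOGUE
    if warp_idx == 4:
        return FmhaRole.MMA
    if warp_idx == 5:
        return FmhaRole.TMA
    raise ValueError(f"warp {warp_idx} is outside the {WARPS_PER_CTA}-warp CTA")


roles = Counter(fmha_role(w) for w in range(WARPS_PER_CTA))
# 4 softmax warps x 32 threads = 128 threads, matching the 128-row M tile
softmax_threads = roles[FmhaRole.SOFTMAX_CORRECTION_EPILOGUE] * THREADS_PER_WARP
```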
P staging — two paths.
- TMEM-P (hd ≤ 64): P stored to TMEM via register bridge (FP32 backing + BF16 view). PV reads P from TMEM. Used only at small head dims, where the QK C-fragment and PV A-fragment TMEM layouts agree.
- SMEM-P (hd > 64): P written to SMEM via a coordinate-indexed store that uses tTMEM_LOADcS to map register indices to (m, k), then into sP's subtile layout. PV reads P from SMEM with OperandSource.SMEM. Required because the QK ↔ PV TMEM layout disagreement at hd > 64 corrupts the round-trip.
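The coordinate-indexed store can be sketched in plain Python. Here cS stands in for the tTMEM_LOADcS coordinate fragment and rP for the register values; the layout in smem_offset is a hypothetical stand-in for sP's real subtile layout. The point is that each value is placed by its (m, k) coordinate, so the store stays correct whatever order the TMEM load returns registers in.

```python
def smem_offset(m: int, k: int, tile_m: int, sub_k: int) -> int:
    """Hypothetical sP layout: the K extent is cut into sub_k-wide subtiles,
    each stored row-major, with subtiles placed back to back."""
    subtile, kk = divmod(k, sub_k)
    return subtile * (tile_m * sub_k) + m * sub_k + kk


def store_p_by_coordinate(rP, cS, sP, tile_m: int, sub_k: int) -> None:
    """Scatter register values into sP. rP[i] is written to the slot of its
    coordinate cS[i] = (m, k); register order carries no layout meaning."""
    for value, (m, k) in zip(rP, cS):
        sP[smem_offset(m, k, tile_m, sub_k)] = value


# Tiny example: 4 x 8 P tile, 4-wide subtiles.
TILE_M, TILE_K, SUB_K = 4, 8, 4
P = [[10.0 * m + k for k in range(TILE_K)] for m in range(TILE_M)]
# Registers arrive in an arbitrary (here: column-major) order; the
# coordinate fragment records where each one belongs.
cS = [(m, k) for k in range(TILE_K) for m in range(TILE_M)]
rP = [P[m][k] for (m, k) in cS]
sP = [0.0] * (TILE_M * TILE_K)
store_p_by_coordinate(rP, cS, sP, TILE_M, SUB_K)
```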
Un-normalized O + LSE output. The kernel emits raw sum(P · V) and lse = ln(row_sum) + row_max · ln(2), where row_max is tracked in log2 units. External code (or the next kernel pass) applies the normalization. This composes: the D5 merge, multi-tile rescale, and the inverse-RoPE → wo_a fuse all rely on it.
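A scalar reference makes the output contract concrete. It computes the exp2-domain softmax for one query row, checks that lse = ln(row_sum) + row_max · ln(2) equals the natural-log log-sum-exp of the scaled logits, and shows why lse is enough to combine partial results over disjoint key sets (each partial is normalized by its own row_sum first). attention_row and merge_normalized are hypothetical helpers; merge_normalized is the standard split-KV combine, and whether D5 uses exactly this form is an assumption.

```python
import math


def attention_row(scores, V, scale_softmax):
    """Reference for one query row in the exp2 domain.

    Returns the un-normalized O (sum_j P_j * V_j), lse in natural-log units,
    and row_sum (kept here only so the example can normalize)."""
    scale_log2 = scale_softmax * math.log2(math.e)
    s2 = [s * scale_log2 for s in scores]        # logits in log2 units
    row_max = max(s2)
    P = [2.0 ** (x - row_max) for x in s2]       # un-normalized probabilities
    row_sum = sum(P)
    d = len(V[0])
    O_raw = [sum(p * v[c] for p, v in zip(P, V)) for c in range(d)]
    lse = math.log(row_sum) + row_max * math.log(2.0)
    return O_raw, lse, row_sum


def merge_normalized(O_a, lse_a, O_b, lse_b):
    """LSE-weighted merge of two normalized partial outputs over disjoint keys."""
    m = max(lse_a, lse_b)
    lse = m + math.log(math.exp(lse_a - m) + math.exp(lse_b - m))
    w_a, w_b = math.exp(lse_a - lse), math.exp(lse_b - lse)
    return [w_a * a + w_b * b for a, b in zip(O_a, O_b)], lse


scores = [0.3, -1.2, 2.0, 0.7]
V = [[1.0, 0.0], [0.0, 1.0], [2.0, -1.0], [0.5, 0.5]]
scale = 0.5

O_raw, lse, row_sum = attention_row(scores, V, scale)
O_full = [o / row_sum for o in O_raw]

# Split the keys into two tiles, normalize each, then merge via lse.
a_raw, lse_a, sum_a = attention_row(scores[:2], V[:2], scale)
b_raw, lse_b, sum_b = attention_row(scores[2:], V[2:], scale)
merged, lse_merged = merge_normalized(
    [o / sum_a for o in a_raw], lse_a, [o / sum_b for o in b_raw], lse_b)
```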
Per-head launch for multi-head. Python loop dispatches the single-CTA kernel once per head. Multi-CTA grid using flat_divide + tma_partition is the next refactor; that path is unblocked once the correction-epilogue rewrite lands.
Head-packed M dimension for decode. Q is reshaped to (n_h * T, hd, 1), packing all heads' rows into the 128-row M tile. Softmax is per row, so packed rows from different heads never mix. At Pro decode (T=1, n_h=128) the 128 packed rows fill the M tile exactly.
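The packing is a reshape plus an index map. This sketch assumes head-major row order (row = h * T + t) and, because every packed row is scored against the same K tile in one QK MMA, that all heads share the same K/V (MQA-style). pack_heads and row_owner are hypothetical names.

```python
M_TILE = 128


def pack_heads(q):
    """q[h][t] is a head_dim vector. Returns (n_h * T) packed rows.
    Row order is head-major (row = h * T + t), an assumption of this sketch."""
    n_h, T = len(q), len(q[0])
    packed = [q[h][t] for h in range(n_h) for t in range(T)]
    if len(packed) > M_TILE:
        raise ValueError(f"{len(packed)} rows exceed the {M_TILE}-row M tile")
    return packed


def row_owner(row, T):
    """Map a packed row back to (head, token)."""
    return divmod(row, T)


# Pro decode shape from the text: T = 1, n_h = 128, so 128 rows fill the tile.
hd = 4
q_decode = [[[float(h)] * hd] for h in range(128)]
packed = pack_heads(q_decode)
```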
K-dim sub-tiling at hd > 256. When head_dim > 256 (MMA instruction K-dim limit), Q and K split into n_k_sub_tiles = head_dim / 256 chunks along head_dim. QK accumulates in TMEM across sub-tiles (additive in logit space). The PV path uses pv_n_tile = 128 for hd > 256 to keep sV+sC within the 232 KB SMEM budget.
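Sub-tiling is exact because a dot product splits into a sum of partial dot products over disjoint index ranges, so accumulating QK across sub-tiles gives the same logit. A sketch with a hypothetical helper; it assumes head_dim is a multiple of 256, matching n_k_sub_tiles = head_dim / 256, and uses head_dim = 512 as an example.

```python
MMA_K_LIMIT = 256


def qk_logit_subtiled(q_row, k_row, sub_k=MMA_K_LIMIT):
    """Accumulate one QK logit over head_dim in sub_k-wide chunks, the way the
    QK MMA accumulates into the same TMEM fragment across sub-tiles."""
    head_dim = len(q_row)
    if head_dim % sub_k:
        raise ValueError("sketch assumes head_dim is a multiple of sub_k")
    n_k_sub_tiles = head_dim // sub_k
    acc = 0.0
    for t in range(n_k_sub_tiles):
        lo, hi = t * sub_k, (t + 1) * sub_k
        acc += sum(a * b for a, b in zip(q_row[lo:hi], k_row[lo:hi]))
    return acc, n_k_sub_tiles


head_dim = 512
q_row = [float(i % 7 - 3) for i in range(head_dim)]
k_row = [float(i % 5 - 2) for i in range(head_dim)]
logit, n_sub = qk_logit_subtiled(q_row, k_row)
```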
Sink bias as logit modification. D3 (SWA length mask), D4 (causal mask on SWA), and D5c (attention sink) all live in the same post-QK, pre-softmax in-register code. They read tTMEM_LOADcS to get (m, k) coordinates and modify tTMEM_LOADrS before the row-max reduction. The sink bias is added in the raw-logit domain as attn_sink / scale_softmax, then the existing * scale_log2 multiply converts to log2 space.
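The post-QK hook can be sketched as coordinate-driven masking plus the sink's domain conversion. The predicates in mask_logits (a causal bound and a window of n_win keys) and the q_pos mapping are hypothetical, since the exact D3/D4 rules aren't spelled out here. sink_log2 only checks the arithmetic, assuming scale_log2 = scale_softmax · log2(e): dividing by scale_softmax and then multiplying by scale_log2 leaves attn_sink · log2(e), the sink's natural-log value expressed in log2 units.

```python
import math

NEG_INF = float("-inf")


def mask_logits(rS, cS, q_pos, n_win):
    """Mask a thread's logits using their (m, k) coordinates.

    Hypothetical predicates: key k is visible to query row m when it is not
    in the future (causal) and lies inside the last n_win positions (SWA)."""
    out = []
    for s, (m, k) in zip(rS, cS):
        q = q_pos(m)
        visible = k <= q and q - k < n_win
        out.append(s if visible else NEG_INF)
    return out


def sink_log2(attn_sink, scale_softmax):
    """Add the sink in the raw-logit domain, then reuse the shared
    * scale_log2 multiply; the result is attn_sink in log2 units."""
    scale_log2 = scale_softmax * math.log2(math.e)
    return (attn_sink / scale_softmax) * scale_log2


# One query row at absolute position 5, window of 3 keys, keys 0..7.
cS = [(0, k) for k in range(8)]
masked = mask_logits([1.0] * 8, cS, q_pos=lambda m: 5 + m, n_win=3)
```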
MoE kernel (FusedSwiGLUScaledGroupedGemmKernel)
7-warp specialization. Warps 0–3 epilogue (TMEM → registers → SMEM → GMEM with global scale, SwiGLU, clamp). Warp 4 MMA (tcgen05.mma.block_scale with SFA/SFB in TMEM). Warp 5 TMA load (A, B, SFA, SFB). Warp 6 scheduler (MoEStaticPersistentTileScheduler).
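Warp 6's job can be illustrated with a generic static persistent schedule over a grouped GEMM: flatten every expert's tiles into one list, then have each CTA stride through it by the CTA count. The helper names and the tile order are assumptions; MoEStaticPersistentTileScheduler's actual ordering isn't described here.

```python
def enumerate_tiles(tokens_per_expert, tile_m, n_tiles_n):
    """Flatten the grouped GEMM into (expert, m_tile, n_tile) work items."""
    tiles = []
    for expert, tokens in enumerate(tokens_per_expert):
        m_tiles = -(-tokens // tile_m)  # ceil division; 0 for idle experts
        for mt in range(m_tiles):
            for nt in range(n_tiles_n):
                tiles.append((expert, mt, nt))
    return tiles


def static_persistent_schedule(tiles, num_ctas):
    """Static persistent schedule: CTA c takes tiles c, c + num_ctas, ..."""
    return {c: tiles[c::num_ctas] for c in range(num_ctas)}


tiles = enumerate_tiles([300, 0, 64], tile_m=128, n_tiles_n=2)
plan = static_persistent_schedule(tiles, num_ctas=3)
```

Because the assignment is fixed up front, no atomics are needed to hand out work, at the cost of possible imbalance when experts have very different token counts.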
One-way TMEM → registers → SMEM → GMEM epilogue. Uses epilogue_tmem_copy_and_partition + epilogue_smem_copy_and_partition (CUTLASS helpers, paired atoms). The SwiGLU + clamping math runs in registers between the t2r and r2s copies. No TMEM round-trip. This is the same pattern FMHA needs to adopt to fix the D1.5 blocker.
Subtile-level gate/up pairing. With granularity-8 interleaved L1 weights and epi_tile_n=8, even subtiles are gate and odd subtiles are up. silu_gate_buf register tensor carries the SiLU result across the subtile-pair boundary.
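The register-side epilogue math can be sketched for one accumulator row: apply the global scale, compute SiLU on an even (gate) subtile, hold it in a buffer, and multiply by the following odd (up) subtile. This is plain Python, not the CuTe DSL epilogue. The single scalar global_scale is an assumption, the function names are hypothetical, and the clamp step is omitted because its bounds aren't given here.

```python
import math

GRAN = 8  # interleave granularity == epi_tile_n


def silu(x):
    """Numerically stable x * sigmoid(x)."""
    if x >= 0:
        return x / (1.0 + math.exp(-x))
    e = math.exp(x)
    return x * e / (1.0 + e)


def swiglu_by_subtile(acc_row, global_scale):
    """Walk one accumulator row GRAN columns at a time. Even subtiles hold gate
    columns, odd subtiles hold the matching up columns."""
    n_sub = len(acc_row) // GRAN
    if len(acc_row) % GRAN or n_sub % 2:
        raise ValueError("row must hold whole gate/up subtile pairs")
    out, silu_gate_buf = [], None
    for s in range(n_sub):
        sub = [global_scale * x for x in acc_row[s * GRAN:(s + 1) * GRAN]]
        if s % 2 == 0:
            silu_gate_buf = [silu(g) for g in sub]  # carried to the next subtile
        else:
            out.extend(g * u for g, u in zip(silu_gate_buf, sub))
    return out


gate = [0.25 * i - 2.0 for i in range(16)]
up = [1.0 - 0.1 * i for i in range(16)]
interleaved = gate[:8] + up[:8] + gate[8:] + up[8:]
y = swiglu_by_subtile(interleaved, global_scale=0.5)
```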
use_2cta_instrs conditional on tokens_sum ≥ 256 and even cluster_m. Decode (small M) stays 1-CTA; prefill/batched gets 2-CTA UMMA with multicast B (1.7–1.9× throughput).
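The selection rule reduces to a two-term predicate. The even-cluster_m term presumably reflects that 2-CTA UMMA pairs two CTAs along M within a cluster. The helper mirrors the flag name but is a hypothetical standalone function.

```python
def use_2cta_instrs(tokens_sum: int, cluster_m: int) -> bool:
    """2-CTA UMMA only for large-M launches whose cluster pairs CTAs along M."""
    return tokens_sum >= 256 and cluster_m % 2 == 0
```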
Heterogeneous KV cache
- State cache per request: fixed-size block holding (n_win SWA KV) and (uncompressed tail tokens awaiting compression). One block per request, lifetime managed by request scheduling.
- Classical paged cache per request: variable blocks holding (k1 CSA compressed entries, k2 HCA compressed entries) per layer. k1 = lcm(m, m') / m = 32, k2 = lcm(m, m') / m' = 1. Block covers 128 original tokens.
- Different layers can produce different KV cache sizes (CSA vs HCA vs SWA-only). The state cache + classical-pool split keeps PagedAttention-style alignment intact for the compressed pool.
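The block arithmetic works out as follows. The CSA and HCA compression ratios m = 4 and m' = 128 are not stated here; they are inferred from k1 = 32, k2 = 1 and the 128-token block (k2 = 1 forces m' = lcm(m, m') = 128, and 128 / 32 = 4). Sizing the block as lcm(m, m') is what lets each block hold a whole number of entries of both kinds.

```python
from math import lcm


def block_entries(m, m_prime):
    """Tokens per block and compressed entries per block for ratios m, m'."""
    block_tokens = lcm(m, m_prime)
    return block_tokens, block_tokens // m, block_tokens // m_prime


# Ratios inferred from the numbers above (k1 = 32, k2 = 1, 128-token block).
block_tokens, k1, k2 = block_entries(m=4, m_prime=128)
```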
NVFP4 throughout
- Weights: NVFP4 (FP8 E4M3 scales, 16-element microblocks). Verified: sf_dtype, TMA element type, MMA kind (mxf4nvf4) all correct.
- Activations: BF16 today, FP4 after NVFP4-1.x epilogue fusion lands.
- KV cache: BF16 today; the mixed layout per paper §2.3.4 (RoPE part in BF16, NoPE part in FP8) is on the roadmap as NVFP4-2.
- Indexer keys: stored as FP4 in the cache today, but scored with a scalar CUDA-core kernel. Tensor-core FP4 scoring (paper §5.2.1) is a Stage F priority.
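For reference, one NVFP4 microblock quantizes roughly as sketched below: a per-block scale maps the block's absolute maximum onto 6 (the largest E2M1 magnitude), and each element rounds to the nearest E2M1 value. This is a simplified model with hypothetical helper names: it keeps the scale in full precision instead of rounding it to FP8 E4M3, and it ignores NVFP4's per-tensor FP32 scale.

```python
import math

E2M1_GRID = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)  # FP4 E2M1 magnitudes
BLOCK = 16  # NVFP4 microblock size


def quantize_block(x):
    """Quantize one 16-element microblock to one scale plus 16 E2M1 values.
    Simplification: the scale stays a Python float (no E4M3 rounding) and
    the per-tensor FP32 scale is ignored."""
    if len(x) != BLOCK:
        raise ValueError(f"expected {BLOCK} elements")
    amax = max(abs(v) for v in x)
    scale = amax / 6.0 if amax > 0 else 1.0  # map amax onto E2M1's max (6)
    codes = []
    for v in x:
        mag = min(E2M1_GRID, key=lambda g: abs(abs(v) / scale - g))
        codes.append(math.copysign(mag, v))
    return scale, codes


def dequantize_block(scale, codes):
    return [scale * c for c in codes]


x = [math.sin(0.7 * i) * 3.0 for i in range(BLOCK)]
scale, codes = quantize_block(x)
x_hat = dequantize_block(scale, codes)
```

The worst-case rounding error per element is half the widest E2M1 gap (between 4 and 6), i.e. one scale unit.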