nvfp4-megamoe-kernel/archived_plans/OLD_README_STUFF.md
2026-06-03 07:37:13 +00:00

Our kernel design choices

Attention kernel (FmhaKernel)

6-warp specialization. Warps 0–3 handle softmax + correction + epilogue. Warp 4 is the MMA warp (QK + PV). Warp 5 is the TMA warp (Q/K/V loads, output store via the pipeline).

P staging — two paths.

  • TMEM-P (hd ≤ 64): P is written to TMEM through a register bridge (FP32 backing + BF16 view), and PV reads P directly from TMEM. Used at the small head dims, where the QK C-fragment and PV A-fragment TMEM layouts agree.
  • SMEM-P (hd > 64): P is written to SMEM with a coordinate-indexed store: tTMEM_LOADcS maps each register index to its (m, k) coordinate, which is then placed into sP's subtile layout. PV reads P from SMEM with OperandSource.SMEM. This path is required because at hd > 64 the QK and PV TMEM layouts disagree, which corrupts the TMEM round-trip.

Un-normalized O + LSE output. The kernel emits the raw sum(P · V) and lse = ln(row_sum) + row_max · ln(2); the ln(2) factor converts row_max, which is taken over log2-scaled logits, back to natural-log units. External code (or the next kernel pass) performs the division. This composes: D5 merge, multi-tile rescale, and the inverse-RoPE → wo_a fusion all rely on it.
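
A minimal pure-Python check of the LSE identity above. It assumes (not stated explicitly in this README) that scale_log2 = scale_softmax · log2(e), that row_max is the maximum of the log2-scaled logits, and that row_sum = Σ 2^(x − row_max). The helper names are hypothetical; this is a sketch of the math, not the kernel code.

```python
import math

def row_stats_log2(scores, scale_softmax):
    """Mimic per-row softmax bookkeeping in the log2 domain (hypothetical helper)."""
    scale_log2 = scale_softmax * math.log2(math.e)    # assumed convention
    x = [s * scale_log2 for s in scores]              # logits in log2 units
    row_max = max(x)
    p = [2.0 ** (xi - row_max) for xi in x]           # un-normalized P
    row_sum = sum(p)
    lse = math.log(row_sum) + row_max * math.log(2)   # natural-log LSE
    return p, row_sum, row_max, lse

def reference_lse(scores, scale_softmax):
    """Direct natural-log logsumexp of the scaled logits, for comparison."""
    m = max(s * scale_softmax for s in scores)
    return m + math.log(sum(math.exp(s * scale_softmax - m) for s in scores))

scores = [0.3, -1.2, 2.5, 0.0]
scale = 1.0 / math.sqrt(64)
p, row_sum, row_max, lse = row_stats_log2(scores, scale)
print(abs(lse - reference_lse(scores, scale)) < 1e-9)  # True

# Un-normalized O for a 1-D "V": O_raw = sum(P * V); dividing by row_sum normalizes it.
v = [1.0, 2.0, 3.0, 4.0]
o_raw = sum(pi * vi for pi, vi in zip(p, v))
o = o_raw / row_sum
```

Dividing the raw output by row_sum recovers the ordinary softmax-weighted sum, which is the division this design leaves to external code.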

Per-head launch for multi-head. A Python loop dispatches the single-CTA kernel once per head. A multi-CTA grid using flat_divide + tma_partition is the next refactor; that path is unblocked once the correction-epilogue rewrite lands.

Head-packed M dimension for decode. Q is reshaped to (n_h * T, hd, 1) so that all heads' rows pack into the 128-row M tile; softmax stays per row. At Pro decode (T=1, n_h=128), n_h * T = 128, so the M tile fits exactly.
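
The packing can be sketched in Python under one assumption: Q is head-major (n_h × T × hd) before the reshape, so packed row r belongs to head r // T and token r % T. pack_heads and m_tiles_needed are hypothetical helpers. Because every packed row multiplies the same K operand in the QK GEMM, the per-row softmax is what keeps the heads independent.

```python
M_TILE = 128  # rows in the M tile, per the text

def pack_heads(q):
    """Flatten q[h][t] (assumed head-major: n_h x T x hd) into n_h * T rows.

    Also returns a row -> (head, token) map so per-row results can be
    routed back to their head.
    """
    rows, row_map = [], []
    for h, per_head in enumerate(q):
        for t, vec in enumerate(per_head):
            rows.append(list(vec))
            row_map.append((h, t))
    return rows, row_map

def m_tiles_needed(n_h, T, m_tile=M_TILE):
    """Number of M tiles needed for n_h * T packed rows (ceil division)."""
    return -(-(n_h * T) // m_tile)

print(m_tiles_needed(128, 1))  # 1: Pro decode fills the 128-row tile exactly
```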

K-dim sub-tiling at hd > 256. When head_dim exceeds 256 (the MMA instruction K-dim limit), Q and K are split along head_dim into n_k_sub_tiles = head_dim / 256 chunks. QK accumulates in TMEM across sub-tiles, which is valid because the partial dot products are additive in logit space. For hd > 256 the PV path uses pv_n_tile = 128 to keep sV + sC within the 232 KB SMEM budget.
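
Sub-tiling along head_dim is exact because a dot product splits into a sum of chunk dot products. A small sketch with hypothetical helper names; integer inputs keep the comparison exact.

```python
MMA_K_LIMIT = 256  # per-instruction K-dim limit, as stated in the text

def n_k_sub_tiles(head_dim, k_limit=MMA_K_LIMIT):
    # Assumes head_dim is a multiple of k_limit, as head_dim / 256 implies.
    assert head_dim % k_limit == 0
    return head_dim // k_limit

def qk_logit_subtiled(q_row, k_row, k_limit=MMA_K_LIMIT):
    """Accumulate q . k one head_dim chunk at a time, like QK accumulating in TMEM."""
    acc = 0.0
    for i in range(n_k_sub_tiles(len(q_row), k_limit)):
        lo, hi = i * k_limit, (i + 1) * k_limit
        acc += sum(a * b for a, b in zip(q_row[lo:hi], k_row[lo:hi]))
    return acc

hd = 512
q = [((i * 7) % 13) - 6 for i in range(hd)]
k = [((i * 5) % 11) - 5 for i in range(hd)]
full = sum(a * b for a, b in zip(q, k))
print(n_k_sub_tiles(hd), qk_logit_subtiled(q, k) == full)  # 2 True
```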

Sink bias as logit modification. D3 (SWA length mask), D4 (causal mask on SWA), and D5c (attention sink) all live in the same in-register code that runs after QK and before softmax. They read tTMEM_LOADcS to get (m, k) coordinates and modify tTMEM_LOADrS before the row-max reduction. The sink bias is added in the raw-logit domain as attn_sink / scale_softmax; the existing multiply by scale_log2 then converts it to log2 space along with the rest of the row.
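
A plain-Python sketch of these logit modifications. The mask semantics are assumptions, not taken from the kernel: causal keeps k_pos ≤ q_pos, and the SWA length mask keeps q_pos − k_pos < n_win. sink_to_log2 only demonstrates the domain conversion, assuming scale_log2 = scale_softmax · log2(e); which logit column receives the sink is not modeled here.

```python
import math

NEG_INF = float("-inf")

def keep(q_pos, k_pos, n_win):
    """Hypothetical predicate combining D4 (causal) and D3 (SWA length) masks."""
    causal_ok = k_pos <= q_pos
    window_ok = q_pos - k_pos < n_win
    return causal_ok and window_ok

def mask_row(raw_logits, q_pos, k_positions, n_win):
    """Set masked raw logits to -inf so they drop out of the row max and row sum."""
    return [s if keep(q_pos, kp, n_win) else NEG_INF
            for s, kp in zip(raw_logits, k_positions)]

def sink_to_log2(attn_sink, scale_softmax):
    """Add the sink in the raw-logit domain, then apply the shared scale_log2 multiply."""
    scale_log2 = scale_softmax * math.log2(math.e)   # assumed convention
    return (attn_sink / scale_softmax) * scale_log2  # equals attn_sink * log2(e)

row = mask_row([1.0, 2.0, 3.0, 4.0, 5.0], q_pos=3, k_positions=[0, 1, 2, 3, 4], n_win=2)
print(row)  # [-inf, -inf, 3.0, 4.0, -inf]
```

Pre-dividing by scale_softmax means the sink comes out as attn_sink · log2(e) after the shared multiply, i.e. exactly attn_sink in natural-logit units, independent of the softmax scale.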

MoE kernel (FusedSwiGLUScaledGroupedGemmKernel)

7-warp specialization. Warps 0–3: epilogue (TMEM → registers → SMEM → GMEM with global scale, SwiGLU, clamp). Warp 4: MMA (tcgen05.mma.block_scale with SFA/SFB in TMEM). Warp 5: TMA load (A, B, SFA, SFB). Warp 6: scheduler (MoEStaticPersistentTileScheduler).

One-way TMEM → registers → SMEM → GMEM epilogue. Uses epilogue_tmem_copy_and_partition + epilogue_smem_copy_and_partition (CUTLASS helpers, paired atoms). The SwiGLU + clamping math runs in registers between the t2r and r2s copies. No TMEM round-trip. This is the same pattern FMHA needs to adopt to fix the D1.5 blocker.

Subtile-level gate/up pairing. With granularity-8 interleaved L1 weights and epi_tile_n=8, even subtiles are gate and odd subtiles are up. silu_gate_buf register tensor carries the SiLU result across the subtile-pair boundary.
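
The pairing can be sketched in Python, assuming the L1 output columns are interleaved in groups of 8 as [gate 0–7, up 0–7, gate 8–15, up 8–15, ...]. interleave_gate_up and swiglu_epilogue are hypothetical helpers; the global scale and clamp are omitted, and silu_gate_buf here is a plain list standing in for the register tensor.

```python
import math

EPI_TILE_N = 8  # subtile width, equal to the interleave granularity

def silu(x):
    return x / (1.0 + math.exp(-x))

def interleave_gate_up(gate_cols, up_cols, g=EPI_TILE_N):
    """Hypothetical weight-prep step: interleave gate/up output columns in groups of g."""
    out = []
    for i in range(0, len(gate_cols), g):
        out += gate_cols[i:i + g] + up_cols[i:i + g]
    return out

def swiglu_epilogue(acc_row, n=EPI_TILE_N):
    """Walk subtiles in order: even subtile = gate, odd subtile = up."""
    out, silu_gate_buf = [], None
    for j in range(len(acc_row) // n):
        sub = acc_row[j * n:(j + 1) * n]
        if j % 2 == 0:
            silu_gate_buf = [silu(x) for x in sub]  # carried into the next subtile
        else:
            out += [s * u for s, u in zip(silu_gate_buf, sub)]
    return out
```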

use_2cta_instrs is enabled only when tokens_sum ≥ 256 and cluster_m is even. Decode (small M) stays 1-CTA; prefill/batched workloads get 2-CTA UMMA with multicast B (1.7–1.9× the 1-CTA throughput).
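
The selection rule written as a hypothetical Python predicate. The even cluster_m requirement is consistent with 2-CTA UMMA pairing two CTAs along M within a cluster.

```python
def use_2cta_instrs(tokens_sum, cluster_m):
    """Hypothetical helper: 2-CTA UMMA only for large M and an even cluster_m."""
    return tokens_sum >= 256 and cluster_m % 2 == 0

print(use_2cta_instrs(8, 2), use_2cta_instrs(4096, 2))  # False True
```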

Heterogeneous KV cache

  • State cache per request: a fixed-size block holding the n_win SWA KV entries and the uncompressed tail tokens awaiting compression. One block per request; its lifetime is managed by request scheduling.
  • Classical paged cache per request: a variable number of blocks, each holding k1 CSA compressed entries and k2 HCA compressed entries per layer, where m and m' are the CSA and HCA compression ratios, k1 = lcm(m, m') / m = 32, and k2 = lcm(m, m') / m' = 1. Each block covers 128 original tokens.
  • Different layers can produce different KV cache sizes (CSA vs HCA vs SWA-only). The state cache + classical-pool split keeps PagedAttention-style alignment intact for the compressed pool.
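
The block geometry follows directly from the lcm formulas. A short sketch; note that m = 4 and m' = 128 are not stated in this README and are inferred here from k1 = 32, k2 = 1, and the 128-token block. block_geometry is a hypothetical helper.

```python
from math import lcm  # Python 3.9+

def block_geometry(m, m_prime):
    """Tokens per paged block and compressed entries per block.

    m and m_prime are the CSA and HCA compression ratios (tokens per entry).
    A block spanning lcm(m, m_prime) tokens never splits a compression group
    of either kind.
    """
    tokens_per_block = lcm(m, m_prime)
    k1 = tokens_per_block // m        # CSA entries per block
    k2 = tokens_per_block // m_prime  # HCA entries per block
    return tokens_per_block, k1, k2

print(block_geometry(4, 128))  # (128, 32, 1)
```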

NVFP4 throughout

  • Weights: NVFP4 (FP8 E4M3 scales, 16-element microblocks). Verified: sf_dtype, TMA element type, MMA kind (mxf4nvf4) all correct.
  • Activations: BF16 today, FP4 after NVFP4-1.x epilogue fusion lands.
  • KV cache: BF16 today; the mixed-precision layout from paper §2.3.4 (RoPE part in BF16, NoPE part in FP8) is on the roadmap as NVFP4-2.
  • Indexer keys: stored as FP4 in the cache today, but scored with a scalar CUDA-core kernel. Tensor-core FP4 scoring (paper §5.2.1) is a Stage F priority.
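
To make the weight format concrete, here is an illustrative pure-Python NVFP4 round trip for one 16-element microblock: one E4M3 scale per block and E2M1 element values. It is a sketch, not the kernel's quantizer: it omits the FP32 per-tensor global scale that NVFP4 also uses, and it resolves E2M1 ties toward the smaller magnitude rather than to even. All names are hypothetical.

```python
import math

# E2M1 (FP4) magnitudes representable by NVFP4 elements.
E2M1_GRID = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
BLOCK = 16  # elements per microblock sharing one E4M3 scale

def round_e4m3(x):
    """Round a non-negative float to the nearest E4M3 value (max 448)."""
    if x == 0.0:
        return 0.0
    x = min(x, 448.0)
    e = max(math.floor(math.log2(x)), -6)  # -6: smallest normal exponent
    step = 2.0 ** (e - 3)                  # 3 mantissa bits
    return min(round(x / step) * step, 448.0)

def quantize_block(vals):
    """Quantize one block: E4M3 block scale plus E2M1 element values."""
    assert len(vals) == BLOCK
    amax = max(abs(v) for v in vals)
    scale = round_e4m3(amax / 6.0)         # 6.0 is the largest E2M1 magnitude
    if scale == 0.0:
        return 0.0, [0.0] * BLOCK
    q = []
    for v in vals:
        # Nearest grid point; values beyond 6 * scale clamp to 6.
        mag = min(E2M1_GRID, key=lambda g: abs(abs(v) / scale - g))
        q.append(math.copysign(mag, v))
    return scale, q

def dequantize_block(scale, q):
    return [scale * x for x in q]
```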