STAGE_D.md — FMHA Kernel Development
⚠️ IKEA INSTRUCTIONS — READ EVERY TIME BEFORE CODING
The Workflow (DO NOT SKIP STEPS)
- Edit code in `~/dev/nvfp4-megamoe-kernel/dsv4/kernels/attention/fmha.py`. This is the ONLY file for the FMHA kernel.
- Commit and push: `cd ~/dev/nvfp4-megamoe-kernel && git add -A && git commit -m "description" && git push origin master`
- Pull on B200: `sshpass -p '<B200_PASSWORD>' ssh -o StrictHostKeyChecking=no root@45.76.247.107 "cd /root/dsv4-nvfp4-workspace/kernel && git pull origin master"`
- Test on B200 using the test harness scripts; see README.md "Test Harness" section.
- Regression check: After every change, verify hd=64 cos ~0.999998 still matches. If it doesn't, the change is WRONG. Revert.
The Rules (BURNED INTO THIS FILE)
- NEVER edit files directly on the B200. Edit locally, commit, push, pull, test. Every time.
- NEVER delete or modify the test files in `tests/unit/` without explicit approval.
- NEVER touch drivers, kernels, firmware, or system packages on the B200.
- CuTeDSL variables defined in `if` blocks are NOT visible in other `if` blocks. Define all variables unconditionally before any branching.
- Always test at hd=64 FIRST. If the proven path (TMEM-P) regresses, nothing else matters.
- After every P store to TMEM, call `cute.arch.fence_view_async_tmem_store()`. Missing this produces NaN.
- `tOrP0` MUST include the `tmem_p0_offset` column offset. Use `const_expr` for the conditional.
- PRINT THE SHAPES. ALWAYS. Reasoning about layouts without evidence is how we waste days.
Current Status (2026-05-24, 21:30 UTC)
✅ WORKING
| hd | n=128 cos | LSE err | Path | SMEM |
|---|---|---|---|---|
| 64 | 0.999998 | 0.000000 | TMEM-P | 128KB |
| 128 | 0.999997 | 0.000000 | TMEM-P / SMEM-P | 128KB |
| 256 | 0.999998 | 0.000000 | TMEM-P | 224KB |
❌ KNOWN ISSUES
- hd=512: MLIR compilation hangs. SMEM budget fixed (192KB ✅), kernel structure correct (tracer 0.8s), but MLIR→PTX backend optimizer cannot process the IR in reasonable time (>3 hours). Both `range()` unrolled and `cutlass.range(unroll=1)` runtime loops trigger this. This is a CuTeDSL/MLIR toolchain limitation.
- External k_sub merge doesn't work. k_sub segments are additive in logit space (S = S_0 + S_1), not attention weight space. The D5 merge formula does not apply. In-kernel k_sub accumulation is the only correct approach.
- O rescale (kt>0): Uses hand-constructed TMEM atoms. May corrupt data for n>128 (multi-KV-tile). At n=128 (1 KV tile, kt=0), no rescale needed. Guarded with `const_expr(n_kv_tiles > 1)`.
- Kernel always outputs un-normalized O + LSE. No in-kernel normalization (eliminates TMEM round-trip error). External normalization: `O_norm = O_unnorm / row_sum`.
Architecture
6-Warp Layout
Warps 0-3: Softmax + Epilogue (row_max, row_sum, P store, O rescale)
Warp 4: MMA (QK, PV)
Warp 5: TMA (Q/K/V load)
Kernel Output
The kernel outputs un-normalized O + LSE via epilogue_tma_store:
- O_unnorm = sum(P * V) where P = exp(S * scale - row_max)
- LSE = ln(row_sum) + row_max * ln(2)
- External normalization: O_norm = O_unnorm / row_sum
- For D5 merge: use exp(LSE) directly in the merge formula
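The output contract above can be sketched for a single query row in plain Python. This is a minimal sketch, assuming the kernel evaluates the exponential in base 2 with `log2(e)` folded into the scale, which is the reading under which the `row_max * ln(2)` term in the LSE formula is consistent. The names `kernel_row` and `normalize` are hypothetical and do not come from `fmha.py`.

```python
import math

LOG2E = math.log2(math.e)

def kernel_row(logits, values, scale):
    """Emulate one output row: un-normalized O, row_sum, and LSE.

    Assumption: exp(x) is evaluated as 2**(x * log2(e)), so row_max lives
    in the base-2 domain and must be multiplied by ln(2) in the LSE.
    """
    s2 = [s * scale * LOG2E for s in logits]
    row_max = max(s2)
    p = [2.0 ** (x - row_max) for x in s2]
    row_sum = sum(p)
    dim = len(values[0])
    o_unnorm = [sum(pj * v[d] for pj, v in zip(p, values)) for d in range(dim)]
    lse = math.log(row_sum) + row_max * math.log(2.0)
    return o_unnorm, row_sum, lse

def normalize(o_unnorm, row_sum):
    """External normalization: O_norm = O_unnorm / row_sum."""
    return [x / row_sum for x in o_unnorm]
```

Under that assumption the returned LSE equals the natural-log sum of `exp(scale * logit)`, so it can be compared directly against a reference softmax.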
TMEM Layout
Col 0-31: S (QK acc, 128 FP32 via Ld32x32bOp Repetition(32))
Col 32-95: P (64 FP32 via register bridge, BF16 view)
Col 128+: O (PV acc, 64+ FP32)
P Staging Paths
TMEM-P (hd≤64, also works at hd=128/256):
- P stored to TMEM via register bridge (FP32 backing + BF16 view)
- PV MMA reads P from TMEM via `tOrP0`
- Works because QK C-fragment and PV A-fragment TMEM layouts agree at tested head dims
SMEM-P (hd>64):
- P written to SMEM via coordinate-indexed store
- Uses `tTMEM_LOADcS` identity tensor to get (m, k) coordinates
- Maps to sP's subtile layout: `sP[(m_coord, k_sub), 0, (k_g1, k_g2)]`
- PV MMA reads P from SMEM via `tCrP = pv_mma.make_fragment_A(sP)`
- SMEM-P uses `OperandSource.SMEM` for PV MMA
Key Configuration
head_dim: constructor arg (64, 128, 256, 512)
pv_n_tile: min(head_dim, 256) # tcgen05 MMA max N=256
n_pv_tiles: head_dim // pv_n_tile
kv_stage: 1 if head_dim > 128 else 2 # Reduce SMEM at large hd
use_smem_p: head_dim > 64 # SMEM-P for hd>64
qk_mma_tiler: (128, 128, head_dim) # K-dim = head_dim (NOT hardcoded!)
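The derivation rules listed above can be written as a small helper that computes every dependent value from `head_dim`. This is a minimal sketch; `FmhaConfig` and `make_config` are hypothetical names, not the kernel's constructor. It mirrors only the rules as listed and does not model the `pv_n_tile=128` override for hd>256 mentioned in the SMEM budget table.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class FmhaConfig:
    head_dim: int
    pv_n_tile: int
    n_pv_tiles: int
    kv_stage: int
    use_smem_p: bool
    qk_mma_tiler: tuple

def make_config(head_dim: int) -> FmhaConfig:
    """Derive the dependent settings from head_dim (hypothetical helper)."""
    if head_dim not in (64, 128, 256, 512):
        raise ValueError(f"unsupported head_dim: {head_dim}")
    pv_n_tile = min(head_dim, 256)              # tcgen05 MMA max N=256
    return FmhaConfig(
        head_dim=head_dim,
        pv_n_tile=pv_n_tile,
        n_pv_tiles=head_dim // pv_n_tile,
        kv_stage=1 if head_dim > 128 else 2,    # fewer stages at large hd
        use_smem_p=head_dim > 64,
        qk_mma_tiler=(128, 128, head_dim),      # K-dim follows head_dim
    )
```

Printing `make_config(hd)` next to the traced shapes is one way to catch a hardcoded value such as the K-dim bug described in the next section.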
Critical Bug Fix: qk_mma_tiler K-dim (2026-05-24)
ROOT CAUSE of hd>64 failure: qk_mma_tiler K-dim was hardcoded to qk_ik * 4 = 64 instead of head_dim.
This caused the QK GEMM to compute only the first 64 of 128 (or 256, 512) dimensions at hd>64. The QK dot products were truncated (half the correct length at hd=128, less at larger hd), producing wrong attention scores.
Fix: self.qk_mma_tiler = (128, 128, self.head_dim) — one line change.
Impact: hd=128 went from cos 0.78 to 0.999997. hd=256 went from broken to 0.999998.
LESSON: The MMA tiler's K dimension must match the actual GEMM K dimension (head_dim), not the MMA instruction's K sub-tile size.
Lessons Learned (2026-05-24)
1. CuTeDSL MLIR Backend Cannot Handle Complex Pipeline Loops
The MLIR→PTX backend optimizer has exponential-or-worse behavior for kernels with TMA pipeline acquire/release inside loops. Both unrolled (Python range) and runtime (cutlass.range unroll=1) loops trigger this. The Python tracer is fast (0.8s) because it just generates IR. The MLIR optimizer then chews on that IR for hours. Workaround: keep pipeline loops as simple as possible. Consider raw CUDA C++ for complex kernels.
2. External k_sub Merge is Mathematically Impossible
You CANNOT merge the outputs of two attention calls that compute softmax(Q_k0 @ K_k0^T)@V and softmax(Q_k1 @ K_k1^T)@V into softmax(Q @ K^T)@V. The k_sub segments are additive in LOGIT space (S = S_0 + S_1), but softmax is nonlinear. The D5 merge formula works because sparse and SWA attend over DIFFERENT token sets (additive in weight space). k_sub attends over the SAME tokens with PARTIAL dot products. These are fundamentally different operations. The only correct approach is in-kernel accumulation (S_0 + S_1 before softmax).
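A two-token numerical example makes the argument concrete. This is an illustrative sketch with made-up partial logits `S0` and `S1` (one value per token, from the two head_dim halves). It applies an LSE-weighted merge to the two partial results and compares it with the correct `softmax(S0 + S1) @ V`.

```python
import math

def softmax(xs):
    m = max(xs)
    es = [math.exp(x - m) for x in xs]
    total = sum(es)
    return [e / total for e in es]

def attend(scores, values):
    return sum(w * v for w, v in zip(softmax(scores), values))

def lse(xs):
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))

S0 = [2.0, 0.0]   # partial logits from head_dim half 0 (illustrative)
S1 = [0.0, 1.0]   # partial logits from head_dim half 1 (illustrative)
V = [1.0, -1.0]   # one scalar value per token

# Correct result: add the partial logits first, then apply softmax once.
full = attend([a + b for a, b in zip(S0, S1)], V)

# LSE-weighted merge, valid only for disjoint token sets.
w0, w1 = math.exp(lse(S0)), math.exp(lse(S1))
merged = (w0 * attend(S0, V) + w1 * attend(S1, V)) / (w0 + w1)

print(f"full={full:.4f} merged={merged:.4f}")
```

The two numbers disagree because both partial calls attend over the same two tokens, so their weights cannot be combined as if they covered different tokens.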
3. pv_n_tile Reduction is the Easiest SMEM Knob
At hd>256, reducing pv_n_tile from 256 to 128 shrinks sV and sC by 2× each. The cost is 4 PV GEMM passes instead of 2. But PV is typically not the bottleneck. This is simpler than SMEM overlap (which requires CuTeDSL SmemAllocator changes) or Q tiling (which adds pipeline complexity).
4. Guard Dead Code with const_expr
CuTeDSL compiles BOTH branches of Python if statements, generating IR for code that will never execute at a given head_dim. Use const_expr(condition) to eliminate dead code at compile time. This is critical for:
- O rescale code (only needed when n_kv_tiles > 1)
- LSE computation (only needed when normalize=False)
- SMEM-P path (only needed when use_smem_p=True)
5. Don't Mix Python Loops and CuTeDSL Pipeline Operations
Python for loops unroll at trace time, creating N copies of the loop body in the IR. For pipeline acquire/release + TMA copy + GEMM, each copy is substantial. cutlass.range(unroll=1) creates a runtime loop with one copy of the body. For pipeline operations, prefer cutlass.range(unroll=1) to reduce IR size, even though the MLIR optimizer may still struggle with it.
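The difference in IR size can be illustrated with a toy tracer. This is a sketch only: `ToyTracer` is a hypothetical stand-in that records operation names in a list, it is not the CuTeDSL API, and the four-operation body is a simplification of a real pipeline step.

```python
class ToyTracer:
    """Hypothetical trace-time IR recorder (not the CuTeDSL API)."""

    def __init__(self):
        self.ops = []

    def emit(self, name):
        self.ops.append(name)

    def runtime_loop(self, trip_count, body):
        # The body is traced exactly once; the loop is a single IR region.
        self.emit(f"loop_begin(trip_count={trip_count})")
        body()
        self.emit("loop_end")

def pipeline_step(tracer):
    for op in ("acquire", "tma_copy", "gemm", "release"):
        tracer.emit(op)

def trace_unrolled(n):
    tracer = ToyTracer()
    for _ in range(n):            # Python loop: body traced n times
        pipeline_step(tracer)
    return tracer.ops

def trace_runtime_loop(n):
    tracer = ToyTracer()
    tracer.runtime_loop(n, lambda: pipeline_step(tracer))
    return tracer.ops
```

For 8 iterations the unrolled trace holds 32 operations while the runtime-loop trace holds 6, which is the IR-size argument in miniature.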
6. The k_tile Parameter is the Key to hd=512
At hd=512, the kernel splits Q and K into sub-tiles of size k_tile=256 along the head_dim. Each sub-tile is loaded via TMA, processed by MMA, and accumulated. n_k_sub_tiles = head_dim // k_tile = 2. The k_tile parameter controls the sub-tile size and the number of iterations. k_tile must be ≤ 256 (MMA instruction K-dim limit) and must evenly divide head_dim.
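The sub-tile split and the two constraints on `k_tile` can be sketched on the host side. This is a minimal sketch; `k_sub_ranges` and `qk_logit` are hypothetical names, and a plain Python dot product stands in for the MMA.

```python
def k_sub_ranges(head_dim, k_tile):
    """Return the [lo, hi) head_dim slices covered by each k_sub iteration."""
    if k_tile > 256:
        raise ValueError("k_tile must be <= 256 (MMA instruction K-dim limit)")
    if head_dim % k_tile != 0:
        raise ValueError("k_tile must evenly divide head_dim")
    return [(i * k_tile, (i + 1) * k_tile) for i in range(head_dim // k_tile)]

def qk_logit(q, k, k_tile):
    """Accumulate one Q.K logit across sub-tiles, before any softmax."""
    acc = 0.0
    for lo, hi in k_sub_ranges(len(q), k_tile):
        acc += sum(a * b for a, b in zip(q[lo:hi], k[lo:hi]))
    return acc
```

For `head_dim=512, k_tile=256` this yields the two ranges `(0, 256)` and `(256, 512)`, matching `n_k_sub_tiles = 2`.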
SMEM Budget at Various hd
| hd | sQ | sK (kv_stage=1) | sV (pv_n_tile) | sP (SMEM-P) | sC | Total | Limit | Status |
|---|---|---|---|---|---|---|---|---|
| 64 | 32KB | 32KB | 32KB (256) | — | 32KB | 128KB | 232KB | ✅ |
| 128 | 32KB | 32KB | 32KB (256) | — | 32KB | 128KB | 232KB | ✅ |
| 256 | 64KB | 64KB | 64KB (256) | 0* | 32KB | 224KB | 232KB | ✅ |
| 512 | 64KB | 64KB | 32KB (128) | 0* | 32KB | 192KB | 232KB | ⚠️ Fits but MLIR hangs |
*TMEM-P path: sP allocation skipped (const_expr conditional).
pv_n_tile shown in parens; hd>256 uses pv_n_tile=128 (4 PV GEMM passes) to fit SMEM.
D1.5: Correction Epilogue (TMEM Round-Trip Error) + O Rescale
Issue 1: TMEM round-trip error. Hand-constructed Ld32x32bOp/St32x32bOp atoms don't preserve the C-fragment layout during TMEM round-trips (load→modify→store). Causes ~3% error per round-trip.
Current workaround: Kernel outputs un-normalized O + LSE. No in-kernel normalization needed. External normalization is exact.
Proper fix (future): Use CUTLASS epilogue_tmem_copy_and_partition + epilogue_smem_copy_and_partition pattern with paired atoms. One-way trip: TMEM → registers (normalize) → SMEM → GMEM.
Priority: MEDIUM. Not a correctness blocker (external normalization is exact). Would enable in-kernel normalization for D5c/d. Also blocks NVFP4-1.2 (inverse RoPE FP4 fuse).
Issue 2: O rescale for kt>0 (multi-KV-tile). CONFIRMED BROKEN (May 24). Even a NO-OP round-trip (load O, multiply by 1.0, store back) produces cos 0.804 at s_k=256. The Ld32x32bOp/St32x32bOp atoms corrupt data regardless of the rescale factor. The same atoms in CUTLASS correction_rescale use the same pattern — unclear why theirs works with 12-warp layout but ours fails with 6-warp.
Workaround (VERIFIED): Python KV merge with per-segment LSE. Run kernel with s_k=128 (1 KV tile, no rescale) per segment. Merge using: O = sum_i [exp(lse_i) * O_i_norm] / sum_i [exp(lse_i)]. Verified cos 0.999998 for s_k=256, 384, 512, 1024 at hd=64. Caveat: requires per-row LSE output (currently only row 0 is written; per-row verified correct with max err 0.000001 but CuTe tensor indexing needs work for full per-row output).
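The merge formula in the workaround can be sketched for one query row in plain Python. This is a minimal sketch, not the project's merge code: `attend_segment` stands in for one kernel launch followed by external normalization, and `merge_segments` subtracts the largest LSE before exponentiating, which is mathematically equivalent to the formula above and only added here for numerical stability.

```python
import math

def attend_segment(q, keys, values, scale):
    """Return (normalized output row, lse) for one KV segment."""
    logits = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(logits)
    p = [math.exp(s - m) for s in logits]
    row_sum = sum(p)
    dim = len(values[0])
    out = [sum(pj * v[d] for pj, v in zip(p, values)) / row_sum
           for d in range(dim)]
    return out, m + math.log(row_sum)

def merge_segments(parts):
    """O = sum_i exp(lse_i) * O_i / sum_i exp(lse_i) over (O_i, lse_i) pairs."""
    lse_max = max(lse for _, lse in parts)
    weights = [math.exp(lse - lse_max) for _, lse in parts]
    total = sum(weights)
    dim = len(parts[0][0])
    return [sum(w * o[d] for w, (o, _) in zip(weights, parts)) / total
            for d in range(dim)]
```

Splitting the keys and values into contiguous segments, calling `attend_segment` on each, and passing the results to `merge_segments` should reproduce a single call over the full sequence up to floating-point rounding.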
Priority: HIGH for production (DSV4 Pro needs s_k=1024). The Python merge works but adds kernel launch overhead (8 launches for s_k=1024). Fused in-kernel rescale requires fixing the TMEM round-trip or using a different accumulator strategy.
Build Order (Remaining)
D1.4 — hd=512 ⚡ CURRENT (BLOCKED)
Problem: hd=512 exceeds the MMA instruction's max K-dim (256). Must split Q and K into 2 sub-tiles along head_dim (k_tile=256, n_k_sub_tiles=2). The QK dot product is S = Q_k0 @ K_k0^T + Q_k1 @ K_k1^T (additive in logit space).
SMEM budget: SOLVED. pv_n_tile=128 for hd>256 reduces sV from 64KB→32KB, sC from 64KB→32KB. Total 192KB ✅.
Compilation: BLOCKED. The CuTeDSL MLIR→PTX backend optimizer cannot compile the hd=512 kernel in reasonable time. Both Python range() (unrolled IR) and cutlass.range(unroll=1) (runtime loop) produce IR that the optimizer chews on for 3+ hours without finishing. The Python tracer completes in 0.8s — the kernel is structurally correct. This is a toolchain limitation.
External merge: IMPOSSIBLE. The D5 online softmax merge formula assumes separate attention distributions over different token sets (additive in weight space). k_sub segments are additive in LOGIT space (S = S_0 + S_1), not weight space. You cannot recover softmax(S_0 + S_1)@V from softmax(S_0)@V and softmax(S_1)@V. In-kernel accumulation before softmax is the only correct approach.
Bug fixes applied along the way:
- LSE type mismatch (BF16 vs FP32 when normalize=True) → guarded with `const_expr(not self.normalize)`
- O rescale IR explosion at n=128 → guarded with `const_expr(n_kv_tiles > 1)`
- k_sub tracer IR explosion → replaced hardcoded `if k_sub==0/1` with Python `range()` loop
- External merge test (cos 0.617) → confirmed mathematically impossible, deleted approach
Possible paths forward (priority order):
- Pre-compile hd=512 kernel offline. Accept 1-2 hour compilation during build. Cache the cubin. This works if the MLIR optimizer eventually finishes (it might just be slow, not stuck — but 3+ hours is excessive even for pre-compilation).
- Add no-softmax mode to the kernel. Output raw S (QK scores) without softmax. Call twice for k_sub=0 and k_sub=1. Accumulate S_0+S_1 in Python. Apply softmax once. This requires modifying the softmax warp to optionally skip normalization and output S to GMEM instead of P to TMEM/SMEM.
- Write hd=512 kernel in CUTLASS C++. Bypass CuTeDSL's MLIR backend entirely. Use raw CUTLASS C++ with tcgen05 MMA intrinsics. More work but compilation is fast (seconds).
- Report CuTeDSL MLIR optimizer bug. The optimizer should handle this IR in reasonable time. File an issue with NVIDIA.
D2 — Multi-Query Grid with Head Packing
- Grid changes from `(1, 1, 1)` to `(num_q_blocks, 1, batch)`
- DSV4 is MQA: all 128 query heads share same K/V
- Head axis folded into M dimension: `M_tile = 128` covers `M = T * n_h` rows
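The grid arithmetic for head packing can be sketched as below. This is a minimal sketch: `launch_grid` is a hypothetical helper, and taking `num_q_blocks` as the ceiling of `M / M_tile` is an assumption, since the plan above does not spell out how partial tiles are handled.

```python
def launch_grid(num_tokens, n_heads, batch, m_tile=128):
    """Return (num_q_blocks, 1, batch) with heads folded into the M axis."""
    m_rows = num_tokens * n_heads           # M = T * n_h
    num_q_blocks = -(-m_rows // m_tile)     # ceiling division (assumed)
    return (num_q_blocks, 1, batch)
```

With 128 query heads, a single decode token fills exactly one 128-row M tile, so `launch_grid(1, 128, batch)` keeps one Q block per batch element.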
D3 — SWA Sequence Length Mask
- Add `swa_lens: [batch] int32` kernel input
- Mask SWA-branch logits to `-inf` where `swa_idx >= swa_lens[b]`
D4 — Causal Mask on SWA Branch
- Add `is_causal: bool` constructor flag
- Apply `swa_idx > q_pos` masking to `-inf` in SWA pass
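The D3 and D4 masking rules can be sketched together on one row of SWA logits. This is a reference sketch in plain Python, assuming `swa_idx` and `q_pos` share the same index space. The function name `mask_swa_logits` is hypothetical, and the kernel would apply the same predicates per element before the softmax.

```python
NEG_INF = float("-inf")

def mask_swa_logits(logits, swa_len, q_pos=None, is_causal=False):
    """Set masked SWA logits to -inf for one query row.

    D3: mask where swa_idx >= swa_len.
    D4: if is_causal, also mask where swa_idx > q_pos.
    """
    out = []
    for swa_idx, s in enumerate(logits):
        masked = swa_idx >= swa_len
        if is_causal and q_pos is not None and swa_idx > q_pos:
            masked = True
        out.append(NEG_INF if masked else s)
    return out
```

A host-side reference like this is also a convenient oracle for the unit tests of D3 and D4, since masked positions contribute exactly zero weight after softmax.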
D5 — SWA + Sink Merge
- D5a ✅: Kernel outputs un-normalized O + LSE
- D5b ✅: Python merge works (cos 0.961 at hd=64)
- D5c: Fuse two passes into one kernel launch
- D5d: Fuse sink merge into kernel epilogue
NVFP4 Precision Roadmap (May 23, 2026)
Three honest buckets. A fourth speculative bucket flagged at the end.
NVFP4-0: Verify Right Blackwell FP4 Primitives ⚡ DO FIRST
No quality risk from the checks themselves. This is purely about implementation correctness. If these are wrong, we're running wrong MMA shapes silently.
NVFP4-0.1 — sf_dtype tracing
What: Trace the SF dtype through the full pipeline: gemm_runner.py → dense.py → blockscaled_utils → TMEM layout.
The problem: dense.py line 137 says NVF4 supports Float8E8M0FNU/Float8E4M3FN at sf_vec_size=16. But UE8M0 is the MXFP4/MXFP8 scale format. NVFP4 uses FP8 E4M3. The examples on lines 90/100 show Float8E8M0FNU at sf_vec_size=16 which is the MXFP4 path. Need to verify the runner is passing E4M3, not E8M0.
Action:
- Print `sf_dtype` in `gemm_runner.py` at construction: `print(f"sf_dtype={sf_dtype}, sf_vec_size={SF_VEC_SIZE}")`
- Print `self.sf_dtype` in `dense.py` `BlockScaledGEMM.__init__`
- Print `self.sf_vec_size` in `dense.py`
- Trace through `blockscaled_utils.make_sm100_sf_layout`: does it produce E4M3 packing (4 FP8 E4M3 → 1 int32) or UE8M0 packing?
- If wrong sf_dtype is found: fix in `gemm_runner.py` SF_DTYPE constant, retest MoE cosine
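Besides printing the dtype, it can help to decode a few raw scale bytes both ways, because the same byte means very different things in the two formats. The sketch below decodes one byte as FP8 E4M3 (the `Float8E4M3FN` variant: bias 7, no infinities, a single NaN encoding per sign) and as UE8M0 (unsigned power-of-two exponent, bias 127). The function names are hypothetical helpers, not part of `blockscaled_utils`.

```python
import math

def decode_e4m3fn(byte):
    """Decode one FP8 E4M3FN byte: 1 sign, 4 exponent (bias 7), 3 mantissa bits."""
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF
    man = byte & 0x7
    if exp == 0xF and man == 0x7:
        return math.nan                     # the only NaN encoding per sign
    if exp == 0:
        return sign * (man / 8.0) * 2.0 ** -6   # subnormal
    return sign * (1.0 + man / 8.0) * 2.0 ** (exp - 7)

def decode_ue8m0(byte):
    """Decode one UE8M0 byte: unsigned exponent only, value = 2**(byte - 127)."""
    if byte == 0xFF:
        return math.nan
    return 2.0 ** (byte - 127)
```

For example, byte `0x38` is 1.0 as E4M3 but 2^-71 as UE8M0, so scales that look reasonable under one decoding and absurd under the other point to which format the buffer actually holds.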
NVFP4-0.2 — SF TMEM layout verification
What: NVFP4 expects scale factors in TMEM in a specific transposed-packed layout. UE4M3 for NVFP4 (4 packed FP8 E4M3 per int32 word). The comment in dense.py about "SM100 requires scaling factors in packed UE8M0 format" is for MXFP8, not NVFP4.
Action:
- Print TMEM scale-factor offsets at GEMM construction: `print(f"sf_smem_layout={sf_smem_layout}, sf_tmem_offset={sf_tmem_offset}")`
- Verify the packing matches UE4M3 (NVFP4) not UE8M0 (MXFP8)
- Trace `blockscaled_utils.make_sm100_sf_layout` and print the output layout
- If wrong packing: fix `make_sm100_sf_layout` or add NVFP4-specific layout path
NVFP4-0.3 — FP4 TMA element type
What: float4_e2m1fn_x2 must survive all the way into TMA descriptor creation. Blackwell TMA supports e2m1_x2 packed-FP4 element type directly. Loading as uint8 works but loses tensor-core awareness.
Action:
- Trace `float4_e2m1fn_x2` through `quantize.py` → TMA atom creation in `fmha.py`
- Print the GMEM tensor dtype at FMHA kernel input
- Print the TMA atom dtype at construction
- Verify `cpasync.tma_partition` receives `float4_e2m1fn_x2` element type, not uint8
- If uint8 fallback: fix TMA atom creation in `fmha.py`
NVFP4-0.4 — MMA kind is mxf4nvf4
What: Blackwell has a single MMA kind for both MXFP4 and NVFP4. NVFP4 = scales are FP8 E4M3, 16-element block. MXFP4 = scales are UE8M0, 32-element block. The MMA kind is determined by scale-factor type at runtime. Need to confirm tcgen05 is inferring NVFP4.
Action:
- Print `tcgen05.mma.kind` at GEMM construction (if accessible)
- Print the MMA instruction shape `(M, N, K)` confirmed by JIT compile
- Verify it matches Blackwell MMA shape for NVFP4 (not MXFP4)
Execution: These are 5-minute print jobs. Do all 4 NVFP4-0 items before touching any code. If any of them reveals a wrong dtype, fix it FIRST before anything else. A wrong sf_dtype poisons every FP4 GEMM result.
NVFP4-1: Eliminate BF16 Round-Trips After FP4 GEMMs 🔴 PURE-WIN, NO QUALITY RISK
These are pure bandwidth/compute wins. The math doesn't change — we just avoid precision loss and kernel launch overhead.
NVFP4-1.1 — Fuse FP4 quant into SwiGLU epilogue (MoE L1 → L2)
What: Current MoE forward:
padded_x_fp4 → L1 GEMM → SwiGLU → BF16 GMEM ← LEAK
quantize_activation_nvfp4 ← SEPARATE KERNEL
padded_activated_fp4 → L2 GEMM → BF16 GMEM
Paper §4.2.2: NVFP4 weights. L1 → SwiGLU → online amax → FP8 scale + FP4 pack → FP4 GMEM → L2 GEMM.
The amax reduction is in-registers: for an epi tile with 16 contiguous elements per thread, each tile produces one FP8 E4M3 scale and 64 bits of packed FP4 nibbles. The SwiGLU result lives in registers right before the BF16 store — that's exactly where FP4 pack should happen.
What you save:
- ~4× GMEM bandwidth between L1 and L2 (4-bit FP4 instead of 16-bit BF16, slightly less once scale bytes are counted)
- Entire `quantize_activation_nvfp4` kernel launch
- `padded_activated_fp4` / `padded_activated_x_sf` scratch buffers
- GPU-side amax computation (runs on tensor cores vs scalar)
- L2 scale-factor TMA reads FP8 scales L1 just produced
How: Extend the fused SwiGLU epilogue. After computing gate * up:
- Compute per-16-element amax across the subtile (all-reduce or butterfly shfl_xor)
- Compute the FP8 E4M3 block scale = amax / 6 (E2M1 max magnitude; the scale itself is stored as E4M3, whose max is 448)
- Pack each element: `sign_bit << 7 | (clamped_val / scale).to(uint4)`
- Write packed nibbles to GMEM as `float4_e2m1fn_x2`
- Write FP8 scale to SF TMA buffer
The amax subtlety: For NVFP4 the microblock is 16 elements. Port the same 16-element logic from quantize.py into the epilogue. Do NOT use 32-element MXFP4 microblocks.
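The per-microblock math can be prototyped on the host before porting it into the epilogue. The sketch below is a plain-Python reference for one 16-element block, under stated assumptions: the block scale maps amax onto the E2M1 maximum magnitude of 6.0, the scale is kept as a Python float instead of being rounded to FP8 E4M3, rounding is nearest-value with ties going to the smaller magnitude, and the first element of each pair sits in the low nibble. None of these choices are confirmed against `quantize.py`.

```python
E2M1_GRID = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)  # magnitudes for codes 0..7
BLOCK = 16                                            # NVFP4 microblock size

def quantize_block(vals):
    """Quantize 16 floats to (scale, 8 packed bytes) with E2M1 nibbles."""
    if len(vals) != BLOCK:
        raise ValueError("NVFP4 microblock must have 16 elements")
    amax = max(abs(v) for v in vals)
    scale = amax / 6.0 if amax > 0.0 else 1.0   # assumption: amax maps to 6.0
    codes = []
    for v in vals:
        mag = min(abs(v) / scale, 6.0)
        idx = min(range(8), key=lambda i: abs(E2M1_GRID[i] - mag))
        codes.append((0x8 if v < 0 else 0x0) | idx)   # sign is nibble bit 3
    # Assumption: even-indexed element goes in the low nibble of each byte.
    packed = bytes(codes[i] | (codes[i + 1] << 4) for i in range(0, BLOCK, 2))
    return scale, packed

def dequantize_block(scale, packed):
    """Inverse of quantize_block, returning 16 floats."""
    out = []
    for b in packed:
        for code in (b & 0xF, b >> 4):
            mag = E2M1_GRID[code & 0x7]
            out.append((-mag if code & 0x8 else mag) * scale)
    return out
```

Each block produces 8 bytes (64 bits) of packed nibbles plus one scale, which matches the per-thread footprint described above. Comparing this reference against the epilogue output on a few blocks is a cheap way to catch a 32-element microblock or a swapped nibble order.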
- Extend `dsv4/kernels/gemm/fused_swiglu.py` epilogue to FP4 pack SwiGLU output
- Add per-subtile amax reduction (register-only, no extra kernel)
- Verify: L1 → L2 cosine matches reference (no regression from BF16 intermediate)
- Verify: L2 GEMM reads FP4 scales produced by L1 epilogue
- Test: MoE layer output cosine with full L1→L2 pipeline
- Scope: MoE-side, NOT fmha.py. Does not block FMHA D1.
NVFP4-1.2 — Fuse FP4 quant into inverse RoPE → wo_a path
What: inverse_rope_bf16 produces BF16, then wo_a quantizes it. Fuse FP4 quant into inverse RoPE epilogue.
Same pattern as NVFP4-1.1: after inverse RoPE rotation, compute amax → FP8 scale → FP4 pack → FP4 GMEM. The wo_a GEMM reads FP4 + scales.
- Extend `dsv4/ops/rope.py` inverse RoPE to emit FP4 instead of BF16
- Wire `wo_a` GEMM to read FP4 scales from inverse RoPE output
- Test: attention sub-block output cosine (full inverse RoPE → wo_a → attention)
NVFP4-1.3 — Fuse FP4 quant into mHC mixing → attention/FFN input
What: B_l @ X_l + C_l ⊗ F_out (mHC post_block) lands in BF16. Attention's q_down and FFN's L1 GEMM quantize it. Fuse quant into mHC mixing kernel.
Same pattern. After mHC mixing post-compute, amax → FP8 scale → FP4 pack → FP4 GMEM. Attention and FFN GEMMs read FP4.
- Add FP4 epilogue to mHC mixing kernel (when building it)
- Wire attention `q_down` and FFN L1 to read mHC FP4 output
- Test: end-to-end layer cosine (mHC → attention → FFN)
Note: NVFP4-1.2 and NVFP4-1.3 depend on D1.5 (correction epilogue fix) because those epilogues need the clean one-way TMEM path. NVFP4-1.1 (MoE SwiGLU) is independent.
NVFP4-2: FP4 KV Pipeline Depth in FMHA 🔴 STAGE D, DEPENDS ON D1
FP4 KV shrinks tiles 4×, same SMEM budget buys 3× more pipeline stages.
| KV dtype | Tile size (hd=512) | 2 stages | 4 stages | 6 stages |
|---|---|---|---|---|
| BF16 | 128 KB (K+V) | 512 KB ✅ | — | — |
| FP8 | 64 KB (K+V) | 256 KB ✅ | 512 KB | — |
| FP4 | ~36 KB (K+V) | 144 KB ✅ | 288 KB | 432 KB |
Each extra stage hides more TMA latency. At 1M-context decode where KV reads dominate, deeper pipelines are a major perf win.
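One plausible way to arrive at the "~36 KB" FP4 tile figure is to count packed nibbles plus one FP8 scale byte per 16-element microblock. This is a sketch under the assumption that every element of the tile is stored as FP4. The plan below keeps the RoPE dims in BF16, so the real figure may differ. `fp4_tile_bytes` is a hypothetical helper.

```python
def fp4_tile_bytes(bf16_tile_bytes, block=16):
    """Bytes for the same tile stored as packed FP4 plus FP8 block scales."""
    elements = bf16_tile_bytes // 2     # BF16 is 2 bytes per element
    packed = elements // 2              # two FP4 values per byte
    scales = elements // block          # one FP8 scale byte per microblock
    return packed + scales

print(fp4_tile_bytes(128 * 1024) // 1024, "KB")
```

For the 128 KB BF16 tile this gives 32 KB of packed data plus 4 KB of scales.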
Implementation:
- After D1 (SMEM-P works with BF16): add FP4 TMA load + SMEM dequant path
- TMA loads FP4 NoPE dims (packed e2m1_x2) to SMEM slot 0
- TMA loads BF16 RoPE dims to SMEM slot 1
- TMA loads FP8 scale factors to SMEM slot 2
- Dequantize FP4→BF16 in SMEM (vectorized `* FP8_scale`, 16-element microblocks)
- Concatenate [NoPE, RoPE] in SMEM
- MMA reads contiguous BF16 from SMEM
- Prerequisite: D1 working at BF16 first. Cannot skip.
- Test: FP4+BF16 split input → identical output to pure BF16 input (dequant is transparent)
NVFP4-3: use_2cta_instrs for Production MoE 🟢 30 MINUTES, PURE PERF
This is the single biggest single-knob perf win for FP4 GEMMs on B200.
What: FusedSwiGLUScaledGroupedGemmKernel supports 2-CTA UMMA but defaults to False. With 2-CTA, the B operand is TMA-multicast: each CTA reads half of B, and the peer's half is accessed across the intra-cluster distributed shared memory (DSMEM) link. Effective MMA tile M doubles (128→256, 256→512).
Measured win: 1.7–1.9× throughput over single-CTA at prefill/batch shapes.
Decision tree:
- M < 128 (decode single-token): 1-CTA is correct. 2-CTA wastes hardware.
- M ≥ 256 (prefill or batched decode): 2-CTA is free perf.
- cluster_m must be even for 2-CTA.
Action:
- Add conditional: `use_2cta_instrs = (M >= 256 and cluster_m % 2 == 0)`
- Default stays `False` (correct for decode)
- Python GEMM runner sets `use_2cta_instrs=True` for prefill shapes
- Test: throughput comparison at M=256, 512, 1024
- Scope: MoE-side, `gemm_runner.py`. Does not affect FMHA.
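The decision tree reduces to a one-line predicate that the runner could evaluate per launch. This is a minimal sketch: `choose_use_2cta_instrs` is a hypothetical helper name, and the range 128 ≤ M < 256, which the decision tree does not cover, falls back to 1-CTA here because the conditional requires M ≥ 256.

```python
def choose_use_2cta_instrs(m, cluster_m):
    """Enable 2-CTA UMMA only for large-M shapes with an even cluster_m."""
    return m >= 256 and cluster_m % 2 == 0

# Example shapes: single-token decode, prefill, and an odd cluster.
for m, cluster_m in ((1, 2), (256, 2), (1024, 2), (512, 1)):
    print(m, cluster_m, choose_use_2cta_instrs(m, cluster_m))
```

Keeping the predicate in one place makes the throughput comparison at M=256, 512, 1024 a matter of toggling a single call site.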
⚠️ Speculative: Beyond V4 Paper Validation
The following are real potential wins but go beyond what the V4 paper explicitly validated for FP4. Listed for completeness, do NOT implement without explicit sign-off from Mike.
- Indexer FP4 tensor-core scoring (paper §5.2.1 "QK path in the indexer... cached, loaded, and multiplied entirely in FP4")
  - Paper says the indexer SHOULD do QK in FP4 with tensor cores
  - Current: scalar FP32 dot products with no tensor cores
  - Huge scope: 2-3 weeks minimum
  - Risk: FP4 dot product precision for index selection needs recall validation
  - Verdict: Track for Stage F. Do NVFP4-0.4 first.
- MXFP4 vs NVFP4 for indexer scoring: not validated in the paper for the indexer specifically. Evaluate after NVFP4-0.
- NVFP4 for full attention Q×K^T GEMM: already closed. NVFP4 Q×K^T is too lossy (cos 0.86 vs FP32). Attention stays FP16/FP32.
- Per-token FP8 activation scaling in FMHA: different precision model, not validated. Out of scope.
NVFP4 Execution Order
| # | Task | Scope | Risk | Blocks | Est. |
|---|---|---|---|---|---|
| NVFP4-0.1 | sf_dtype tracing | Both | NONE — print only | D1 if wrong | 5 min |
| NVFP4-0.2 | SF TMEM layout | Both | NONE — print only | D1 if wrong | 5 min |
| NVFP4-0.3 | FP4 TMA element type | FMHA | NONE — print only | D1 if wrong | 5 min |
| NVFP4-0.4 | MMA kind verification | GEMM | NONE — print only | everything | 5 min |
| NVFP4-3 | use_2cta_instrs conditional | MoE | NONE — perf only | nothing | 30 min |
| NVFP4-1.1 | Fuse FP4 quant into SwiGLU epilogue | MoE | NONE | nothing | 1 day |
| NVFP4-1.2 | Fuse FP4 quant into invRoPE→wo_a | Attention | NONE | D1.5 | 1 day |
| NVFP4-1.3 | Fuse FP4 quant into mHC mixing | Attention | NONE | post-D5 | 2 days |
| D1.5 | Correction epilogue fix | FMHA | MEDIUM | NVFP4-1.2 | 2-3 hours |
| NVFP4-2 | FP4 KV pipeline depth | FMHA | NONE — perf only | D1 | 1 day |
NVFP4-0 results gate the critical path. If NVFP4-0.1–0.4 find a wrong sf_dtype, the fix comes before D2. Everything else is either parallel or post-D1.
NVFP4-3 (use_2cta_instrs) is the fastest win and has no dependencies. Do it immediately after the NVFP4-0 prints.
NVFP4-1.1 (fuse FP4 into SwiGLU) is the next-biggest win. No FMHA dependency. Do it in parallel with D2.
NVFP4-2 (FP4 KV) depends on D1 being solid. Do after D2 or alongside hd=512 fix.