nvfp4-megamoe-kernel/FINAL_STRETCH.md
2026-06-02 22:31:13 +00:00

DSV4 Audit — Decode Repetition + Precision / Tensor-Core Plan

PART B — Precision / NVFP4 / tensor-core (WE ARE SKIPPING PART A FOR RIGHT NOW AND WILL REVISIT IT)

Goal: native NVFP4 where the math allows, FP8_E4M3 where it doesn't, BF16/FP32 only where required. Validate each change with per-layer cosine vs dsv4/reference before trusting it.

B0 — What's already optimal: DO NOT "fix" the MoE

dsv4/layers/moe.py already runs native NVFP4: expert weights and activations are float4_e2m1fn_x2, block scales are float8_e4m3fn. This matches the paper (routed experts in FP4). Leave it. The remaining wins are in attention and the indexer, not MoE.

P5 — Fused mHC pre_block + RMSNorm + NVFP4 quantize: DONE

  • fused_mhc_rmsnorm_quantize.cu — 2-kernel approach (mhc_rmsnorm_amax_gsa + mhc_rmsnorm_quantize_nvfp4)
  • Integrated into forward_layer for BOTH attn and ffn mHC paths (commit 0b6ca0d)
  • Replaces: pre_block bmm (1 launch) + rmsnorm (4+ launches) + quantize (2 launches) → 2 launches
  • Savings: ~5 launches/site × 2 sites × 61 layers = 610 launches/token
  • Unit test: cos=0.999 vs unfused, 0.995 vs true mHC+RMSNorm at T=1/8/128
  • gsa per-row diff: ~1-2e-6 (excellent)
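A slow pure-Python reference for what the two passes compute could serve as the oracle for that unit test. It is a sketch under assumptions. The pre_block is modeled as a weighted sum of n residual streams with hypothetical weights h_pre. RMSNorm uses eps=1e-6. Quantization follows the common NVFP4 recipe: 16-element blocks, E2M1 values, one float8_e4m3fn scale per block, and a per-row FP32 global scale gsa = amax / (6 * 448). The function names are illustrative and map only loosely onto mhc_rmsnorm_amax_gsa and mhc_rmsnorm_quantize_nvfp4.

```python
import math
from typing import List, Sequence, Tuple

E2M1_GRID = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]  # FP4 magnitudes for codes 0..7
E4M3_MAX = 448.0  # largest finite float8_e4m3fn value
BLOCK = 16        # NVFP4 block size


def round_e4m3(x: float) -> float:
    """Round to the nearest float8_e4m3fn value (ties to even, saturating at 448)."""
    if x == 0.0:
        return 0.0
    a = min(abs(x), E4M3_MAX)
    _, e = math.frexp(a)           # a = m * 2**e with m in [0.5, 1)
    exp = max(e - 1, -6)           # subnormals use the minimum normal exponent
    step = 2.0 ** (exp - 3)        # spacing from 3 mantissa bits
    return math.copysign(min(round(a / step) * step, E4M3_MAX), x)


def round_e2m1(x: float) -> float:
    """Round to the nearest E2M1 value, saturating at 6; ties go to the even code."""
    a = min(abs(x), 6.0)
    code = min(range(8), key=lambda c: (abs(a - E2M1_GRID[c]), c & 1))
    return math.copysign(E2M1_GRID[code], x)


def pass1_mhc_rmsnorm_gsa(streams: Sequence[Sequence[float]], h_pre: Sequence[float],
                          gamma: Sequence[float], eps: float = 1e-6) -> Tuple[List[float], float]:
    """Pass 1: pre_block mix of the streams, RMSNorm, then the row's global scale."""
    d = len(gamma)
    x = [sum(h * s[j] for h, s in zip(h_pre, streams)) for j in range(d)]
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / d + eps)
    y = [v * inv_rms * g for v, g in zip(x, gamma)]
    amax = max(abs(v) for v in y)
    gsa = amax / (6.0 * E4M3_MAX) if amax > 0 else 1.0
    return y, gsa


def pass2_quantize_nvfp4(y: Sequence[float], gsa: float) -> Tuple[List[float], List[float]]:
    """Pass 2: one E4M3 scale per 16-element block, E2M1 values inside the block."""
    vals: List[float] = []
    scales: List[float] = []
    for b in range(0, len(y), BLOCK):
        blk = y[b:b + BLOCK]
        s = round_e4m3(max(abs(v) for v in blk) / 6.0 / gsa)
        scales.append(s)
        vals.extend(round_e2m1(v / (s * gsa)) if s > 0 else 0.0 for v in blk)
    return vals, scales


def dequantize(vals: Sequence[float], scales: Sequence[float], gsa: float) -> List[float]:
    return [v * scales[i // BLOCK] * gsa for i, v in enumerate(vals)]


# Usage: two toy streams of width 32 (two NVFP4 blocks).
streams = [[math.sin(0.3 * j) for j in range(32)], [0.05 * j - 0.8 for j in range(32)]]
y, gsa = pass1_mhc_rmsnorm_gsa(streams, h_pre=[0.6, 0.4], gamma=[1.0] * 32)
vals, scales = pass2_quantize_nvfp4(y, gsa)
y_hat = dequantize(vals, scales, gsa)
```

In this sketch the split between the passes falls where it does because every block scale depends on gsa, and gsa needs the whole row's amax first.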

P4 — Fused RMSNorm + NVFP4 quantize: DONE

  • fused_rmsnorm_quantize.cu — 2-kernel approach
  • Integrated for standalone rmsnorm+quantize paths
  • gsa scalar fix in Nvfp4Linear.run_from_quantized: per-row gsa reduced to scalar (max) for GEMM compatibility
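One way to see why the shared scalar has to be the max is to assume that block scales are computed against the shared gsa as block_amax / 6 / gsa. Under that assumption, every block scale must fit in float8_e4m3fn (≤ 448). A scalar taken from a row with a smaller amax then inflates the other rows' block scales past 448. The row amax values below are hypothetical.

```python
from typing import Callable, Sequence

E4M3_MAX = 448.0  # largest finite float8_e4m3fn value


def gsa_from_amax(amax: float) -> float:
    """Per-row global scale that maps the row's largest block to block scale 448."""
    return amax / (6.0 * E4M3_MAX)


def worst_block_scale(row_amax: Sequence[float],
                      reduce: Callable[[Sequence[float]], float]) -> float:
    """Largest unrounded block scale (block_amax / 6 / gsa) when all rows share one gsa."""
    gsa = reduce([gsa_from_amax(a) for a in row_amax])
    # Each row's largest block has block_amax == row amax, so that block is the worst case.
    return max(a / 6.0 / gsa for a in row_amax)


row_amax = [3.0, 12.0, 0.5]  # hypothetical per-row activation amax values
for name, reduce in (("max", max), ("min", min)):
    worst = worst_block_scale(row_amax, reduce)
    # Tiny tolerance: amax / 6 / (amax / 2688) can land a hair above 448 in floating point.
    print(f"{name}: worst block scale {worst:.1f}, fits E4M3: {worst <= E4M3_MAX * (1 + 1e-9)}")
```

With max, the worst block scale is exactly 448: the largest row maps to the top of the E4M3 range. With min, it is 12 / 0.5 × 448 = 10752, which is 24× past what E4M3 can hold.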

Stale Lock Fix: DONE (commit 845227c)

  • dsv4/kernels/cuda/loader.py: _cleanup_stale_lock() removes lock files older than 10 minutes
  • Prevents infinite spin after crash/kill during CUDA kernel compilation
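Here is a minimal sketch of the age-based cleanup. It assumes that lock files match *.lock under the build directory and that a lock's mtime marks when it was taken. cleanup_stale_locks and its parameters are hypothetical; they are not loader.py's actual _cleanup_stale_lock() signature.

```python
import time
from pathlib import Path
from typing import List, Optional


def cleanup_stale_locks(build_dir: str, max_age_s: float = 600.0,
                        pattern: str = "*.lock", now: Optional[float] = None) -> List[Path]:
    """Delete lock files under build_dir older than max_age_s (default 10 minutes).

    Hypothetical stand-in for loader.py's _cleanup_stale_lock(): the glob pattern
    and the use of mtime as "time the lock was taken" are assumptions.
    """
    now = time.time() if now is None else now
    removed: List[Path] = []
    for lock in Path(build_dir).rglob(pattern):
        try:
            if now - lock.stat().st_mtime > max_age_s:
                lock.unlink()
                removed.append(lock)
        except FileNotFoundError:
            pass  # another process released or removed the lock first
    return removed
```

The 10-minute threshold has to stay longer than the slowest legitimate compile. Otherwise a live build's lock gets deleted, and a second process can start compiling into the same directory.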

B1 — FP8_E4M3 FMHA (BIG win; perf + memory + native Blackwell)

Today: KV is stored mixed (FP8 nope + BF16 rope), then in "5. Gather KV" it's dequantized to BF16 into gbuf, and the FMHA runs in BF16. That throws away the FP8 you stored and runs the heaviest kernel at half the tensor-core throughput Blackwell offers.

NVFP4 KV is correctly ruled out. Your own KVCache docstring shows that 4-bit KV values cost ~0.4% per round-trip, and that loss compounds fatally over 61 layers. FP8_E4M3 is the right target, and you already store the nope dims in it. Plan:

  • Feed FP8 nope dims to the FMHA directly (skip the FP8→BF16 dequant in comp_nope_selective/comp_nope_all). Keep the 64 rope dims in BF16 (precision-sensitive) → a split-precision FMHA, or quantize rope to FP8 too and measure cos.
  • Quantize q to FP8 before the FMHA (it's BF16 now; see B3). Blackwell FP8 MMA consumes FP8×FP8.
  • Wins: removes the per-entry dequant, nearly halves gbuf bandwidth for the per-step gather on the decode hot path (nope drops to FP8; the BF16 rope dims stay the same size), and uses FP8 tensor cores. The DeepGEMM reference fp8_mqa_logits / FP8 attention paths are the template.
  • Gate it behind a cos check vs the BF16 FMHA per layer; if rope-in-FP8 drops cos, keep rope BF16.
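  • If DeepGEMM's FP8 paths show E4M3 for forward passes and E5M2 for gradients, that split is expected: E5M2 trades mantissa bits for range, which suits gradients. This inference path only needs E4M3.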
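Before writing the kernel, a CPU-side numerical model can preview the per-layer cos gate for the two rope options above. This sketch quantize-dequantizes q and the nope dims to E4M3 with one scale per vector (scale = amax / 448, an assumed granularity). It rounds the rope dims either to BF16 or to E4M3 and compares the attention logits against full precision. Multiplying dequantized values matches an FP8×FP8 MMA followed by the scale multiply, up to accumulation precision. Toy sizes and random data stand in for real activations, so this shows how to measure, not what DSV4 will score.

```python
import math
import random
import struct
from typing import List, Sequence, Tuple

E4M3_MAX = 448.0


def round_e4m3(x: float) -> float:
    """Nearest float8_e4m3fn value (ties to even, saturating at 448)."""
    if x == 0.0:
        return 0.0
    a = min(abs(x), E4M3_MAX)
    _, e = math.frexp(a)
    step = 2.0 ** (max(e - 1, -6) - 3)  # 3 mantissa bits; subnormals below 2**-6
    return math.copysign(min(round(a / step) * step, E4M3_MAX), x)


def to_bf16(x: float) -> float:
    """Round to bfloat16 via float32 bits (round-to-nearest-even; NaN/inf not handled)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)
    return struct.unpack("<f", struct.pack("<I", (bits >> 16 << 16) & 0xFFFFFFFF))[0]


def fp8(v: Sequence[float]) -> List[float]:
    """Quantize-dequantize with one scale per vector (scale = amax / 448)."""
    amax = max(abs(a) for a in v)
    s = amax / E4M3_MAX if amax > 0 else 1.0
    return [round_e4m3(a / s) * s for a in v]


def bf16(v: Sequence[float]) -> List[float]:
    return [to_bf16(a) for a in v]


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    return dot(a, b) / math.sqrt(dot(a, a) * dot(b, b))


Key = Tuple[List[float], List[float]]  # (nope dims, rope dims)


def logits(q_nope: Sequence[float], q_rope: Sequence[float],
           keys: Sequence[Key], rope_fp8: bool) -> List[float]:
    """FP8 nope dims; rope dims in BF16, or in FP8 too when rope_fp8 is set."""
    rope = fp8 if rope_fp8 else bf16
    qn, qr = fp8(q_nope), rope(q_rope)
    return [dot(qn, fp8(kn)) + dot(qr, rope(kr)) for kn, kr in keys]


rng = random.Random(0)


def gauss(n: int) -> List[float]:
    return [rng.gauss(0.0, 1.0) for _ in range(n)]


q_nope, q_rope = gauss(16), gauss(4)  # toy sizes, not DSV4's
keys = [(gauss(16), gauss(4)) for _ in range(32)]
ref = [dot(q_nope, kn) + dot(q_rope, kr) for kn, kr in keys]  # full-precision baseline
cos_split = cosine(ref, logits(q_nope, q_rope, keys, rope_fp8=False))
cos_all_fp8 = cosine(ref, logits(q_nope, q_rope, keys, rope_fp8=True))
print(f"split (FP8 nope + BF16 rope): {cos_split:.6f}  all-FP8: {cos_all_fp8:.6f}")
```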

B2 — Indexer scoring on FP8/FP4 tensor cores (BIG at long context; native FP4)

single_shot_inference.py scores the indexer with torch.einsum('tnd,cd->tnc', q_idx.float(), k_idx.float()). That is a full FP32 einsum on CUDA cores over all n_comp entries, in every CSA layer, at every decode step. At long context this is the dominant indexer cost, and it is the opposite of native FP4. The indexer keys are already FP8 in the cache. Replace the einsum with a tensor-core weighted-ReLU MQA-logits kernel in FP8, or in FP4 for the QK path as the paper does ("lightning indexer ... FP4"), mirroring DeepGEMM fp8_fp4_mqa_logits. This is both the long-context perf unlock and a native-FP4 conversion. (The dead dsv4/kernels/indexer/*.cu files are not this. Write the kernel fresh against the DeepGEMM kernel: score in FP8/FP4, do top-k with a warp-local reduction, and use no global lock.)
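The new kernel needs an oracle. This slow reference assumes the weighted-ReLU form of the lightning indexer, score[t][c] = Σ_n w[t][n] · ReLU(q[t][n] · k[c]), where the inner dot product is exactly the 'tnd,cd->tnc' einsum term. How single_shot_inference.py actually applies the per-head weights w is an assumption here. The indexer output only drives top-k selection, so top-k overlap against this reference is arguably a more direct acceptance check for the FP8/FP4 kernel than cosine alone.

```python
import heapq
from typing import List, Sequence


def weighted_relu_logits(q: Sequence[Sequence[Sequence[float]]],
                         w: Sequence[Sequence[float]],
                         k: Sequence[Sequence[float]]) -> List[List[float]]:
    """score[t][c] = sum_n w[t][n] * relu(q[t][n] . k[c]); q is [T][H][D], w is [T][H], k is [C][D]."""
    scores = []
    for q_t, w_t in zip(q, w):
        row = []
        for k_c in k:
            s = 0.0
            for q_tn, w_tn in zip(q_t, w_t):
                dot = sum(a * b for a, b in zip(q_tn, k_c))  # the 'tnd,cd->tnc' einsum term
                s += w_tn * max(dot, 0.0)                     # ReLU, then the head weight
            row.append(s)
        scores.append(row)
    return scores


def topk(scores: Sequence[float], k: int) -> List[int]:
    """Indices of the k largest scores, best first."""
    return [i for i, _ in heapq.nlargest(k, enumerate(scores), key=lambda p: p[1])]


def topk_overlap(ref: Sequence[int], cand: Sequence[int]) -> float:
    """Fraction of the reference top-k that the candidate kernel also selected."""
    return len(set(ref) & set(cand)) / len(ref) if ref else 1.0


q = [[[1.0, 0.0], [0.0, 1.0]]]                          # T=1 token, H=2 indexer heads, D=2
w = [[1.0, 0.5]]                                        # per-head weights for that token
k = [[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0], [2.0, 2.0]]  # C=4 compressed entries
scores = weighted_relu_logits(q, w, k)[0]
print(scores, topk(scores, 2))                          # [1.0, 0.5, 0.0, 3.0] [3, 0]
```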

B3 — Fused rmsnorm→quant for q_a_norm / kv_norm (small, removes BF16 round-trips)

  • DONE: the q_a_norm → q_b path now uses fused rmsnorm_quantize_nvfp4 + run_from_quantized (commit 0b6ca0d)
  • Skips BF16 materialization between q_a_norm and q_b GEMM
  • Saves ~6 kernel launches per layer
  • kv_norm still uses the unfused rmsnorm. Fusing it needs the FP8 FMHA (B1) to pay off fully, since kv_norm's output feeds RoPE rather than another GEMM

B4 — General "producer BF16 → consumer FP32" sweep (the user's pattern)

Find and fix places that cast up immediately after producing a narrower dtype:

grep -nE "\.float\(\)" single_shot_inference.py dsv4/layers/*.py dsv4/ops/*.py

For each hit, check the producing line just above. The rule: emit the dtype the next consumer needs. Two directions:

  • Producer makes BF16, consumer's first act is .float() → make the producer emit FP32 (or fuse), skip the cast.
  • Producer makes FP32 only to be quantized to FP4/FP8 next → fuse the quant into the producing kernel (as B3). Do not apply this to the compression boundaries: the compressor should emit FP32 then downcast to FP8/BF16 for storage — that downcast is the architecture's memory budget, not a wasted step.
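The grep only finds the casts; pairing each hit with its producer is the manual part. A small helper can print both side by side. It is hypothetical (not in the repo). Here "producer" means the nearest non-blank, non-comment line above the cast, which is only a heuristic. A cast chained directly on the producing call (foo(x).float()) or a multi-line producer still needs a human look.

```python
import re
import sys
from pathlib import Path
from typing import Iterator, List, Tuple

FLOAT_CAST = re.compile(r"\.float\(\)")  # same pattern as the grep


def casts_with_producer(paths: List[str]) -> Iterator[Tuple[str, int, str, str]]:
    """Yield (path, 1-based line, producer line, cast line) for every .float() hit."""
    for path in paths:
        lines = Path(path).read_text().splitlines()
        for i, line in enumerate(lines):
            if not FLOAT_CAST.search(line):
                continue
            j = i - 1
            # Heuristic: the producer is the nearest non-blank, non-comment line above.
            while j >= 0 and (not lines[j].strip() or lines[j].lstrip().startswith("#")):
                j -= 1
            producer = lines[j].strip() if j >= 0 else ""
            yield path, i + 1, producer, line.strip()


if __name__ == "__main__":
    for path, lineno, producer, cast in casts_with_producer(sys.argv[1:]):
        print(f"{path}:{lineno}\n  producer: {producer}\n  cast:     {cast}")
```

Run it with the same file list as the grep, e.g. python float_sweep.py single_shot_inference.py dsv4/layers/*.py dsv4/ops/*.py. The shell expands the globs, and float_sweep.py is a hypothetical filename.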

B5 — Residual-stream precision (low priority; only if A-items don't fully resolve degeneration)

The mHC residual X is BF16 at |X|≈300, where BF16 ULP ≈ 2. This is probably fine (matches the reference / paper's expected magnitude, and mHC's doubly-stochastic B is non-expansive). But if late-decode degeneration survives Part A, A/B test the residual stream in FP32 for a few layers and watch whether the repetition onset moves. If it does, the residual precision is a contributor; if not, rule it out. Keep this last — FP32 residual doubles mHC activation memory/bandwidth, against the concurrency goal.
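To make the ULP claim concrete: at |X|≈300 the BF16 spacing is 2. Any update smaller than half an ULP (1.0) that is added and then rounded back to BF16 is therefore lost entirely. The sketch below emulates BF16 rounding in pure Python (no torch) and shows ten +0.9 updates vanishing. Whether real DSV4 residual updates are that small relative to |X| is exactly what the FP32-residual A/B would reveal, so treat this as an illustration, not evidence.

```python
import math
import struct


def to_bf16(x: float) -> float:
    """Round to bfloat16 via float32 bits (round-to-nearest-even; NaN/inf not handled)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # RNE on the 16 low bits being dropped
    return struct.unpack("<f", struct.pack("<I", (bits >> 16 << 16) & 0xFFFFFFFF))[0]


def bf16_ulp(x: float) -> float:
    """Spacing between adjacent bfloat16 values at |x| (8 significant bits)."""
    _, e = math.frexp(abs(x))  # |x| = m * 2**e, m in [0.5, 1)
    return math.ldexp(1.0, e - 8)


# Ten residual updates of +0.9 at |X| = 300, each rounded back to BF16 as a BF16 stream would be.
x_bf16 = 300.0
x_ref = 300.0  # Python float (FP64) stands in for an FP32 residual stream
for _ in range(10):
    x_bf16 = to_bf16(x_bf16 + 0.9)  # 0.9 < ULP/2 = 1.0, so the update rounds away
    x_ref += 0.9
print(bf16_ulp(300.0), x_bf16, x_ref)
```

A real residual update is a vector, and many of its elements may well exceed half an ULP. Only the A/B test can tell whether the discarded low-order parts matter.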


PART C — Guardrails for the agent

  1. Every precision change is gated by a per-layer cosine vs dsv4/reference for a fixed prompt, before judging end-to-end output. Record the cos in the commit message.
  2. One change per commit, with the A/B result. If a change drops end-to-end coherence, the per-layer cos tells you which layer/op regressed.
  3. Don't re-create the dead indexer. B2 is a new FP8/FP4 kernel; the dsv4/kernels/indexer/*.cu files are archived/dead — confirm with helpers/import_closure.py before reusing anything there.
  4. Re-validate the stop fix (A1) on a long generation (≥512 tokens) and a multi-turn prompt, not just "capital of France" — the turn-end token differs by prompt type.
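Guardrails 1 and 2 can be mechanized with a small helper. It computes per-layer cosines, names the first layer below a threshold, and formats a line for the commit message. This is a sketch: the function names are hypothetical, and the 0.999 default threshold is a placeholder, since this plan does not fix a number.

```python
import math
from typing import List, Optional, Sequence, Tuple


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity; 0.0 if either vector is all zeros (so it fails the gate)."""
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return dot / (na * nb) if na > 0 and nb > 0 else 0.0


def per_layer_gate(ref: Sequence[Sequence[float]], cand: Sequence[Sequence[float]],
                   threshold: float = 0.999) -> Tuple[List[float], Optional[int]]:
    """Per-layer cosines vs the reference, plus the first layer below threshold (None if all pass)."""
    if len(ref) != len(cand):
        raise ValueError(f"layer count mismatch: {len(ref)} vs {len(cand)}")
    coses = [cosine(r, c) for r, c in zip(ref, cand)]
    first_bad = next((i for i, c in enumerate(coses) if c < threshold), None)
    return coses, first_bad


def commit_summary(change: str, coses: Sequence[float], first_bad: Optional[int]) -> str:
    """One line to paste into the commit message (guardrail 1)."""
    verdict = "pass" if first_bad is None else f"FAIL at layer {first_bad}"
    return f"{change}: per-layer cos min={min(coses):.5f} over {len(coses)} layers ({verdict})"
```

Here ref and cand would be flattened per-layer outputs captured on the same fixed prompt, one set from dsv4/reference and one from the candidate build.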

Suggested sequence

B1 (FP8 FMHA) → B2 (FP8/FP4 indexer) → B3 remainder (kv_norm fusion, after B1) → B4 (cast sweep) → B5 only if needed.


PART D — Dangling TODOS

  • /home/openclaw/dev/nvfp4-megamoe-kernel/docs/PERFORMANCE_AUDIT.md says the P5 kernel (fuse mHC pre_block + RMSNorm into a single op) is done but integration is pending. P5 above is marked DONE with integration into forward_layer (commit 0b6ca0d). Confirm that and update PERFORMANCE_AUDIT.md, or wire it up if it is still missing.

  • Batched prefill: was this ever implemented? Verify.