nvfp4-megamoe-kernel/FINAL_STRETCH.md
biondizzle 75288bd12f Wire prefill FMHA into production.py and single_shot
- Add dsv4_attention_mixed_fp8_prefill to production.py
- _run_production_fmha_mixed now dispatches to prefill kernel for T>1
- Remove decode-only T==1 restriction
- Update FINAL_STRETCH.md: prefill marked DONE, batched prefill TODO noted
2026-06-03 03:49:57 +00:00

DSV4 Audit — Decode Repetition + Precision / Tensor-Core Plan

PART B — Precision / NVFP4 / tensor-core (WE ARE SKIPPING PART A FOR RIGHT NOW AND WILL REVISIT IT)

Goal: native NVFP4 where the math allows, FP8_E4M3 where it doesn't, BF16/FP32 only where required. Validate each change with per-layer cosine vs dsv4/reference before trusting it.
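The per-layer cosine check can be sketched as below. This is a minimal illustration, not the project's test harness: `cosine_similarity`, `check_layers`, the dict-of-arrays layout, and the 0.999 threshold are all hypothetical names and values chosen for the example.

```python
import numpy as np


def cosine_similarity(candidate, reference, eps=1e-12):
    """Cosine similarity of two tensors, flattened and computed in float64."""
    a = np.asarray(candidate, dtype=np.float64).ravel()
    b = np.asarray(reference, dtype=np.float64).ravel()
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    return float(np.dot(a, b) / max(denom, eps))


def check_layers(candidate_outputs, reference_outputs, threshold=0.999):
    """Return [(layer_name, cosine)] for every layer below the threshold.

    Both arguments are assumed to be dicts mapping a layer name to its
    output array (hypothetical layout for this sketch).
    """
    failures = []
    for name, ref in reference_outputs.items():
        cos = cosine_similarity(candidate_outputs[name], ref)
        if cos < threshold:
            failures.append((name, cos))
    return failures
```

Computing in float64 keeps the metric itself from adding rounding noise on top of the low-precision path being measured.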

B0 — What's already optimal: DO NOT "fix" the MoE

dsv4/layers/moe.py already runs native NVFP4: expert weights and activations are float4_e2m1fn_x2, block scales are float8_e4m3fn. This matches the paper (routed experts in FP4). Leave it. The remaining wins are in attention and the indexer, not MoE.
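For readers unfamiliar with the format, the sketch below simulates block-scaled FP4 quantization in NumPy. It assumes 16-element blocks and the E2M1 magnitude grid {0, 0.5, 1, 1.5, 2, 3, 4, 6}; it keeps the block scale in full precision instead of rounding it to float8_e4m3fn, and it does not pack two codes per byte, so it is a numerical illustration only and not what dsv4/layers/moe.py does. The function names are hypothetical.

```python
import numpy as np

# Magnitudes representable by an E2M1 (4-bit) float.
E2M1_GRID = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])
BLOCK = 16  # assumed block size for the per-block scale


def quantize_blocks(x):
    """Quantize a 1-D array to E2M1 values with one scale per 16-element block.

    Returns (codes, scales): codes holds grid values with sign, shape (n, 16);
    scales has shape (n, 1). The scale stays in float64 here (simplification).
    """
    x = np.asarray(x, dtype=np.float64)
    assert x.size % BLOCK == 0, "length must be a multiple of the block size"
    blocks = x.reshape(-1, BLOCK)
    amax = np.abs(blocks).max(axis=1, keepdims=True)
    # Map each block's largest magnitude onto the largest grid value (6.0).
    scales = np.where(amax > 0, amax / E2M1_GRID[-1], 1.0)
    scaled = blocks / scales
    # Round every element to the nearest grid magnitude, then restore the sign.
    idx = np.abs(np.abs(scaled)[..., None] - E2M1_GRID).argmin(axis=-1)
    codes = np.sign(scaled) * E2M1_GRID[idx]
    return codes, scales


def dequantize_blocks(codes, scales):
    """Inverse of quantize_blocks, returning a flat array."""
    return (codes * scales).reshape(-1)
```

Because the widest gap in the grid is between 4 and 6, the worst-case error per element in this sketch is one sixth of the block's largest magnitude.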

P5 — Fused mHC pre_block + RMSNorm + NVFP4 quantize: DONE

  • fused_mhc_rmsnorm_quantize.cu — 2-kernel approach (mhc_rmsnorm_amax_gsa + mhc_rmsnorm_quantize_nvfp4)
  • Integrated into forward_layer for BOTH attn and ffn mHC paths (commit 0b6ca0d)
  • Unit test: cos=0.999 vs unfused, 0.995 vs true mHC+RMSNorm at T=1/8/128

P4 — Fused RMSNorm + NVFP4 quantize: DONE

  • fused_rmsnorm_quantize.cu — 2-kernel approach
  • gsa scalar fix in Nvfp4Linear.run_from_quantized

Stale Lock Fix: DONE (commit 845227c)

B1 — FP8_E4M3 FMHA: DONE

Implementation: dsv4/kernels/attention/fmha_mixed_fp8_decode.cuh + C API + Python bridge.

Storage-native DSV4 attention: noPE KV stays FP8_E4M3, RoPE KV stays BF16, no global FP8→BF16 dequant.
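The reason no global dequant is needed is that the attention logit splits into a noPE dot product and a RoPE dot product that can be computed in their own storage types and then summed. The sketch below shows only that decomposition. It assumes one scalar scale per Q vector and per K vector (the real scale granularity is not stated here), uses plain float arrays as stand-ins for FP8 and BF16 data, and `mixed_logit` is a hypothetical name.

```python
import numpy as np

NOPE, ROPE = 448, 64  # DSV4 head split, HD = NOPE + ROPE = 512


def mixed_logit(q_nope, q_nope_scale, k_nope, k_nope_scale, q_rope, k_rope):
    """Attention logit computed as two partial dot products.

    q_nope / k_nope stand in for FP8 codes; their scales are applied once to
    the accumulated dot product instead of per element (assumed scheme).
    q_rope / k_rope stand in for BF16 values and need no scale.
    """
    nope_part = np.dot(q_nope, k_nope) * (q_nope_scale * k_nope_scale)
    rope_part = np.dot(q_rope, k_rope)
    return nope_part + rope_part
```

With exact inputs the result matches the full 512-wide dot product, so any difference seen in practice comes from the FP8 rounding of the noPE part, not from the split.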

Unit Test Results (2026-06-03, tests/unit/test_b1_mixed_fp8_fmha.py)

| Test | Status |
| --- | --- |
| quantize_q_fp8_split | PASS (cos=0.9997) |
| gather_mixed kernels | PASS |
| FMHA cosine (N=128..2048, H=128) | PASS (cos=0.9999..0.9997) |
| Attention sinks | PASS |
| GQA/MQA (128 Q heads) | PASS |
| Weight loading verification | PASS |
| Batch sizes (B=1,2,4) | PASS |

Bugs Found and Fixed

  1. V matrix canonical layout swap (commit 4fe7f9d): canon_idx_bf16_16x16(kk, dd) was wrong — should be canon_idx_bf16_16x16(dd, kk). The SMEM group structure was transposed vs the working TMA-loaded V in the multitile kernel. This caused cos=0.158 vs BF16 reference. After fix: cos=0.999972 at N=128.

Known Limitations

  • Prefill batch size: T=1..128 is supported per launch. For T>128 the caller must split the work into multiple launches. T_BATCH=32 sub-batches are used internally.
  • Specialized for DSV4 HD=512/NOPE=448/ROPE=64.

Bug Fix (2026-06-03)

  1. CRITICAL: T-dimension strides were wrong for T>1 — the kernel used q_nope_head_stride (stride(1) = T*NOPE) for the T dimension, but the correct stride is stride(2) = NOPE. For T=1 this is invisible (qr=0 always), but for T>1 it reads garbage from adjacent heads' data. Fix: added explicit T-dimension strides (q_nope_t_stride, q_scale_t_stride, q_rope_t_stride) to params struct, C API, and Python wrapper. All 16 T>1 test configs now pass (cos >= 0.999887).
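The stride bug can be reproduced on a small contiguous tensor. The sketch below assumes a [B, H, T, NOPE] layout, which is what the quoted strides (stride(1) = T*NOPE, stride(2) = NOPE) imply; the toy sizes and the helper `load_row` are hypothetical and only mimic the address arithmetic, not the kernel.

```python
import numpy as np

B, H, T, NOPE = 1, 4, 3, 8  # toy sizes; the real NOPE is 448
q = np.arange(B * H * T * NOPE, dtype=np.float32).reshape(B, H, T, NOPE)
flat = q.ravel()

# Strides in elements, matching the quantities named in the bug report.
elem_strides = [s // q.itemsize for s in q.strides]
head_stride = elem_strides[1]  # stride(1) = T * NOPE
t_stride = elem_strides[2]     # stride(2) = NOPE


def load_row(h, t, stride_for_t):
    """Read one NOPE-wide row the way a kernel computes its offset."""
    offset = h * head_stride + t * stride_for_t
    return flat[offset:offset + NOPE]


# Correct: stepping along T uses the T stride.
good = load_row(0, 1, t_stride)
# Buggy: stepping along T with the head stride lands in the next head.
bad = load_row(0, 1, head_stride)
```

Here `good` equals `q[0, 0, 1]` while `bad` equals `q[0, 1, 0]`, that is, token 0 of the neighbouring head. At t=0 both variants return the same row, which is why decode (T=1) never exposed the problem.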

B2 — FP8 tensor-core indexer scoring: DONE

Implementation: dsv4/kernels/cuda/indexer_fp8_score_topk.cu

Native Blackwell FP8 GEMM via tcgen05 for CSA Lightning Indexer scoring. No PyTorch einsum fallback.

Unit Test Results (2026-06-03, tests/unit/test_b2_indexer_fp8.py)

| Test | Status |
| --- | --- |
| Score cosine vs FP32 reference (n_comp=128..8192) | PASS (100% overlap ≤1024, ~88% at 8192) |
| Score distribution sanity | PASS |
| Determinism | PASS |
| Edge cases (n_comp < top_k, n_comp=1) | PASS |
| Weight format verification | PASS |
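One way to compute a top-k overlap figure like the one reported above is sketched here. How the unit test actually measures overlap is not shown in this file, so `topk_indices` and `topk_overlap` are hypothetical helpers and the set-intersection definition is an assumption.

```python
import numpy as np


def topk_indices(scores, k):
    """Indices of the k largest scores (k is clamped to the input length)."""
    scores = np.asarray(scores, dtype=np.float64)
    k = min(k, scores.size)
    # Stable sort on negated scores gives descending order with fixed tie order.
    return np.argsort(-scores, kind="stable")[:k]


def topk_overlap(test_scores, ref_scores, k):
    """Fraction of the reference top-k set that the test top-k set also selects."""
    test_set = set(topk_indices(test_scores, k).tolist())
    ref_set = set(topk_indices(ref_scores, k).tolist())
    return len(test_set & ref_set) / len(ref_set)
```

Under this definition, whenever n_comp is at most k both sets contain every index and the overlap is 100% regardless of score accuracy. If the test uses the production top_k=1024, the 8192 case is therefore the informative one.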

Bugs Found and Fixed

  1. Broken 16x256b.x1 TMEM read — instruction was hanging. Root cause: the 16x256b.x1 PTX instruction either doesn't exist on SM100 or has different alignment requirements. Fix: use the proven 32x32b.x8 instruction from B1 FMHA.

  2. TMEM_COLS too small — TMEM_COLS=128 was insufficient for the 128×128 MMA output. The MMA writes ALL 128 rows, requiring 4 row-groups × 128 columns = 512 TMEM columns. Fix: TMEM_COLS=512.

  3. Wrong TMEM offset for rows 32-63 — tried tb + SK_TILE + col_base and tb + 16 + col_base, both gave wrong results. Root cause: the 32x32b.x8 instruction maps different warps to different row slices from the SAME TMEM address. Warp 0 reads rows 0-31, warp 1 reads rows 32-63, all from tb + col_base. Fix: warps 0-1 both read from the same address, accumulate into separate SMEM partitions, then merge.

  4. Cross-warp accumulation race condition — initial attempt used shared sLogits[c] with first-warp-writes/second-warp-adds pattern, which was non-deterministic. Fix: per-warp score partitions (sWarpScores[0..SK_TILE-1] and sWarpScores[SK_TILE..2*SK_TILE-1]), merged after __syncthreads().

Production Configuration

  • n_ih=64, ihd=128, top_k=1024
  • Warps 0-1: TMEM read + per-warp score accumulation
  • Warp 4: MMA (FP8 GEMM)
  • Per-thread local top-k (INDEXER_LOCAL_K=8) → block-level merge
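The local top-k followed by a block-level merge can be modelled in a few lines of Python. This is a sketch of the selection logic only: the strided assignment of scores to threads, the function name `block_topk`, and the use of `heapq` are assumptions for illustration and say nothing about how the CUDA kernel lays out its work.

```python
import heapq


def block_topk(scores, num_threads, local_k, top_k):
    """Two-stage top-k: each simulated thread keeps its local_k best scores,
    then all survivors are merged and the top_k best indices are returned.

    Thread `tid` is assumed to own elements tid, tid + num_threads, ...
    """
    candidates = []
    for tid in range(num_threads):
        owned = [(scores[i], i) for i in range(tid, len(scores), num_threads)]
        candidates.extend(heapq.nlargest(local_k, owned))
    merged = heapq.nlargest(top_k, candidates)
    return [index for _, index in merged]
```

The merge is exact only when no single thread owns more than local_k of the true top_k entries. If one thread owns more, the extras are dropped before the merge, so the scheme trades a small risk of missed entries for bounded per-thread storage.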

B3 — Fused rmsnorm→quant for q_a_norm / kv_norm: DONE

  • q_a_norm → q_b path uses fused rmsnorm_quantize_nvfp4 + run_from_quantized
  • kv_norm still uses unfused rmsnorm — requires FP8 FMHA (B1) to fully benefit

B4 — General "producer BF16 → consumer FP32" sweep: NOT STARTED

B5 — Residual-stream precision: NOT STARTED (low priority)


PART D — Dangling TODOs

  • Batched Prefill: DONE (T=1..128, mixed FP8/BF16 kernel)
  • Need to wire prefill into single_shot_inference.py (replace T=1 token-by-token prefill)
  • Need T>128 support (split into multiple launches)
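The T>128 split could start from a helper like the one below. It only computes the token ranges for each launch; `split_prefill` is a hypothetical name, and how each launch sees the KV entries written by earlier chunks is outside this sketch.

```python
def split_prefill(total_tokens, max_t=128):
    """Split a prefill of total_tokens into [start, end) ranges of at most max_t.

    max_t defaults to 128, the largest T the prefill kernel is documented to
    accept in one launch.
    """
    if total_tokens < 0 or max_t <= 0:
        raise ValueError("total_tokens must be >= 0 and max_t must be > 0")
    return [
        (start, min(start + max_t, total_tokens))
        for start in range(0, total_tokens, max_t)
    ]
```

For example, a 300-token prompt yields three launches covering tokens 0..127, 128..255, and 256..299.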