DSV4 Audit — Decode Repetition + Precision / Tensor-Core Plan
PART B — Precision / NVFP4 / tensor-core (WE ARE SKIPPING PART A FOR RIGHT NOW AND WILL REVISIT IT)
Goal: native NVFP4 where the math allows, FP8_E4M3 where it doesn't, BF16/FP32 only where required. Validate each change with per-layer cosine vs dsv4/reference before trusting it.
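The per-layer cosine check can be sketched as below. This is a minimal illustration, not the project's test harness: `cosine_similarity`, `check_layers`, the dict-of-flat-lists layout, and the 0.999 threshold are all assumptions chosen for the example.

```python
import math

def cosine_similarity(a, b):
    """Cosine similarity between two equal-length flat sequences of floats."""
    if len(a) != len(b):
        raise ValueError("length mismatch")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0  # treat an all-zero tensor as "no agreement"
    return dot / (norm_a * norm_b)

def check_layers(candidate, reference, threshold=0.999):
    """Return names of layers whose cosine vs the reference is below threshold.

    Both arguments are hypothetical dicts mapping layer name -> flat activations.
    """
    return [name for name, ref in reference.items()
            if cosine_similarity(candidate[name], ref) < threshold]
```

Running `check_layers` after each precision change gives a list of layers to inspect before the change is trusted.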
B0 — What's already optimal: DO NOT "fix" the MoE
dsv4/layers/moe.py already runs native NVFP4: expert weights and activations are float4_e2m1fn_x2, block scales are float8_e4m3fn. This matches the paper (routed experts in FP4). Leave it. The remaining wins are in attention and the indexer, not MoE.
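For readers unfamiliar with the format, here is a rough sketch of block-scaled FP4 quantization. It assumes the usual 16-element NVFP4 block and keeps the block scale as a plain Python float; the real format stores that scale as FP8_E4M3 and packs two E2M1 codes per byte, which this sketch does not model. The function names are hypothetical.

```python
import math

# Magnitudes representable by FP4 E2M1 (sign is stored separately).
E2M1_GRID = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)
BLOCK = 16  # assumed NVFP4 block size

def quantize_block(values):
    """Quantize one block to (scale, codes); codes are signed grid values."""
    amax = max(abs(v) for v in values)
    # Map the block's largest magnitude onto the largest grid value (6.0).
    scale = amax / 6.0 if amax > 0.0 else 1.0
    codes = []
    for v in values:
        # Round the scaled magnitude to the nearest representable grid point.
        mag = min(E2M1_GRID, key=lambda g: abs(abs(v) / scale - g))
        codes.append(math.copysign(mag, v))
    return scale, codes

def dequantize_block(scale, codes):
    """Reconstruct approximate values from a block scale and its codes."""
    return [c * scale for c in codes]
```

Values that already sit on the scaled grid round-trip exactly; everything else picks up rounding error, which is why the cosine checks in this document matter.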
P5 — Fused mHC pre_block + RMSNorm + NVFP4 quantize: ✅ DONE
- `fused_mhc_rmsnorm_quantize.cu`: 2-kernel approach (`mhc_rmsnorm_amax_gsa` + `mhc_rmsnorm_quantize_nvfp4`)
- Integrated into `forward_layer` for BOTH attn and ffn mHC paths (commit `0b6ca0d`)
- Unit test: cos=0.999 vs unfused, 0.995 vs true mHC+RMSNorm at T=1/8/128
P4 — Fused RMSNorm + NVFP4 quantize: ✅ DONE
- `fused_rmsnorm_quantize.cu`: 2-kernel approach
- gsa scalar fix in `Nvfp4Linear.run_from_quantized`
Stale Lock Fix: ✅ DONE (commit 845227c)
B1 — FP8_E4M3 FMHA: ✅ DONE
Implementation: dsv4/kernels/attention/fmha_mixed_fp8_decode.cuh + C API + Python bridge.
Storage-native DSV4 attention: noPE KV stays FP8_E4M3, RoPE KV stays BF16, no global FP8→BF16 dequant.
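As a rough illustration of the storage-native idea, one attention logit can be computed as two partial dot products, so the FP8 part never has to be expanded into a BF16 copy of the whole KV. The sketch below assumes a single scalar dequantization scale per Q row and per K row; the kernel's actual scale granularity is not stated here, and `mixed_logit` is a hypothetical name. Softmax scaling is omitted.

```python
NOPE, ROPE = 448, 64  # DSV4 head split: HD = NOPE + ROPE = 512

def mixed_logit(q_nope, q_scale, k_nope, k_scale, q_rope, k_rope):
    """One attention logit from an FP8-style noPE part and a BF16-style RoPE part.

    q_nope / k_nope hold quantized values; q_scale / k_scale are the assumed
    scalar dequantization scales. q_rope / k_rope are used as-is.
    """
    assert len(q_nope) == len(k_nope) == NOPE
    assert len(q_rope) == len(k_rope) == ROPE
    nope_dot = sum(a * b for a, b in zip(q_nope, k_nope))
    rope_dot = sum(a * b for a, b in zip(q_rope, k_rope))
    # Scales are applied once to the accumulated noPE dot product,
    # instead of dequantizing every element first.
    return q_scale * k_scale * nope_dot + rope_dot
```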
Unit Test Results (2026-06-03, tests/unit/test_b1_mixed_fp8_fmha.py)
| Test | Status |
|---|---|
| quantize_q_fp8_split | ✅ PASS (cos=0.9997) |
| gather_mixed kernels | ✅ PASS |
| FMHA cosine (N=128..2048, H=128) | ✅ PASS (cos=0.9999..0.9997) |
| Attention sinks | ✅ PASS |
| GQA/MQA (128 Q heads) | ✅ PASS |
| Weight loading verification | ✅ PASS |
| Batch sizes (B=1,2,4) | ✅ PASS |
Bugs Found and Fixed
- V matrix canonical layout swap (commit `4fe7f9d`): `canon_idx_bf16_16x16(kk, dd)` was wrong; it should be `canon_idx_bf16_16x16(dd, kk)`. The SMEM group structure was transposed relative to the working TMA-loaded V in the multitile kernel, which caused cos=0.158 vs the BF16 reference. After the fix: cos=0.999972 at N=128.
Known Limitations
- Prefill batch size: T=1..128 supported. For T>128, caller must split. T_BATCH=32 sub-batches used internally.
- Specialized for DSV4 HD=512/NOPE=448/ROPE=64.
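A caller-side split for T>128 might look like the sketch below. The helper names are hypothetical, and it assumes each launch window can see the KV written by earlier windows, which the caller would have to guarantee.

```python
MAX_T = 128    # per-launch limit stated above
T_BATCH = 32   # internal sub-batch size stated above

def split_prefill(total_t, max_t=MAX_T):
    """Yield (start, length) launch windows covering total_t prompt tokens."""
    if total_t < 1:
        raise ValueError("total_t must be >= 1")
    for start in range(0, total_t, max_t):
        yield start, min(max_t, total_t - start)

def sub_batches(length, t_batch=T_BATCH):
    """Number of T_BATCH sub-batches inside one launch (ceiling division)."""
    return -(-length // t_batch)
```

For example, a 300-token prompt becomes three launches of 128, 128, and 44 tokens, and the last one runs two internal sub-batches.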
Bug Fix (2026-06-03)
- CRITICAL: T-dimension strides were wrong for T>1. The kernel used `q_nope_head_stride` (`stride(1) = T*NOPE`) for the T dimension, but the correct stride is `stride(2) = NOPE`. For T=1 this is invisible (qr=0 always), but for T>1 it reads garbage from adjacent heads' data. Fix: added explicit T-dimension strides (`q_nope_t_stride`, `q_scale_t_stride`, `q_rope_t_stride`) to the params struct, C API, and Python wrapper. All 16 T>1 test configs now pass (cos >= 0.999887).
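The stride bug can be reproduced with plain index arithmetic. The sketch assumes a contiguous `[H, T, NOPE]` layout per batch element, consistent with the strides quoted above; the function names are illustrative only.

```python
def offset_buggy(h, t, d, T, NOPE):
    """Element offset with the head stride wrongly reused for the T dimension."""
    head_stride = T * NOPE  # stride(1)
    return h * head_stride + t * head_stride + d

def offset_fixed(h, t, d, T, NOPE):
    """Element offset with a separate T stride, as in the fix."""
    head_stride = T * NOPE  # stride(1)
    t_stride = NOPE         # stride(2)
    return h * head_stride + t * t_stride + d

# With T=4, NOPE=448, the buggy offset for (head 0, token 1) lands on the
# first element of head 1, which is the "adjacent heads' data" symptom.
```

With T=1 the token index is always 0, so both functions agree and the bug stays hidden in decode-only tests.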
B2 — FP8 tensor-core indexer scoring: ✅ DONE
Implementation: dsv4/kernels/cuda/indexer_fp8_score_topk.cu
Native Blackwell FP8 GEMM via tcgen05 for CSA Lightning Indexer scoring. No PyTorch einsum fallback.
Unit Test Results (2026-06-03, tests/unit/test_b2_indexer_fp8.py)
| Test | Status |
|---|---|
| Score cosine vs FP32 reference (n_comp=128..8192) | ✅ PASS (100% top-k overlap for n_comp ≤ 1024, ~88% at n_comp=8192) |
| Score distribution sanity | ✅ PASS |
| Determinism | ✅ PASS |
| Edge cases (n_comp < top_k, n_comp=1) | ✅ PASS |
| Weight format verification | ✅ PASS |
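Reading the overlap figures in the table as top-k index overlap against the FP32 reference, the metric can be sketched as follows. The helper names are hypothetical and this is not the code in `tests/unit/test_b2_indexer_fp8.py`.

```python
import heapq

def topk_indices(scores, k):
    """Set of indices of the k largest scores (all indices if fewer exist)."""
    k = min(k, len(scores))
    return set(heapq.nlargest(k, range(len(scores)), key=scores.__getitem__))

def topk_overlap(test_scores, ref_scores, k):
    """Fraction of the reference top-k indices that the test scores also select."""
    if not ref_scores:
        raise ValueError("ref_scores must be non-empty")
    ref = topk_indices(ref_scores, k)
    got = topk_indices(test_scores, k)
    return len(ref & got) / len(ref)
```

When `n_comp <= top_k` every index is selected on both sides, so the overlap is trivially 100%; the metric only becomes informative once `n_comp` exceeds `top_k`.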
Bugs Found and Fixed
- Broken `16x256b.x1` TMEM read: the instruction was hanging. Root cause: the `16x256b.x1` PTX instruction either doesn't exist on SM100 or has different alignment requirements. Fix: use the proven `32x32b.x8` instruction from the B1 FMHA.
- TMEM_COLS too small: TMEM_COLS=128 was insufficient for the 128×128 MMA output. The MMA writes ALL 128 rows, requiring 4 row-groups × 128 columns = 512 TMEM columns. Fix: TMEM_COLS=512.
- Wrong TMEM offset for rows 32-63: tried `tb + SK_TILE + col_base` and `tb + 16 + col_base`; both gave wrong results. Root cause: the `32x32b.x8` instruction maps different warps to different row slices from the SAME TMEM address. Warp 0 reads rows 0-31, warp 1 reads rows 32-63, all from `tb + col_base`. Fix: warps 0-1 both read from the same address, accumulate into separate SMEM partitions, then merge.
- Cross-warp accumulation race condition: the initial attempt used a shared `sLogits[c]` with a first-warp-writes/second-warp-adds pattern, which was non-deterministic. Fix: per-warp score partitions (`sWarpScores[0..SK_TILE-1]` and `sWarpScores[SK_TILE..2*SK_TILE-1]`), merged after `__syncthreads()`.
Production Configuration
- n_ih=64, ihd=128, top_k=1024
- Warps 0-1: TMEM read + per-warp score accumulation
- Warp 4: MMA (FP8 GEMM)
- Per-thread local top-k (INDEXER_LOCAL_K=8) → block-level merge
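The per-thread local top-k followed by a block-level merge can be simulated on the host as below. This is a sketch under assumptions: the strided assignment of candidates to threads and the `block_topk` name are illustrative, not taken from the kernel.

```python
import heapq

LOCAL_K = 8  # INDEXER_LOCAL_K from the configuration above

def block_topk(scores, top_k, n_threads, local_k=LOCAL_K):
    """Simulate per-thread local top-k followed by a block-level merge."""
    candidates = []
    for tid in range(n_threads):
        # Assumed split: thread tid owns elements tid, tid + n_threads, ...
        owned = [(scores[i], i) for i in range(tid, len(scores), n_threads)]
        # Each thread keeps only its local_k best (score, index) pairs.
        candidates.extend(heapq.nlargest(local_k, owned))
    # Block-level merge: pick the global top_k from the surviving candidates.
    merged = heapq.nlargest(top_k, candidates)
    return [i for _, i in merged]
```

In this simulation the merge matches an exact top-k only when no single thread owns more than `local_k` of the true top-k entries; if one thread owns more, the extras are dropped before the merge sees them.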
B3 — Fused rmsnorm→quant for q_a_norm / kv_norm: ✅ DONE
- `q_a_norm` → `q_b` path uses fused `rmsnorm_quantize_nvfp4` + `run_from_quantized`
- `kv_norm` still uses unfused rmsnorm; requires FP8 FMHA (B1) to fully benefit
B4 — General "producer BF16 → consumer FP32" sweep: NOT STARTED
B5 — Residual-stream precision: NOT STARTED (low priority)
PART D — Dangling TODOS
- Batched Prefill: ✅ DONE (T=1..128, mixed FP8/BF16 kernel)
- Need to wire prefill into single_shot_inference.py (replace T=1 token-by-token prefill)
- Need T>128 support (split into multiple launches)