nvfp4-megamoe-kernel/FINAL_STRETCH.md
biondizzle 75288bd12f Wire prefill FMHA into production.py and single_shot
- Add dsv4_attention_mixed_fp8_prefill to production.py
- _run_production_fmha_mixed now dispatches to prefill kernel for T>1
- Remove decode-only T==1 restriction
- Update FINAL_STRETCH.md: prefill marked DONE, batched prefill TODO noted
2026-06-03 03:49:57 +00:00

# DSV4 Audit — Decode Repetition + Precision / Tensor-Core Plan
# PART B — Precision / NVFP4 / tensor-core (WE ARE SKIPPING PART A FOR RIGHT NOW AND WILL REVISIT IT)
Goal: native NVFP4 where the math allows, FP8_E4M3 where it doesn't, BF16/FP32 only where required. Validate each change with per-layer cosine vs `dsv4/reference` before trusting it.
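A minimal sketch of the per-layer cosine check described above, in plain Python. The function names, the dictionary layout (layer name mapped to a flat list of activations), and the `0.999` threshold are illustrative assumptions, not the repository's actual test harness.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length sequences of floats."""
    if len(a) != len(b):
        raise ValueError("length mismatch")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        # Two all-zero outputs agree; zero vs non-zero does not.
        return 1.0 if norm_a == norm_b else 0.0
    return dot / (norm_a * norm_b)


def failing_layers(candidate, reference, threshold=0.999):
    """Names of layers whose cosine vs the reference is below threshold.

    `candidate` and `reference` are hypothetical dicts: name -> flat list.
    """
    return [
        name
        for name, ref in reference.items()
        if cosine_similarity(candidate[name], ref) < threshold
    ]
```

Running the check per layer rather than only on final logits makes it easier to see which change introduced a regression.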
## B0 — What's already optimal: DO NOT "fix" the MoE
`dsv4/layers/moe.py` already runs **native NVFP4**: expert weights and activations are `float4_e2m1fn_x2`, block scales are `float8_e4m3fn`. This matches the paper (routed experts in FP4). Leave it. The remaining wins are in **attention** and the **indexer**, not MoE.
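For readers unfamiliar with the `_x2` suffix, the sketch below shows how two 4-bit E2M1 values can share one byte. The E2M1 magnitude set (1 sign, 2 exponent, 1 mantissa bit) is standard; the nibble order (even element in the low nibble) and the tie-breaking rule are assumptions made for illustration and are not taken from `moe.py`.

```python
# Magnitudes representable in E2M1, indexed by the low 3 bits of the code.
E2M1_MAGNITUDES = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)


def encode_e2m1(x):
    """Round x to the nearest E2M1 value and return its 4-bit code.

    Ties resolve to the smaller magnitude here; real hardware rounding
    may differ.
    """
    sign = 0x8 if x < 0 else 0x0
    mag = min(abs(x), 6.0)  # saturate at the largest magnitude
    idx = min(range(8), key=lambda i: abs(E2M1_MAGNITUDES[i] - mag))
    return sign | idx


def decode_e2m1(code):
    """Map a 4-bit code back to its float value."""
    mag = E2M1_MAGNITUDES[code & 0x7]
    return -mag if code & 0x8 else mag


def pack_x2(values):
    """Pack an even-length list of floats, two E2M1 codes per byte.

    ASSUMPTION: element 2i sits in the low nibble, 2i+1 in the high nibble.
    """
    if len(values) % 2:
        raise ValueError("need an even number of values")
    return bytes(
        encode_e2m1(values[i]) | (encode_e2m1(values[i + 1]) << 4)
        for i in range(0, len(values), 2)
    )


def unpack_x2(data):
    """Inverse of pack_x2 under the same nibble-order assumption."""
    out = []
    for byte in data:
        out.append(decode_e2m1(byte & 0xF))
        out.append(decode_e2m1(byte >> 4))
    return out
```

Because only eight magnitudes exist, the per-block scale carries most of the dynamic range, which is why the block scales are stored in a wider format.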
### P5 — Fused mHC pre_block + RMSNorm + NVFP4 quantize: ✅ DONE
- `fused_mhc_rmsnorm_quantize.cu` — 2-kernel approach (mhc_rmsnorm_amax_gsa + mhc_rmsnorm_quantize_nvfp4)
- **Integrated into `forward_layer`** for BOTH attn and ffn mHC paths (commit 0b6ca0d)
- Unit test: cos=0.999 vs unfused, 0.995 vs true mHC+RMSNorm at T=1/8/128
### P4 — Fused RMSNorm + NVFP4 quantize: ✅ DONE
- `fused_rmsnorm_quantize.cu` — 2-kernel approach
- gsa scalar fix in `Nvfp4Linear.run_from_quantized`
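The two-kernel structure can be pictured as a reduction pass followed by an elementwise pass. The sketch below is a plain-Python model of that idea, not the CUDA code: the function names, the `eps` value, and the two-level scaling (a per-tensor scale derived from the amax, plus per-16-element block scales kept within the E4M3 range) are assumptions. Block scales stay as Python floats here and are not rounded to E4M3.

```python
import math

E2M1_GRID = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)
E2M1_MAX = 6.0
E4M3_MAX = 448.0
BLOCK = 16  # elements sharing one block scale


def pass1_reduce(x, weight, eps=1e-6):
    """Pass 1: row reductions. Returns 1/RMS and amax of the normalized row."""
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    amax = max(abs(v * inv_rms * w) for v, w in zip(x, weight))
    return inv_rms, amax


def pass2_quantize(x, weight, inv_rms, amax):
    """Pass 2: normalize again and snap each block to the E2M1 grid."""
    # Per-tensor scale chosen so the largest block scale lands at E4M3_MAX.
    global_scale = amax / (E2M1_MAX * E4M3_MAX) if amax > 0 else 1.0
    codes, block_scales = [], []
    for start in range(0, len(x), BLOCK):
        y = [
            v * inv_rms * w
            for v, w in zip(x[start:start + BLOCK], weight[start:start + BLOCK])
        ]
        block_amax = max(abs(v) for v in y)
        scale = block_amax / E2M1_MAX if block_amax > 0 else 1.0
        block_scales.append(scale / global_scale)  # stored relative to global
        for v in y:
            mag = min(E2M1_GRID, key=lambda g: abs(g - abs(v) / scale))
            codes.append(math.copysign(mag, v))
    return codes, block_scales, global_scale
```

Dequantization in this model is `code * block_scale * global_scale`. The amax has to be known before any block scale can be expressed relative to it, which is one reason a single pass is not enough.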
### Stale Lock Fix: ✅ DONE (commit 845227c)
## B1 — FP8_E4M3 FMHA: ✅ DONE
**Implementation**: `dsv4/kernels/attention/fmha_mixed_fp8_decode.cuh` + C API + Python bridge.
Storage-native DSV4 attention: noPE KV stays FP8_E4M3, RoPE KV stays BF16, no global FP8→BF16 dequant.
### Unit Test Results (2026-06-03, `tests/unit/test_b1_mixed_fp8_fmha.py`)
| Test | Status |
|------|--------|
| quantize_q_fp8_split | ✅ PASS (cos=0.9997) |
| gather_mixed kernels | ✅ PASS |
| FMHA cosine (N=128..2048, H=128) | ✅ PASS (cos=0.9999..0.9997) |
| Attention sinks | ✅ PASS |
| GQA/MQA (128 Q heads) | ✅ PASS |
| Weight loading verification | ✅ PASS |
| Batch sizes (B=1,2,4) | ✅ PASS |
### Bugs Found and Fixed
1. **V matrix canonical layout swap** (commit 4fe7f9d): `canon_idx_bf16_16x16(kk, dd)` was wrong — should be `canon_idx_bf16_16x16(dd, kk)`. The SMEM group structure was transposed vs the working TMA-loaded V in the multitile kernel. This caused cos=0.158 vs BF16 reference. After fix: cos=0.999972 at N=128.
### Known Limitations
- **Prefill batch size**: T=1..128 is supported per launch. For T>128, the caller must split the sequence into multiple launches. Internally, work is processed in sub-batches of T_BATCH=32.
- Specialized for DSV4 HD=512/NOPE=448/ROPE=64.
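The caller-side split for T>128 could look like the sketch below. The helper names are hypothetical, and the sketch covers only the range arithmetic; it assumes the caller makes the KV written by earlier chunks visible to later launches.

```python
MAX_T = 128   # per-launch limit stated above
T_BATCH = 32  # internal sub-batch size stated above


def split_prefill(total_tokens, max_t=MAX_T):
    """Yield (start, length) launch ranges that cover [0, total_tokens)."""
    if total_tokens < 0:
        raise ValueError("total_tokens must be non-negative")
    for start in range(0, total_tokens, max_t):
        yield start, min(max_t, total_tokens - start)


def num_sub_batches(length, t_batch=T_BATCH):
    """Number of T_BATCH-sized sub-batches needed for one launch."""
    return -(-length // t_batch)  # ceiling division
```

For example, a 300-token prompt becomes three launches of 128, 128, and 44 tokens, and the last launch needs two sub-batches.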
### Bug Fix (2026-06-03)
1. **CRITICAL: T-dimension strides were wrong for T>1** — the kernel used `q_nope_head_stride` (stride(1) = T*NOPE) for the T dimension, but the correct stride is `stride(2) = NOPE`. For T=1 this is invisible (qr=0 always), but for T>1 it reads garbage from adjacent heads' data. Fix: added explicit T-dimension strides (`q_nope_t_stride`, `q_scale_t_stride`, `q_rope_t_stride`) to params struct, C API, and Python wrapper. All 16 T>1 test configs now pass (cos >= 0.999887).
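The stride mix-up can be reproduced with flat-offset arithmetic alone. The sketch assumes a contiguous `[B, H, T, NOPE]` layout, which matches the strides quoted above (`stride(1) = T*NOPE`, `stride(2) = NOPE`); the toy sizes and helper names are illustrative.

```python
def contiguous_strides(shape):
    """Row-major element strides for a shape."""
    strides, acc = [], 1
    for dim in reversed(shape):
        strides.append(acc)
        acc *= dim
    return tuple(reversed(strides))


B, H, T, NOPE = 1, 4, 3, 8  # toy sizes, not the production shape
_, HEAD_STRIDE, T_STRIDE, _ = contiguous_strides((B, H, T, NOPE))


def offset(head, t, t_dim_stride):
    """Flat offset of element 0 of row (head, t) for a given T stride."""
    return head * HEAD_STRIDE + t * t_dim_stride


def owner(flat):
    """(head, t) that actually owns a flat offset in the contiguous layout."""
    return flat // HEAD_STRIDE, (flat % HEAD_STRIDE) // T_STRIDE


# Buggy indexing: head stride reused for the T dimension.
buggy = owner(offset(0, 1, HEAD_STRIDE))
# Fixed indexing: dedicated T stride.
fixed = owner(offset(0, 1, T_STRIDE))
```

Here `buggy` is `(1, 0)`, so the read for head 0, token 1 lands in head 1's data, while `fixed` is `(0, 1)`. At `t = 0` both variants give the same offset, which is why a T=1 test cannot catch the bug.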
## B2 — FP8 tensor-core indexer scoring: ✅ DONE
**Implementation**: `dsv4/kernels/cuda/indexer_fp8_score_topk.cu`
Native Blackwell FP8 GEMM via tcgen05 for CSA Lightning Indexer scoring. No PyTorch einsum fallback.
### Unit Test Results (2026-06-03, `tests/unit/test_b2_indexer_fp8.py`)
| Test | Status |
|------|--------|
| Score cosine vs FP32 reference (n_comp=128..8192) | ✅ PASS (100% overlap for n_comp ≤ 1024, ~88% at n_comp=8192) |
| Score distribution sanity | ✅ PASS |
| Determinism | ✅ PASS |
| Edge cases (n_comp < top_k, n_comp=1) | ✅ PASS |
| Weight format verification | ✅ PASS |
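The "overlap" figure in the first row is not defined in this file. One plausible reading, sketched below as an assumption, is the fraction of the reference top-k indices that the FP8 scores also select. The function names are hypothetical.

```python
def topk_indices(scores, k):
    """Indices of the k largest scores (ties broken by lower index)."""
    k = min(k, len(scores))
    return sorted(range(len(scores)), key=lambda i: (-scores[i], i))[:k]


def topk_overlap(test_scores, ref_scores, k):
    """Fraction of the reference top-k that also appears in the test top-k."""
    ref = set(topk_indices(ref_scores, k))
    if not ref:
        return 1.0
    got = set(topk_indices(test_scores, k))
    return len(ref & got) / len(ref)
```

Under this definition, any case with `n_comp <= k` scores 100% regardless of score accuracy, because every index is selected. If the test uses this metric with top_k=1024, only the n_comp > 1024 cases are informative about ranking quality.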
### Bugs Found and Fixed
1. **Broken `16x256b.x1` TMEM read** (the instruction hung). Suspected cause: the `16x256b.x1` PTX instruction either doesn't exist on SM100 or has different alignment requirements. **Fix**: use the proven `32x32b.x8` instruction from the B1 FMHA.
2. **TMEM_COLS too small** — TMEM_COLS=128 was insufficient for the 128×128 MMA output. The MMA writes ALL 128 rows, requiring 4 row-groups × 128 columns = 512 TMEM columns. **Fix**: TMEM_COLS=512.
3. **Wrong TMEM offset for rows 32-63** (both `tb + SK_TILE + col_base` and `tb + 16 + col_base` were tried and gave wrong results). **Root cause**: the `32x32b.x8` instruction maps different warps to different row slices from the SAME TMEM address. Warp 0 reads rows 0-31 and warp 1 reads rows 32-63, both from `tb + col_base`. **Fix**: warps 0-1 both read from the same address, accumulate into separate SMEM partitions, then merge.
4. **Cross-warp accumulation race condition** — initial attempt used shared `sLogits[c]` with first-warp-writes/second-warp-adds pattern, which was non-deterministic. **Fix**: per-warp score partitions (`sWarpScores[0..SK_TILE-1]` and `sWarpScores[SK_TILE..2*SK_TILE-1]`), merged after `__syncthreads()`.
### Production Configuration
- n_ih=64, ihd=128, top_k=1024
- Warps 0-1: TMEM read + per-warp score accumulation
- Warp 4: MMA (FP8 GEMM)
- Per-thread local top-k (INDEXER_LOCAL_K=8) → block-level merge
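The "local top-k, then block-level merge" step can be modeled in a few lines. This is a sketch under assumptions: the strided assignment of scores to threads and the function name are illustrative, and the real kernel's work distribution is not described in this file.

```python
import heapq


def local_then_merge_topk(scores, k, num_threads, local_k):
    """Model of per-thread local top-k followed by a block-level merge.

    Thread t scans indices t, t + num_threads, ... (ASSUMED distribution),
    keeps its local_k best (score, index) pairs, and the merge selects the
    k best among all surviving pairs.
    """
    survivors = []
    for t in range(num_threads):
        mine = [(scores[i], i) for i in range(t, len(scores), num_threads)]
        survivors.extend(heapq.nlargest(local_k, mine))
    return [idx for _, idx in heapq.nlargest(k, survivors)]
```

The merge is exact only when no single thread holds more than `local_k` of the true top-k entries. With `scores = [10, 1, 9, 2, 8, 3, 7, 4]`, two threads, and `k = 3`, a `local_k` of 2 returns indices `[0, 2, 7]` and drops index 4, while a `local_k` of 3 returns the exact `[0, 2, 4]`.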
## B3 — Fused rmsnorm→quant for q_a_norm / kv_norm: ✅ DONE
- `q_a_norm` → `q_b` path uses fused `rmsnorm_quantize_nvfp4` + `run_from_quantized`
- `kv_norm` still uses unfused rmsnorm — requires FP8 FMHA (B1) to fully benefit
## B4 — General "producer BF16 → consumer FP32" sweep: NOT STARTED
## B5 — Residual-stream precision: NOT STARTED (low priority)
---
# PART D — Dangling TODOS
- Batched Prefill: ✅ DONE (T=1..128, mixed FP8/BF16 kernel)
- Need to wire prefill into single_shot_inference.py (replace T=1 token-by-token prefill)
- Need T>128 support (split into multiple launches)