# DSV4 Audit — Decode Repetition + Precision / Tensor-Core Plan

# PART B — Precision / NVFP4 / tensor-core (Part A is skipped for now; we will revisit it)

Goal: native NVFP4 where the math allows, FP8_E4M3 where it doesn't, and BF16/FP32 only where required. Validate each change with per-layer cosine similarity against `dsv4/reference` before trusting it.

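A minimal sketch of that per-layer cosine gate, assuming activations have already been captured as flat float lists keyed by layer name. The dict capture format, the `check_layers` helper, and the 0.999 default threshold are illustrative choices, not part of `dsv4/reference`.

```python
import math


def cosine(a, b):
    """Cosine similarity of two equal-length float sequences."""
    if len(a) != len(b):
        raise ValueError("length mismatch")
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    if na == 0.0 or nb == 0.0:
        # Two all-zero tensors agree perfectly; a single all-zero tensor does not.
        return 1.0 if na == nb else 0.0
    return dot / (na * nb)


def check_layers(candidate, reference, threshold=0.999):
    """Return [(layer, cos)] for every layer whose cosine falls below threshold.

    candidate / reference: dict of layer name -> flattened activations
    (hypothetical capture format).
    """
    failures = []
    for name, ref in reference.items():
        c = cosine(candidate[name], ref)
        if c < threshold:
            failures.append((name, c))
    return failures


ref = {"layer0.attn": [1.0, 2.0, 3.0], "layer0.ffn": [1.0, 0.0, 0.0]}
cand = {"layer0.attn": [1.0, 2.0, 3.01], "layer0.ffn": [0.0, 1.0, 0.0]}
print(check_layers(cand, ref))  # [('layer0.ffn', 0.0)]
```

Reporting every failing layer, rather than stopping at the first one, makes it easier to see whether a regression starts at one layer and propagates.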
## B0 — What's already optimal: DO NOT "fix" the MoE

`dsv4/layers/moe.py` already runs **native NVFP4**: expert weights and activations are `float4_e2m1fn_x2`, block scales are `float8_e4m3fn`. This matches the paper (routed experts in FP4). Leave it. The remaining wins are in **attention** and the **indexer**, not MoE.

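For reference, here is a simplified sketch of what NVFP4 block quantization does numerically. Each 16-element block shares one scale, chosen so that the block's largest magnitude maps onto the E2M1 maximum of 6.0, and every element is then rounded to the nearest E2M1 value. This is not the `moe.py` code, and it makes several simplifying assumptions: scales stay in FP32 rather than `float8_e4m3fn`, there is no per-tensor global scale, codes are not packed two per byte as in `float4_e2m1fn_x2`, and ties round toward the smaller magnitude instead of to nearest-even.

```python
E2M1_GRID = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)  # FP4 E2M1 magnitudes
BLOCK = 16  # elements sharing one scale


def round_e2m1(x):
    """Nearest E2M1 value; ties go to the smaller magnitude (not round-to-even)."""
    mag = min(E2M1_GRID, key=lambda g: abs(abs(x) - g))
    return -mag if x < 0 else mag


def quantize_nvfp4(values):
    """Return (codes, scales): one scale per BLOCK values, codes on the E2M1 grid."""
    codes, scales = [], []
    for start in range(0, len(values), BLOCK):
        block = values[start:start + BLOCK]
        amax = max(abs(v) for v in block)
        scale = amax / 6.0 if amax > 0 else 1.0  # block amax maps to E2M1 max (6.0)
        scales.append(scale)
        codes.extend(round_e2m1(v / scale) for v in block)
    return codes, scales


def dequantize_nvfp4(codes, scales):
    return [c * scales[i // BLOCK] for i, c in enumerate(codes)]


codes, scales = quantize_nvfp4([float(i) for i in range(16)])
print(codes[:4], scales)  # [0.0, 0.5, 1.0, 1.0] [2.5]
```

The coarse grid (only eight magnitudes) is why the per-block scale matters: one outlier per 16 elements only coarsens its own block, not the whole tensor.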
### P5 — Fused mHC pre_block + RMSNorm + NVFP4 quantize: ✅ DONE

- `fused_mhc_rmsnorm_quantize.cu`: 2-kernel approach (`mhc_rmsnorm_amax_gsa` + `mhc_rmsnorm_quantize_nvfp4`)
- **Integrated into `forward_layer`** for BOTH the attn and ffn mHC paths (commit 0b6ca0d)
- Unit test: cos=0.999 vs the unfused path and 0.995 vs true mHC + RMSNorm, at T=1/8/128

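Why two kernels: the NVFP4 global scale depends on the amax of the *normalized* tensor, which is only known after a full reduction, so no block can be quantized in the same pass that discovers it. The minimal Python sketch below models that dependency. It rests on two assumptions that the source does not confirm: that `gsa` is a global scale derived from the tensor-wide amax, and that block scales follow the usual NVFP4 two-level scheme (FP32 global scale times an E4M3 block scale). The helper names are hypothetical, and the mHC step is omitted.

```python
import math

E2M1_GRID = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)
FP4_MAX, FP8_E4M3_MAX, BLOCK = 6.0, 448.0, 16


def rmsnorm(row, weight, eps=1e-6):
    inv = 1.0 / math.sqrt(sum(x * x for x in row) / len(row) + eps)
    return [x * inv * w for x, w in zip(row, weight)]


def round_e2m1(x):
    mag = min(E2M1_GRID, key=lambda g: abs(abs(x) - g))
    return -mag if x < 0 else mag


def kernel1_amax(rows, weight):
    """Pass 1 (hypothetical): RMSNorm every row, keep only the tensor-wide amax."""
    return max(abs(v) for row in rows for v in rmsnorm(row, weight))


def kernel2_quantize(rows, weight, amax):
    """Pass 2 (hypothetical): recompute RMSNorm, quantize with pass 1's global scale."""
    global_scale = FP8_E4M3_MAX * FP4_MAX / amax if amax > 0 else 1.0
    out = []
    for row in rows:
        y = rmsnorm(row, weight)
        codes, block_scales = [], []
        for s in range(0, len(y), BLOCK):
            blk = y[s:s + BLOCK]
            # block amax <= tensor amax, so the stored block scale stays within E4M3's 448
            bs = max(abs(v) for v in blk) / FP4_MAX * global_scale
            block_scales.append(bs)
            # In FP32 the global scale cancels here; on GPU bs is first rounded to E4M3.
            step = bs / global_scale if bs > 0 else 1.0
            codes.extend(round_e2m1(v / step) for v in blk)
        out.append((codes, block_scales))
    return global_scale, out


rows = [[1.0] * 16, [float(i) for i in range(16)]]
weight = [1.0] * 16
gs, quantized = kernel2_quantize(rows, weight, kernel1_amax(rows, weight))
```

Recomputing RMSNorm in pass 2 is cheaper than writing the normalized BF16 tensor out and reading it back, which is the usual motivation for this split.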
### P4 — Fused RMSNorm + NVFP4 quantize: ✅ DONE

- `fused_rmsnorm_quantize.cu`: 2-kernel approach
- `gsa` scalar fix in `Nvfp4Linear.run_from_quantized`

### Stale Lock Fix: ✅ DONE (commit 845227c)

## B1 — FP8_E4M3 FMHA: ✅ DONE

**Implementation**: `dsv4/kernels/attention/fmha_mixed_fp8_decode.cuh` + C API + Python bridge.

Storage-native DSV4 attention: the noPE part of the KV stays FP8_E4M3 and the RoPE part stays BF16, so there is no global FP8→BF16 dequantization pass.

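To make the storage-native split concrete, here is a minimal single-query sketch. The noPE slice of each key is held in simulated FP8_E4M3 with a per-row scale, the RoPE slice stays in full precision, and the two partial dot products are summed into one logit. It also models an attention sink as an extra logit that enters only the softmax denominator. Several pieces are assumptions made for illustration: the per-row scale (amax/448), the 1/sqrt(HD) softmax scale, the sink formulation, and all helper names. The CUDA kernel works on tiles in FP8/BF16, not on Python floats.

```python
import math
import random

FP8_E4M3_MAX = 448.0


def round_e4m3(x):
    """Round to the nearest FP8 E4M3 value (3 mantissa bits, saturating at 448)."""
    if x == 0.0:
        return 0.0
    a = min(abs(x), FP8_E4M3_MAX)
    e = max(math.floor(math.log2(a)), -6)  # exponents below -6 use the subnormal step
    step = 2.0 ** (e - 3)
    q = min(round(a / step) * step, FP8_E4M3_MAX)
    return q if x > 0 else -q


def to_fp8(vec):
    """Per-row scaled FP8: returns (codes, scale) with vec ~= codes * scale."""
    amax = max(abs(v) for v in vec)
    scale = amax / FP8_E4M3_MAX if amax > 0 else 1.0
    return [round_e4m3(v / scale) for v in vec], scale


def attend(q, keys, values, nope, sink_logit=None, fp8=True):
    """Single-query attention with a split noPE (FP8) / RoPE (full precision) score."""
    def dot(a, b):
        return sum(x * y for x, y in zip(a, b))

    if fp8:
        q_n, q_s = to_fp8(q[:nope])
    logits = []
    for k in keys:
        if fp8:
            k_n, k_s = to_fp8(k[:nope])  # the real KV cache already stores this as FP8
            s_nope = dot(q_n, k_n) * q_s * k_s
        else:
            s_nope = dot(q[:nope], k[:nope])
        s_rope = dot(q[nope:], k[nope:])
        logits.append((s_nope + s_rope) / math.sqrt(len(q)))
    extra = [sink_logit] if sink_logit is not None else []
    m = max(logits + extra)
    exps = [math.exp(z - m) for z in logits]
    denom = sum(exps) + sum(math.exp(s - m) for s in extra)  # the sink only adds mass here
    return [sum(e / denom * v[d] for e, v in zip(exps, values)) for d in range(len(values[0]))]


rng = random.Random(0)
HD, NOPE, N = 16, 12, 32  # toy sizes; DSV4 uses HD=512, NOPE=448, ROPE=64
q = [rng.gauss(0, 1) for _ in range(HD)]
keys = [[rng.gauss(0, 1) for _ in range(HD)] for _ in range(N)]
values = [[rng.gauss(0, 1) for _ in range(HD)] for _ in range(N)]
mixed = attend(q, keys, values, NOPE)
exact = attend(q, keys, values, NOPE, fp8=False)
```

Comparing `mixed` against `exact` mirrors the "FMHA cosine" row of the table below, on toy sizes.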
### Unit Test Results (2026-06-03, `tests/unit/test_b1_mixed_fp8_fmha.py`)

| Test | Status |
|------|--------|
| `quantize_q_fp8_split` | ✅ PASS (cos=0.9997) |
| `gather_mixed` kernels | ✅ PASS |
| FMHA cosine (N=128..2048, H=128) | ✅ PASS (cos=0.9999..0.9997) |
| Attention sinks | ✅ PASS |
| GQA/MQA (128 Q heads) | ✅ PASS |
| Weight loading verification | ✅ PASS |
| Batch sizes (B=1,2,4) | ✅ PASS |

### Bugs Found and Fixed

1. **V matrix canonical layout swap** (commit 4fe7f9d): `canon_idx_bf16_16x16(kk, dd)` was wrong; it should be `canon_idx_bf16_16x16(dd, kk)`. The SMEM group structure was transposed relative to the working TMA-loaded V in the multitile kernel, which dropped cosine similarity vs the BF16 reference to 0.158. After the fix: cos=0.999972 at N=128.

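Why swapping the arguments is enough to wreck the output: a canonical index function is a bijection from (row, col) to an SMEM slot. If `V[kk][dd]` is stored at `canon(kk, dd)` while the consumer reads it from `canon(dd, kk)`, the consumer receives Vᵀ instead of V, whatever the layout. The sketch uses a *hypothetical* 16×16 layout (a 2×2 grid of 8×8 core matrices, each row-major), because the real `canon_idx_bf16_16x16` layout is not given here.

```python
N = 16


def canon_idx_bf16_16x16(r, c):
    """HYPOTHETICAL canonical layout: 2x2 grid of 8x8 core matrices, each row-major."""
    core = (r // 8) + 2 * (c // 8)
    return core * 64 + (r % 8) * 8 + (c % 8)


def store_v(v, idx):
    """Write V[kk][dd] into a flat 256-slot 'SMEM' buffer at idx(kk, dd)."""
    smem = [None] * (N * N)
    for kk in range(N):
        for dd in range(N):
            smem[idx(kk, dd)] = v[kk][dd]
    return smem


def mma_view(smem):
    """What the consumer sees: it expects V[kk][dd] at canon(dd, kk)."""
    return [[smem[canon_idx_bf16_16x16(dd, kk)] for dd in range(N)] for kk in range(N)]


v = [[kk * N + dd for dd in range(N)] for kk in range(N)]
fixed = mma_view(store_v(v, lambda kk, dd: canon_idx_bf16_16x16(dd, kk)))
buggy = mma_view(store_v(v, lambda kk, dd: canon_idx_bf16_16x16(kk, dd)))
print(fixed == v, buggy == [list(col) for col in zip(*v)])  # True True
```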
### Known Limitations

- **Prefill batch size**: T=1..128 is supported; for T>128 the caller must split the sequence. Internally the kernel processes T_BATCH=32 sub-batches.
- Specialized for the DSV4 head dims: HD=512 (NOPE=448 + ROPE=64).

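A caller-side split for T>128 can be as simple as the plan below: launches of at most 128 tokens, plus (for illustration) the T_BATCH=32 sub-batches each launch processes internally. `plan_prefill` is a hypothetical helper. The sketch assumes that each launch appends its KV entries to the cache before the next launch starts, so later chunks still attend causally to earlier ones.

```python
MAX_T = 128   # per-launch limit from Known Limitations
T_BATCH = 32  # internal sub-batch size


def plan_prefill(T, max_t=MAX_T, t_batch=T_BATCH):
    """Split a T-token prefill into launches of <= max_t tokens.

    Returns one list per launch of (start, end) token ranges, one range per
    t_batch-sized sub-batch, with end exclusive.
    """
    if T <= 0:
        raise ValueError("T must be positive")
    launches = []
    for ls in range(0, T, max_t):
        le = min(ls + max_t, T)
        launches.append([(s, min(s + t_batch, le)) for s in range(ls, le, t_batch)])
    return launches


plan = plan_prefill(300)
print(len(plan), plan[-1])  # 3 [(256, 288), (288, 300)]
```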
### Bug Fix (2026-06-03)

1. **CRITICAL: T-dimension strides were wrong for T>1.** The kernel stepped along T using `q_nope_head_stride` (stride(1) = T*NOPE), but the correct T stride is stride(2) = NOPE. For T=1 the bug is invisible (qr=0 always); for T>1 the kernel reads another head's data. Fix: added explicit T-dimension strides (`q_nope_t_stride`, `q_scale_t_stride`, `q_rope_t_stride`) to the params struct, the C API, and the Python wrapper. All 16 T>1 test configs now pass (cos >= 0.999887).

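The stride mix-up is easy to reproduce with plain index arithmetic. The sketch assumes Q is a contiguous `[B, H, T, NOPE]` tensor. That layout is consistent with stride(1) = T*NOPE and stride(2) = NOPE above, but the exact production layout is an assumption. For a given (b, h, t, d), the code computes the flat offset once with the correct T stride and once with the head stride, then decodes each offset back to coordinates.

```python
def contiguous_strides(shape):
    """Element strides of a row-major (contiguous) tensor, like torch's .stride()."""
    strides, acc = [], 1
    for dim in reversed(shape):
        strides.append(acc)
        acc *= dim
    return tuple(reversed(strides))


B, H, T, NOPE = 1, 4, 3, 448
SB, SH, ST, SD = contiguous_strides((B, H, T, NOPE))  # SH == T*NOPE, ST == NOPE


def offset(b, h, t, d, t_stride):
    return b * SB + h * SH + t * t_stride + d * SD


def coords(off):
    """Invert a flat offset back to (b, h, t, d)."""
    b, r = divmod(off, SB)
    h, r = divmod(r, SH)
    t, d = divmod(r, ST)
    return (b, h, t, d)


print(coords(offset(0, 1, 2, 0, ST)))  # (0, 1, 2, 0): head 1, token 2 as intended
print(coords(offset(0, 1, 2, 0, SH)))  # (0, 3, 0, 0): head 3's token 0 instead
```

With t=0 both strides give the same offset, which is why T=1 decode never exposed the bug.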
## B2 — FP8 tensor-core indexer scoring: ✅ DONE

**Implementation**: `dsv4/kernels/cuda/indexer_fp8_score_topk.cu`

Native Blackwell FP8 GEMM via tcgen05 for CSA Lightning Indexer scoring. No PyTorch einsum fallback.

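For orientation, the scoring math that ends up on the FP8 GEMM can be written out in plain Python. We assume the DSV4 CSA indexer follows the DeepSeek-V3.2 Lightning Indexer form, score[s] = Σ_j w[j] · ReLU(q[j] · k[s]) over the n_ih indexer heads j, with one shared key per compressed entry. That formula and the function names are assumptions, not taken from `indexer_fp8_score_topk.cu`. The `q · kᵀ` product is the GEMM, and ReLU plus the weighted head reduction form the epilogue.

```python
import heapq


def indexer_scores(q, w, k):
    """q: [n_ih][ihd] queries for one token, w: [n_ih] head weights,
    k: [n_comp][ihd] compressed keys. Returns [n_comp] scores."""
    # GEMM part: logits[j][s] = q[j] . k[s]  (n_ih x n_comp)
    logits = [[sum(a * b for a, b in zip(qj, ks)) for ks in k] for qj in q]
    # Epilogue: ReLU per head, then a weighted sum over heads
    return [sum(w[j] * max(logits[j][s], 0.0) for j in range(len(q))) for s in range(len(k))]


def select_topk(scores, top_k):
    """Indices of the top_k highest scores (all indices if n_comp <= top_k)."""
    return sorted(i for i, _ in heapq.nlargest(top_k, enumerate(scores), key=lambda p: p[1]))


scores = indexer_scores(q=[[1.0, 0.0], [0.0, 1.0]], w=[1.0, 2.0], k=[[1.0, 1.0], [-1.0, 2.0]])
print(scores, select_topk(scores, 1))  # [3.0, 4.0] [1]
```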
### Unit Test Results (2026-06-03, `tests/unit/test_b2_indexer_fp8.py`)

| Test | Status |
|------|--------|
| Score cosine vs FP32 reference (n_comp=128..8192) | ✅ PASS (top-k overlap: 100% for n_comp ≤ 1024, ~88% at 8192) |
| Score distribution sanity | ✅ PASS |
| Determinism | ✅ PASS |
| Edge cases (n_comp < top_k, n_comp=1) | ✅ PASS |
| Weight format verification | ✅ PASS |

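We read the overlap percentages as the fraction of the reference top-k indices that the FP8 path also selects. The helper below is a hypothetical sketch of that metric, not the test's actual code. One consequence is worth keeping in mind: with top_k=1024, any n_comp ≤ 1024 selects every candidate on both sides, so 100% overlap there holds regardless of score accuracy. Only the rows with larger n_comp actually exercise the ranking.

```python
import heapq


def topk_set(scores, k):
    return {i for i, _ in heapq.nlargest(k, enumerate(scores), key=lambda p: p[1])}


def topk_overlap(test_scores, ref_scores, k):
    """|topk(test) ∩ topk(ref)| / k, with k clamped to the number of candidates."""
    k_eff = min(k, len(ref_scores))
    return len(topk_set(test_scores, k_eff) & topk_set(ref_scores, k_eff)) / k_eff


print(topk_overlap([0.9, 0.1, 0.5], [0.2, 0.3, 0.1], k=1024))  # 1.0: every index selected
print(topk_overlap([4.0, 3.0, 2.0, 1.0], [1.0, 2.0, 3.0, 4.0], k=2))  # 0.0
```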
### Bugs Found and Fixed

1. **Broken `16x256b.x1` TMEM read**: the instruction hung. Suspected cause: the `16x256b.x1` PTX instruction either doesn't exist on SM100 or has different alignment requirements. **Fix**: use the proven `32x32b.x8` instruction from the B1 FMHA.

2. **TMEM_COLS too small**: TMEM_COLS=128 was insufficient for the 128×128 MMA output. The MMA writes ALL 128 rows, requiring 4 row-groups × 128 columns = 512 TMEM columns. **Fix**: TMEM_COLS=512.

3. **Wrong TMEM offset for rows 32-63**: tried `tb + SK_TILE + col_base` and `tb + 16 + col_base`; both gave wrong results. **Root cause**: the `32x32b.x8` instruction maps different warps to different row slices from the SAME TMEM address. Warp 0 reads rows 0-31 and warp 1 reads rows 32-63, both from `tb + col_base`. **Fix**: warps 0-1 both read from the same address, accumulate into separate SMEM partitions, then merge.

4. **Cross-warp accumulation race condition**: the initial attempt used a shared `sLogits[c]` with a first-warp-writes/second-warp-adds pattern, which was non-deterministic. **Fix**: per-warp score partitions (`sWarpScores[0..SK_TILE-1]` and `sWarpScores[SK_TILE..2*SK_TILE-1]`), merged after `__syncthreads()`.

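A Python model of the final scheme from bugs 3 and 4. It assumes the 64 accumulator rows are the n_ih=64 indexer heads and the columns are the SK_TILE compressed keys of one tile. The lane rule in `warp_rows` follows the tcgen05 restriction that warp w of a warpgroup reaches TMEM lanes 32·(w % 4) through 32·(w % 4)+31 from the same column address. The per-head weighted ReLU and the toy SK_TILE value are assumptions. Each warp owns its own partition, and the merge adds partitions in a fixed order after the barrier, so the result cannot depend on warp timing.

```python
SK_TILE = 4          # toy tile width (hypothetical)
LANES_PER_WARP = 32
N_WARPS = 2          # warps 0-1 do the TMEM reads


def warp_rows(warp_id):
    """Rows (TMEM lanes) a warp can read with a 32x32b load."""
    base = LANES_PER_WARP * (warp_id % 4)
    return range(base, base + LANES_PER_WARP)


def score_tile(acc, w):
    """acc: [64][SK_TILE] GEMM output (rows = indexer heads), w: [64] head weights."""
    # Per-warp partitions: the analogue of sWarpScores[0..SK_TILE-1] and [SK_TILE..2*SK_TILE-1]
    partitions = [[0.0] * SK_TILE for _ in range(N_WARPS)]
    for warp in range(N_WARPS):
        for c in range(SK_TILE):  # every warp uses the same column address
            partitions[warp][c] = sum(w[r] * max(acc[r][c], 0.0) for r in warp_rows(warp))
    # --- __syncthreads(): all partitions are written before anyone merges ---
    return [sum(partitions[p][c] for p in range(N_WARPS)) for c in range(SK_TILE)]


acc = [[float((r + c) % 5 - 2) for c in range(SK_TILE)] for r in range(64)]
w = [1.0] * 64
direct = [sum(w[r] * max(acc[r][c], 0.0) for r in range(64)) for c in range(SK_TILE)]
print(score_tile(acc, w) == direct)  # True
```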
### Production Configuration

- n_ih=64, ihd=128, top_k=1024
- Warps 0-1: TMEM read + per-warp score accumulation
- Warp 4: MMA (FP8 GEMM)
- Per-thread local top-k (INDEXER_LOCAL_K=8) → block-level merge

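A minimal model of a per-thread local top-k followed by a block-level merge, with `heapq` standing in for the kernel's register/SMEM code. The strided assignment of candidates to threads is an assumption. The sketch also makes the accuracy condition explicit: the merged result equals the true top-k only when no single thread owns more than INDEXER_LOCAL_K of the true top-k entries, because anything a thread drops locally can never be recovered by the merge.

```python
import heapq

INDEXER_LOCAL_K = 8


def local_topk_merge(scores, n_threads, top_k, local_k=INDEXER_LOCAL_K):
    """Return sorted indices chosen by per-thread top-local_k + block merge."""
    candidates = []
    for t in range(n_threads):
        # Thread t scans a strided slice of the tile's scores (assumed assignment)
        mine = ((i, scores[i]) for i in range(t, len(scores), n_threads))
        candidates.extend(heapq.nlargest(local_k, mine, key=lambda p: p[1]))
    # Block-level merge over at most n_threads * local_k survivors
    merged = heapq.nlargest(top_k, candidates, key=lambda p: p[1])
    return sorted(i for i, _ in merged)


def exact_topk(scores, top_k):
    return sorted(i for i, _ in heapq.nlargest(top_k, enumerate(scores), key=lambda p: p[1]))


# Exact: each of 8 threads owns only 8 candidates, so nothing is dropped.
s = [float(i) for i in range(64)]
print(local_topk_merge(s, n_threads=8, top_k=16) == exact_topk(s, 16))  # True

# Lossy: thread 0 owns all eight maxima but may keep only 2 of them.
skewed = [100.0 if i % 8 == 0 else 0.0 for i in range(64)]
print(local_topk_merge(skewed, n_threads=8, top_k=4, local_k=2))
```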
## B3 — Fused rmsnorm→quant for q_a_norm / kv_norm: ✅ DONE

- `q_a_norm` → `q_b` path uses fused `rmsnorm_quantize_nvfp4` + `run_from_quantized`
- `kv_norm` still uses the unfused RMSNorm; it needs the FP8 FMHA (B1) to benefit fully

## B4 — General "producer BF16 → consumer FP32" sweep: NOT STARTED

## B5 — Residual-stream precision: NOT STARTED (low priority)

---

# PART D — Dangling TODOs

- Batched prefill: ✅ DONE (T=1..128, mixed FP8/BF16 kernel)
- Wire batched prefill into `single_shot_inference.py` (replacing the T=1 token-by-token prefill)
- Add T>128 support (split into multiple launches)