diff --git a/NEXT_STEPS.md b/NEXT_STEPS.md new file mode 100644 index 00000000..8cba410d --- /dev/null +++ b/NEXT_STEPS.md @@ -0,0 +1,133 @@ +# Next Steps — Post v0.1 E2E Working + +**Tag:** `v0.1-e2e-working` — Single-shot inference produces coherent output ("The capital of France is Paris") but has stability issues during multi-step decode. + +--- + +## The Mandate: Every Component Must Be Wired Up + +The single-shot script is NOT a test harness. It is a **reference implementation** that exercises the full production pipeline end-to-end. Every component must be connected and working together — mHC, compressor, indexer, attention, MoE, KV cache, RoPE, sinks. There is no "skip this for now" or "simplified path for short sequences." If a component is bypassed, we are not testing the real pipeline, and we will ship bugs into vLLM/SGLang integration. + +The compressor feeds compressed KV into the attention. The indexer selects which compressed entries to attend. The KV cache holds both SWA and compressed entries across decode steps. The mHC bounds the residual. Every piece depends on the others. A bug in the compressor silently corrupts attention, which corrupts the residual, which makes the model output garbage 30 steps later. The only way to catch these is to run the full pipeline. + +--- + +## Issue 1: Residual Growth in Later Layers (L56–60) + +**Symptom:** `|X|` grows to 300–500 by layer 60, and continues growing across decode steps (428→436→344→428→384 over 30 steps). The mHC should bound the residual via the doubly-stochastic B_l matrix and the sigmoid-constrained A_l/C_l. + +**Likely causes:** +- **mHC weight loading is correct** (verified against HF: [pre,post,comb] ordering, B^T, Sinkhorn from softmax). But the FP32 precision of the fused projection (Xn @ W.T) may differ from the HF path which uses DeepGEMM tf32_hc_prenorm_gemm with split-K. This could cause B_l to be slightly non-doubly-stochastic, allowing drift. +- **The `do_nvfp4_linear` dequant allocates a full (O, I) BF16 tensor every call.** This is slow and introduces BF16 quantization noise in the weight. The kernel path (tcgen05 MMA with NVFP4) avoids this. +- **The post_block accumulates in FP32** (CF.float() + BX) then casts to BF16. Loss of precision is expected but shouldn't cause unbounded growth. + +**Fix direction:** +- Compare per-layer B_l row/col sums against 1.0. If they drift, the Sinkhorn isn't converging (unlikely with t_max=20). +- Check if the residual growth matches what the HF reference produces for the same input. It may be expected — the model has 61 layers and the mHC doesn't guarantee bounded norms, just doubly-stochastic mixing. +- If growth is genuinely excessive, investigate: (a) using FP64 for the Sinkhorn, (b) clamping the residual (HF doesn't clamp), (c) checking the alpha scale values. + +**Kernel responsibility:** The mHC pre_block does `Xn @ W.T` as a Python FP32 matmul. The production path should use `tf32_hc_prenorm_gemm` from DeepGEMM (or our CuTeDSL equivalent). This is already in `dsv4/layers/mhc.py` (`_project_and_rms` method with `_HAS_DEEP_GEMM` guard). The single_shot bypasses the production mHCLayer and reimplements it inline — **this is a patch that should be the kernel's responsibility.** + +--- + +## Issue 2: Decode Quality Degradation After ~10 Steps + +**Symptom:** After generating a coherent initial response ("You're asking about the capital of France. The capital of France is **Paris**."), the model starts generating generic tokens like " like", " or" instead of continuing the response. + +**Likely causes:** +- **KV cache state management:** The SWA ring buffer and compressed KV grow across decode steps. After 10+ steps, the attention pattern shifts from mostly-SWA to mostly-compressed (for CSA/HCA layers). If the compressed KV is not properly accumulated (e.g., compressor only runs during prefill, not decode), later tokens see stale KV. +- **Compressor running during decode:** The single_shot runs `compressor.forward(x_normed, positions)` every step, including decode. For CSA (ratio=4), a single decode token can't form a complete window (needs 4 tokens). The compressor returns None for n_complete=0, which is correct — no new compressed entry is added. But after 4 decode tokens, a new compressed entry IS added. This is correct behavior but the transition may be sharp. +- **Block bias / causal masking:** The current implementation uses `block_bias = torch.zeros(...)` (all compressed entries visible to all tokens). For proper causal attention, earlier tokens should NOT see compressed entries from later windows. This could cause "future leaking" and degrade decode quality. +- **Attention score accumulation:** With growing KV sequence (compressed + SWA), the softmax denominator grows, potentially diluting attention to the most relevant positions. + +**Fix direction:** +- **Implement proper causal block_bias.** Token at position p should only attend to compressed entries whose window ends at or before p. This is critical for correctness. +- **Debug the KV cache state after 10+ decode steps.** Print: n_comp, swa_len, total seq_len per layer. Check if the sequence length grows as expected. +- **Compare decode output quality with/without compressed KV.** If the model generates better output with SWA-only attention, the compressor/indexer pipeline has a bug. + +**Kernel responsibility:** The attention mask / block_bias construction is currently in the single_shot. The production path should use the FMHA kernel's built-in causal mask + the sink merge logic from the kernel. The single_shot's `block_bias = torch.zeros(...)` is a patch that masks a missing feature. + +--- + +## Issue 3: Performance — 1.45s/token + +**Symptom:** Decode runs at ~1.45 seconds per token on the B200. Target: <100ms/token. + +**Bottlenecks:** +- **NVFP4 dequant allocates (O, I) BF16 tensor every call.** For 384-expert MoE with 7168×3072 weights, this is ~42M elements per expert, 6 experts per token = 252M elements dequant per token. Each dequant allocates, computes, then the allocation is freed. This is the dominant cost. +- **PyTorch SDPA for attention** instead of our FMHA kernel. The Python attention implementation does explicit matmul, softmax, matmul — all in BF16 on GPU, but without the FMHA kernel's SM100 tensor-core acceleration. +- **Per-expert loop in Python** instead of grouped GEMM. The MoE forward loops over 6 experts sequentially with 3 dequant+matmul calls each = 18 dequant+matmul per token. +- **No CUDA graphs.** Every kernel launch has Python overhead. +- **Weight streaming:** Weights are pre-cached on GPU, so this is not a bottleneck (already fixed in previous sessions). + +**Fix direction (in priority order):** +1. **Use the production FMHA kernel** (`dsv4/kernels/attention/production.py`) instead of PyTorch SDPA. Already proven at hd=512, 128 heads. +2. **Use the production MoE grouped GEMM kernel** (`dsv4/kernels/gemm/`) instead of Python expert loop. Already implemented as `FusedSwiGLUScaledGroupedGemmKernel`. +3. **Keep weights in NVFP4 and use tensor-core MMA** instead of dequant-to-BF16-then-matmul. This is the whole point of the kernel stack. +4. **CUDA graph capture** (E9 on roadmap) for decode. + +**Kernel responsibility:** All of this. The single_shot uses PyTorch fallbacks (dequant→BF16→matmul) because we needed to verify the math first. Now that the math is verified, we must replace every fallback with the production kernel path. The single_shot should call into `dsv4/layers/` and `dsv4/kernels/` instead of reimplementing the math. + +--- + +## Issue 4: Single-Shot Patches That Belong in the Kernel + +The single_shot reimplements several things that should be the kernel's responsibility. These must be migrated: + +| What | Single-shot patch | Where it belongs | +|---|---|---| +| NVFP4 dequant | `dequant_nvfp4()` → full (O,I) BF16 alloc | `dsv4/ops/quantize.py` → tcgen05 MMA with NVFP4 | +| mHC pre/post | Inline `mHCBlock` class | `dsv4/layers/mhc.py` (production `mHCLayer`) | +| Compressor | Inline `Compressor` class | `dsv4/kernels/compressor/` (CUDA kernel) | +| Indexer | Inline `Indexer` class | `dsv4/kernels/indexer/` (CUDA kernel) | +| Attention | PyTorch SDPA + explicit softmax | `dsv4/kernels/attention/production.py` (FMHA kernel) | +| MoE | Python expert loop + dequant | `dsv4/kernels/gemm/` (grouped GEMM) | +| Output projection | Manual grouped BMM | `dsv4/layers/grouped_linear.py` | +| KV cache | Simple ring buffer | `dsv4/cache/` (production paged + state cache) | +| RoPE | Inline `_apply_rope()` | `dsv4/ops/rope.py` (already exists) | +| RMSNorm | Inline `rmsnorm()` | `dsv4/layers/norm.py` (already exists) | + +**The migration plan:** Replace single_shot's inline implementations with calls to the production `dsv4/layers/` and `dsv4/kernels/` modules. The single_shot should become a thin orchestration layer: load weights → construct model → run inference. The heavy lifting should be in the kernel stack. + +The key invariant: **after each migration step, the single_shot must produce the same output.** If it doesn't, the kernel has a bug. This is the whole point of the reference implementation. + +--- + +## Issue 5: NVFP4 Dequant — input_scale Clarification + +**Critical finding:** The `input_scale` in the checkpoint is the FP8 activation quantization scale. It should NOT be folded into the weight dequant when using BF16 activations. The correct dequant is: + +``` +weight_bf16 = lut[weight_uint8] * weight_scale_e4m3 * weight_scale_2_scalar +``` + +NOT: +``` +weight_bf16 = lut[weight_uint8] * weight_scale_e4m3 * weight_scale_2_scalar * input_scale # WRONG +``` + +The `input_scale` would be used when the activation is also quantized to FP8 (the NVFP4-1.x path where both sides of the GEMM are FP4/FP8). For our current BF16-activation path, it must be excluded. This cost us a full debug cycle — the weights were ~4000x too small. + +**Kernel impact:** The production GEMM kernels (tcgen05 MMA with `mxf4nvf4`) handle this correctly by using separate weight and activation scales. But any Python fallback path must also get this right. + +--- + +## Immediate Next Steps (Priority Order) + +1. **Fix causal block_bias** in the compressor output. Token at position p must not attend to compressed entries from future windows. This is likely the main cause of decode degradation. +2. **Debug decode quality** by comparing SWA-only vs. full (compressed+SWA) attention at step 10+. If SWA-only is better, the compressor→attention pipeline has a bug. +3. **Replace PyTorch SDPA with production FMHA kernel** in the single_shot. The kernel is already proven (cos ≥ 0.999996 at hd=512). This should be a drop-in replacement. +4. **Replace Python MoE loop with production grouped GEMM** in the single_shot. +5. **Replace inline mHC with production mHCLayer** from `dsv4/layers/mhc.py`. Already has DeepGEMM integration. +6. **Profile residual growth** — determine if it matches the HF reference or is a bug. If expected, document it and move on. +7. **Performance tuning** — after kernel integration, benchmark and optimize. + +--- + +## Lessons From This Session + +1. **The checkpoint key format matters.** We had `layers.{li}.attn.*` hardcoded but the real format is `model.layers.{li}.self_attn.*`. Always probe the checkpoint first. +2. **The NVFP4 two-level scale has three components.** `weight_scale` (E4M3, per 16 elements), `weight_scale_2` (scalar, per projection), and `input_scale` (scalar, per projection). The `input_scale` is for FP8 activations, NOT for BF16. This is the #1 pitfall. +3. **Every component must be wired up.** The compressor, indexer, and KV cache are not optional. Without them, the model can "work" for 1-2 tokens on simple prompts but fails on real inference. The single_shot must exercise the full pipeline, always. +4. **Test with the harness.** Every run must go through `fire_b200_test` or `fire_b200_cuda_test`. Raw SSH execution is fragile and loses the kill/cleanup/timeout guarantees. +5. **The B200 is remote, code is local.** Edit locally → commit → push → pull on B200 → test. Never edit on B200.