Add PART_A_NEXT_SESSION.md: clues for decode degeneration debugging

2026-06-03 04:34:28 +00:00
parent d8306be3f2
commit 1ebe7f0dde

PART_A_NEXT_SESSION.md Normal file

@@ -0,0 +1,59 @@
# PART A — Next Session Clues
## The Core Mystery
FMHA per-layer cos = 0.999993 during **prefill**. All component tests pass. But decode output is degenerate (loops on "capitalizing"). The FMHA is NOT the problem. Something else in the pipeline produces wrong logits during decode.
## The Most Likely Suspects (ranked)
### 1. We Only Tested FMHA Cosine During PREFILL — NOT During DECODE
`test_production_fmha_layer.py` runs 5 prefill tokens and checks cos on the LAST prefill token only. During **decode**, attention operates on a growing KV cache with compressed entries + SWA window, so the per-layer cos could be completely different because:
- At decode step 5, the KV cache has SWA entries + compressed entries from the prefill
- The attention pattern is compressed (top-k or dense) + SWA — NOT the simple dense attention tested during prefill
- If the compressed/SWA gathering is off-by-one or has a parity issue, it would compound across 61 layers
**What to do:** Add verbose=2 logging to single_shot during DECODE (not just prefill). Print per-layer FMHA cos vs reference on the FIRST decode step. If L0-L3 are already wrong during decode, the bug is in the decode-time KV gathering, not the FMHA.
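A minimal sketch of the per-layer comparison described above, in plain Python so it runs without the project installed. The names `cosine` and `first_divergent_layer` and the 0.999 threshold are hypothetical, and each layer's FMHA output is assumed to have been flattened to a list of floats before the call; the real hook into `single_shot` is not shown.

```python
import math

def cosine(a, b):
    """Cosine similarity between two flat sequences of floats."""
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    if na == 0.0 or nb == 0.0:
        return 0.0  # treat an all-zero vector as "no agreement"
    return dot / (na * nb)

def first_divergent_layer(prod_outputs, ref_outputs, threshold=0.999):
    """Print cos per layer; return (layer, cos) for the first layer below threshold.

    prod_outputs / ref_outputs: one flat list of floats per layer (assumed layout).
    Returns None when every layer stays at or above the threshold.
    """
    for layer, (p, r) in enumerate(zip(prod_outputs, ref_outputs)):
        c = cosine(p, r)
        print(f"L{layer}: cos={c:.6f}")
        if c < threshold:
            return layer, c
    return None
```

If this returns a layer in L0-L3 on the first decode step while prefill stays near 0.999993, that points at decode-time KV gathering rather than the kernel.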
### 2. mHC Residual Growth Amplifies Errors
|X| grows to 472-732 over 61 layers. The paper says 300-500 is expected, so 700+ suggests the Sinkhorn isn't bounding properly. Each decode step passes this amplified residual through another 61 layers, compounding further. Even a tiny per-layer error (0.001% from FP8/BF16 quantization) gets amplified by the residual magnitude.
**What to do:** In single_shot, print |X| per layer during decode. If it's growing linearly (not bounded), the Sinkhorn t_max or the alpha values might be wrong. Check if the mHC weights (fn, base, scale) are loaded correctly from the checkpoint.
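The check above could be summarized with something like the following sketch. It assumes the per-layer |X| values have already been collected into a list (how |X| is measured is left to the caller), and it assumes the Sinkhorn step is meant to produce a mixing matrix whose rows and columns each sum to 1; both `growth_report` and `doubly_stochastic_error` are hypothetical helper names, not functions from `dsv4/layers/mhc.py`.

```python
def growth_report(norms, bound):
    """Summarize per-layer |X| values.

    norms: list of |X| per layer, in layer order.
    bound: largest value considered acceptable (e.g. 500, the upper end
           of the expected range quoted in these notes).
    """
    steps = [b - a for a, b in zip(norms, norms[1:])]
    mean_step = sum(steps) / len(steps) if steps else 0.0
    over = [i for i, n in enumerate(norms) if n > bound]
    return {"mean_step": mean_step, "max": max(norms), "over_bound": over}

def doubly_stochastic_error(matrix):
    """Largest deviation of any row or column sum from 1.0.

    ASSUMPTION: the Sinkhorn output should be (close to) doubly stochastic.
    A large value here would suggest too few iterations or wrong weights.
    """
    row_err = max(abs(sum(row) - 1.0) for row in matrix)
    col_err = max(abs(sum(col) - 1.0) for col in zip(*matrix))
    return max(row_err, col_err)
```

A roughly constant positive `mean_step` across layers is the linear-growth signature described above; a bounded residual should show steps that shrink toward zero.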
### 3. Compressed/SWA Visible-Range Parity (CORRECTNESS_BACKLOG A3)
During decode at step s, the query attends to `[top-k compressed entries] + [SWA window]`. Two things to verify:
1. **Causality**: A decode query must see only compressed blocks STRICTLY PRECEDING its own current (incomplete) block. If it sees its own block, it's attending to the future.
2. **SWA + compressed overlap**: The most recent tokens are in SWA AND may be in the newest compressed block → the query attends to both representations. This must match the reference, or the recent-context weighting drifts.
**What to do:** At decode step 10, print: (a) which compressed indices are visible, (b) SWA positions, (c) total seq_len. Compare against the PyTorch reference implementation in dsv4/reference/.
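To have something to compare the printed values against, the expected visible set can be computed by hand. This sketch assumes fixed-size compression blocks, 0-based absolute positions, an SWA window that includes the current token, and the strict rule stated above (only blocks before the query's own block are visible). `visible_ranges`, `block_size`, and `window` are hypothetical names; the real values and the exact boundary rule should come from the reference in dsv4/reference/.

```python
def visible_ranges(pos, block_size, window):
    """Expected attention sources for a decode query at absolute position `pos`.

    ASSUMPTIONS: block b covers tokens [b*block_size, (b+1)*block_size);
    the query's own block (pos // block_size) is never visible in
    compressed form; the SWA window holds the last `window` tokens
    including the query itself.
    """
    own_block = pos // block_size
    compressed_blocks = list(range(own_block))       # strictly preceding blocks
    covered_until = own_block * block_size           # first token NOT compressed-visible
    swa_start = max(0, pos - window + 1)
    swa_positions = list(range(swa_start, pos + 1))
    # Tokens seen twice: once raw via SWA, once inside a compressed block.
    overlap = [p for p in swa_positions if p < covered_until]
    return {
        "seq_len": pos + 1,
        "compressed_blocks": compressed_blocks,
        "swa_positions": swa_positions,
        "overlap": overlap,
    }
```

For example, with `pos=10`, `block_size=4`, `window=4`, the query sits in block 2, sees compressed blocks 0 and 1, SWA positions 7 through 10, and position 7 appears in both representations.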
### 4. The Thinking Tokens Are a Red Herring (Probably)
The model emits ◇ (think_start), generates thinking content, then emits ◇ (think_end) and answers. If thinking content is degenerate, the answer will be wrong. BUT the degenerate thinking is a SYMPTOM, not a cause — if the logits are wrong, the thinking will be wrong, and the answer will be wrong. Fix the logits, and the thinking + answer will fix themselves.
## What NOT to Waste Time On
- FMHA kernel itself (cos 0.999993, it's correct)
- BF16 fallback paths (all removed)
- CuTeDSL/CUTLASS version issues (never the problem)
- Tokenizer / stop tokens (already tested, not the issue)
## The Test to Write First
A unit test that:
1. Runs single_shot through 5 prefill tokens (building KV cache)
2. On the FIRST decode step, for layers 0-3, compares the PRODUCTION FMHA output against PyTorch SDPA on the SAME gathered KV
3. This is exactly what test_production_fmha_layer.py does, but during DECODE instead of PREFILL
The key insight: if the per-layer cos is 0.999993 during prefill but 0.7 during decode, the bug is in the decode-time KV gathering or the compressed/SWA parity, NOT in the FMHA kernel itself.
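As a dependency-free stand-in for the reference side of that test, here is a sketch of single-query scaled dot-product attention over an already gathered KV set. It is not the project's reference; the real test would more likely call `torch.nn.functional.scaled_dot_product_attention` on the same gathered tensors. `sdpa_single_query` is a hypothetical name, and the default scale of 1/sqrt(d) is an assumption that may not match the model's softmax scale.

```python
import math

def sdpa_single_query(q, keys, values, scale=None):
    """Attention output for one query vector over gathered keys/values.

    q: list of d floats; keys: list of n lists of d floats;
    values: list of n lists of dv floats. No mask is applied because a
    decode query is assumed to be allowed to see everything gathered.
    """
    if scale is None:
        scale = 1.0 / math.sqrt(len(q))  # ASSUMPTION: standard 1/sqrt(d) scaling
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)                       # subtract max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    z = sum(exps)
    weights = [e / z for e in exps]
    dv = len(values[0])
    return [sum(w * v[j] for w, v in zip(weights, values)) for j in range(dv)]
```

Feeding the production path and this kind of reference the SAME gathered KV isolates the kernel; feeding the reference an independently gathered KV isolates the gather.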
## Concrete Code Locations to Read
- `single_shot_inference.py:860-900` — KV gathering during forward_attention (the "5. Gather KV" section)
- `single_shot_inference.py:660-720` — KVCache.gather_mixed_selective (CSA top-k gather)
- `single_shot_inference.py:722-750` — KVCache.gather_mixed_all (HCA dense gather)
- `dsv4/kernels/cuda/fp8_attention_io.cu:68-95` — copy_comp_rows_kernel (the actual CUDA gather)
- `dsv4/layers/mhc.py` — Sinkhorn-Knopp implementation
- `single_shot_inference.py:1455-1510` — Prefill loop (token-by-token, decode-style)