Update STAGE_D.md checklist with current progress and lessons learned

2026-05-23 09:27:48 +00:00
parent c9dda47971
commit a3659c581d


@@ -420,27 +420,53 @@ The indexer needs a full rewrite from scalar CUDA to tcgen05 MMA + radix-select.
| # | Task | Blocks | Est. |
|---|------|--------|------|
| D0 | SwiGLU clamping (CG-1) | Nothing — do first | 30 min |
| D1.2 | TMEM budget probe at hd=512 | D1.3 | 1 hr |
| D1.3 | Register→SMEM copy for P | D1.4, D2 | 1-2 days |
| D1.4 | Multi-PV-tile hd>256 | D2 | 1 day |
| D1.5 | Correction epilog fix (3% → 0.01%) | Nothing (can parallel) | 1-2 days |
| D2 | Multi-query grid + head packing | D3 | 1 day |
| D3 | SWA sequence length mask | D5 | ½ day |
| D4 | Causal mask on SWA | D5 | ½ day |
| D5a | Emit un-normalized o + lse | D5b | 1 day |
| D5b | Python merge (correctness) | D5c | ½ day |
| D5c | Fuse two passes in one launch | D5d | 2 days |
| D5d | Fuse sink merge in epilogue | D6 | 2 days |
| D6 | Mixed-precision KV load | E1 | 2 days |
| CG-4 | Inverse RoPE round-trip test | Nothing | 2 hrs |
| CG-6 | Per-token valid_lens (indexer) | Nothing | ½ day |
## Checklist — Updated 2026-05-23 09:30 UTC
**Critical path:** D0 → D1.2 → D1.3 → D1.4 → D5a → D5b (end-to-end correctness)
### ✅ COMPLETED
- **NVFP4-0**: Verified Blackwell FP4 primitives are correct (SF dtype Float8E4M3FN, SF_VEC_SIZE=16, FP4 tensor is float4_e2m1fn_x2)
- **D0/CG-1**: SwiGLU clamping already implemented in fused_swiglu.py (checked)
- **D1.0**: HEAD_DIM parameterization ✅ DONE
- **D1.1**: SMEM-P path flag (use_smem_p = head_dim > 64) ✅ WIRED
- **D1.2**: TMEM column budget — VERIFIED (use_smem_p=True for hd>64 due to layout mismatch)
- **D1.4**: Multi-PV-tile for hd>256 ✅ IMPLEMENTED (pv_n_tile = min(head_dim, 256), n_pv_tiles = head_dim // pv_n_tile)
**D1.5 (correction epilog) and CG-4 (RoPE test) can happen in parallel with D2-D4.**
### 🔨 IN PROGRESS / PARTIAL
- **D1.3**: Register→SMEM copy for P — IMPLEMENTED BUT HAS RANK MISMATCH
  - Current status: the code replaces the zeroing stub with an actual copy.
  - Issue: `cute.copy(tiled_smem_copy, tSMEM_CPYrP, tSMEM_CPYsP)` fails with a rank mismatch (5 vs 3).
  - Shapes observed:
    - `rP_bf16`: `((32, 1), 4, 1, 1)`
    - `tSMEM_CPYrP`: `((8, (4, 4, 128)), 1, 1, 1, 1)` (rank 5)
    - `tSMEM_CPYsP`: `((8, (16, 128)), (64, 2), 1)` (rank 3)
  - Root cause: the copy atom produces source and destination partitions of different rank.
  - Need: either reshape the tensors so the ranks match, or use a different copy approach.
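
The 5 vs 3 mismatch can be reproduced on paper from the shapes recorded for D1.3. The following is a minimal sketch in plain Python with no CuTe DSL import. `rank` and `size` are hypothetical helpers written for this note; they assume the usual CuTe convention that rank is the number of top-level modes and size is the product of all leaf extents.

```python
from math import prod

def rank(shape):
    """Number of top-level modes of a (possibly nested) shape tuple."""
    return len(shape) if isinstance(shape, tuple) else 1

def size(shape):
    """Product of all leaf extents of a (possibly nested) shape tuple."""
    if isinstance(shape, tuple):
        return prod(size(mode) for mode in shape)
    return shape

# Shapes copied verbatim from the D1.3 notes.
shapes = {
    "rP_bf16": ((32, 1), 4, 1, 1),
    "tSMEM_CPYrP": ((8, (4, 4, 128)), 1, 1, 1, 1),
    "tSMEM_CPYsP": ((8, (16, 128)), (64, 2), 1),
}

for name, shp in shapes.items():
    # Report the rank, the size of the first mode, and the total size.
    print(f"{name}: rank={rank(shp)} first_mode={size(shp[0])} total={size(shp)}")
```

With these shapes the first mode of both partitions has 16384 elements, while the remaining modes multiply to 1 on the source side and 128 on the destination side. The two partitions therefore differ in total size as well as in rank, which may be worth keeping in mind when trying `group_modes` combinations.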
---
### ❌ NOT STARTED
- **D1.5**: Correction epilog fix (3% error from TMEM layout mismatch)
- **D2**: Multi-query grid with head packing
- **D3**: SWA sequence length mask
- **D4**: Causal mask on SWA branch
- **D5**: SWA+sink merge path
- **D6**: FP4 KV load path
### 🔴 BLOCKING ISSUE
The SMEM-P path rank mismatch prevents hd>64 from working. All tests at hd=128, 256, and 512 fail until it is fixed.
### NEXT STEPS RECOMMENDED
1. **Fix SMEM-P rank mismatch** — Most critical
   Options:
   - Try different `group_modes` combinations on the source and destination.
   - Use `make_tiled_copy_C` with `tcgen05.copy.St32x32bOp` (as suggested in the doc) instead of `CopyUniversalOp`.
   - Debug why `partition_S` and `partition_D` produce tensors of different rank.
2. **Test hd=512 two-pass** — Once SMEM-P works, verify multi-PV-tile logic (n_pv_tiles=2)
3. **D1.5 correction epilog** — Fix 3% error from TMEM layout mismatch
### LESSONS LEARNED
- CuTeDSL `cute.compile` zeroes GPU memory — keep index/mapping tensors on CPU
- Always verify with `.cpu().tolist()` after JIT
- TMEM-P path works for hd=64 (cos 0.972537) — good regression baseline
- `pv_n_tile = min(head_dim, 256)` critical for tcgen05 MMA which has max N=256
- `use_smem_p = (head_dim > 64)` due to TMEM layout mismatch at higher dimensions
- PRINT SHAPES saves days of debugging
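
The two tiling rules in the list above can be checked in isolation before touching the kernel. The following is a minimal sketch in plain Python. `pv_tiling` and `MAX_MMA_N` are hypothetical names introduced here, not identifiers from the kernel, and the divisibility check is an added assumption: the floor division only gives a full tiling when `head_dim` is at most 256 or a multiple of 256.

```python
MAX_MMA_N = 256  # maximum N of a tcgen05 MMA, per the lesson above

def pv_tiling(head_dim):
    """Return (use_smem_p, pv_n_tile, n_pv_tiles) for a given head_dim."""
    pv_n_tile = min(head_dim, MAX_MMA_N)
    # Assumption for this sketch: reject head dims that do not tile evenly,
    # since head_dim // pv_n_tile would silently drop the remainder.
    if head_dim % pv_n_tile != 0:
        raise ValueError(f"head_dim={head_dim} is not a multiple of {pv_n_tile}")
    n_pv_tiles = head_dim // pv_n_tile
    use_smem_p = head_dim > 64  # TMEM-P path is only used at hd=64
    return use_smem_p, pv_n_tile, n_pv_tiles

for hd in (64, 128, 256, 512):
    print(hd, pv_tiling(hd))
```

Under these rules only hd=512 needs more than one PV tile (`n_pv_tiles=2`), and every head dimension above 64 takes the SMEM-P path, which is why the rank mismatch blocks the hd=128, 256, and 512 tests together.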
## NVFP4 Precision Roadmap (May 23, 2026)