Add MAY_24_26_PLAN.md: next session startup plan
98
MAY_24_26_PLAN.md
Normal file
@@ -0,0 +1,98 @@
# May 24, 2026 — Session Start Plan
## Quick Context
You're working on the DSV4 (DeepSeek V4 Pro) NVFP4 inference kernel for Blackwell B200. The FMHA (Fused Multi-Head Attention) kernel is working at hd=64/128/256 (cos 0.999998). The next milestones are: fix O rescale for multi-KV-tile, add multi-head grid (D2), and verify NVFP4 primitives.

**B200:** `root@45.76.247.107`, pass: `<B200_PASSWORD>`
**Repo:** `git@sweetapi.com:2222/biondizzle/nvfp4-megamoe-kernel.git`
**Local:** `~/dev/nvfp4-megamoe-kernel`
**Test command:** `~/.openclaw/workspace/fire_b200_test <test_file>`
## ⚡ Execute in This Order
### 1. NVFP4-0: Verify FP4 Primitives (20 min, NO CODE CHANGES)
These are **print-only diagnostics**. If any of them reveals a wrong dtype, stop and fix it before anything else.

- **NVFP4-0.1:** Trace `sf_dtype` through `gemm_runner.py` → `dense.py` → `blockscaled_utils`. NVFP4 uses FP8 E4M3 scales (NOT UE8M0, which is the MXFP4 scale format). If the runner is passing UE8M0, every FP4 GEMM is wrong.
- **NVFP4-0.2:** Verify the SF TMEM layout is UE4M3 packed (4 FP8 E4M3 per int32), NOT UE8M0 (MXFP8).
- **NVFP4-0.3:** Verify `float4_e2m1fn_x2` survives into the TMA descriptors (not downcast to uint8).
- **NVFP4-0.4:** Verify the tcgen05 MMA kind resolves to NVFP4 (16-element blocks, E4M3 scales), not MXFP4 (32-element blocks, UE8M0 scales).

**How:** Add `print()` calls in the Python layer, run any FP4 GEMM test, and check the output. Remove the prints afterwards.
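To see why a scale-dtype mix-up is fatal rather than a small precision loss, the sketch below decodes the same raw scale byte under both interpretations. It is a standalone illustration in plain Python; `decode_e4m3` and `decode_ue8m0` are hypothetical helper names, not functions from `gemm_runner.py` or `blockscaled_utils`.

```python
def decode_e4m3(byte: int) -> float:
    """Decode one FP8 E4M3 byte (1 sign, 4 exponent, 3 mantissa bits, bias 7)."""
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF
    man = byte & 0x7
    if exp == 0xF and man == 0x7:
        return float("nan")  # the only NaN encoding; E4M3 has no infinities
    if exp == 0:
        return sign * (man / 8.0) * 2.0 ** -6  # subnormal
    return sign * (1.0 + man / 8.0) * 2.0 ** (exp - 7)


def decode_ue8m0(byte: int) -> float:
    """Decode one UE8M0 byte (unsigned, exponent only, bias 127)."""
    if byte == 0xFF:
        return float("nan")
    return 2.0 ** (byte - 127)


# 0x38 is exactly 1.0 as E4M3, but 2**-71 when misread as UE8M0.
raw = 0x38
print(f"E4M3: {decode_e4m3(raw)}  UE8M0: {decode_ue8m0(raw)}")
```

A scale of 1.0 misread as 2**-71 zeroes the whole block, so a wrong `sf_dtype` should show up immediately in any FP4 GEMM output.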
### 2. Test O Rescale at s_k > 128 (30 min)
**The problem:** The O rescale code (for multi-KV-tile, kt>0) is guarded away with `const_expr(n_kv_tiles > 1)` at s_k=128 (n_kv_tiles=1), so the current test never exercises it. It uses hand-constructed TMEM atoms. **Untested and likely broken.**

**Why it matters NOW:** DSV4 Pro uses top_k=1024 → s_k=1024 → n_kv_tiles=8. D2 multi-head will exercise s_k>128. If rescale is broken, all D2 production tests fail.

**How to test:**
1. Create `test_d1_multi_kv.py` with `FmhaKernel(head_dim=64, s_k=256, normalize=False)` (2 KV tiles).
2. Run it on B200.
3. If cos < 0.99, O rescale is broken → fix before D2.
4. If cos ~0.999, rescale works → proceed to D2.

**If broken, fix approach:** Replace the hand-constructed TMEM round-trip with the CUTLASS `correction_rescale_and_partition` pattern (one-way TMEM→SMEM). See STAGE_D.md D1.5 Issue 2.
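As a reference for what the rescale has to compute, here is a minimal pure-Python sketch of the standard tile-by-tile online-softmax update for one query row, returning un-normalized O plus the running max and sum (the `normalize=False` case). The function names, the toy inputs, and the tile size of 2 are illustrative assumptions; none of this is taken from `fmha.py`.

```python
import math


def attend_full(scores, values):
    """Single-pass softmax(scores) @ values for one query row."""
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    total = sum(w)
    dim = len(values[0])
    return [sum(wi * v[d] for wi, v in zip(w, values)) / total for d in range(dim)]


def attend_tiled(scores, values, tile):
    """Online softmax over KV tiles; O is rescaled whenever the row max grows."""
    dim = len(values[0])
    m, l, o = -math.inf, 0.0, [0.0] * dim
    for start in range(0, len(scores), tile):
        s_tile = scores[start:start + tile]
        v_tile = values[start:start + tile]
        m_new = max(m, max(s_tile))
        alpha = math.exp(m - m_new)  # rescale factor; 0.0 on the first tile
        p = [math.exp(s - m_new) for s in s_tile]
        l = alpha * l + sum(p)
        o = [alpha * o[d] + sum(pi * v[d] for pi, v in zip(p, v_tile))
             for d in range(dim)]
        m = m_new
    return o, m, l  # un-normalized O; LSE = m + log(l)


def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


scores = [0.1, 0.5, 2.0, 3.5]  # max rises in the second tile, forcing a rescale
values = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, -1.0]]
o, m, l = attend_tiled(scores, values, tile=2)
print("cos:", cosine(attend_full(scores, values), [x / l for x in o]))
```

Dropping the `alpha` factor from the `o` update reproduces the kind of failure the s_k=256 test is meant to catch: the result stays correct for a single tile and diverges as soon as a later tile raises the row max.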
### 3. Start D2: Multi-Query Grid (main work)
See `STAGE_D2.md` for the full plan. Summary:

- Add `num_query_heads` to the `FmhaKernel` constructor
- Change the grid from `(1,1,1)` to `(ceil_div(T, 128), num_query_heads, batch)`
- Map `block_idx` → `(m_tile, head_idx, batch_idx)` inside the kernel
- Q TMA indexed per-head, K/V shared (MQA)
- Test with n_h=2 → n_h=8 → n_h=64/128

**First step:** Create `test_d2_multihead.py` with an n_h=1 regression test (verify nothing breaks), then n_h=2.
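The grid change and index mapping listed above can be sketched host-side in plain Python. This only illustrates the arithmetic; `launch_grid`, `work_for_block`, and the returned field names are hypothetical and do not reflect the actual `FmhaKernel` interface.

```python
M_TILE = 128  # query rows handled per block, from ceil_div(T, 128)


def ceil_div(a: int, b: int) -> int:
    return (a + b - 1) // b


def launch_grid(T: int, num_query_heads: int, batch: int):
    """Grid dimensions for the D2 layout: (m tiles, query heads, batch)."""
    return (ceil_div(T, M_TILE), num_query_heads, batch)


def work_for_block(block_idx, T: int):
    """Map a 3D block index to the Q rows and heads that block should touch."""
    m_tile, head_idx, batch_idx = block_idx
    row_start = m_tile * M_TILE
    return {
        "batch_idx": batch_idx,
        "q_head_idx": head_idx,   # Q is indexed per head
        "kv_head_idx": 0,         # MQA: every query head reads the same K/V
        "q_rows": (row_start, min(row_start + M_TILE, T)),
    }


print(launch_grid(T=1, num_query_heads=128, batch=1))
print(work_for_block((2, 1, 3), T=300))
```

With n_h=1 the grid collapses back to `(ceil_div(T, 128), 1, batch)`, which is why the n_h=1 regression test is a useful first check.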
### 4. NVFP4-3: use_2cta_instrs Conditional (30 min, parallel)
Pure perf win for MoE GEMMs. Add `use_2cta_instrs = (M >= 256 and cluster_m % 2 == 0)` in `gemm_runner.py`. 1.7–1.9× throughput at prefill shapes. No FMHA dependency.
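A minimal sketch of that conditional as a standalone predicate, so it can be unit-tested without a GPU. The wrapper name `should_use_2cta_instrs` is an assumption; only the boolean expression comes from this plan.

```python
def should_use_2cta_instrs(M: int, cluster_m: int) -> bool:
    """Enable 2-CTA MMA instructions only for large-M shapes on even clusters."""
    return M >= 256 and cluster_m % 2 == 0


for M, cluster_m in [(128, 2), (256, 1), (256, 2), (4096, 4)]:
    print(M, cluster_m, should_use_2cta_instrs(M, cluster_m))
```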
### 5. NVFP4-1.1: Fuse FP4 into SwiGLU Epilogue (1 day, parallel)
Biggest bandwidth win. Current: L1 GEMM → SwiGLU → BF16 GMEM → quantize → FP4 GMEM → L2 GEMM. Target: L1 GEMM → SwiGLU → FP4 pack in registers → FP4 GMEM → L2 GEMM. Saves entire quantize kernel launch + 2× bandwidth. See STAGE_D.md for full spec.
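For orientation, the "SwiGLU → FP4 pack" step can be modeled numerically in plain Python. This is a rough sketch under stated assumptions, not the spec in STAGE_D.md: the block scale is kept as a Python float (real NVFP4 stores it as E4M3), rounding picks the nearest E2M1 magnitude with ties going to the smaller one, and element 0 is assumed to sit in the low nibble of each packed byte.

```python
import math

E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]  # codes 0..7
BLOCK = 16  # NVFP4 scale-block size


def swiglu(gate: float, up: float) -> float:
    """SiLU(gate) * up."""
    return gate / (1.0 + math.exp(-gate)) * up


def quantize_block(vals):
    """Quantize one 16-element block to (scale, 8 packed bytes of E2M1 codes)."""
    assert len(vals) == BLOCK
    amax = max(abs(v) for v in vals)
    scale = amax / 6.0 if amax > 0 else 1.0  # map the block max onto E2M1 max
    codes = []
    for v in vals:
        mag = abs(v) / scale
        idx = min(range(8), key=lambda i: abs(E2M1_MAGNITUDES[i] - mag))
        codes.append((0x8 if v < 0 else 0x0) | idx)  # bit 3 is the sign
    packed = bytes(codes[i] | (codes[i + 1] << 4) for i in range(0, BLOCK, 2))
    return scale, packed


def dequantize_block(scale, packed):
    out = []
    for byte in packed:
        for code in (byte & 0xF, byte >> 4):
            mag = E2M1_MAGNITUDES[code & 0x7]
            out.append((-mag if code & 0x8 else mag) * scale)
    return out


acts = [swiglu(0.25 * i - 2.0, 1.5) for i in range(BLOCK)]
scale, packed = quantize_block(acts)
print(scale, packed.hex())
```

A host-side model like this could serve as the reference when checking the fused epilogue against the existing separate quantize kernel.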

---
## File Map (what to read for context)
| File | What it contains |
|------|-----------------|
| `STAGE_D.md` | Full FMHA kernel status, NVFP4 precision roadmap, D1.5 gaps |
| `STAGE_D2.md` | D2 multi-query grid plan with 9-item to-do list |
| `README.md` | Architecture, CuTeDSL constraints (#1–#16), test harness docs |
| `dsv4/kernels/attention/fmha.py` | The FMHA kernel (518 lines, FmhaKernel class) |
| `dsv4/model/config.py` | DSV4 dimensions: Flash n_h=64, Pro n_h=128, hd=512 |
| `dsv4/ops/decode_sparse.py` | Sink merge formula, MQA op interface |
| `MEMORY.md` | Long-term memory (B200 access, all stage results) |
| `memory/2026-05-24.md` | Today's daily log (hd=512 SMEM fix, MLIR hang, all bug fixes) |
## Key Numbers
| Config | n_h | top_k | s_k | n_kv_tiles | O rescale needed? |
|--------|----:|------:|----:|-----------:|:------------------|
| Flash decode | 64 | 512 | 512 | 4 | YES |
| Pro decode | 128 | 1024 | 1024 | 8 | YES |
| Current test | 1 | n/a | 128 | 1 | No (guarded away) |
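The `n_kv_tiles` column follows from a 128-wide KV tile, and rescale is needed exactly when there is more than one tile. A small sketch of that arithmetic (the helper names are illustrative, not from the repo):

```python
KV_TILE = 128  # KV positions per tile, as implied by s_k=256 giving 2 tiles


def n_kv_tiles(s_k: int) -> int:
    return (s_k + KV_TILE - 1) // KV_TILE


def needs_o_rescale(s_k: int) -> bool:
    """O must be rescaled from the second KV tile onward (kt > 0)."""
    return n_kv_tiles(s_k) > 1


for name, s_k in [("Flash decode", 512), ("Pro decode", 1024), ("Current test", 128)]:
    print(f"{name}: s_k={s_k} tiles={n_kv_tiles(s_k)} rescale={needs_o_rescale(s_k)}")
```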
## D1 Status Summary
- ✅ hd=64/128/256: cos 0.999998, LSE err 0.0
- ❌ hd=512: SMEM fits (192KB) but MLIR compilation hangs (3+ hours). External k_sub merge mathematically impossible. Need either: (a) pre-compile offline, (b) no-softmax mode for S accumulation, or (c) raw CUDA C++ kernel.
- ⚠️ O rescale (kt>0): untested for s_k>128, likely broken
- ✅ D5a (un-normalized O + LSE): done
- ✅ D5b (Python sink merge): done, cos 0.961
## Rules (don't forget)
- NEVER edit on B200. Edit locally → commit → push → pull → test.
- ALWAYS use `fire_b200_test` or `run_test.sh`. Never raw SSH+nohup.
- ALWAYS verify hd=64 regression (cos ~0.999998) after every change.
- Guard dead code with `const_expr()`. Without the guard, CuTeDSL compiles both branches.
- CuTeDSL `if` blocks create separate MLIR regions, so variables are NOT visible across blocks.
- After every P store to TMEM, call `cute.arch.fence_view_async_tmem_store()`.
- NOHUP DOES NOT WORK on B200. Screen sessions survive SSH drops.