plan update
@@ -1,92 +1,96 @@
|
||||
# May 24, 2026 — Session Start Plan
|
||||
# May 24–26, 2026 — Session Plan & Progress
|
||||
|
||||
## Quick Context
|
||||
|
||||
You're working on the DSV4 (DeepSeek V4 Pro) NVFP4 inference kernel for Blackwell B200. The FMHA (Fused Multi-Head Attention) kernel is working at hd=64/128/256 (cos 0.999998). The next milestones are: fix O rescale for multi-KV-tile, add multi-head grid (D2), and verify NVFP4 primitives.
|
||||
You're working on the DSV4 (DeepSeek V4 Pro) NVFP4 inference kernel for Blackwell B200. The FMHA (Fused Multi-Head Attention) kernel is working at hd=64/128/256 (cos 0.999998). D5 (sink merge) is COMPLETE. The next milestones are: fix D1.5 O rescale, then proceed to production extraction (Stage E).
|
||||
|
||||
**B200:** See MEMORY.md for access (not committed to repo)
|
||||
**Repo:** `git@sweetapi.com:2222/biondizzle/nvfp4-megamoe-kernel.git`
|
||||
**Local:** `~/dev/nvfp4-megamoe-kernel`
|
||||
**Test command:** `~/.openclaw/workspace/fire_b200_test <test_file>`
|
||||
|
||||
## ⚡ Execute in This Order
|
||||
## ✅ Completed (May 24–26)
|
||||
|
||||
### 1. NVFP4-0: Verify FP4 Primitives (20 min, NO CODE CHANGES)
|
||||
### 1. NVFP4-0: Verify FP4 Primitives ✅
|
||||
- All four diagnostics PASS — sf_dtype, TMA element type, MMA kind all correct
|
||||
- NVFP4 uses FP8 E4M3 scales (NOT UE8M0), 16-element blocks, tcgen05 MMA kind correct
|
||||
|
||||
These are **print-only diagnostics**. If any reveal a wrong dtype, stop and fix it before everything else.
|
||||
### 2. NVFP4-3: use_2cta_instrs Conditional ✅
|
||||
- tokens_sum >= 256 and cluster_m even → 2-CTA UMMA
|
||||
- 1.7-1.9× throughput at prefill shapes. Decode stays 1-CTA.
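The selection rule above can be sketched as a small predicate. This is only an illustration: the function name `select_use_2cta_instrs` is hypothetical, and the real check lives in `gemm_runner.py`, whose surrounding code is not shown here.

```python
def select_use_2cta_instrs(tokens_sum: int, cluster_m: int) -> bool:
    """Return True when the 2-CTA UMMA path should be used.

    Mirrors the stated rule: tokens_sum >= 256 and cluster_m even.
    Small (decode-sized) problems fall through to the 1-CTA path.
    """
    return tokens_sum >= 256 and cluster_m % 2 == 0


if __name__ == "__main__":
    # Prefill-like shape with an even cluster: 2-CTA.
    print(select_use_2cta_instrs(tokens_sum=1024, cluster_m=2))
    # Decode-like shape: stays 1-CTA.
    print(select_use_2cta_instrs(tokens_sum=8, cluster_m=2))
```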
|
||||
|
||||
- **NVFP4-0.1** — Trace `sf_dtype` through `gemm_runner.py` → `dense.py` → `blockscaled_utils`. NVFP4 uses FP8 E4M3 scales (NOT UE8M0 which is MXFP4). If the runner is passing E8M0, every FP4 GEMM is wrong.
|
||||
- **NVFP4-0.2** — Verify SF TMEM layout is UE4M3 packed (4 FP8 E4M3 per int32), NOT UE8M0 (MXFP8).
|
||||
- **NVFP4-0.3** — Verify `float4_e2m1fn_x2` survives into TMA descriptors (not downcast to uint8).
|
||||
- **NVFP4-0.4** — Verify tcgen05 MMA kind resolves to NVFP4 (16-element blocks, E4M3 scales), not MXFP4 (32-element, UE8M0).
|
||||
### 3. D1: Parameterized HEAD_DIM ✅ (hd≤256)
|
||||
- hd=64/128/256: cos 0.999998, LSE err 0.0
|
||||
- hd=512: BLOCKED by MLIR compilation hang (>3hr). External k_sub merge impossible.
|
||||
- D1.5 O rescale: BLOCKED by TMEM round-trip corruption. Workaround: Python KV merge.
|
||||
|
||||
**How:** Add `print()` calls in the Python layer, run any FP4 GEMM test, check output. Remove prints after.
|
||||
### 4. D2: Multi-Query Grid ✅ (per-head launch)
|
||||
- Head-packed M-dimension: Q reshaped to (n_h*T, hd, 1), per-row softmax
|
||||
- cos 0.999995 for n_h=1-128 at hd=64, n_h=2-8 at hd=128, n_h=2 at hd=256
|
||||
- Multi-CTA grid: BLOCKED by flat_divide + epilogue_tma_store mismatch
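A minimal host-side sketch of the head-packing idea, using plain nested lists instead of tensors. The row order (`row = t * n_h + h`) and both helper names are assumptions made for illustration; the source only states that Q is reshaped to `(n_h*T, hd, 1)` and that softmax is per row.

```python
def pack_heads(q):
    """Flatten q[T][n_h][hd] into packed rows [T * n_h][hd].

    Assumed row order: row = t * n_h + h (token-major). Each packed row
    is an independent query vector, so per-row softmax needs no change.
    """
    return [head_vec for token in q for head_vec in token]


def row_to_token_head(row: int, n_h: int):
    """Invert the assumed packing: packed row index -> (token, head)."""
    return divmod(row, n_h)


if __name__ == "__main__":
    T, n_h, hd = 2, 3, 2
    # Encode (t, h, d) into each element so the mapping is easy to inspect.
    q = [[[100 * t + 10 * h + d for d in range(hd)] for h in range(n_h)]
         for t in range(T)]
    packed = pack_heads(q)
    print(len(packed), packed[4], row_to_token_head(4, n_h))
```

With this layout, Pro decode (`n_h=128`, `T=1`) yields exactly 128 packed rows, which matches the single 128-row M tile mentioned in the D2 notes.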
|
||||
|
||||
### 2. Test O Rescale at s_k > 128 (30 min)
|
||||
### 5. D3: SWA Sequence Length Mask ✅
|
||||
- In-kernel post-QK masking via tTMEM_LOADcS coordinates
|
||||
- swa_len as Int32 scalar, offset by n_comp for D5c
|
||||
|
||||
**The problem:** The O rescale code (for multi-KV-tile, kt>0) is guarded away with `const_expr(n_kv_tiles > 1)` at n=128. It uses hand-constructed TMEM atoms. **Untested and likely broken.**
|
||||
### 6. D4: Causal Mask ✅
|
||||
- SWA-relative position (kv_pos - n_comp) > m_coord → -inf
|
||||
- Combined with D3 via OR logic
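The combined D3 + D4 predicate can be sketched in plain Python as below. This is a reference-style illustration, not the in-kernel code: the function names are hypothetical, and `m_coord` is assumed to be a non-negative query position.

```python
NEG_INF = float("-inf")


def is_masked(kv_pos, m_coord, n_comp, swa_len, is_causal):
    """True if the logit at (m_coord, kv_pos) should be set to -inf."""
    # D3: positions past the valid SWA window, offset by n_comp.
    beyond_window = kv_pos >= n_comp + swa_len
    # D4: SWA-relative position ahead of the query row. For compressed
    # positions (kv_pos < n_comp) the relative position is negative, so
    # this term is never true when m_coord >= 0.
    causal = is_causal and (kv_pos - n_comp) > m_coord
    return beyond_window or causal  # combined via OR


def apply_mask(logits_row, m_coord, n_comp, swa_len, is_causal):
    """Return a copy of one logit row with masked positions set to -inf."""
    return [NEG_INF if is_masked(k, m_coord, n_comp, swa_len, is_causal) else s
            for k, s in enumerate(logits_row)]


if __name__ == "__main__":
    row = [0.0] * 8
    print(apply_mask(row, m_coord=1, n_comp=2, swa_len=3, is_causal=True))
```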
|
||||
|
||||
**Why it matters NOW:** DSV4 Pro uses top_k=1024 → s_k=1024 → n_kv_tiles=8. D2 multi-head will exercise s_k>128. If rescale is broken, all D2 production tests fail.
|
||||
### 7. D5: SWA + Sink Merge ✅
|
||||
- D5a: normalize flag + LSE + row_sums output
|
||||
- D5b: Per-row LSE + Python KV merge (cos 0.999994)
|
||||
- D5c: Sink bias as logit modification — **key insight: sink merge = single softmax over [S_comp, S_swa + attn_sink]**
|
||||
- Single-tile: cos 0.999996
|
||||
- Multi-tile (Python KV merge): cos 0.999996
|
||||
- D5d NOT NEEDED — sink bias approach supersedes fused merge epilogue
|
||||
|
||||
**How to test:**
|
||||
1. Create `test_d1_multi_kv.py` with `FmhaKernel(head_dim=64, s_k=256, normalize=False)` (2 KV tiles)
|
||||
2. Run it on B200
|
||||
3. If cos < 0.99, O rescale is broken → fix before D2
|
||||
4. If cos ~0.999, rescale works → proceed to D2
|
||||
## 🎯 Next Priorities (in order)
|
||||
|
||||
**If broken, fix approach:** Replace hand-constructed TMEM round-trip with CUTLASS `correction_rescale_and_partition` pattern (one-way TMEM→SMEM). See STAGE_D.md D1.5 Issue 2.
|
||||
### Priority 1: D1.5 — Fix O Rescale for Multi-KV-Tile (BLOCKER)
|
||||
**Why:** Production DSV4 Pro decode needs s_k=1152 (9 KV tiles). Python KV merge works but requires 5-9 kernel launches per decode step.
|
||||
|
||||
### 3. Start D2: Multi-Query Grid (main work)
|
||||
**Approaches (ordered by feasibility):**
|
||||
1. **Correction epilog pattern** (1-2 days): One-way TMEM→regs→SMEM→GMEM. Study CUTLASS reference `correction_rescale_and_partition` + `epilogue_tmem_copy_and_partition`. This is the proper Blackwell pipeline.
|
||||
2. **Skip TMEM round-trip entirely**: After softmax, write P to SMEM, PV accumulate to SMEM, TMA store from SMEM. Requires SMEM budget for both P and O.
|
||||
3. **Python KV merge as production path**: Accept the multi-launch overhead. Profile to see if it's actually a bottleneck (5-9 launches × ~50μs ≈ 250-450μs, vs ~50μs for single launch with O rescale).
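The launch-overhead estimate in option 3 is simple arithmetic; a small sketch makes it easy to re-run with measured numbers. The ~50 μs per-launch figure is the rough estimate quoted above, not a measurement, and the function name is hypothetical.

```python
def merge_path_latency_us(n_launches: int, per_launch_us: float = 50.0) -> float:
    """Estimated latency of the multi-launch Python KV merge path.

    Assumes launches run back to back and each costs per_launch_us.
    """
    return n_launches * per_launch_us


if __name__ == "__main__":
    for launches in (5, 9):
        print(launches, merge_path_latency_us(launches))
```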
|
||||
|
||||
See `STAGE_D2.md` for the full plan. Summary:
|
||||
### Priority 2: Stage E — Production Extraction
|
||||
**Why:** D5 is complete. The kernel works. Time to wrap it in a proper interface.
|
||||
|
||||
- Add `num_query_heads` to `FmhaKernel` constructor
|
||||
- Change grid from `(1,1,1)` to `(ceil_div(T, 128), num_query_heads, batch)`
|
||||
- Map `block_idx` → `(m_tile, head_idx, batch_idx)` inside kernel
|
||||
- Q TMA indexed per-head, K/V shared (MQA)
|
||||
- Test with n_h=2 → n_h=8 → n_h=64/128
|
||||
- E1: File placement (already done — `dsv4/kernels/attention/fmha.py`)
|
||||
- E2: Constructor signature (partially done — needs cleanup)
|
||||
- E3: Call signature (needs sink_bias, row_sums, n_comp integration)
|
||||
- E4: Kernel cache + warmup (key on n_comp, apply_sink_bias, head_dim, s_k)
|
||||
- E5: torch.library custom op
|
||||
- E6: Reference parity test
|
||||
- E7: Cleanup (delete debug test files)
|
||||
|
||||
**First step:** Create `test_d2_multihead.py` with n_h=1 regression test (verify nothing breaks), then n_h=2.
|
||||
### Priority 3: NVFP4-1.1 — Fuse FP4 into SwiGLU Epilogue (1 day, parallel)
|
||||
**Why:** Biggest bandwidth win for MoE pipeline. No FMHA dependency. Can work in parallel with D1.5.
|
||||
|
||||
### 4. NVFP4-3: use_2cta_instrs Conditional (30 min, parallel)
|
||||
Current: L1 GEMM → SwiGLU → BF16 GMEM → quantize → FP4 GMEM → L2 GEMM
|
||||
Target: L1 GEMM → SwiGLU → FP4 pack in registers → FP4 GMEM → L2 GEMM
|
||||
|
||||
Pure perf win for MoE GEMMs. Add `use_2cta_instrs = (M >= 256 and cluster_m % 2 == 0)` in `gemm_runner.py`. 1.7–1.9× throughput at prefill shapes. No FMHA dependency.
|
||||
### Priority 4: hd=512 Fix (BLOCKED by MLIR)
|
||||
**Status:** Kernel structurally correct (tracer 0.8s). MLIR optimizer hangs for 3+ hours.
|
||||
**Options:**
|
||||
1. Pre-compile offline + cache cubin (if MLIR eventually finishes)
|
||||
2. Write hd=512 path in raw CUTLASS C++ (bypass CuTeDSL MLIR)
|
||||
3. Report bug to NVIDIA
|
||||
|
||||
### 5. NVFP4-1.1: Fuse FP4 into SwiGLU Epilogue (1 day, parallel)
|
||||
|
||||
Biggest bandwidth win. Current: L1 GEMM → SwiGLU → BF16 GMEM → quantize → FP4 GMEM → L2 GEMM. Target: L1 GEMM → SwiGLU → FP4 pack in registers → FP4 GMEM → L2 GEMM. Saves entire quantize kernel launch + 2× bandwidth. See STAGE_D.md for full spec.
|
||||
|
||||
---
|
||||
|
||||
## File Map (what to read for context)
|
||||
|
||||
| File | What it contains |
|
||||
|------|-----------------|
|
||||
| `STAGE_D.md` | Full FMHA kernel status, NVFP4 precision roadmap, D1.5 gaps |
|
||||
| `STAGE_D2.md` | D2 multi-query grid plan with 9-item to-do list |
|
||||
| `README.md` | Architecture, CuTeDSL constraints (#1–#16), test harness docs |
|
||||
| `dsv4/kernels/attention/fmha.py` | The FMHA kernel (518 lines, FmhaKernel class) |
|
||||
| `dsv4/model/config.py` | DSV4 dimensions: Flash n_h=64, Pro n_h=128, hd=512 |
|
||||
| `dsv4/ops/decode_sparse.py` | Sink merge formula, MQA op interface |
|
||||
| `MEMORY.md` | Long-term memory (B200 access, all stage results) |
|
||||
| `memory/2026-05-24.md` | Today's daily log (hd=512 SMEM fix, MLIR hang, all bug fixes) |
|
||||
### Priority 5: D2 Multi-CTA Grid (BLOCKED by flat_divide)
|
||||
**Status:** Per-head launch works for decode. Multi-CTA needed for prefill.
|
||||
**Requires:** Full tma_partition + epilogue refactor into kernel (1-2 day effort).
|
||||
|
||||
## Key Numbers
|
||||
|
||||
| Config | n_h | top_k | s_k | n_kv_tiles | O rescale needed? |
|
||||
|--------|----:|------:|----:|-----------:|:------------------|
|
||||
| Flash decode | 64 | 512 | 512 | 4 | YES |
|
||||
| Pro decode | 128 | 1024 | 1024 | 8 | YES |
|
||||
| Flash decode | 64 | 512 | 640 | 5 | YES |
|
||||
| Pro decode | 128 | 1024 | 1152 | 9 | YES |
|
||||
| Current test | 1 | — | 128 | 1 | No (guarded away) |
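The `n_kv_tiles` column follows from a ceiling division of `s_k` by the 128-token KV tile. A short sketch, with hypothetical helper names, that reproduces the table rows:

```python
KV_TILE = 128  # tokens per KV tile, as used throughout these notes


def n_kv_tiles(s_k: int, tile: int = KV_TILE) -> int:
    """Number of KV tiles needed to cover s_k tokens (ceiling division)."""
    return -(-s_k // tile)


def needs_o_rescale(s_k: int) -> bool:
    """O rescale is only exercised when there is more than one KV tile."""
    return n_kv_tiles(s_k) > 1


if __name__ == "__main__":
    for s_k in (128, 512, 640, 1024, 1152):
        print(s_k, n_kv_tiles(s_k), needs_o_rescale(s_k))
```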
|
||||
|
||||
## D1 Status Summary
|
||||
|
||||
- ✅ hd=64/128/256: cos 0.999998, LSE err 0.0
|
||||
- ❌ hd=512: SMEM fits (192KB) but MLIR compilation hangs (3+ hours). External k_sub merge mathematically impossible. Need either: (a) pre-compile offline, (b) no-softmax mode for S accumulation, or (c) raw CUDA C++ kernel.
|
||||
- ⚠️ O rescale (kt>0): untested for s_k>128, likely broken
|
||||
- ✅ D5a (un-normalized O + LSE): done
|
||||
- ✅ D5b (Python sink merge): done, cos 0.961
|
||||
|
||||
## Rules (don't forget)
|
||||
|
||||
- NEVER edit on B200. Edit locally → commit → push → pull → test.
|
||||
|
||||
57
README.md
@@ -138,7 +138,7 @@ Summary
|
||||
|
||||
---
|
||||
|
||||
## Status (May 25, 2026 — 01:10 UTC)
|
||||
## Status (May 26, 2026 — 18:40 UTC)
|
||||
|
||||
| Stage | Status | Description |
|
||||
|-------|--------|-------------|
|
||||
@@ -146,10 +146,11 @@ Summary
|
||||
| B | ✅ COMPLETE | QK → identity softmax → P@V pipeline (TMEM alias, KV-tile interleaving) |
|
||||
| C | ✅ COMPLETE | Real online softmax. Kernel outputs un-norm O + LSE (no TMEM round-trip). Migrated to `dsv4/kernels/attention/fmha.py` as `FmhaKernel`. |
|
||||
| D1 | 🟡 hd≤256 DONE | Parameterized HEAD_DIM. qk_mma_tiler fix (hd=64/128/256 cos 0.999998). hd=512 SMEM fits but MLIR compilation hangs (>3hr). External k_sub merge proven impossible. O rescale TMEM round-trip BROKEN (Ld32x32bOp/St32x32bOp corrupt data). Python KV merge workaround works. |
|
||||
| D2 | 🟡 Per-head DONE | Multi-query grid. Per-head launch works (cos 0.999998, n_h=1-64 hd=64, n_h=2-8 hd=128, n_h=2 hd=256). Multi-CTA grid blocked: `flat_divide` + `epilogue_tma_store` layout mismatch. Requires full tma_partition refactor into kernel. |
|
||||
| D3 | ✅ DONE | SWA sequence length mask (in-kernel post-QK via tTMEM_LOADcS coordinates, swa_len Int32 scalar) |
|
||||
| D4 | ✅ DONE | Causal mask on SWA branch (k_coord > m_coord → -inf, combined with D3 via OR logic) |
|
||||
| D5 | 🟢 D5a+D5b+D5c DONE | D5a: normalize flag + LSE output. D5b: Per-row LSE + Python KV merge (cos 0.999994). D5c: Sink bias (attn_sink) as logit modification in combined KV (cos 0.999996, single KV tile). Multi-tile blocked by D1.5. |
|
||||
| D1.5 | ❌ BLOCKER | O rescale for multi-KV-tile (kt>0). TMEM round-trip corruption (even NO-OP round-trip fails). Python KV merge workaround: cos 0.999994. Production: 5-9 kernel launches per decode. Fix requires correction epilog (one-way TMEM→regs→SMEM→GMEM). |
|
||||
| D2 | 🟡 Per-head DONE | Head-packed M-dimension launch (cos 0.999995, n_h=1-128). Multi-CTA grid blocked: `flat_divide` + `epilogue_tma_store` layout mismatch. |
|
||||
| D3 | ✅ DONE | SWA sequence length mask (in-kernel post-QK via tTMEM_LOADcS coordinates, swa_len Int32 scalar, offset by n_comp for D5c) |
|
||||
| D4 | ✅ DONE | Causal mask on SWA branch (SWA-relative position > m_coord → -inf, combined with D3 via OR logic) |
|
||||
| D5 | ✅ D5a+D5b+D5c DONE | D5a: normalize flag + LSE + row_sums output. D5b: Per-row LSE + Python KV merge (cos 0.999994). D5c: Sink bias as logit modification — mathematically equivalent to separate merge, single pass over combined KV (cos 0.999996 single-tile AND multi-tile). D5d (fused in-kernel merge) NOT NEEDED — sink bias approach supersedes it. |
|
||||
| E1-E7 | TODO | Production extraction (class, custom op, cache, cleanup) |
|
||||
| NVFP4-3 | ✅ DONE | `use_2cta_instrs` conditional in gemm_runner.py. 1.7-1.9× throughput at prefill shapes. |
|
||||
|
||||
@@ -231,7 +232,9 @@ dsv4/
|
||||
| `test_fmha_v3_12w.py` | A+B | ✅ 12-warp QK→PV, cosine 0.999999 |
|
||||
| `test_fmha_v3_stage_c.py` | C | ✅ Real online softmax + normalize, n=128 cos 0.973. **Also in module as `FmhaKernel`.** |
|
||||
| `test_fmha_v3_stage_d1.py` | D1 | ✅ hd=64/128/256 PASS (cos 0.999998, TMEM-P). hd=512 SMEM overflow. |
|
||||
| `test_fmha_v3_stage_d5b.py` | D5b | ✅ Python SWA+sink merge (cos 0.961, LSE err=0.0) |
|
||||
| `test_fmha_v3_stage_d5b.py` | D5b | ✅ Python SWA+sink merge (cos 0.999994, LSE err=0.0) |
|
||||
| `test_d5c_fused.py` | D5c | ✅ Single-tile combined KV + sink bias (cos 0.999996) |
|
||||
| `test_d5c_multitile.py` | D5c | ✅ Multi-tile with Python KV merge + sink bias (cos 0.999996) |
|
||||
| `test_d1_*.py` | D1 | 🔨 Debug/diagnostic variants (hd512, regression, sweep, raw, debug) |
|
||||
| `test_paired_epilog.py` | C | ✅ Paired atom epilogue experiments |
|
||||
| `test_pv64_with_softmax.py` | B | ✅ (128,64) PV, single AB pipeline |
|
||||
@@ -394,6 +397,9 @@ Col 128+: O (PV acc, 64 FP32, rescale via Ld32x32bOp Repetition(16))
|
||||
18. **KV merge formula uses NORMALIZED outputs, not un-normalized.** The correct D5 merge for different token sets: `O = sum_i [exp(lse_i) * O_i_norm] / sum_i [exp(lse_i)]`. Using `O_i_unnorm` instead of `O_i_norm` gives cos ~0.91. The un-norm merge only works when both segments share the same `row_max` (global max), which isn't the case for separate KV segments.
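A pure-Python sketch of the merge in item 18, checked against attention over the concatenated segments. The helper names are hypothetical, and subtracting the maximum LSE before exponentiating is a numerical-stability step added here; it cancels in the ratio and does not change the formula.

```python
import math


def attend(logits, values):
    """Softmax attention for one query row.

    logits: list of floats; values: list of hd-length lists.
    Returns (o_norm, lse) with lse = log(sum(exp(logits))).
    """
    m = max(logits)
    w = [math.exp(s - m) for s in logits]
    row_sum = sum(w)
    hd = len(values[0])
    o = [sum(wi * v[d] for wi, v in zip(w, values)) / row_sum
         for d in range(hd)]
    return o, m + math.log(row_sum)


def merge_segments(outs_norm, lses):
    """O = sum_i exp(lse_i) * O_i_norm / sum_i exp(lse_i)."""
    m = max(lses)  # stability only; cancels between numerator and denominator
    ws = [math.exp(l - m) for l in lses]
    total = sum(ws)
    hd = len(outs_norm[0])
    return [sum(w * o[d] for w, o in zip(ws, outs_norm)) / total
            for d in range(hd)]


if __name__ == "__main__":
    s_a, v_a = [0.1, 2.0, -1.0], [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
    s_b, v_b = [3.0, 0.5], [[-1.0, 4.0], [0.5, 0.5]]
    o_a, lse_a = attend(s_a, v_a)
    o_b, lse_b = attend(s_b, v_b)
    print(merge_segments([o_a, o_b], [lse_a, lse_b]))
    print(attend(s_a + s_b, v_a + v_b)[0])
```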
|
||||
19. **`flat_divide` + `epilogue_tma_store` layout mismatch.** When using `cute.flat_divide` to create per-CTA GMEM views with runtime block coordinates (for multi-CTA grid), the resulting tensor layout is incompatible with CUTLASS's `epilogue_tma_store` pipeline, which expects the layout from `local_tile`. The tma_partition and epilogue must be refactored together to support multi-CTA grids.
|
||||
20. **`local_tile` does not support runtime coordinates.** `cute.local_tile(mQ, tiler, (runtime_val, None))` fails at trace time. Must use `cute.flat_divide(mQ, tiler)` instead, which creates a tiled view with all rest dimensions accessible via runtime indexing.
|
||||
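21. **Sink bias domain correction.** Adding `attn_sink` directly to raw logits is wrong: the bias is then multiplied by `scale_log2` along with the logit, so the attention weights are scaled by `exp(attn_sink * scale)` instead of `exp(attn_sink)`. Fix: add `attn_sink / scale` to the raw logits. Since `scale_log2 = scale * log2(e)`, the bias becomes `attn_sink * log2(e)` after the `* scale_log2` step, which multiplies the attention weights by exactly `exp(attn_sink)`.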
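The domain correction can be checked numerically. The sketch below assumes `scale_log2 = scale * log2(e)`, which is what the stated result (`attn_sink * log2(e)` after scaling) implies; the function names are hypothetical.

```python
import math

LOG2E = math.log2(math.e)


def biased_logit_exp2(raw_logit: float, attn_sink: float, scale: float) -> float:
    """Exp2-domain logit with the sink bias applied before scaling."""
    scale_log2 = scale * LOG2E  # assumption: softmax scale folded with log2(e)
    return (raw_logit + attn_sink / scale) * scale_log2


def naive_biased_logit_exp2(raw_logit: float, attn_sink: float, scale: float) -> float:
    """The incorrect variant: the bias gets scaled along with the logit."""
    return (raw_logit + attn_sink) * scale * LOG2E


if __name__ == "__main__":
    raw, sink, scale = 1.7, 0.5, 0.125
    want = math.exp(raw * scale) * math.exp(sink)
    print(2.0 ** biased_logit_exp2(raw, sink, scale), want)
    print(2.0 ** naive_biased_logit_exp2(raw, sink, scale))
```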
|
||||
22. **O normalization uses row_sum, NOT LSE.** `O_norm = O_unnorm / row_sum` is correct. `O_unnorm * exp(-LSE)` is WRONG because O_unnorm is max-shifted (divided by `2^row_max`), not raw `exp(S) @ V`. The kernel now outputs `row_sum` alongside LSE.
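A scalar-valued sketch of why `row_sum` is the right divisor. It models one query row in the exp2 domain with a max shift, using the LSE formula from the D5a notes (`lse = ln(row_sum) + row_max * ln(2)`); the function name and the scalar V are simplifications for illustration.

```python
import math


def fmha_row(logits_log2, values):
    """One query row in the exp2 domain with max shift.

    logits_log2: logits already multiplied by scale_log2.
    Returns (o_unnorm, row_sum, lse), where lse is in natural-log units.
    """
    row_max = max(logits_log2)
    p = [2.0 ** (t - row_max) for t in logits_log2]
    row_sum = sum(p)
    o_unnorm = sum(pi * v for pi, v in zip(p, values))
    lse = math.log(row_sum) + row_max * math.log(2.0)
    return o_unnorm, row_sum, lse


if __name__ == "__main__":
    t, v = [3.0, 1.0, -2.0], [1.0, 2.0, 4.0]
    o_unnorm, row_sum, lse = fmha_row(t, v)
    print(o_unnorm / row_sum)         # correct normalization
    print(o_unnorm * math.exp(-lse))  # off by a factor of 2**(-row_max)
```

The second expression differs from the first by exactly `2**(-row_max)`, so it only agrees when `row_max` happens to be zero.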
|
||||
23. **n_comp is compile-time, swa_len is runtime.** The `n_comp` parameter controls `const_expr` guards in the kernel and cannot vary between segments of the same kernel instance. `swa_len` is an `Int32` scalar and can vary per request. For multi-tile production, use a kernel cache keyed on `(n_comp, apply_sink_bias, head_dim, s_k)`.
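One possible shape for the kernel cache in item 23, sketched with a plain dictionary. The `KernelCache` class and its `build` callback are hypothetical stand-ins; the actual compile call for `FmhaKernel` is not shown in these notes, so it is injected rather than guessed.

```python
from typing import Callable, Dict, Tuple

CacheKey = Tuple[int, bool, int, int]  # (n_comp, apply_sink_bias, head_dim, s_k)


class KernelCache:
    """Compile once per compile-time configuration and reuse afterwards.

    swa_len is deliberately not part of the key: it is a runtime scalar
    and can change per request without recompiling.
    """

    def __init__(self, build: Callable[[int, bool, int, int], object]):
        self._build = build  # hypothetical compile function, injected
        self._kernels: Dict[CacheKey, object] = {}

    def get(self, n_comp: int, apply_sink_bias: bool, head_dim: int, s_k: int):
        key = (n_comp, apply_sink_bias, head_dim, s_k)
        if key not in self._kernels:
            self._kernels[key] = self._build(*key)
        return self._kernels[key]

    def __len__(self) -> int:
        return len(self._kernels)


if __name__ == "__main__":
    cache = KernelCache(lambda *key: ("compiled", key))
    cache.get(64, True, 64, 128)
    cache.get(64, True, 64, 128)  # cache hit, no rebuild
    print(len(cache))
```

A warmup pass would simply call `get` for each expected configuration before serving requests.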
|
||||
|
||||
---
|
||||
|
||||
@@ -459,7 +465,7 @@ The SWA branch needs a causal mask within the window. Add `is_causal: bool` cons
|
||||
|
||||
Done when: prefill mode produces correct output with the causal mask applied to SWA.
|
||||
|
||||
**D5 — SWA + sink merge** (~2-3 days) ← D5a+D5b DONE (May 23), D5c/D5d remaining
|
||||
**D5 — SWA + sink merge** ✅ COMPLETE (May 26)
|
||||
|
||||
Per `dsv4/ops/decode_sparse.py`:
|
||||
```
|
||||
@@ -467,41 +473,36 @@ o = (exp(lse_sparse) * o_sparse + exp(attn_sink) * exp(lse_swa) * o_swa)
|
||||
/ (exp(lse_sparse) + exp(attn_sink) * exp(lse_swa))
|
||||
```
|
||||
|
||||
With un-normalized O (D5a): `o_unnorm = o_norm * exp(lse)`, so:
|
||||
**Key insight (May 26):** This merge is mathematically identical to a single attention pass over concatenated KV with a logit bias on SWA positions:
|
||||
```
|
||||
o = (o_unnorm_sparse + exp(attn_sink) * o_unnorm_swa)
|
||||
/ (exp(lse_sparse) + exp(attn_sink) * exp(lse_swa))
|
||||
S = [S_comp, S_swa + attn_sink]
|
||||
O = softmax(S) @ [V_comp; V_swa]
|
||||
```
|
||||
This means D5c is a **logit bias addition**, not a two-pass + merge kernel. D5d (fused in-kernel merge epilogue) is NO LONGER NEEDED.
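The equivalence can be checked with a small reference in plain Python. This sketch uses scalar values per KV position and hypothetical function names; it compares the two-pass merge formula from `decode_sparse.py` (as quoted above) against a single softmax over the concatenated logits with `attn_sink` added on the SWA positions.

```python
import math


def attend(logits, values):
    """Softmax attention for one row with scalar values. Returns (o, lse)."""
    m = max(logits)
    w = [math.exp(s - m) for s in logits]
    z = sum(w)
    return sum(wi * vi for wi, vi in zip(w, values)) / z, m + math.log(z)


def two_pass_sink_merge(s_comp, v_comp, s_swa, v_swa, attn_sink):
    """Two attention passes followed by the sink merge formula."""
    o_sparse, lse_sparse = attend(s_comp, v_comp)
    o_swa, lse_swa = attend(s_swa, v_swa)
    w_sparse = math.exp(lse_sparse)
    w_swa = math.exp(attn_sink) * math.exp(lse_swa)
    return (w_sparse * o_sparse + w_swa * o_swa) / (w_sparse + w_swa)


def single_pass_sink_bias(s_comp, v_comp, s_swa, v_swa, attn_sink):
    """One softmax over [S_comp, S_swa + attn_sink]."""
    logits = list(s_comp) + [s + attn_sink for s in s_swa]
    o, _ = attend(logits, list(v_comp) + list(v_swa))
    return o


if __name__ == "__main__":
    args = ([0.3, -1.2, 2.0], [1.0, -2.0, 0.5], [1.5, 0.0], [4.0, -1.0], 0.5)
    print(two_pass_sink_merge(*args), single_pass_sink_bias(*args))
```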
|
||||
|
||||
**D5a DONE (May 23):** `normalize` flag added to FmhaKernel. When False, emits un-normalized O + LSE. LSE formula: `lse = ln(row_sum) + row_max * ln(2)` (row_max in scale_log2 domain, multiply by ln(2) to convert). LSE err=0.000000 verified.
|
||||
**D5a ✅:** `normalize` flag + LSE + row_sums output. When False, emits un-normalized O + LSE + row_sum.
|
||||
|
||||
**D5b DONE (May 23):** Python SWA+sink merge works end-to-end at hd=64. Run FMHA twice (compressed KV + SWA KV, normalize=False), merge in Python. Merge cos 0.961, individual attention cos 0.963/0.960.
|
||||
**D5b ✅:** Per-row LSE output (all 128 rows now write). Python KV merge with per-row LSE: cos 0.999994.
|
||||
|
||||
Sub-steps remaining:
|
||||
- **5c:** Fuse the two passes into one kernel launch. Q stays in SMEM, two MMA loops sequentially.
|
||||
- **5d:** Fuse the merge into the kernel epilogue.
|
||||
**D5c ✅:** Sink bias as logit modification. Parameters: `n_comp` (compressed KV length, compile-time), `apply_sink_bias` (compile-time flag), `sink_bias` (runtime FP32 tensor). Sink bias added to raw logits as `attn_sink / scale` so after `* scale_log2` it correctly becomes `attn_sink * log2(e)` in the exp2 domain. Multi-tile via Python KV merge: cos 0.999996.
|
||||
|
||||
Done when: end-to-end kernel produces correct attention against FP32 oracle that does sparse+SWA+sink merge.
|
||||
**D5d:** NOT NEEDED. The sink bias approach makes a fused merge epilogue unnecessary.
|
||||
|
||||
Done when: ✅ End-to-end kernel produces correct attention against FP32 oracle.
|
||||
|
||||
**~~D5 (old) paged TMA~~ — REMOVED.** The indexer + gather handles all paging upstream.
|
||||
|
||||
### Kernel Architecture (after D5)
|
||||
### Kernel Architecture (after D5 — COMPLETE)
|
||||
|
||||
```
|
||||
Input: Q [T, n_h, 512], compressed_kv [T, top_k, 512], swa_kv [batch, n_win, 512]
|
||||
swa_lens [batch], sink_logits [n_h], request_ids [T]
|
||||
│
|
||||
├─ Load Q to SMEM (once)
|
||||
│
|
||||
├─ Loop 1: compressed KV (top_k tokens)
|
||||
│ QK → online softmax → PV → O_sparse, lse_sparse in TMEM
|
||||
│
|
||||
├─ Loop 2: SWA window (n_win tokens, masked by swa_lens)
|
||||
│ QK → online softmax → PV → O_swa, lse_swa in TMEM
|
||||
│
|
||||
└─ Sink merge epilogue:
|
||||
O = (exp(lse_sparse) * O_sparse + exp(sink) * exp(lse_swa) * O_swa)
|
||||
/ (exp(lse_sparse) + exp(sink) * exp(lse_swa))
|
||||
└─ Single pass over concatenated KV [compressed_kv; swa_kv]:
|
||||
QK → online softmax (with sink bias on SWA, D3/D4 masking) → PV
|
||||
→ O_unnorm + LSE + row_sum
|
||||
→ External normalize: O = O_unnorm / row_sum
|
||||
→ (Multi-tile: Python KV merge across 128-token segments)
|
||||
```
|
||||
|
||||
### Reference Files
|
||||
|
||||
53
STAGE_D.md
@@ -31,7 +31,7 @@
|
||||
|
||||
---
|
||||
|
||||
## Current Status (2026-05-24, 21:30 UTC)
|
||||
## Current Status (2026-05-26, 18:40 UTC)
|
||||
|
||||
### ✅ WORKING
|
||||
|
||||
@@ -41,12 +41,21 @@
|
||||
| 128 | 0.999997 | 0.000000 | TMEM-P / SMEM-P | 128KB |
|
||||
| 256 | 0.999998 | 0.000000 | TMEM-P | 224KB |
|
||||
|
||||
### ✅ D5 COMPLETE (May 26)
|
||||
|
||||
| Test | Config | cos | Status |
|
||||
|------|--------|-----|--------|
|
||||
| D5c single-tile | n_comp=64, n_swa=64, sink=0.5 | 0.999996 | ✅ |
|
||||
| D5c causal | n_comp=64, n_swa=64, sink=0.3, causal | 0.999996 | ✅ |
|
||||
| D5c multi-tile | n_comp=96, n_swa=160, s_k=256, Python KV merge | 0.999996 | ✅ |
|
||||
| D3 regression | In-kernel mask, s_k=128 | 0.999996 | ✅ |
|
||||
| D4 regression | Causal mask, s_k=128 | 0.999996 | ✅ |
|
||||
|
||||
### ❌ KNOWN ISSUES
|
||||
|
||||
- **hd=512: MLIR compilation hangs.** SMEM budget fixed (192KB ✅), kernel structure correct (tracer 0.8s), but MLIR→PTX backend optimizer cannot process the IR in reasonable time (>3 hours). Both `range()` unrolled and `cutlass.range(unroll=1)` runtime loops trigger this. This is a CuTeDSL/MLIR toolchain limitation.
|
||||
- **External k_sub merge doesn't work.** k_sub segments are additive in logit space (S = S_0 + S_1), not attention weight space. The D5 merge formula does not apply. In-kernel k_sub accumulation is the only correct approach.
|
||||
- **O rescale (kt>0):** Uses hand-constructed TMEM atoms. May corrupt data for n>128 (multi-KV-tile). At n=128 (1 KV tile, kt=0), no rescale needed. Guarded with `const_expr(n_kv_tiles > 1)`.
|
||||
- **Kernel always outputs un-normalized O + LSE.** No in-kernel normalization (eliminates TMEM round-trip error). External normalization: `O_norm = O_unnorm / row_sum`.
|
||||
- **hd=512: MLIR compilation hangs.** SMEM budget fixed (192KB ✅), kernel structure correct (tracer 0.8s), but MLIR→PTX backend optimizer cannot process the IR in reasonable time (>3 hours). This is a CuTeDSL/MLIR toolchain limitation.
|
||||
- **D1.5 O rescale (multi-KV-tile): TMEM round-trip corruption.** Hand-constructed Ld32x32bOp/St32x32bOp atoms corrupt data on round-trip (even NO-OP). Workaround: Python KV merge (cos 0.999994). Fix requires correction epilog pattern (one-way TMEM→regs→SMEM→GMEM).
|
||||
- **D2 multi-CTA grid: flat_divide + epilogue_tma_store layout mismatch.** Requires full tma_partition refactor into kernel. Head-packed per-head launch works (cos 0.999995).
|
||||
|
||||
---
|
||||
|
||||
@@ -198,28 +207,32 @@ pv_n_tile shown in parens; hd>256 uses pv_n_tile=128 (4 PV GEMM passes) to fit S
|
||||
3. **Write hd=512 kernel in CUTLASS C++.** Bypass CuTeDSL's MLIR backend entirely. Use raw CUTLASS C++ with tcgen05 MMA intrinsics. More work but compilation is fast (seconds).
|
||||
4. **Report CuTeDSL MLIR optimizer bug.** The optimizer should handle this IR in reasonable time. File an issue with NVIDIA.
|
||||
|
||||
### D2 — Multi-Query Grid with Head Packing
|
||||
### D2 — Multi-Query Grid with Head Packing ✅ (per-head launch)
|
||||
|
||||
- Grid changes from `(1, 1, 1)` to `(num_q_blocks, 1, batch)`
|
||||
- DSV4 is MQA: all 128 query heads share same K/V
|
||||
- Head axis folded into M dimension: `M_tile = 128` covers `M = T * n_h` rows
|
||||
- Head-packed M-dimension launch: Q reshaped to (n_h*T, hd, 1), kernel treats each row independently
|
||||
- cos 0.999995 for n_h=1-128 at hd=64, n_h=2-8 at hd=128, n_h=2 at hd=256
|
||||
- Multi-CTA grid (flat_divide) BLOCKED — see Known Issues
|
||||
|
||||
### D3 — SWA Sequence Length Mask
|
||||
### D3 — SWA Sequence Length Mask ✅
|
||||
|
||||
- Add `swa_lens: [batch] int32` kernel input
|
||||
- Mask SWA-branch logits to `-inf` where `swa_idx >= swa_lens[b]`
|
||||
- In-kernel post-QK masking via tTMEM_LOADcS coordinates
|
||||
- swa_len as Int32 scalar (runtime, not compile-time)
|
||||
- Offset by n_comp for D5c: mask positions >= n_comp + swa_len
|
||||
|
||||
### D4 — Causal Mask on SWA Branch
|
||||
### D4 — Causal Mask on SWA Branch ✅
|
||||
|
||||
- Add `is_causal: bool` constructor flag
|
||||
- Apply `swa_idx > q_pos` masking to `-inf` in SWA pass
|
||||
- SWA-relative position (kv_pos - n_comp) > m_coord → -inf
|
||||
- Combined with D3 via OR logic
|
||||
|
||||
### D5 — SWA + Sink Merge
|
||||
### D5 — SWA + Sink Merge ✅ (May 26)
|
||||
|
||||
- **D5a ✅:** Kernel outputs un-normalized O + LSE
|
||||
- **D5b ✅:** Python merge works (cos 0.961 at hd=64)
|
||||
- **D5c:** Fuse two passes into one kernel launch
|
||||
- **D5d:** Fuse sink merge into kernel epilogue
|
||||
**Key insight:** Sink merge = single softmax over [S_comp, S_swa + attn_sink].
|
||||
One pass, one kernel. D5d NOT NEEDED.
|
||||
|
||||
- **D5a ✅:** normalize flag + LSE + row_sums output
|
||||
- **D5b ✅:** Per-row LSE + Python KV merge (cos 0.999994)
|
||||
- **D5c ✅:** Sink bias as logit modification (cos 0.999996 single-tile AND multi-tile)
|
||||
- **D5d:** NOT NEEDED — sink bias approach supersedes fused merge epilogue
|
||||
|
||||
---
|
||||
|
||||
|
||||
11
STAGE_D2.md
@@ -204,10 +204,14 @@ O has shape `(batch, n_h, T, head_dim)`. Each CTA writes its head's output. The
|
||||
- hd=128, n_h=8 — PASS
|
||||
- hd=256, n_h=2 — PASS
|
||||
|
||||
- [x] **D2 Head-Packed:** Q reshaped to (n_h*T, hd, 1), per-row softmax
|
||||
- cos 0.999995 for n_h=1-128 at hd=64
|
||||
- Pro decode (n_h=128, T=1): M=128, one CTA processes all 128 heads
|
||||
|
||||
### 🟡 Blocked (Multi-CTA Grid)
|
||||
|
||||
- [ ] **D2.1:** Add `num_query_heads` and `batch_size` to `FmhaKernel.__init__`
|
||||
- Simple to add, but the grid change is blocked (see below)
|
||||
- Added as constructor params, but grid still (1,1,batch) — per-head launch in Python
|
||||
|
||||
- [ ] **D2.3–D2.6:** Multi-CTA grid with runtime block coordinates
|
||||
- **BLOCKED:** `cute.local_tile` does not support runtime coordinates. Must use `cute.flat_divide` instead.
|
||||
@@ -216,9 +220,8 @@ O has shape `(batch, n_h, T, head_dim)`. Each CTA writes its head's output. The
|
||||
- **CUTLASS reference approach:** Uses `flat_divide` + `tma_partition` inside the TMA warp block, and a custom epilogue that handles the flat_divide coordinate system. Estimated 1-2 day effort.
|
||||
|
||||
- [ ] **D2.9:** LSE output for multi-head
|
||||
- Per-row LSE verified correct (max err 0.000001) but CuTe tensor indexing needs work
|
||||
- Currently only row 0 is written (sfw_idx==0 guard)
|
||||
- Full per-row output needed for D5 KV merge
|
||||
- Per-row LSE verified correct (max err 0.000001), all 128 rows now write
|
||||
- row_sums output also working — O_norm = O_unnorm / row_sum
|
||||
|
||||
---
|
||||
|
||||
|
||||