diff --git a/= b/= deleted file mode 100644 index e69de29b..00000000 diff --git a/MAY_24_26_PLAN.md b/MAY_24_26_PLAN.md index c6a2d240..557df3b1 100644 --- a/MAY_24_26_PLAN.md +++ b/MAY_24_26_PLAN.md @@ -1,92 +1,96 @@ -# May 24, 2026 — Session Start Plan +# May 24–26, 2026 — Session Plan & Progress ## Quick Context -You're working on the DSV4 (DeepSeek V4 Pro) NVFP4 inference kernel for Blackwell B200. The FMHA (Fused Multi-Head Attention) kernel is working at hd=64/128/256 (cos 0.999998). The next milestones are: fix O rescale for multi-KV-tile, add multi-head grid (D2), and verify NVFP4 primitives. +You're working on the DSV4 (DeepSeek V4 Pro) NVFP4 inference kernel for Blackwell B200. The FMHA (Fused Multi-Head Attention) kernel is working at hd=64/128/256 (cos 0.999998). D5 (sink merge) is COMPLETE. The next milestones are: fix D1.5 O rescale, then proceed to production extraction (Stage E). **B200:** See MEMORY.md for access (not committed to repo) **Repo:** `git@sweetapi.com:2222/biondizzle/nvfp4-megamoe-kernel.git` **Local:** `~/dev/nvfp4-megamoe-kernel` **Test command:** `~/.openclaw/workspace/fire_b200_test ` -## ⚡ Execute in This Order +## ✅ Completed (May 24–26) -### 1. NVFP4-0: Verify FP4 Primitives (20 min, NO CODE CHANGES) +### 1. NVFP4-0: Verify FP4 Primitives ✅ +- All four diagnostics PASS — sf_dtype, TMA element type, MMA kind all correct +- NVFP4 uses FP8 E4M3 scales (NOT UE8M0), 16-element blocks, tcgen05 MMA kind correct -These are **print-only diagnostics**. If any reveal a wrong dtype, stop and fix it before everything else. +### 2. NVFP4-3: use_2cta_instrs Conditional ✅ +- tokens_sum >= 256 and cluster_m even → 2-CTA UMMA +- 1.7-1.9× throughput at prefill shapes. Decode stays 1-CTA. -- **NVFP4-0.1** — Trace `sf_dtype` through `gemm_runner.py` → `dense.py` → `blockscaled_utils`. NVFP4 uses FP8 E4M3 scales (NOT UE8M0 which is MXFP4). If the runner is passing E8M0, every FP4 GEMM is wrong. 
-- **NVFP4-0.2** — Verify SF TMEM layout is UE4M3 packed (4 FP8 E4M3 per int32), NOT UE8M0 (MXFP8). -- **NVFP4-0.3** — Verify `float4_e2m1fn_x2` survives into TMA descriptors (not downcast to uint8). -- **NVFP4-0.4** — Verify tcgen05 MMA kind resolves to NVFP4 (16-element blocks, E4M3 scales), not MXFP4 (32-element, UE8M0). +### 3. D1: Parameterized HEAD_DIM ✅ (hd≤256) +- hd=64/128/256: cos 0.999998, LSE err 0.0 +- hd=512: BLOCKED by MLIR compilation hang (>3hr). External k_sub merge impossible. +- D1.5 O rescale: BLOCKED by TMEM round-trip corruption. Python KV merge workaround. -**How:** Add `print()` calls in the Python layer, run any FP4 GEMM test, check output. Remove prints after. +### 4. D2: Multi-Query Grid ✅ (per-head launch) +- Head-packed M-dimension: Q reshaped to (n_h*T, hd, 1), per-row softmax +- cos 0.999995 for n_h=1-128 at hd=64, n_h=2-8 at hd=128, n_h=2 at hd=256 +- Multi-CTA grid: BLOCKED by flat_divide + epilogue_tma_store mismatch -### 2. Test O Rescale at s_k > 128 (30 min) +### 5. D3: SWA Sequence Length Mask ✅ +- In-kernel post-QK masking via tTMEM_LOADcS coordinates +- swa_len as Int32 scalar, offset by n_comp for D5c -**The problem:** The O rescale code (for multi-KV-tile, kt>0) is guarded away with `const_expr(n_kv_tiles > 1)` at n=128. It uses hand-constructed TMEM atoms. **Untested and likely broken.** +### 6. D4: Causal Mask ✅ +- SWA-relative position (kv_pos - n_comp) > m_coord → -inf +- Combined with D3 via OR logic -**Why it matters NOW:** DSV4 Pro uses top_k=1024 → s_k=1024 → n_kv_tiles=8. D2 multi-head will exercise s_k>128. If rescale is broken, all D2 production tests fail. +### 7. 
D5: SWA + Sink Merge ✅ +- D5a: normalize flag + LSE + row_sums output +- D5b: Per-row LSE + Python KV merge (cos 0.999994) +- D5c: Sink bias as logit modification — **key insight: sink merge = single softmax over [S_comp, S_swa + attn_sink]** + - Single-tile: cos 0.999996 + - Multi-tile (Python KV merge): cos 0.999996 + - D5d NOT NEEDED — sink bias approach supersedes fused merge epilogue -**How to test:** -1. Create `test_d1_multi_kv.py` with `FmhaKernel(head_dim=64, s_k=256, normalize=False)` (2 KV tiles) -2. Run it on B200 -3. If cos < 0.99, O rescale is broken → fix before D2 -4. If cos ~0.999, rescale works → proceed to D2 +## 🎯 Next Priorities (in order) -**If broken, fix approach:** Replace hand-constructed TMEM round-trip with CUTLASS `correction_rescale_and_partition` pattern (one-way TMEM→SMEM). See STAGE_D.md D1.5 Issue 2. +### Priority 1: D1.5 — Fix O Rescale for Multi-KV-Tile (BLOCKER) +**Why:** Production DSV4 Pro decode needs s_k=1152 (9 KV tiles). Python KV merge works but requires 5-9 kernel launches per decode step. -### 3. Start D2: Multi-Query Grid (main work) +**Approaches (ordered by feasibility):** +1. **Correction epilog pattern** (1-2 days): One-way TMEM→regs→SMEM→GMEM. Study CUTLASS reference `correction_rescale_and_partition` + `epilogue_tmem_copy_and_partition`. This is the proper Blackwell pipeline. +2. **Skip TMEM round-trip entirely**: After softmax, write P to SMEM, PV accumulate to SMEM, TMA store from SMEM. Requires SMEM budget for both P and O. +3. **Python KV merge as production path**: Accept the multi-launch overhead. Profile to see if it's actually a bottleneck (5-9 launches × ~50μs ≈ 250-450μs, vs ~50μs for single launch with O rescale). -See `STAGE_D2.md` for the full plan. Summary: +### Priority 2: Stage E — Production Extraction +**Why:** D5 is complete. The kernel works. Time to wrap it in a proper interface. 
-- Add `num_query_heads` to `FmhaKernel` constructor -- Change grid from `(1,1,1)` to `(ceil_div(T, 128), num_query_heads, batch)` -- Map `block_idx` → `(m_tile, head_idx, batch_idx)` inside kernel -- Q TMA indexed per-head, K/V shared (MQA) -- Test with n_h=2 → n_h=8 → n_h=64/128 +- E1: File placement (already done — `dsv4/kernels/attention/fmha.py`) +- E2: Constructor signature (partially done — needs cleanup) +- E3: Call signature (needs sink_bias, row_sums, n_comp integration) +- E4: Kernel cache + warmup (key on n_comp, apply_sink_bias, head_dim, s_k) +- E5: torch.library custom op +- E6: Reference parity test +- E7: Cleanup (delete debug test files) -**First step:** Create `test_d2_multihead.py` with n_h=1 regression test (verify nothing breaks), then n_h=2. +### Priority 3: NVFP4-1.1 — Fuse FP4 into SwiGLU Epilogue (1 day, parallel) +**Why:** Biggest bandwidth win for MoE pipeline. No FMHA dependency. Can work in parallel with D1.5. -### 4. NVFP4-3: use_2cta_instrs Conditional (30 min, parallel) +Current: L1 GEMM → SwiGLU → BF16 GMEM → quantize → FP4 GMEM → L2 GEMM +Target: L1 GEMM → SwiGLU → FP4 pack in registers → FP4 GMEM → L2 GEMM -Pure perf win for MoE GEMMs. Add `use_2cta_instrs = (M >= 256 and cluster_m % 2 == 0)` in `gemm_runner.py`. 1.7–1.9× throughput at prefill shapes. No FMHA dependency. +### Priority 4: hd=512 Fix (BLOCKED by MLIR) +**Status:** Kernel structurally correct (tracer 0.8s). MLIR optimizer hangs for 3+ hours. +**Options:** +1. Pre-compile offline + cache cubin (if MLIR eventually finishes) +2. Write hd=512 path in raw CUTLASS C++ (bypass CuTeDSL MLIR) +3. Report bug to NVIDIA -### 5. NVFP4-1.1: Fuse FP4 into SwiGLU Epilogue (1 day, parallel) - -Biggest bandwidth win. Current: L1 GEMM → SwiGLU → BF16 GMEM → quantize → FP4 GMEM → L2 GEMM. Target: L1 GEMM → SwiGLU → FP4 pack in registers → FP4 GMEM → L2 GEMM. Saves entire quantize kernel launch + 2× bandwidth. See STAGE_D.md for full spec. 
- ---- - -## File Map (what to read for context) - -| File | What it contains | -|------|-----------------| -| `STAGE_D.md` | Full FMHA kernel status, NVFP4 precision roadmap, D1.5 gaps | -| `STAGE_D2.md` | D2 multi-query grid plan with 9-item to-do list | -| `README.md` | Architecture, CuTeDSL constraints (#1–#16), test harness docs | -| `dsv4/kernels/attention/fmha.py` | The FMHA kernel (518 lines, FmhaKernel class) | -| `dsv4/model/config.py` | DSV4 dimensions: Flash n_h=64, Pro n_h=128, hd=512 | -| `dsv4/ops/decode_sparse.py` | Sink merge formula, MQA op interface | -| `MEMORY.md` | Long-term memory (B200 access, all stage results) | -| `memory/2026-05-24.md` | Today's daily log (hd=512 SMEM fix, MLIR hang, all bug fixes) | +### Priority 5: D2 Multi-CTA Grid (BLOCKED by flat_divide) +**Status:** Per-head launch works for decode. Multi-CTA needed for prefill. +**Requires:** Full tma_partition + epilogue refactor into kernel (1-2 day effort). ## Key Numbers | Config | n_h | top_k | s_k | n_kv_tiles | O rescale needed? | |--------|----:|------:|----:|-----------:|:------------------| -| Flash decode | 64 | 512 | 512 | 4 | YES | -| Pro decode | 128 | 1024 | 1024 | 8 | YES | +| Flash decode | 64 | 512 | 640 | 5 | YES | +| Pro decode | 128 | 1024 | 1152 | 9 | YES | | Current test | 1 | — | 128 | 1 | No (guarded away) | -## D1 Status Summary - -- ✅ hd=64/128/256: cos 0.999998, LSE err 0.0 -- ❌ hd=512: SMEM fits (192KB) but MLIR compilation hangs (3+ hours). External k_sub merge mathematically impossible. Need either: (a) pre-compile offline, (b) no-softmax mode for S accumulation, or (c) raw CUDA C++ kernel. -- ⚠️ O rescale (kt>0): untested for s_k>128, likely broken -- ✅ D5a (un-normalized O + LSE): done -- ✅ D5b (Python sink merge): done, cos 0.961 - ## Rules (don't forget) - NEVER edit on B200. Edit locally → commit → push → pull → test. 
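The Key Numbers table and the Priority 1 launch estimate in the plan above reduce to simple arithmetic. A minimal sketch, assuming a 128-token KV tile (matching the "128-token segments" wording) and the plan's rough figure of about 50 μs per launch; the function names are hypothetical and nothing here was measured:

```python
KV_TILE = 128      # assumed KV tile width (tokens per kernel launch in the merge path)
LAUNCH_US = 50.0   # rough per-launch cost quoted in the plan, not a measurement


def n_kv_tiles(s_k: int, tile: int = KV_TILE) -> int:
    """Number of KV tiles needed to cover s_k keys (ceiling division)."""
    return (s_k + tile - 1) // tile


def merge_path_latency_us(s_k: int, launch_us: float = LAUNCH_US) -> float:
    """Estimated cost of the Python KV merge path: one launch per KV tile."""
    return n_kv_tiles(s_k) * launch_us


configs = {"flash_decode": 640, "pro_decode": 1152, "current_test": 128}
for name, s_k in configs.items():
    tiles = n_kv_tiles(s_k)
    print(
        f"{name}: s_k={s_k} n_kv_tiles={tiles} "
        f"rescale_needed={tiles > 1} merge_path~{merge_path_latency_us(s_k):.0f}us"
    )
```

Under these assumptions the two decode configs land at 5 and 9 launches, which is where the 250-450 μs range in Approach 3 comes from.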
diff --git a/README.md b/README.md index 54f1209b..34624227 100644 --- a/README.md +++ b/README.md @@ -138,7 +138,7 @@ Summary --- -## Status (May 25, 2026 — 01:10 UTC) +## Status (May 26, 2026 — 18:40 UTC) | Stage | Status | Description | |-------|--------|-------------| @@ -146,10 +146,11 @@ Summary | B | ✅ COMPLETE | QK → identity softmax → P@V pipeline (TMEM alias, KV-tile interleaving) | | C | ✅ COMPLETE | Real online softmax. Kernel outputs un-norm O + LSE (no TMEM round-trip). Migrated to `dsv4/kernels/attention/fmha.py` as `FmhaKernel`. | | D1 | 🟡 hd≤256 DONE | Parameterized HEAD_DIM. qk_mma_tiler fix (hd=64/128/256 cos 0.999998). hd=512 SMEM fits but MLIR compilation hangs (>3hr). External k_sub merge proven impossible. O rescale TMEM round-trip BROKEN (Ld32x32bOp/St32x32bOp corrupt data). Python KV merge workaround works. | -| D2 | 🟡 Per-head DONE | Multi-query grid. Per-head launch works (cos 0.999998, n_h=1-64 hd=64, n_h=2-8 hd=128, n_h=2 hd=256). Multi-CTA grid blocked: `flat_divide` + `epilogue_tma_store` layout mismatch. Requires full tma_partition refactor into kernel. | -| D3 | ✅ DONE | SWA sequence length mask (in-kernel post-QK via tTMEM_LOADcS coordinates, swa_len Int32 scalar) | -| D4 | ✅ DONE | Causal mask on SWA branch (k_coord > m_coord → -inf, combined with D3 via OR logic) | -| D5 | 🟢 D5a+D5b+D5c DONE | D5a: normalize flag + LSE output. D5b: Per-row LSE + Python KV merge (cos 0.999994). D5c: Sink bias (attn_sink) as logit modification in combined KV (cos 0.999996, single KV tile). Multi-tile blocked by D1.5. | +| D1.5 | ❌ BLOCKER | O rescale for multi-KV-tile (kt>0). TMEM round-trip corruption (even NO-OP round-trip fails). Python KV merge workaround: cos 0.999994. Production: 5-9 kernel launches per decode. Fix requires correction epilog (one-way TMEM→regs→SMEM→GMEM). | +| D2 | 🟡 Per-head DONE | Head-packed M-dimension launch (cos 0.999995, n_h=1-128). Multi-CTA grid blocked: `flat_divide` + `epilogue_tma_store` layout mismatch. 
| +| D3 | ✅ DONE | SWA sequence length mask (in-kernel post-QK via tTMEM_LOADcS coordinates, swa_len Int32 scalar, offset by n_comp for D5c) | +| D4 | ✅ DONE | Causal mask on SWA branch (SWA-relative position > m_coord → -inf, combined with D3 via OR logic) | +| D5 | ✅ D5a+D5b+D5c DONE | D5a: normalize flag + LSE + row_sums output. D5b: Per-row LSE + Python KV merge (cos 0.999994). D5c: Sink bias as logit modification — mathematically equivalent to separate merge, single pass over combined KV (cos 0.999996 single-tile AND multi-tile). D5d (fused in-kernel merge) NOT NEEDED — sink bias approach supersedes it. | | E1-E7 | TODO | Production extraction (class, custom op, cache, cleanup) | | NVFP4-3 | ✅ DONE | `use_2cta_instrs` conditional in gemm_runner.py. 1.7-1.9× throughput at prefill shapes. | @@ -231,7 +232,9 @@ dsv4/ | `test_fmha_v3_12w.py` | A+B | ✅ 12-warp QK→PV, cosine 0.999999 | | `test_fmha_v3_stage_c.py` | C | ✅ Real online softmax + normalize, n=128 cos 0.973. **Also in module as `FmhaKernel`.** | | `test_fmha_v3_stage_d1.py` | D1 | ✅ hd=64/128/256 PASS (cos 0.999998, TMEM-P). hd=512 SMEM overflow. | -| `test_fmha_v3_stage_d5b.py` | D5b | ✅ Python SWA+sink merge (cos 0.961, LSE err=0.0) | +| `test_fmha_v3_stage_d5b.py` | D5b | ✅ Python SWA+sink merge (cos 0.999994, LSE err=0.0) | +| `test_d5c_fused.py` | D5c | ✅ Single-tile combined KV + sink bias (cos 0.999996) | +| `test_d5c_multitile.py` | D5c | ✅ Multi-tile with Python KV merge + sink bias (cos 0.999996) | | `test_d1_*.py` | D1 | 🔨 Debug/diagnostic variants (hd512, regression, sweep, raw, debug) | | `test_paired_epilog.py` | C | ✅ Paired atom epilogue experiments | | `test_pv64_with_softmax.py` | B | ✅ (128,64) PV, single AB pipeline | @@ -394,6 +397,9 @@ Col 128+: O (PV acc, 64 FP32, rescale via Ld32x32bOp Repetition(16)) 18. **KV merge formula uses NORMALIZED outputs, not un-normalized.** The correct D5 merge for different token sets: `O = sum_i [exp(lse_i) * O_i_norm] / sum_i [exp(lse_i)]`. 
Using `O_i_unnorm` instead of `O_i_norm` gives cos ~0.91. The un-norm merge only works when both segments share the same `row_max` (global max), which isn't the case for separate KV segments. 19. **`flat_divide` + `epilogue_tma_store` layout mismatch.** When using `cute.flat_divide` to create per-CTA GMEM views with runtime block coordinates (for multi-CTA grid), the resulting tensor layout is incompatible with CUTLASS's `epilogue_tma_store` pipeline, which expects the layout from `local_tile`. The tma_partition and epilogue must be refactored together to support multi-CTA grids. 20. **`local_tile` does not support runtime coordinates.** `cute.local_tile(mQ, tiler, (runtime_val, None))` fails at trace time. Must use `cute.flat_divide(mQ, tiler)` instead, which creates a tiled view with all rest dimensions accessible via runtime indexing. +21. **Sink bias domain correction.** Adding `attn_sink` directly to raw logits is wrong — it gets scaled by `scale_log2`. Fix: add `attn_sink / scale` to raw logits, so after `* scale_log2` it becomes `attn_sink * log2(e)`, correctly multiplying attention weights by `exp(attn_sink)`. +22. **O normalization uses row_sum, NOT LSE.** `O_norm = O_unnorm / row_sum` is correct. `O_unnorm * exp(-LSE)` is WRONG because O_unnorm is max-shifted (divided by `2^row_max`), not raw `exp(S) @ V`. The kernel now outputs `row_sum` alongside LSE. +23. **n_comp is compile-time, swa_len is runtime.** The `n_comp` parameter controls `const_expr` guards in the kernel and cannot vary between segments of the same kernel instance. `swa_len` is an `Int32` scalar and can vary per request. For multi-tile production, use a kernel cache keyed on `(n_comp, apply_sink_bias, head_dim, s_k)`. --- @@ -459,7 +465,7 @@ The SWA branch needs a causal mask within the window. Add `is_causal: bool` cons Done when: prefill mode produces correct output with the causal mask applied to SWA. 
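Notes 21 and 22 above can be checked numerically outside the kernel. The sketch below emulates a single query row in the exp2 domain in plain Python; it is an illustration under stated assumptions, not the kernel code. `kernel_row`, `reference` and `biased` are hypothetical helpers, values are scalars per KV position for brevity, `scale = 0.125` assumes hd=64, and the LSE formula is the one from the D5a notes (`lse = ln(row_sum) + row_max * ln(2)`).

```python
import math

LOG2_E = math.log2(math.e)


def kernel_row(raw_logits, values, scale):
    """Emulate one query row: exp2 softmax with a max shift.

    Returns (o_unnorm, row_sum, lse). o_unnorm is max-shifted, like the kernel output.
    """
    x = [s * scale * LOG2_E for s in raw_logits]  # raw logit * scale_log2
    row_max = max(x)
    p = [2.0 ** (xi - row_max) for xi in x]       # max-shifted, un-normalized weights
    row_sum = sum(p)
    o_unnorm = sum(pi * vi for pi, vi in zip(p, values))
    lse = math.log(row_sum) + row_max * math.log(2.0)
    return o_unnorm, row_sum, lse


def reference(nat_logits, values):
    """Plain softmax(nat_logits) @ values in the natural-log domain."""
    m = max(nat_logits)
    w = [math.exp(l - m) for l in nat_logits]
    return sum(wi * vi for wi, vi in zip(w, values)) / sum(w)


scale, attn_sink, n_comp = 0.125, 0.5, 2  # assumed: hd=64, first 2 positions are compressed KV
raw = [3.0, -1.0, 2.0, 0.5]               # raw (unscaled) q.k logits
values = [1.0, 2.0, 3.0, 4.0]


def biased(bias):
    """Add `bias` to the raw logits of SWA positions only (index >= n_comp)."""
    return [s + (bias if j >= n_comp else 0.0) for j, s in enumerate(raw)]


# Target: SWA weights multiplied by exp(attn_sink).
expected = reference(
    [scale * s + (attn_sink if j >= n_comp else 0.0) for j, s in enumerate(raw)], values
)

o_fix, sum_fix, lse_fix = kernel_row(biased(attn_sink / scale), values, scale)
o_naive, sum_naive, _ = kernel_row(biased(attn_sink), values, scale)

good = o_fix / sum_fix                  # note 21 fix + note 22 normalization
naive = o_naive / sum_naive             # sink added directly to raw logits
bad_norm = o_fix * math.exp(-lse_fix)   # normalizing with LSE instead of row_sum

print(f"expected={expected:.6f} good={good:.6f} naive={naive:.6f} bad_norm={bad_norm:.6f}")
```

In this emulation only `good` matches the reference. `naive` under-weights the SWA positions because the bias gets multiplied by `scale`, and `bad_norm` is off by a factor of `2^row_max` because `o_unnorm` is already max-shifted.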
-**D5 — SWA + sink merge** (~2-3 days) ← D5a+D5b DONE (May 23), D5c/D5d remaining +**D5 — SWA + sink merge** ✅ COMPLETE (May 26) Per `dsv4/ops/decode_sparse.py`: ``` @@ -467,41 +473,36 @@ o = (exp(lse_sparse) * o_sparse + exp(attn_sink) * exp(lse_swa) * o_swa) / (exp(lse_sparse) + exp(attn_sink) * exp(lse_swa)) ``` -With un-normalized O (D5a): `o_unnorm = o_norm * exp(lse)`, so: +**Key insight (May 26):** This merge is mathematically identical to a single attention pass over concatenated KV with a logit bias on SWA positions: ``` -o = (o_unnorm_sparse + exp(attn_sink) * o_unnorm_swa) - / (exp(lse_sparse) + exp(attn_sink) * exp(lse_swa)) +S = [S_comp, S_swa + attn_sink] +O = softmax(S) @ [V_comp; V_swa] ``` +This means D5c is a **logit bias addition**, not a two-pass + merge kernel. D5d (fused in-kernel merge epilogue) is NO LONGER NEEDED. -**D5a DONE (May 23):** `normalize` flag added to FmhaKernel. When False, emits un-normalized O + LSE. LSE formula: `lse = ln(row_sum) + row_max * ln(2)` (row_max in scale_log2 domain, multiply by ln(2) to convert). LSE err=0.000000 verified. +**D5a ✅:** `normalize` flag + LSE + row_sums output. When False, emits un-normalized O + LSE + row_sum. -**D5b DONE (May 23):** Python SWA+sink merge works end-to-end at hd=64. Run FMHA twice (compressed KV + SWA KV, normalize=False), merge in Python. Merge cos 0.961, individual attention cos 0.963/0.960. +**D5b ✅:** Per-row LSE output (all 128 rows now write). Python KV merge with per-row LSE: cos 0.999994. -Sub-steps remaining: -- **5c:** Fuse the two passes into one kernel launch. Q stays in SMEM, two MMA loops sequentially. -- **5d:** Fuse the merge into the kernel epilogue. +**D5c ✅:** Sink bias as logit modification. Parameters: `n_comp` (compressed KV length, compile-time), `apply_sink_bias` (compile-time flag), `sink_bias` (runtime FP32 tensor). Sink bias added to raw logits as `attn_sink / scale` so after `* scale_log2` it correctly becomes `attn_sink * log2(e)` in the exp2 domain. 
Multi-tile via Python KV merge: cos 0.999996. -Done when: end-to-end kernel produces correct attention against FP32 oracle that does sparse+SWA+sink merge. +**D5d:** NOT NEEDED. The sink bias approach makes a fused merge epilogue unnecessary. + +Done when: ✅ End-to-end kernel produces correct attention against FP32 oracle. **~~D5 (old) paged TMA~~ — REMOVED.** The indexer + gather handles all paging upstream. -### Kernel Architecture (after D5) +### Kernel Architecture (after D5 — COMPLETE) ``` Input: Q [T, n_h, 512], compressed_kv [T, top_k, 512], swa_kv [batch, n_win, 512] swa_lens [batch], sink_logits [n_h], request_ids [T] │ - ├─ Load Q to SMEM (once) - │ - ├─ Loop 1: compressed KV (top_k tokens) - │ QK → online softmax → PV → O_sparse, lse_sparse in TMEM - │ - ├─ Loop 2: SWA window (n_win tokens, masked by swa_lens) - │ QK → online softmax → PV → O_swa, lse_swa in TMEM - │ - └─ Sink merge epilogue: - O = (exp(lse_sparse) * O_sparse + exp(sink) * exp(lse_swa) * O_swa) - / (exp(lse_sparse) + exp(sink) * exp(lse_swa)) + └─ Single pass over concatenated KV [compressed_kv; swa_kv]: + QK → online softmax (with sink bias on SWA, D3/D4 masking) → PV + → O_unnorm + LSE + row_sum + → External normalize: O = O_unnorm / row_sum + → (Multi-tile: Python KV merge across 128-token segments) ``` ### Reference Files diff --git a/STAGE_D.md b/STAGE_D.md index 87395dc9..5737d945 100644 --- a/STAGE_D.md +++ b/STAGE_D.md @@ -31,7 +31,7 @@ --- -## Current Status (2026-05-24, 21:30 UTC) +## Current Status (2026-05-26, 18:40 UTC) ### ✅ WORKING @@ -41,12 +41,21 @@ | 128 | 0.999997 | 0.000000 | TMEM-P / SMEM-P | 128KB | | 256 | 0.999998 | 0.000000 | TMEM-P | 224KB | +### ✅ D5 COMPLETE (May 26) + +| Test | Config | cos | Status | +|------|--------|-----|--------| +| D5c single-tile | n_comp=64, n_swa=64, sink=0.5 | 0.999996 | ✅ | +| D5c causal | n_comp=64, n_swa=64, sink=0.3, causal | 0.999996 | ✅ | +| D5c multi-tile | n_comp=96, n_swa=160, s_k=256, Python KV merge | 0.999996 | ✅ | +| D3 
regression | In-kernel mask, s_k=128 | 0.999996 | ✅ | +| D4 regression | Causal mask, s_k=128 | 0.999996 | ✅ | + ### ❌ KNOWN ISSUES -- **hd=512: MLIR compilation hangs.** SMEM budget fixed (192KB ✅), kernel structure correct (tracer 0.8s), but MLIR→PTX backend optimizer cannot process the IR in reasonable time (>3 hours). Both `range()` unrolled and `cutlass.range(unroll=1)` runtime loops trigger this. This is a CuTeDSL/MLIR toolchain limitation. -- **External k_sub merge doesn't work.** k_sub segments are additive in logit space (S = S_0 + S_1), not attention weight space. The D5 merge formula does not apply. In-kernel k_sub accumulation is the only correct approach. -- **O rescale (kt>0):** Uses hand-constructed TMEM atoms. May corrupt data for n>128 (multi-KV-tile). At n=128 (1 KV tile, kt=0), no rescale needed. Guarded with `const_expr(n_kv_tiles > 1)`. -- **Kernel always outputs un-normalized O + LSE.** No in-kernel normalization (eliminates TMEM round-trip error). External normalization: `O_norm = O_unnorm / row_sum`. +- **hd=512: MLIR compilation hangs.** SMEM budget fixed (192KB ✅), kernel structure correct (tracer 0.8s), but MLIR→PTX backend optimizer cannot process the IR in reasonable time (>3 hours). This is a CuTeDSL/MLIR toolchain limitation. +- **D1.5 O rescale (multi-KV-tile): TMEM round-trip corruption.** Hand-constructed Ld32x32bOp/St32x32bOp atoms corrupt data on round-trip (even NO-OP). Workaround: Python KV merge (cos 0.999994). Fix requires correction epilog pattern (one-way TMEM→regs→SMEM→GMEM). +- **D2 multi-CTA grid: flat_divide + epilogue_tma_store layout mismatch.** Requires full tma_partition refactor into kernel. Head-packed per-head launch works (cos 0.999995). --- @@ -198,28 +207,32 @@ pv_n_tile shown in parens; hd>256 uses pv_n_tile=128 (4 PV GEMM passes) to fit S 3. **Write hd=512 kernel in CUTLASS C++.** Bypass CuTeDSL's MLIR backend entirely. Use raw CUTLASS C++ with tcgen05 MMA intrinsics. 
More work but compilation is fast (seconds). 4. **Report CuTeDSL MLIR optimizer bug.** The optimizer should handle this IR in reasonable time. File an issue with NVIDIA. -### D2 — Multi-Query Grid with Head Packing +### D2 — Multi-Query Grid with Head Packing ✅ (per-head launch) -- Grid changes from `(1, 1, 1)` to `(num_q_blocks, 1, batch)` -- DSV4 is MQA: all 128 query heads share same K/V -- Head axis folded into M dimension: `M_tile = 128` covers `M = T * n_h` rows +- Head-packed M-dimension launch: Q reshaped to (n_h*T, hd, 1), kernel treats each row independently +- cos 0.999995 for n_h=1-128 at hd=64, n_h=2-8 at hd=128, n_h=2 at hd=256 +- Multi-CTA grid (flat_divide) BLOCKED — see Known Issues -### D3 — SWA Sequence Length Mask +### D3 — SWA Sequence Length Mask ✅ -- Add `swa_lens: [batch] int32` kernel input -- Mask SWA-branch logits to `-inf` where `swa_idx >= swa_lens[b]` +- In-kernel post-QK masking via tTMEM_LOADcS coordinates +- swa_len as Int32 scalar (runtime, not compile-time) +- Offset by n_comp for D5c: mask positions >= n_comp + swa_len -### D4 — Causal Mask on SWA Branch +### D4 — Causal Mask on SWA Branch ✅ -- Add `is_causal: bool` constructor flag -- Apply `swa_idx > q_pos` masking to `-inf` in SWA pass +- SWA-relative position (kv_pos - n_comp) > m_coord → -inf +- Combined with D3 via OR logic -### D5 — SWA + Sink Merge +### D5 — SWA + Sink Merge ✅ (May 26) -- **D5a ✅:** Kernel outputs un-normalized O + LSE -- **D5b ✅:** Python merge works (cos 0.961 at hd=64) -- **D5c:** Fuse two passes into one kernel launch -- **D5d:** Fuse sink merge into kernel epilogue +**Key insight:** Sink merge = single softmax over [S_comp, S_swa + attn_sink]. +One pass, one kernel. D5d NOT NEEDED. 
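The key insight above can be sanity-checked in a few lines of plain Python. This is a sketch under stated assumptions rather than the kernel path: logits are already in the natural-log domain, values are scalars per KV position, and `attend`, `two_pass_sink_merge` and `single_pass_sink_bias` are hypothetical helper names. The two-pass formula is the one quoted from `dsv4/ops/decode_sparse.py`.

```python
import math


def attend(logits, values):
    """One query row: return (normalized output, log-sum-exp)."""
    m = max(logits)
    w = [math.exp(l - m) for l in logits]
    total = sum(w)
    out = sum(wi * vi for wi, vi in zip(w, values)) / total
    return out, m + math.log(total)


def two_pass_sink_merge(s_comp, v_comp, s_swa, v_swa, attn_sink):
    """Two attention passes followed by the sink merge formula."""
    o_sparse, lse_sparse = attend(s_comp, v_comp)
    o_swa, lse_swa = attend(s_swa, v_swa)
    w_sparse = math.exp(lse_sparse)
    w_swa = math.exp(attn_sink) * math.exp(lse_swa)
    return (w_sparse * o_sparse + w_swa * o_swa) / (w_sparse + w_swa)


def single_pass_sink_bias(s_comp, v_comp, s_swa, v_swa, attn_sink):
    """One softmax over [S_comp, S_swa + attn_sink] against [V_comp; V_swa]."""
    logits = list(s_comp) + [s + attn_sink for s in s_swa]
    out, _ = attend(logits, list(v_comp) + list(v_swa))
    return out


s_comp, v_comp = [0.2, -1.0, 0.7], [1.0, 2.0, 3.0]
s_swa, v_swa = [0.5, 0.1], [-1.0, 4.0]
attn_sink = 0.5

merged = two_pass_sink_merge(s_comp, v_comp, s_swa, v_swa, attn_sink)
fused = single_pass_sink_bias(s_comp, v_comp, s_swa, v_swa, attn_sink)
print(f"two-pass={merged:.12f} single-pass={fused:.12f}")
```

Both paths compute the same ratio of exponential sums, so they agree up to floating-point rounding. That is why a logit bias on the SWA positions can stand in for a separate merge epilogue.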
+ +- **D5a ✅:** normalize flag + LSE + row_sums output +- **D5b ✅:** Per-row LSE + Python KV merge (cos 0.999994) +- **D5c ✅:** Sink bias as logit modification (cos 0.999996 single-tile AND multi-tile) +- **D5d:** NOT NEEDED — sink bias approach supersedes fused merge epilogue --- diff --git a/STAGE_D2.md b/STAGE_D2.md index 8bfe514c..cb0cbbc5 100644 --- a/STAGE_D2.md +++ b/STAGE_D2.md @@ -204,10 +204,14 @@ O has shape `(batch, n_h, T, head_dim)`. Each CTA writes its head's output. The - hd=128, n_h=8 — PASS - hd=256, n_h=2 — PASS +- [x] **D2 Head-Packed:** Q reshaped to (n_h*T, hd, 1), per-row softmax + - cos 0.999995 for n_h=1-128 at hd=64 + - Pro decode (n_h=128, T=1): M=128, one CTA processes all 128 heads + ### 🟡 Blocked (Multi-CTA Grid) - [ ] **D2.1:** Add `num_query_heads` and `batch_size` to `FmhaKernel.__init__` - - Simple to add, but the grid change is blocked (see below) + - Added as constructor params, but grid still (1,1,batch) — per-head launch in Python - [ ] **D2.3–D2.6:** Multi-CTA grid with runtime block coordinates - **BLOCKED:** `cute.local_tile` does not support runtime coordinates. Must use `cute.flat_divide` instead. @@ -216,9 +220,8 @@ O has shape `(batch, n_h, T, head_dim)`. Each CTA writes its head's output. The - **CUTLASS reference approach:** Uses `flat_divide` + `tma_partition` inside the TMA warp block, and a custom epilogue that handles the flat_divide coordinate system. Estimated 1-2 day effort. - [ ] **D2.9:** LSE output for multi-head - - Per-row LSE verified correct (max err 0.000001) but CuTe tensor indexing needs work - - Currently only row 0 is written (sfw_idx==0 guard) - - Full per-row output needed for D5 KV merge + - Per-row LSE verified correct (max err 0.000001), all 128 rows now write + - row_sums output also working — O_norm = O_unnorm / row_sum ---
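The per-row LSE and `row_sum` outputs above are what the Python KV merge consumes. A minimal sketch of that merge for one query row, following the formula in README note 18 (`O = sum_i [exp(lse_i) * O_i_norm] / sum_i [exp(lse_i)]`). Assumptions: natural-log-domain logits, scalar values per KV position, a 4-token tile standing in for the real 128-token segments, and hypothetical helper names (`attend`, `merge_segments`, `chunks`).

```python
import math


def attend(logits, values):
    """One query row over one KV segment.

    Returns (o_norm, o_unnorm, lse). o_unnorm is shifted by this segment's own max.
    """
    m = max(logits)
    w = [math.exp(l - m) for l in logits]
    row_sum = sum(w)
    o_unnorm = sum(wi * vi for wi, vi in zip(w, values))
    return o_unnorm / row_sum, o_unnorm, m + math.log(row_sum)


def merge_segments(outs, lses):
    """LSE-weighted merge of per-segment outputs, shifted by max LSE for stability."""
    top = max(lses)
    weights = [math.exp(l - top) for l in lses]
    return sum(w * o for w, o in zip(weights, outs)) / sum(weights)


def chunks(seq, size):
    """Split a list into consecutive pieces of at most `size` items."""
    return [seq[i:i + size] for i in range(0, len(seq), size)]


TILE = 4  # stands in for the 128-token KV tile
logits = [0.3, 2.0, -0.5, 1.2, 4.0, 0.1, -2.0, 0.9]
values = [1.0, -1.0, 2.0, 0.5, 3.0, -2.0, 1.5, 0.0]

segs = [attend(l, v) for l, v in zip(chunks(logits, TILE), chunks(values, TILE))]
lses = [s[2] for s in segs]

merged = merge_segments([s[0] for s in segs], lses)  # normalized outputs: correct
wrong = merge_segments([s[1] for s in segs], lses)   # un-normalized outputs: incorrect
full, _, _ = attend(logits, values)                  # single pass over all keys

print(f"full={full:.6f} merged={merged:.6f} wrong={wrong:.6f}")
```

The merge with normalized outputs reproduces the single-pass result. Feeding the un-normalized outputs into the same formula does not, because each segment's `O_unnorm` carries its own `row_sum` and max shift.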