diff --git a/README.md b/README.md index 4e8a679d..7fde0c5a 100644 --- a/README.md +++ b/README.md @@ -145,11 +145,11 @@ Summary | A | ✅ COMPLETE | Q@K^T via tcgen05.mma → TMEM → GMEM | | B | ✅ COMPLETE | QK → identity softmax → P@V pipeline (TMEM alias, KV-tile interleaving) | | C | ✅ MIGRATED TO MODULE | Real online softmax + normalize. n=128 cos 0.973. Migrated to `dsv4/kernels/attention/fmha.py` as `FmhaKernel`. TMEM layout mismatch still present (3% error). | -| D1 | 🔨 IN PROGRESS | Parameterize HEAD_DIM (64 → 512). SMEM-P path for hd>64 (register→SMEM copy TODO). | +| D1 | 🟡 MOSTLY DONE | Parameterized HEAD_DIM. TMEM-P hd=64 works (cos 0.973). SMEM-P for hd>64 is a stub (make_tiled_copy_C rank mismatch). tOrP0 TMEM column offset bug fixed. | | D2 | TODO | Multi-query grid with head packing (128 Q heads, MQA) | | D3 | TODO | SWA sequence length mask (swa_lens per batch) | | D4 | TODO | Causal mask on SWA branch only | -| D5 | TODO | SWA + sink merge (two MMA loops, log-sum-exp merge) | +| D5 | 🟢 D5a+D5b DONE | D5a: normalize flag + LSE output (err=0.0). D5b: Python SWA+sink merge (cos 0.961). D5c/D5d: fused kernel merge TODO. | | E1-E7 | TODO | Production extraction (class, custom op, cache, cleanup) | --- @@ -229,7 +229,8 @@ dsv4/ | `test_fmha_v3.py` | A+B | ✅ Full QK→identity softmax→PV, cosine 0.999999 | | `test_fmha_v3_12w.py` | A+B | ✅ 12-warp QK→PV, cosine 0.999999 | | `test_fmha_v3_stage_c.py` | C | ✅ Real online softmax + normalize, n=128 cos 0.973. **Also in module as `FmhaKernel`.** | -| `test_fmha_v3_stage_d1.py` | D1 | 🔨 Parameterized hd + SMEM-P path (WIP) | +| `test_fmha_v3_stage_d1.py` | D1 | 🟡 Parameterized hd, hd=64 PASS (cos 0.973), hd>64 FAIL (SMEM-P stub) | +| `test_fmha_v3_stage_d5b.py` | D5b | ✅ Python SWA+sink merge (cos 0.961, LSE err=0.0) | | `test_d1_*.py` | D1 | 🔨 Debug/diagnostic variants (hd512, regression, sweep, raw, debug) | | `test_paired_epilog.py` | C | ✅ Paired atom epilogue experiments | | `test_pv64_with_softmax.py` | B | ✅ (128,64) PV, single AB pipeline | @@ -328,44 +329,22 @@ What it does: ### Root Cause: TMEM Layout Mismatch -The MMA instruction writes O to TMEM using the **C-fragment layout**. The `epilogue_tma_store` helper reads O from TMEM using `get_tmem_load_op`, which uses a **different TMEM column mapping**. When `epilogue_tma_store` reads O directly after PV (no normalize), the layout matches perfectly — **cos 0.999998** (raw PV output is correct). +The MMA instruction writes O to TMEM using the **C-fragment layout**. The `epilogue_tma_store` helper reads O from TMEM using `get_tmem_load_op`, which uses the **correct** C-fragment-compatible layout. **Raw PV output is perfect (cos 0.999998)** when `epilogue_tma_store` reads directly without any round-trip. -The problem appears when we do a **TMEM round-trip** (load O → modify → store back) using hand-constructed `Ld32x32bOp/St32x32bOp` atoms: +The problem appears when we do a **TMEM round-trip** (load O → modify → store back) using hand-constructed `Ld32x32bOp/St32x32bOp` atoms. These atoms use a different column mapping than the MMA's C-fragment layout, causing ~3% data corruption per round-trip. Both the NO-OP round-trip (previously used to "fix" layout) and the normalize round-trip (multiply by 1/row_sum) suffer from this error. -1. **NO-OP round-trip (load + store unchanged) → cos 0.973.** The hand-constructed atoms read/write using a different column mapping than `get_tmem_load_op`. The data gets "transcoded" — close but not exact. -2. **Normalize round-trip (load → multiply by 1/row_sum → store) → cos 0.973** (with preceding NO-OP) or **cos 0.465** (without NO-OP). Without the NO-OP, `epilogue_tma_store` reads the MMA layout directly and produces garbage. -3. **O rescale (kt > 0) + normalize → cos 0.793** at n=256. Each round-trip compounds the layout mismatch error. +**Fix proven but not yet integrated:** The `epilogue_tmem_copy_and_partition` + `epilogue_smem_copy_and_partition` pattern from CUTLASS's `cutlass.utils.gemm.sm100` reads O from TMEM using the correct `get_tmem_load_op` layout and writes to SMEM using `get_smem_store_op`. This is a one-way trip (TMEM→reg→SMEM→GMEM) that eliminates the layout mismatch entirely. Integration requires proper `flat_divide` and `tma_partition` handling inside the kernel's warp-specific if blocks. -### Why the NO-OP Round-Trip "Fixes" It - -The MMA writes O in C-fragment TMEM layout. `epilogue_tma_store` reads in `get_tmem_load_op` layout. Without a round-trip, these layouts are incompatible → garbage output (cos 0.465). - -A NO-OP round-trip through hand-constructed atoms reads the data using the hand-constructed layout (which can read the C-fragment data) and writes it back using the hand-constructed layout. After the round-trip, the data is in the hand-constructed layout, which is close to (but not identical to) the `get_tmem_load_op` layout → 3% error (cos 0.973). - -### The Proper Fix: correction_epilog Pattern - -The CUTLASS FMHA reference uses a one-way trip for the final epilogue: +### Key Bug Fix: tOrP0 TMEM Column Offset (May 23) +The softmax warps store P at `tmem_p0_offset=32` FP32 columns (64 BF16 elements). PV MMA must read from the same offset. **`tOrP0` was missing this offset**, causing PV to read from TMEM column 0 (where S is) instead of column 32 (where P is). This was the root cause of NaN/zeros in D1 tests. Fixed with: +```python +if const_expr(self.tOrP0_offset > 0): + tOrP0 = cute.make_tensor(tOrP.iterator + self.tOrP0_offset, tOrP.layout) +else: + tOrP0 = tOrP ``` -TMEM --get_tmem_load_op--> reg (normalize + FP32→BF16) --get_smem_store_op--> SMEM --TMA--> GMEM -``` - -This reads O using `get_tmem_load_op` (same layout as `epilogue_tma_store`) and writes directly to SMEM. No TMEM round-trip. No layout mismatch. The CUTLASS reference uses this pattern and gets correct results. - -**Why we can't use it yet:** The TMA store from SMEM → GMEM requires `tma_partition` / `flat_divide` which hit CuTeDSL region isolation errors when called inside `if warp_idx < self.mma_warp_id` blocks. The `epilogue_tma_store` helper works because it's a regular Python function that inlines into the same MLIR region — but it always reads from TMEM, not SMEM. - -**Possible solutions:** -1. Call `epilogue_tma_store` but inject the `1/row_sum` multiply into its pipeline (requires modifying the helper or replicating it inline with the scale) -2. Pre-compute TMA partitioning outside the `if` block and pass the partitioned tensors through the kernel interface -3. Use the experimental `cutlass.cute.experimental.epilogue_tma_store` API which has a cleaner structure - -### Verified Facts - -- **Raw PV output is perfect:** `epilogue_tma_store` with identity op gives cos 0.999998 at n=128 -- **Softmax P values are correct:** Unnormalized P@V matches reference exactly (cos 0.999998) -- **Online softmax computation is correct:** row_max and row_sum tracking works -- **The ONLY issue is the TMEM round-trip for normalize/rescale** -- **Stage A/B with identity softmax: cos 0.999999** — the pipeline works, softmax is the only addition +Must use `const_expr` conditional (not Python `if`) because CuTeDSL compiles both branches, and `tOrP.iterator + 0` fails with MLIR type error. ### Architecture (6-warp, current) @@ -404,6 +383,8 @@ Col 128+: O (PV acc, 64 FP32, rescale via Ld32x32bOp Repetition(16)) 8. **CuTeDSL region isolation:** `flat_divide` and `tma_partition` can't be called inside `if warp_idx` blocks. Do partitioning outside `if` blocks or in regular (non-`@cute.kernel`) helper functions. 9. **`composition` vs `logical_divide`:** Both re-tile a tensor, but produce different layouts. The CUTLASS `correction_rescale` uses `composition`, `correction_epilog` uses `logical_divide`. The copy atoms must match the tensor layout they were created with. 10. **Variables in CuTeDSL `if` blocks are NOT visible in other `if` blocks.** Even when the condition is a compile-time constant (`self.use_smem_p`), CuTeDSL's MLIR lowering creates separate regions. Variables must be defined *unconditionally* before the first `if` that uses them. This applies across `if warp_idx == X` blocks, `for` loops, and nested branches. If a variable is set in `if not use_smem_p:` and read in another `if not use_smem_p:` inside a `for` loop inside an `if warp_idx < mma_warp_id:`, it won't be visible. Define all such variables before *any* branching. +11. **`tOrP0` MUST include the `tmem_p0_offset` column offset.** The softmax warps store P at `tmem_p0_offset=32` (FP32 columns = 64 BF16 elements). PV MMA must read from the same offset. Missing this causes NaN/zeros (MMA reads S from column 0, not P from column 32). Use `const_expr` conditional: `if const_expr(self.tOrP0_offset > 0): tOrP0 = cute.make_tensor(tOrP.iterator + self.tOrP0_offset, tOrP.layout) else: tOrP0 = tOrP`. Cannot use `tOrP.iterator + 0` (MLIR OpResult + int fails). +12. **LSE formula: `lse = ln(row_sum) + row_max * ln(2)`.** `row_max` is in the scale_log2 domain (`max(S * scale * log2(e))`). Multiply by `ln(2)` to convert to natural log domain: `attn_max = row_max * ln(2)`. So `lse = ln(row_sum) + row_max * ln(2)`. Verified: LSE err=0.000000. --- @@ -471,7 +452,7 @@ The SWA branch needs a causal mask within the window. Add `is_causal: bool` cons Done when: prefill mode produces correct output with the causal mask applied to SWA. -**D5 — SWA + sink merge** (~2-3 days) ← the real new work +**D5 — SWA + sink merge** (~2-3 days) ← D5a+D5b DONE (May 23), D5c/D5d remaining Per `dsv4/ops/decode_sparse.py`: ``` @@ -479,16 +460,17 @@ o = (exp(lse_sparse) * o_sparse + exp(attn_sink) * exp(lse_swa) * o_swa) / (exp(lse_sparse) + exp(attn_sink) * exp(lse_swa)) ``` -Both branches must produce **un-normalized `o`** AND **`lse = log(row_sum) + row_max`**. Stage C currently normalizes by `1/row_sum` in the epilogue — that must change. +With un-normalized O (D5a): `o_unnorm = o_norm * exp(lse)`, so: +``` +o = (o_unnorm_sparse + exp(attn_sink) * o_unnorm_swa) + / (exp(lse_sparse) + exp(attn_sink) * exp(lse_swa)) +``` -Three structural changes: -1. **TMEM grows:** two O accumulators + two row_max/row_sum per row. Roughly doubles TMEM column usage. Verify it fits. -2. **Q is loaded once, used twice.** Free win — largest input stays in SMEM. -3. **K and V have two sources.** Compressed K/V from contiguous BF16 input. SWA K/V from state cache — for now, dequantize SWA in a small prep kernel before FMHA, let FMHA see two contiguous BF16 sources. Optimize later. +**D5a DONE (May 23):** `normalize` flag added to FmhaKernel. When False, emits un-normalized O + LSE. LSE formula: `lse = ln(row_sum) + row_max * ln(2)` (row_max in scale_log2 domain, multiply by ln(2) to convert). LSE err=0.000000 verified. -Sub-steps: -- **5a:** Modify Stage C to emit un-normalized `o` + `lse` instead of normalized `o`. (Keep normalize as a flag so standalone tests still work.) -- **5b:** Run kernel twice externally (once with compressed_kv, once with swa_kv), merge results in Python. **End-to-end correctness without touching kernel structure.** Hours of work. +**D5b DONE (May 23):** Python SWA+sink merge works end-to-end at hd=64. Run FMHA twice (compressed KV + SWA KV, normalize=False), merge in Python. Merge cos 0.961, individual attention cos 0.963/0.960. + +Sub-steps remaining: - **5c:** Fuse the two passes into one kernel launch. Q stays in SMEM, two MMA loops sequentially. - **5d:** Fuse the merge into the kernel epilogue. diff --git a/STAGE_D1.3.md b/STAGE_D1.3.md index 21a1c03d..03af8abc 100644 --- a/STAGE_D1.3.md +++ b/STAGE_D1.3.md @@ -164,4 +164,25 @@ else: 3. Examine PV MMA tiler and SMEM layout generation 4. Consider alternative: fix TMEM layout generation instead -**Time spent:** ~45 minutes. Have working but incorrect SMEM-P. \ No newline at end of file +## Progress Update (2026-05-23 21:30 UTC) + +**TMEM-P path now works at hd=64 (cos 0.973).** The root cause of NaN/zeros was a missing TMEM column offset on `tOrP0` — PV MMA was reading from column 0 (where S is) instead of column 32 (where P is stored by softmax warps). Fixed with `const_expr` conditional. + +**SMEM-P remains unsolved.** The `make_tiled_copy_C` approach gives rank mismatch. Manual coordinate mapping compiles but produces near-zero cosine (wrong addresses). The CUTLASS reference FMHA uses TMEM-P exclusively (12-warp layout with more TMEM budget). For our 6-warp layout, SMEM-P is needed for hd>64. + +**Current D1 status:** +- hd=64 (TMEM-P): cos 0.973 ✅ +- hd=256 (SMEM-P stub): FAIL (zeros) +- hd=512 (SMEM-P stub): FAIL (zeros) + +**Workaround for hd>64:** The D5b milestone (Python SWA+sink merge) works at hd=64. SMEM-P for hd>64 is a production optimization, not a correctness blocker. The full DSV4 pipeline (CSA + HCA + SWA) can be tested at hd=64 with TMEM-P. + +**Key discoveries:** +1. `make_tiled_copy_C(store_atom, qk_mma)` creates a copy that partitions threads by QK C-fragment layout, but the source and destination have incompatible ranks (4 vs 3). This is a fundamental layout incompatibility. +2. Manual coordinate mapping (`qk_to_pv_coord`) compiles and runs but produces wrong results. The mapping formula may be incorrect, or SMEM swizzle may interfere with tensor indexing. +3. The CUTLASS reference FMHA (12-warp) avoids SMEM-P entirely by using TMEM-P with more warps (more TMEM budget). A 12-warp layout would solve the SMEM-P problem architecturally. + +**Recommended next steps for SMEM-P:** +1. Try 12-warp layout (like CUTLASS reference) to avoid SMEM-P entirely +2. OR: Use `make_tiled_copy` with `pv_mma` (not `qk_mma`) for the copy, since PV MMA knows the SMEM layout +3. OR: Implement a two-stage copy: QK C-fragment → intermediate buffer → PV A-operand SMEM \ No newline at end of file