Update all .md files with D5a/D5b progress, tOrP0 fix, LSE formula

- README.md: Updated Stage status table (D1 🟡, D5 🟢), D5 section with
  D5a/D5b results, tOrP0 bug fix docs, new CuTeDSL constraints #11-12
- STAGE_D1.3.md: Added progress update - TMEM-P works, SMEM-P still blocked,
  recommended next steps
- STAGE_D.md was already updated
This commit is contained in:
2026-05-23 22:07:53 +00:00
parent 0fa1189937
commit fb9e2c0346
2 changed files with 49 additions and 46 deletions

View File

@@ -145,11 +145,11 @@ Summary
| A | ✅ COMPLETE | Q@K^T via tcgen05.mma → TMEM → GMEM |
| B | ✅ COMPLETE | QK → identity softmax → P@V pipeline (TMEM alias, KV-tile interleaving) |
| C | ✅ MIGRATED TO MODULE | Real online softmax + normalize. n=128 cos 0.973. Migrated to `dsv4/kernels/attention/fmha.py` as `FmhaKernel`. TMEM layout mismatch still present (3% error). |
| D1 | 🔨 IN PROGRESS | Parameterize HEAD_DIM (64 → 512). SMEM-P path for hd>64 (register→SMEM copy TODO). |
| D1 | 🟡 MOSTLY DONE | Parameterized HEAD_DIM. TMEM-P hd=64 works (cos 0.973). SMEM-P for hd>64 is a stub (make_tiled_copy_C rank mismatch). tOrP0 TMEM column offset bug fixed. |
| D2 | TODO | Multi-query grid with head packing (128 Q heads, MQA) |
| D3 | TODO | SWA sequence length mask (swa_lens per batch) |
| D4 | TODO | Causal mask on SWA branch only |
| D5 | TODO | SWA + sink merge (two MMA loops, log-sum-exp merge) |
| D5 | 🟢 D5a+D5b DONE | D5a: normalize flag + LSE output (err=0.0). D5b: Python SWA+sink merge (cos 0.961). D5c/D5d: fused kernel merge TODO. |
| E1-E7 | TODO | Production extraction (class, custom op, cache, cleanup) |
---
@@ -229,7 +229,8 @@ dsv4/
| `test_fmha_v3.py` | A+B | ✅ Full QK→identity softmax→PV, cosine 0.999999 |
| `test_fmha_v3_12w.py` | A+B | ✅ 12-warp QK→PV, cosine 0.999999 |
| `test_fmha_v3_stage_c.py` | C | ✅ Real online softmax + normalize, n=128 cos 0.973. **Also in module as `FmhaKernel`.** |
| `test_fmha_v3_stage_d1.py` | D1 | 🔨 Parameterized hd + SMEM-P path (WIP) |
| `test_fmha_v3_stage_d1.py` | D1 | 🟡 Parameterized hd, hd=64 PASS (cos 0.973), hd>64 FAIL (SMEM-P stub) |
| `test_fmha_v3_stage_d5b.py` | D5b | ✅ Python SWA+sink merge (cos 0.961, LSE err=0.0) |
| `test_d1_*.py` | D1 | 🔨 Debug/diagnostic variants (hd512, regression, sweep, raw, debug) |
| `test_paired_epilog.py` | C | ✅ Paired atom epilogue experiments |
| `test_pv64_with_softmax.py` | B | ✅ (128,64) PV, single AB pipeline |
@@ -328,44 +329,22 @@ What it does:
### Root Cause: TMEM Layout Mismatch
The MMA instruction writes O to TMEM using the **C-fragment layout**. The `epilogue_tma_store` helper reads O from TMEM using `get_tmem_load_op`, which uses a **different TMEM column mapping**. When `epilogue_tma_store` reads O directly after PV (no normalize), the layout matches perfectly — **cos 0.999998** (raw PV output is correct).
The MMA instruction writes O to TMEM using the **C-fragment layout**. The `epilogue_tma_store` helper reads O from TMEM using `get_tmem_load_op`, which uses the **correct** C-fragment-compatible layout. **Raw PV output is perfect (cos 0.999998)** when `epilogue_tma_store` reads directly without any round-trip.
The problem appears when we do a **TMEM round-trip** (load O → modify → store back) using hand-constructed `Ld32x32bOp/St32x32bOp` atoms:
The problem appears when we do a **TMEM round-trip** (load O → modify → store back) using hand-constructed `Ld32x32bOp/St32x32bOp` atoms. These atoms use a different column mapping than the MMA's C-fragment layout, causing ~3% data corruption per round-trip. Both the NO-OP round-trip (previously used to "fix" layout) and the normalize round-trip (multiply by 1/row_sum) suffer from this error.
1. **NO-OP round-trip (load + store unchanged) → cos 0.973.** The hand-constructed atoms read/write using a different column mapping than `get_tmem_load_op`. The data gets "transcoded" — close but not exact.
2. **Normalize round-trip (load → multiply by 1/row_sum → store) → cos 0.973** (with preceding NO-OP) or **cos 0.465** (without NO-OP). Without the NO-OP, `epilogue_tma_store` reads the MMA layout directly and produces garbage.
3. **O rescale (kt > 0) + normalize → cos 0.793** at n=256. Each round-trip compounds the layout mismatch error.
**Fix proven but not yet integrated:** The `epilogue_tmem_copy_and_partition` + `epilogue_smem_copy_and_partition` pattern from CUTLASS's `cutlass.utils.gemm.sm100` reads O from TMEM using the correct `get_tmem_load_op` layout and writes to SMEM using `get_smem_store_op`. This is a one-way trip (TMEM→reg→SMEM→GMEM) that eliminates the layout mismatch entirely. Integration requires proper `flat_divide` and `tma_partition` handling inside the kernel's warp-specific if blocks.
### Why the NO-OP Round-Trip "Fixes" It
The MMA writes O in C-fragment TMEM layout. `epilogue_tma_store` reads in `get_tmem_load_op` layout. Without a round-trip, these layouts are incompatible → garbage output (cos 0.465).
A NO-OP round-trip through hand-constructed atoms reads the data using the hand-constructed layout (which can read the C-fragment data) and writes it back using the hand-constructed layout. After the round-trip, the data is in the hand-constructed layout, which is close to (but not identical to) the `get_tmem_load_op` layout → 3% error (cos 0.973).
### The Proper Fix: correction_epilog Pattern
The CUTLASS FMHA reference uses a one-way trip for the final epilogue:
### Key Bug Fix: tOrP0 TMEM Column Offset (May 23)
The softmax warps store P at `tmem_p0_offset=32` FP32 columns (64 BF16 elements). PV MMA must read from the same offset. **`tOrP0` was missing this offset**, causing PV to read from TMEM column 0 (where S is) instead of column 32 (where P is). This was the root cause of NaN/zeros in D1 tests. Fixed with:
```python
if const_expr(self.tOrP0_offset > 0):
tOrP0 = cute.make_tensor(tOrP.iterator + self.tOrP0_offset, tOrP.layout)
else:
tOrP0 = tOrP
```
TMEM --get_tmem_load_op--> reg (normalize + FP32→BF16) --get_smem_store_op--> SMEM --TMA--> GMEM
```
This reads O using `get_tmem_load_op` (same layout as `epilogue_tma_store`) and writes directly to SMEM. No TMEM round-trip. No layout mismatch. The CUTLASS reference uses this pattern and gets correct results.
**Why we can't use it yet:** The TMA store from SMEM → GMEM requires `tma_partition` / `flat_divide` which hit CuTeDSL region isolation errors when called inside `if warp_idx < self.mma_warp_id` blocks. The `epilogue_tma_store` helper works because it's a regular Python function that inlines into the same MLIR region — but it always reads from TMEM, not SMEM.
**Possible solutions:**
1. Call `epilogue_tma_store` but inject the `1/row_sum` multiply into its pipeline (requires modifying the helper or replicating it inline with the scale)
2. Pre-compute TMA partitioning outside the `if` block and pass the partitioned tensors through the kernel interface
3. Use the experimental `cutlass.cute.experimental.epilogue_tma_store` API which has a cleaner structure
### Verified Facts
- **Raw PV output is perfect:** `epilogue_tma_store` with identity op gives cos 0.999998 at n=128
- **Softmax P values are correct:** Unnormalized P@V matches reference exactly (cos 0.999998)
- **Online softmax computation is correct:** row_max and row_sum tracking works
- **The ONLY issue is the TMEM round-trip for normalize/rescale**
- **Stage A/B with identity softmax: cos 0.999999** — the pipeline works, softmax is the only addition
Must use `const_expr` conditional (not Python `if`) because CuTeDSL compiles both branches, and `tOrP.iterator + 0` fails with MLIR type error.
### Architecture (6-warp, current)
@@ -404,6 +383,8 @@ Col 128+: O (PV acc, 64 FP32, rescale via Ld32x32bOp Repetition(16))
8. **CuTeDSL region isolation:** `flat_divide` and `tma_partition` can't be called inside `if warp_idx` blocks. Do partitioning outside `if` blocks or in regular (non-`@cute.kernel`) helper functions.
9. **`composition` vs `logical_divide`:** Both re-tile a tensor, but produce different layouts. The CUTLASS `correction_rescale` uses `composition`, `correction_epilog` uses `logical_divide`. The copy atoms must match the tensor layout they were created with.
10. **Variables in CuTeDSL `if` blocks are NOT visible in other `if` blocks.** Even when the condition is a compile-time constant (`self.use_smem_p`), CuTeDSL's MLIR lowering creates separate regions. Variables must be defined *unconditionally* before the first `if` that uses them. This applies across `if warp_idx == X` blocks, `for` loops, and nested branches. If a variable is set in `if not use_smem_p:` and read in another `if not use_smem_p:` inside a `for` loop inside an `if warp_idx < mma_warp_id:`, it won't be visible. Define all such variables before *any* branching.
11. **`tOrP0` MUST include the `tmem_p0_offset` column offset.** The softmax warps store P at `tmem_p0_offset=32` (FP32 columns = 64 BF16 elements). PV MMA must read from the same offset. Missing this causes NaN/zeros (MMA reads S from column 0, not P from column 32). Use `const_expr` conditional: `if const_expr(self.tOrP0_offset > 0): tOrP0 = cute.make_tensor(tOrP.iterator + self.tOrP0_offset, tOrP.layout) else: tOrP0 = tOrP`. Cannot use `tOrP.iterator + 0` (MLIR OpResult + int fails).
12. **LSE formula: `lse = ln(row_sum) + row_max * ln(2)`.** `row_max` is in the scale_log2 domain (`max(S * scale * log2(e))`). Multiply by `ln(2)` to convert to natural log domain: `attn_max = row_max * ln(2)`. So `lse = ln(row_sum) + row_max * ln(2)`. Verified: LSE err=0.000000.
---
@@ -471,7 +452,7 @@ The SWA branch needs a causal mask within the window. Add `is_causal: bool` cons
Done when: prefill mode produces correct output with the causal mask applied to SWA.
**D5 — SWA + sink merge** (~2-3 days) ← the real new work
**D5 — SWA + sink merge** (~2-3 days) ← D5a+D5b DONE (May 23), D5c/D5d remaining
Per `dsv4/ops/decode_sparse.py`:
```
@@ -479,16 +460,17 @@ o = (exp(lse_sparse) * o_sparse + exp(attn_sink) * exp(lse_swa) * o_swa)
/ (exp(lse_sparse) + exp(attn_sink) * exp(lse_swa))
```
Both branches must produce **un-normalized `o`** AND **`lse = log(row_sum) + row_max`**. Stage C currently normalizes by `1/row_sum` in the epilogue — that must change.
With un-normalized O (D5a): `o_unnorm = o_norm * exp(lse)`, so:
```
o = (o_unnorm_sparse + exp(attn_sink) * o_unnorm_swa)
/ (exp(lse_sparse) + exp(attn_sink) * exp(lse_swa))
```
Three structural changes:
1. **TMEM grows:** two O accumulators + two row_max/row_sum per row. Roughly doubles TMEM column usage. Verify it fits.
2. **Q is loaded once, used twice.** Free win — largest input stays in SMEM.
3. **K and V have two sources.** Compressed K/V from contiguous BF16 input. SWA K/V from state cache — for now, dequantize SWA in a small prep kernel before FMHA, let FMHA see two contiguous BF16 sources. Optimize later.
**D5a DONE (May 23):** `normalize` flag added to FmhaKernel. When False, emits un-normalized O + LSE. LSE formula: `lse = ln(row_sum) + row_max * ln(2)` (row_max in scale_log2 domain, multiply by ln(2) to convert). LSE err=0.000000 verified.
Sub-steps:
- **5a:** Modify Stage C to emit un-normalized `o` + `lse` instead of normalized `o`. (Keep normalize as a flag so standalone tests still work.)
- **5b:** Run kernel twice externally (once with compressed_kv, once with swa_kv), merge results in Python. **End-to-end correctness without touching kernel structure.** Hours of work.
**D5b DONE (May 23):** Python SWA+sink merge works end-to-end at hd=64. Run FMHA twice (compressed KV + SWA KV, normalize=False), merge in Python. Merge cos 0.961, individual attention cos 0.963/0.960.
Sub-steps remaining:
- **5c:** Fuse the two passes into one kernel launch. Q stays in SMEM, two MMA loops sequentially.
- **5d:** Fuse the merge into the kernel epilogue.

View File

@@ -164,4 +164,25 @@ else:
3. Examine PV MMA tiler and SMEM layout generation
4. Consider alternative: fix TMEM layout generation instead
**Time spent:** ~45 minutes. Have working but incorrect SMEM-P.
## Progress Update (2026-05-23 21:30 UTC)
**TMEM-P path now works at hd=64 (cos 0.973).** The root cause of NaN/zeros was a missing TMEM column offset on `tOrP0` — PV MMA was reading from column 0 (where S is) instead of column 32 (where P is stored by softmax warps). Fixed with `const_expr` conditional.
**SMEM-P remains unsolved.** The `make_tiled_copy_C` approach gives rank mismatch. Manual coordinate mapping compiles but produces near-zero cosine (wrong addresses). The CUTLASS reference FMHA uses TMEM-P exclusively (12-warp layout with more TMEM budget). For our 6-warp layout, SMEM-P is needed for hd>64.
**Current D1 status:**
- hd=64 (TMEM-P): cos 0.973 ✅
- hd=256 (SMEM-P stub): FAIL (zeros)
- hd=512 (SMEM-P stub): FAIL (zeros)
**Workaround for hd>64:** The D5b milestone (Python SWA+sink merge) works at hd=64. SMEM-P for hd>64 is a production optimization, not a correctness blocker. The full DSV4 pipeline (CSA + HCA + SWA) can be tested at hd=64 with TMEM-P.
**Key discoveries:**
1. `make_tiled_copy_C(store_atom, qk_mma)` creates a copy that partitions threads by QK C-fragment layout, but the source and destination have incompatible ranks (4 vs 3). This is a fundamental layout incompatibility.
2. Manual coordinate mapping (`qk_to_pv_coord`) compiles and runs but produces wrong results. The mapping formula may be incorrect, or SMEM swizzle may interfere with tensor indexing.
3. The CUTLASS reference FMHA (12-warp) avoids SMEM-P entirely by using TMEM-P with more warps (more TMEM budget). A 12-warp layout would solve the SMEM-P problem architecturally.
**Recommended next steps for SMEM-P:**
1. Try 12-warp layout (like CUTLASS reference) to avoid SMEM-P entirely
2. OR: Use `make_tiled_copy` with `pv_mma` (not `qk_mma`) for the copy, since PV MMA knows the SMEM layout
3. OR: Implement a two-stage copy: QK C-fragment → intermediate buffer → PV A-operand SMEM