Docs: Update STAGE_D.md, README.md with hd=512 compilation blocker, lessons learned

This commit is contained in:
2026-05-24 21:35:25 +00:00
parent a5fef69363
commit dadfad8f89
2 changed files with 68 additions and 26 deletions

View File

@@ -138,14 +138,14 @@ Summary
---
## Status (May 24, 2026 — 04:30 UTC)
## Status (May 24, 2026 — 21:30 UTC)
| Stage | Status | Description |
|-------|--------|-------------|
| A | ✅ COMPLETE | Q@K^T via tcgen05.mma → TMEM → GMEM |
| B | ✅ COMPLETE | QK → identity softmax → P@V pipeline (TMEM alias, KV-tile interleaving) |
| C | ✅ COMPLETE | Real online softmax. Kernel outputs un-norm O + LSE (no TMEM round-trip). Migrated to `dsv4/kernels/attention/fmha.py` as `FmhaKernel`. |
| D1 | 🟡 hd≤256 DONE | Parameterized HEAD_DIM. qk_mma_tiler fix (hd=64/128/256 cos 0.999998). hd=512 SMEM overflow. |
| D1 | 🟡 hd≤256 DONE | Parameterized HEAD_DIM. qk_mma_tiler fix (hd=64/128/256 cos 0.999998). hd=512 SMEM fits but MLIR compilation hangs (>3hr). External k_sub merge proven impossible. |
| D2 | TODO | Multi-query grid with head packing (128 Q heads, MQA) |
| D3 | TODO | SWA sequence length mask (swa_lens per batch) |
| D4 | TODO | Causal mask on SWA branch only |
@@ -385,6 +385,10 @@ Col 128+: O (PV acc, 64 FP32, rescale via Ld32x32bOp Repetition(16))
10. **Variables in CuTeDSL `if` blocks are NOT visible in other `if` blocks.** Even when the condition is a compile-time constant (`self.use_smem_p`), CuTeDSL's MLIR lowering creates separate regions. Variables must be defined *unconditionally* before the first `if` that uses them. This applies across `if warp_idx == X` blocks, `for` loops, and nested branches. If a variable is set in `if not use_smem_p:` and read in another `if not use_smem_p:` inside a `for` loop inside an `if warp_idx < mma_warp_id:`, it won't be visible. Define all such variables before *any* branching.
11. **`tOrP0` MUST include the `tmem_p0_offset` column offset.** The softmax warps store P at `tmem_p0_offset=32` (FP32 columns = 64 BF16 elements). PV MMA must read from the same offset. Missing this causes NaN/zeros (MMA reads S from column 0, not P from column 32). Use `const_expr` conditional: `if const_expr(self.tOrP0_offset > 0): tOrP0 = cute.make_tensor(tOrP.iterator + self.tOrP0_offset, tOrP.layout) else: tOrP0 = tOrP`. Cannot use `tOrP.iterator + 0` (MLIR OpResult + int fails).
12. **LSE formula: `lse = ln(row_sum) + row_max * ln(2)`.** `row_max` is in the scale_log2 domain (`max(S * scale * log2(e))`). Multiply by `ln(2)` to convert to natural log domain: `attn_max = row_max * ln(2)`. So `lse = ln(row_sum) + row_max * ln(2)`. Verified: LSE err=0.000000.
13. **CuTeDSL MLIR backend cannot handle complex pipeline loops.** The MLIR→PTX optimizer has exponential-or-worse behavior for kernels with TMA pipeline acquire/release inside loops. Both Python `range()` (unrolled) and `cutlass.range(unroll=1)` (runtime) trigger 3+ hour compilation for hd=512. Consider raw CUDA C++ for complex kernels. Pre-compilation + cubin caching is a viable workaround if the optimizer eventually finishes.
14. **Guard dead code with `const_expr`.** CuTeDSL compiles BOTH branches of Python `if` statements. Use `const_expr(condition)` to eliminate dead code at compile time. Critical for: O rescale (only when n_kv_tiles>1), LSE (only when normalize=False), SMEM-P path (only when use_smem_p=True), k_sub path (only when n_k_sub_tiles>1).
15. **External k_sub merge is mathematically impossible.** k_sub segments are additive in LOGIT space (S = S_0 + S_1), not attention weight space. You cannot recover softmax(S_0+S_1)@V from softmax(S_0)@V and softmax(S_1)@V. The D5 merge formula works for different token sets (additive in weight space), NOT for partial dot products. In-kernel k_sub accumulation before softmax is the only correct approach.
16. **`pv_n_tile` reduction is the easiest SMEM knob.** At hd>256, reducing pv_n_tile from 256 to 128 shrinks sV and sC by 2× each. Cost: 4 PV GEMM passes instead of 2. But PV is typically not the bottleneck, and this is simpler than SMEM overlap or Q tiling.
---
@@ -418,15 +422,13 @@ The SWA branch is the only "irregular" thing: it reads from the state cache's ri
### Build Order
**D1 — Parameterize HEAD_DIM + SMEM-P** (~1 day, in progress)
**D1 — Parameterize HEAD_DIM + SMEM-P** (~1 day, MOSTLY DONE)
Currently hardcoded at 64. Promote to constructor arg, thread through `_setup`. Test at 64, then 512 (DSV4's real value).
**Two P staging paths:**
- **TMEM-P** (hd≤64): P stored to TMEM via register bridge. PV reads from TMEM. Proven at cos 0.973.
- **SMEM-P** (hd>64): P stored to SMEM via PV A-operand layout. PV reads from SMEM. Avoids QK↔PV TMEM layout mismatch at large hd. **Register→SMEM copy needs `make_tiled_copy_C(store_atom, qk_mma)` to partition threads by QK C-fragment.** The SMEM rendezvous pattern: softmax writes P to SMEM at logical (row, col) addresses using `p_smem_s` layout, MMA warp reads from same SMEM. Barrier in between.
hd≤256: ✅ DONE. cos 0.999998 at hd=64/128/256. Both TMEM-P and SMEM-P paths work.
Risk at HEAD_DIM=512: TMEM column budget. `_setup` already does `find_tmem_tensor_col_offset(tOtO)` dynamically. Verify the total fits in 512 TMEM columns. If not, reduce `kv_stage` from 2 to 1 (lose K/V double-buffering) before sacrificing math.
hd=512: ❌ BLOCKED. SMEM budget fixed (192KB, fits 232KB limit). Kernel structurally correct (tracer 0.8s). But CuTeDSL's MLIR→PTX backend optimizer hangs for 3+ hours when compiling the k_sub loop. External k_sub merge is mathematically impossible (k_sub segments additive in logit space, not weight space). Need either: (a) pre-compile offline + cache cubin, (b) add no-softmax mode for S accumulation in Python, or (c) write hd=512 path in raw CUDA C++.
Done when: identical result at HEAD_DIM=64 (regression), passes at HEAD_DIM=512 against FP32 oracle.

View File

@@ -31,20 +31,21 @@
---
## Current Status (2026-05-24)
## Current Status (2026-05-24, 21:30 UTC)
### ✅ WORKING
| hd | n=128 cos | LSE err | Path |
|---:|----------:|--------:|------|
| 64 | 0.999998 | 0.000000 | TMEM-P |
| 128 | 0.999997 | 0.000000 | TMEM-P / SMEM-P |
| 256 | 0.999998 | 0.000000 | TMEM-P |
| hd | n=128 cos | LSE err | Path | SMEM |
|---:|----------:|--------:|------|-----:|
| 64 | 0.999998 | 0.000000 | TMEM-P | 128KB |
| 128 | 0.999997 | 0.000000 | TMEM-P / SMEM-P | 128KB |
| 256 | 0.999998 | 0.000000 | TMEM-P | 224KB |
### ❌ KNOWN ISSUES
- **hd=512:** SMEM overflow (344KB > 232KB). sQ(128KB) + sK(128KB) + sV(64KB) too large. Needs SMEM tiling or buffer overlap.
- **O rescale (kt>0):** Uses hand-constructed TMEM atoms. May corrupt data for n>128 (multi-KV-tile). At n=128 (1 KV tile, kt=0), no rescale needed.
- **hd=512: MLIR compilation hangs.** SMEM budget fixed (192KB ✅), kernel structure correct (tracer 0.8s), but MLIR→PTX backend optimizer cannot process the IR in reasonable time (>3 hours). Both `range()` unrolled and `cutlass.range(unroll=1)` runtime loops trigger this. This is a CuTeDSL/MLIR toolchain limitation.
- **External k_sub merge doesn't work.** k_sub segments are additive in logit space (S = S_0 + S_1), not attention weight space. The D5 merge formula does not apply. In-kernel k_sub accumulation is the only correct approach.
- **O rescale (kt>0):** Uses hand-constructed TMEM atoms. May corrupt data for n>128 (multi-KV-tile). At n=128 (1 KV tile, kt=0), no rescale needed. Guarded with `const_expr(n_kv_tiles > 1)`.
- **Kernel always outputs un-normalized O + LSE.** No in-kernel normalization (eliminates TMEM round-trip error). External normalization: `O_norm = O_unnorm / row_sum`.
---
@@ -116,16 +117,42 @@ This caused the QK GEMM to only compute 64 of 128 (or 256, 512) dimensions at hd
---
## Lessons Learned (2026-05-24)
### 1. CuTeDSL MLIR Backend Cannot Handle Complex Pipeline Loops
The MLIR→PTX backend optimizer has exponential-or-worse behavior for kernels with TMA pipeline acquire/release inside loops. Both unrolled (Python `range`) and runtime (`cutlass.range unroll=1`) loops trigger this. The Python tracer is fast (0.8s) because it just generates IR. The MLIR optimizer then chews on that IR for hours. **Workaround:** keep pipeline loops as simple as possible. Consider raw CUDA C++ for complex kernels.
### 2. External k_sub Merge is Mathematically Impossible
You CANNOT merge the outputs of two attention calls that compute softmax(Q_k0 @ K_k0^T)@V and softmax(Q_k1 @ K_k1^T)@V into softmax(Q @ K^T)@V. The k_sub segments are additive in LOGIT space (S = S_0 + S_1), but softmax is nonlinear. The D5 merge formula works because sparse and SWA attend over DIFFERENT token sets (additive in weight space). k_sub attends over the SAME tokens with PARTIAL dot products. These are fundamentally different operations. **The only correct approach is in-kernel accumulation (S_0 + S_1 before softmax).**
### 3. pv_n_tile Reduction is the Easiest SMEM Knob
At hd>256, reducing `pv_n_tile` from 256 to 128 shrinks sV and sC by 2× each. The cost is 4 PV GEMM passes instead of 2. But PV is typically not the bottleneck. This is simpler than SMEM overlap (which requires CuTeDSL SmemAllocator changes) or Q tiling (which adds pipeline complexity).
### 4. Guard Dead Code with const_expr
CuTeDSL compiles BOTH branches of Python `if` statements, generating IR for code that will never execute at a given head_dim. Use `const_expr(condition)` to eliminate dead code at compile time. This is critical for:
- O rescale code (only needed when n_kv_tiles > 1)
- LSE computation (only needed when normalize=False)
- SMEM-P path (only needed when use_smem_p=True)
### 5. Don't Mix Python Loops and CuTeDSL Pipeline Operations
Python `for` loops unroll at trace time, creating N copies of the loop body in the IR. For pipeline acquire/release + TMA copy + GEMM, each copy is substantial. `cutlass.range(unroll=1)` creates a runtime loop with one copy of the body. **For pipeline operations, prefer `cutlass.range(unroll=1)` to reduce IR size**, even though the MLIR optimizer may still struggle with it.
### 6. The k_tile Parameter is the Key to hd=512
At hd=512, the kernel splits Q and K into sub-tiles of size `k_tile=256` along the head_dim. Each sub-tile is loaded via TMA, processed by MMA, and accumulated. `n_k_sub_tiles = head_dim // k_tile = 2`. The k_tile parameter controls the sub-tile size and the number of iterations. **k_tile must be ≤ 256** (MMA instruction K-dim limit) and must evenly divide head_dim.
---
## SMEM Budget at Various hd
| hd | sQ | sK (kv_stage=1) | sV (kv_stage=1) | sP (SMEM-P) | sC | Total | Limit | Status |
| hd | sQ | sK (kv_stage=1) | sV (pv_n_tile) | sP (SMEM-P) | sC | Total | Limit | Status |
|---:|----:|----:|----:|----:|----:|------:|------:|--------|
| 64 | 32KB | 32KB | 32KB | 32KB | 32KB | 160KB | 232KB | ✅ |
| 128 | 32KB | 32KB | 32KB | 32KB | 32KB | 160KB | 232KB | ✅ |
| 256 | 64KB | 64KB | 64KB | 0* | 32KB | 224KB | 232KB | ✅ |
| 512 | 128KB | 128KB | 64KB | 0* | 32KB | 352KB | 232KB | |
| 64 | 32KB | 32KB | 32KB (256) | — | 32KB | 128KB | 232KB | ✅ |
| 128 | 32KB | 32KB | 32KB (256) | — | 32KB | 128KB | 232KB | ✅ |
| 256 | 64KB | 64KB | 64KB (256) | 0* | 32KB | 224KB | 232KB | ✅ |
| 512 | 64KB | 64KB | 32KB (128) | 0* | 32KB | 192KB | 232KB | ⚠️ Fits but MLIR hangs |
*TMEM-P path: sP allocation skipped (const_expr conditional)
pv_n_tile shown in parens; hd>256 uses pv_n_tile=128 (4 PV GEMM passes) to fit SMEM
---
@@ -143,14 +170,27 @@ This caused the QK GEMM to only compute 64 of 128 (or 256, 512) dimensions at hd
## Build Order (Remaining)
### D1.4 — hd=512 SMEM Budget ⚡ CURRENT
### D1.4 — hd=512 ⚡ CURRENT (BLOCKED)
hd=512 needs sQ(128KB) + sK(128KB) + sV(64KB) = 320KB. Must reduce to fit 232KB.
**Problem:** hd=512 exceeds the MMA instruction's max K-dim (256). Must split Q and K into 2 sub-tiles along head_dim (k_tile=256, n_k_sub_tiles=2). The QK dot product is S = Q_k0 @ K_k0^T + Q_k1 @ K_k1^T (additive in logit space).
Options:
1. **Tile Q along head_dim:** Process Q in chunks of 256. Two Q sub-tiles per kernel.
2. **SMEM buffer overlap:** sQ and sK/sV used at different times. After Q is consumed by MMA, reuse sQ's SMEM for K/V.
3. **Split the GEMM K dimension:** Process K in sub-tiles (K=256 then K=256-512). Each sub-tile fits SMEM.
**SMEM budget: SOLVED.** pv_n_tile=128 for hd>256 reduces sV from 64KB→32KB, sC from 64KB→32KB. Total 192KB ✅.
**Compilation: BLOCKED.** The CuTeDSL MLIR→PTX backend optimizer cannot compile the hd=512 kernel in reasonable time. Both Python `range()` (unrolled IR) and `cutlass.range(unroll=1)` (runtime loop) produce IR that the optimizer chews on for 3+ hours without finishing. The Python tracer completes in 0.8s — the kernel is structurally correct. This is a toolchain limitation.
**External merge: IMPOSSIBLE.** The D5 online softmax merge formula assumes separate attention distributions over different token sets (additive in weight space). k_sub segments are additive in LOGIT space (S = S_0 + S_1), not weight space. You cannot recover softmax(S_0 + S_1)@V from softmax(S_0)@V and softmax(S_1)@V. In-kernel accumulation before softmax is the only correct approach.
**Bug fixes applied along the way:**
- LSE type mismatch (BF16 vs FP32 when normalize=True) → guarded with `const_expr(not self.normalize)`
- O rescale IR explosion at n=128 → guarded with `const_expr(n_kv_tiles > 1)`
- k_sub tracer IR explosion → replaced hardcoded `if k_sub==0/1` with Python `range()` loop
- External merge test (cos 0.617) → confirmed mathematically impossible, deleted approach
**Possible paths forward (priority order):**
1. **Pre-compile hd=512 kernel offline.** Accept 1-2 hour compilation during build. Cache the cubin. This works if the MLIR optimizer eventually finishes (it might just be slow, not stuck — but 3+ hours is excessive even for pre-compilation).
2. **Add no-softmax mode to the kernel.** Output raw S (QK scores) without softmax. Call twice for k_sub=0 and k_sub=1. Accumulate S_0+S_1 in Python. Apply softmax once. This requires modifying the softmax warp to optionally skip normalization and output S to GMEM instead of P to TMEM/SMEM.
3. **Write hd=512 kernel in CUTLASS C++.** Bypass CuTeDSL's MLIR backend entirely. Use raw CUTLASS C++ with tcgen05 MMA intrinsics. More work but compilation is fast (seconds).
4. **Report CuTeDSL MLIR optimizer bug.** The optimizer should handle this IR in reasonable time. File an issue with NVIDIA.
### D2 — Multi-Query Grid with Head Packing