From dadfad8f895f714a2b3ef41d0b2c570eedcfde99 Mon Sep 17 00:00:00 2001
From: biondizzle
Date: Sun, 24 May 2026 21:35:25 +0000
Subject: [PATCH] Docs: Update STAGE_D.md, README.md with hd=512 compilation blocker, lessons learned

---
 README.md  | 16 ++++++-----
 STAGE_D.md | 78 +++++++++++++++++++++++++++++++++++++++++-------------
 2 files changed, 68 insertions(+), 26 deletions(-)

diff --git a/README.md b/README.md
index 00a55272..c1358dee 100644
--- a/README.md
+++ b/README.md
@@ -138,14 +138,14 @@ Summary
 
 ---
 
-## Status (May 24, 2026 — 04:30 UTC)
+## Status (May 24, 2026 — 21:30 UTC)
 
 | Stage | Status | Description |
 |-------|--------|-------------|
 | A | ✅ COMPLETE | Q@K^T via tcgen05.mma → TMEM → GMEM |
 | B | ✅ COMPLETE | QK → identity softmax → P@V pipeline (TMEM alias, KV-tile interleaving) |
 | C | ✅ COMPLETE | Real online softmax. Kernel outputs un-norm O + LSE (no TMEM round-trip). Migrated to `dsv4/kernels/attention/fmha.py` as `FmhaKernel`. |
-| D1 | 🟡 hd≤256 DONE | Parameterized HEAD_DIM. qk_mma_tiler fix (hd=64/128/256 cos 0.999998). hd=512 SMEM overflow. |
+| D1 | 🟡 hd≤256 DONE | Parameterized HEAD_DIM. qk_mma_tiler fix (hd=64/128/256 cos 0.999998). hd=512 SMEM fits but MLIR compilation hangs (>3hr). External k_sub merge proven impossible. |
 | D2 | TODO | Multi-query grid with head packing (128 Q heads, MQA) |
 | D3 | TODO | SWA sequence length mask (swa_lens per batch) |
 | D4 | TODO | Causal mask on SWA branch only |
@@ -385,6 +385,10 @@ Col 128+: O (PV acc, 64 FP32, rescale via Ld32x32bOp Repetition(16))
 10. **Variables in CuTeDSL `if` blocks are NOT visible in other `if` blocks.** Even when the condition is a compile-time constant (`self.use_smem_p`), CuTeDSL's MLIR lowering creates separate regions. Variables must be defined *unconditionally* before the first `if` that uses them. This applies across `if warp_idx == X` blocks, `for` loops, and nested branches. If a variable is set in `if not use_smem_p:` and read in another `if not use_smem_p:` inside a `for` loop inside an `if warp_idx < mma_warp_id:`, it won't be visible. Define all such variables before *any* branching.
 11. **`tOrP0` MUST include the `tmem_p0_offset` column offset.** The softmax warps store P at `tmem_p0_offset=32` (FP32 columns = 64 BF16 elements). PV MMA must read from the same offset. Missing this causes NaN/zeros (MMA reads S from column 0, not P from column 32). Use `const_expr` conditional: `if const_expr(self.tOrP0_offset > 0): tOrP0 = cute.make_tensor(tOrP.iterator + self.tOrP0_offset, tOrP.layout) else: tOrP0 = tOrP`. Cannot use `tOrP.iterator + 0` (MLIR OpResult + int fails).
 12. **LSE formula: `lse = ln(row_sum) + row_max * ln(2)`.** `row_max` is in the scale_log2 domain (`max(S * scale * log2(e))`). Multiply by `ln(2)` to convert to natural log domain: `attn_max = row_max * ln(2)`. So `lse = ln(row_sum) + row_max * ln(2)`. Verified: LSE err=0.000000.
+13. **CuTeDSL MLIR backend cannot handle complex pipeline loops.** The MLIR→PTX optimizer has exponential-or-worse behavior for kernels with TMA pipeline acquire/release inside loops. Both Python `range()` (unrolled) and `cutlass.range(unroll=1)` (runtime) trigger 3+ hour compilation for hd=512. Consider raw CUDA C++ for complex kernels. Pre-compilation + cubin caching is a viable workaround if the optimizer eventually finishes.
+14. **Guard dead code with `const_expr`.** CuTeDSL compiles BOTH branches of Python `if` statements. Use `const_expr(condition)` to eliminate dead code at compile time. Critical for: O rescale (only when n_kv_tiles>1), LSE (only when normalize=False), SMEM-P path (only when use_smem_p=True), k_sub path (only when n_k_sub_tiles>1).
+15. **External k_sub merge is mathematically impossible.** k_sub segments are additive in LOGIT space (S = S_0 + S_1), not attention weight space. You cannot recover softmax(S_0+S_1)@V from softmax(S_0)@V and softmax(S_1)@V. The D5 merge formula works for different token sets (additive in weight space), NOT for partial dot products. In-kernel k_sub accumulation before softmax is the only correct approach.
+16. **`pv_n_tile` reduction is the easiest SMEM knob.** At hd>256, reducing pv_n_tile from 256 to 128 shrinks sV and sC by 2× each. Cost: 4 PV GEMM passes instead of 2. But PV is typically not the bottleneck, and this is simpler than SMEM overlap or Q tiling.
 
 ---
 
@@ -418,15 +422,13 @@ The SWA branch is the only "irregular" thing: it reads from the state cache's ri
 
 ### Build Order
 
-**D1 — Parameterize HEAD_DIM + SMEM-P** (~1 day, in progress)
+**D1 — Parameterize HEAD_DIM + SMEM-P** (~1 day, MOSTLY DONE)
 
 Currently hardcoded at 64. Promote to constructor arg, thread through `_setup`. Test at 64, then 512 (DSV4's real value).
 
-**Two P staging paths:**
-- **TMEM-P** (hd≤64): P stored to TMEM via register bridge. PV reads from TMEM. Proven at cos 0.973.
-- **SMEM-P** (hd>64): P stored to SMEM via PV A-operand layout. PV reads from SMEM. Avoids QK↔PV TMEM layout mismatch at large hd. **Register→SMEM copy needs `make_tiled_copy_C(store_atom, qk_mma)` to partition threads by QK C-fragment.** The SMEM rendezvous pattern: softmax writes P to SMEM at logical (row, col) addresses using `p_smem_s` layout, MMA warp reads from same SMEM. Barrier in between.
+hd≤256: ✅ DONE. cos 0.999998 at hd=64/128/256. Both TMEM-P and SMEM-P paths work.
 
-Risk at HEAD_DIM=512: TMEM column budget. `_setup` already does `find_tmem_tensor_col_offset(tOtO)` dynamically. Verify the total fits in 512 TMEM columns. If not, reduce `kv_stage` from 2 to 1 (lose K/V double-buffering) before sacrificing math.
+hd=512: ❌ BLOCKED. SMEM budget fixed (192KB, fits 232KB limit). Kernel structurally correct (tracer 0.8s). But CuTeDSL's MLIR→PTX backend optimizer hangs for 3+ hours when compiling the k_sub loop. External k_sub merge is mathematically impossible (k_sub segments additive in logit space, not weight space). Need either: (a) pre-compile offline + cache cubin, (b) add no-softmax mode for S accumulation in Python, or (c) write hd=512 path in raw CUDA C++.
 
 Done when: identical result at HEAD_DIM=64 (regression), passes at HEAD_DIM=512 against FP32 oracle.
 
diff --git a/STAGE_D.md b/STAGE_D.md
index 1aa4ef9a..e90f3c7a 100644
--- a/STAGE_D.md
+++ b/STAGE_D.md
@@ -31,20 +31,21 @@
 
 ---
 
-## Current Status (2026-05-24)
+## Current Status (2026-05-24, 21:30 UTC)
 
 ### ✅ WORKING
 
-| hd | n=128 cos | LSE err | Path |
-|---:|----------:|--------:|------|
-| 64 | 0.999998 | 0.000000 | TMEM-P |
-| 128 | 0.999997 | 0.000000 | TMEM-P / SMEM-P |
-| 256 | 0.999998 | 0.000000 | TMEM-P |
+| hd | n=128 cos | LSE err | Path | SMEM |
+|---:|----------:|--------:|------|-----:|
+| 64 | 0.999998 | 0.000000 | TMEM-P | 128KB |
+| 128 | 0.999997 | 0.000000 | TMEM-P / SMEM-P | 128KB |
+| 256 | 0.999998 | 0.000000 | TMEM-P | 224KB |
 
 ### ❌ KNOWN ISSUES
 
-- **hd=512:** SMEM overflow (344KB > 232KB). sQ(128KB) + sK(128KB) + sV(64KB) too large. Needs SMEM tiling or buffer overlap.
-- **O rescale (kt>0):** Uses hand-constructed TMEM atoms. May corrupt data for n>128 (multi-KV-tile). At n=128 (1 KV tile, kt=0), no rescale needed.
+- **hd=512: MLIR compilation hangs.** SMEM budget fixed (192KB ✅), kernel structure correct (tracer 0.8s), but MLIR→PTX backend optimizer cannot process the IR in reasonable time (>3 hours). Both `range()` unrolled and `cutlass.range(unroll=1)` runtime loops trigger this. This is a CuTeDSL/MLIR toolchain limitation.
+- **External k_sub merge doesn't work.** k_sub segments are additive in logit space (S = S_0 + S_1), not attention weight space. The D5 merge formula does not apply. In-kernel k_sub accumulation is the only correct approach.
+- **O rescale (kt>0):** Uses hand-constructed TMEM atoms. May corrupt data for n>128 (multi-KV-tile). At n=128 (1 KV tile, kt=0), no rescale needed. Guarded with `const_expr(n_kv_tiles > 1)`.
 - **Kernel always outputs un-normalized O + LSE.** No in-kernel normalization (eliminates TMEM round-trip error). External normalization: `O_norm = O_unnorm / row_sum`.
 
 ---
@@ -116,16 +117,42 @@ This caused the QK GEMM to only compute 64 of 128 (or 256, 512) dimensions at hd
 
 ---
 
+## Lessons Learned (2026-05-24)
+
+### 1. CuTeDSL MLIR Backend Cannot Handle Complex Pipeline Loops
+The MLIR→PTX backend optimizer has exponential-or-worse behavior for kernels with TMA pipeline acquire/release inside loops. Both unrolled (Python `range`) and runtime (`cutlass.range unroll=1`) loops trigger this. The Python tracer is fast (0.8s) because it just generates IR. The MLIR optimizer then chews on that IR for hours. **Workaround:** keep pipeline loops as simple as possible. Consider raw CUDA C++ for complex kernels.
+
+### 2. External k_sub Merge is Mathematically Impossible
+You CANNOT merge the outputs of two attention calls that compute softmax(Q_k0 @ K_k0^T)@V and softmax(Q_k1 @ K_k1^T)@V into softmax(Q @ K^T)@V. The k_sub segments are additive in LOGIT space (S = S_0 + S_1), but softmax is nonlinear. The D5 merge formula works because sparse and SWA attend over DIFFERENT token sets (additive in weight space). k_sub attends over the SAME tokens with PARTIAL dot products. These are fundamentally different operations. **The only correct approach is in-kernel accumulation (S_0 + S_1 before softmax).**
+
+### 3. pv_n_tile Reduction is the Easiest SMEM Knob
+At hd>256, reducing `pv_n_tile` from 256 to 128 shrinks sV and sC by 2× each. The cost is 4 PV GEMM passes instead of 2. But PV is typically not the bottleneck. This is simpler than SMEM overlap (which requires CuTeDSL SmemAllocator changes) or Q tiling (which adds pipeline complexity).
+
+### 4. Guard Dead Code with const_expr
+CuTeDSL compiles BOTH branches of Python `if` statements, generating IR for code that will never execute at a given head_dim. Use `const_expr(condition)` to eliminate dead code at compile time. This is critical for:
+- O rescale code (only needed when n_kv_tiles > 1)
+- LSE computation (only needed when normalize=False)
+- SMEM-P path (only needed when use_smem_p=True)
+
+### 5. Don't Mix Python Loops and CuTeDSL Pipeline Operations
+Python `for` loops unroll at trace time, creating N copies of the loop body in the IR. For pipeline acquire/release + TMA copy + GEMM, each copy is substantial. `cutlass.range(unroll=1)` creates a runtime loop with one copy of the body. **For pipeline operations, prefer `cutlass.range(unroll=1)` to reduce IR size**, even though the MLIR optimizer may still struggle with it.
+
+### 6. The k_tile Parameter is the Key to hd=512
+At hd=512, the kernel splits Q and K into sub-tiles of size `k_tile=256` along the head_dim. Each sub-tile is loaded via TMA, processed by MMA, and accumulated. `n_k_sub_tiles = head_dim // k_tile = 2`. The k_tile parameter controls the sub-tile size and the number of iterations. **k_tile must be ≤ 256** (MMA instruction K-dim limit) and must evenly divide head_dim.
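Lessons 2 and 6 can be checked numerically on a toy problem. The sketch below is plain Python, not CuTeDSL, and every name in it is illustrative. It assumes the D5 merge is the usual LSE-weighted combination of two normalized attention outputs, which the patch itself does not spell out. It compares three things against full attention: accumulating sub-tile logits before softmax, LSE-merging per-sub-tile attention outputs, and LSE-merging attention over disjoint token sets.

```python
import math

def qk_scores(q, k):
    """S = Q @ K^T for row-major lists of lists."""
    return [[sum(a * b for a, b in zip(qr, kr)) for kr in k] for qr in q]

def attend(s, v):
    """Return (softmax(S) @ V, LSE) per query row; LSE is in the natural-log domain."""
    outs, lses = [], []
    for row in s:
        m = max(row)
        e = [math.exp(x - m) for x in row]
        z = sum(e)
        outs.append([sum(w * vr[c] for w, vr in zip(e, v)) / z
                     for c in range(len(v[0]))])
        lses.append(m + math.log(z))
    return outs, lses

def lse_merge(o_a, lse_a, o_b, lse_b):
    """LSE-weighted merge of two normalized outputs (assumed form of the D5 formula)."""
    merged = []
    for ra, la, rb, lb in zip(o_a, lse_a, o_b, lse_b):
        m = max(la, lb)
        wa, wb = math.exp(la - m), math.exp(lb - m)
        merged.append([(wa * x + wb * y) / (wa + wb) for x, y in zip(ra, rb)])
    return merged

def max_abs_diff(a, b):
    return max(abs(x - y) for ra, rb in zip(a, b) for x, y in zip(ra, rb))

def cols(m, lo, hi):
    """Slice columns [lo, hi) of a row-major matrix."""
    return [row[lo:hi] for row in m]

# Toy sizes standing in for head_dim=512, k_tile=256: head_dim=4, k_tile=2.
Q = [[1.0, 0.0, 2.0, -1.0], [0.5, 1.5, -1.0, 0.0]]
K = [[1.0, 2.0, 0.0, 1.0], [0.0, -1.0, 1.0, 2.0], [2.0, 0.0, -1.0, 0.5]]
V = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
k_tile = 2

o_ref, _ = attend(qk_scores(Q, K), V)

# In-kernel style: accumulate partial logits S_0 + S_1, then one softmax.
s_parts = [qk_scores(cols(Q, i, i + k_tile), cols(K, i, i + k_tile))
           for i in range(0, len(Q[0]), k_tile)]
s_acc = [[sum(p[r][c] for p in s_parts) for c in range(len(K))]
         for r in range(len(Q))]
o_in_kernel, _ = attend(s_acc, V)
err_in_kernel = max_abs_diff(o_in_kernel, o_ref)

# External merge: softmax each sub-tile's logits separately, then LSE-merge.
(o0, l0), (o1, l1) = (attend(p, V) for p in s_parts)
err_external = max_abs_diff(lse_merge(o0, l0, o1, l1), o_ref)

# Same merge applied to disjoint token sets (tokens 0-1 vs token 2).
oa, la = attend(qk_scores(Q, K[:2]), V[:2])
ob, lb = attend(qk_scores(Q, K[2:]), V[2:])
err_disjoint = max_abs_diff(lse_merge(oa, la, ob, lb), o_ref)

print(f"in-kernel accumulation error: {err_in_kernel:.3e}")
print(f"external k_sub merge error:   {err_external:.3e}")
print(f"disjoint-token merge error:   {err_disjoint:.3e}")
```

Under these assumptions the in-kernel and disjoint-token errors sit at floating-point noise, while the external k_sub merge is visibly wrong. That matches the qualitative failure reported for the external merge test in this patch.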
+
+---
+
 ## SMEM Budget at Various hd
 
-| hd | sQ | sK (kv_stage=1) | sV (kv_stage=1) | sP (SMEM-P) | sC | Total | Limit | Status |
+| hd | sQ | sK (kv_stage=1) | sV (pv_n_tile) | sP (SMEM-P) | sC | Total | Limit | Status |
 |---:|----:|----:|----:|----:|----:|------:|------:|--------|
-| 64 | 32KB | 32KB | 32KB | 32KB | 32KB | 160KB | 232KB | ✅ |
-| 128 | 32KB | 32KB | 32KB | 32KB | 32KB | 160KB | 232KB | ✅ |
-| 256 | 64KB | 64KB | 64KB | 0* | 32KB | 224KB | 232KB | ✅ |
-| 512 | 128KB | 128KB | 64KB | 0* | 32KB | 352KB | 232KB | ❌ |
+| 64 | 32KB | 32KB | 32KB (256) | — | 32KB | 128KB | 232KB | ✅ |
+| 128 | 32KB | 32KB | 32KB (256) | — | 32KB | 128KB | 232KB | ✅ |
+| 256 | 64KB | 64KB | 64KB (256) | 0* | 32KB | 224KB | 232KB | ✅ |
+| 512 | 64KB | 64KB | 32KB (128) | 0* | 32KB | 192KB | 232KB | ⚠️ Fits but MLIR hangs |
 
 *TMEM-P path: sP allocation skipped (const_expr conditional)
+pv_n_tile shown in parens; hd>256 uses pv_n_tile=128 (4 PV GEMM passes) to fit SMEM
 
 ---
 
@@ -143,14 +170,27 @@ This caused the QK GEMM to only compute 64 of 128 (or 256, 512) dimensions at hd
 
 ## Build Order (Remaining)
 
-### D1.4 — hd=512 SMEM Budget ⚡ CURRENT
+### D1.4 — hd=512 ⚡ CURRENT (BLOCKED)
 
-hd=512 needs sQ(128KB) + sK(128KB) + sV(64KB) = 320KB. Must reduce to fit 232KB.
+**Problem:** hd=512 exceeds the MMA instruction's max K-dim (256). Must split Q and K into 2 sub-tiles along head_dim (k_tile=256, n_k_sub_tiles=2). The QK dot product is S = Q_k0 @ K_k0^T + Q_k1 @ K_k1^T (additive in logit space).
 
-Options:
-1. **Tile Q along head_dim:** Process Q in chunks of 256. Two Q sub-tiles per kernel.
-2. **SMEM buffer overlap:** sQ and sK/sV used at different times. After Q is consumed by MMA, reuse sQ's SMEM for K/V.
-3. **Split the GEMM K dimension:** Process K in sub-tiles (K=256 then K=256-512). Each sub-tile fits SMEM.
+**SMEM budget: SOLVED.** pv_n_tile=128 for hd>256 reduces sV from 64KB→32KB, sC from 64KB→32KB. Total 192KB ✅.
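The budget arithmetic can be re-checked with a few lines of Python. This sketch only sums the per-buffer sizes as they appear in the updated SMEM table (treating the `—` and `0*` sP entries as 0 KB) and compares them with the 232KB limit. It does not derive the sizes from tile shapes. The dictionary layout and the `pv_passes` helper are illustrative, and the pass count assumes passes = head_dim / pv_n_tile, which reproduces the "4 instead of 2" figure quoted for hd=512.

```python
SMEM_LIMIT_KB = 232

# Per-buffer sizes in KB, copied from the updated SMEM budget table.
# sP is 0 where the table shows "—" or "0*" (TMEM-P path, allocation skipped).
budget_kb = {
    64:  {"sQ": 32, "sK": 32, "sV": 32, "sP": 0, "sC": 32},
    128: {"sQ": 32, "sK": 32, "sV": 32, "sP": 0, "sC": 32},
    256: {"sQ": 64, "sK": 64, "sV": 64, "sP": 0, "sC": 32},
    512: {"sQ": 64, "sK": 64, "sV": 32, "sP": 0, "sC": 32},
}

# The hd=512 row as it stood before this patch (removed row of the table).
old_hd512_kb = {"sQ": 128, "sK": 128, "sV": 64, "sP": 0, "sC": 32}

def total_kb(buffers):
    """Sum all SMEM buffers for one configuration."""
    return sum(buffers.values())

def pv_passes(head_dim, pv_n_tile):
    """Assumed relation: one PV GEMM pass per pv_n_tile-wide slice of head_dim."""
    return head_dim // pv_n_tile

totals = {hd: total_kb(b) for hd, b in budget_kb.items()}
fits = {hd: t <= SMEM_LIMIT_KB for hd, t in totals.items()}

for hd, t in totals.items():
    print(f"hd={hd}: {t}KB of {SMEM_LIMIT_KB}KB, fits={fits[hd]}")
print(f"hd=512 before the fix: {total_kb(old_hd512_kb)}KB")
print(f"hd=512 PV passes: {pv_passes(512, 128)} at pv_n_tile=128, "
      f"{pv_passes(512, 256)} at pv_n_tile=256")
```

The sums agree with the Total column of the new table, and with the 352KB figure in the removed hd=512 row.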
+
+**Compilation: BLOCKED.** The CuTeDSL MLIR→PTX backend optimizer cannot compile the hd=512 kernel in reasonable time. Both Python `range()` (unrolled IR) and `cutlass.range(unroll=1)` (runtime loop) produce IR that the optimizer chews on for 3+ hours without finishing. The Python tracer completes in 0.8s — the kernel is structurally correct. This is a toolchain limitation.
+
+**External merge: IMPOSSIBLE.** The D5 online softmax merge formula assumes separate attention distributions over different token sets (additive in weight space). k_sub segments are additive in LOGIT space (S = S_0 + S_1), not weight space. You cannot recover softmax(S_0 + S_1)@V from softmax(S_0)@V and softmax(S_1)@V. In-kernel accumulation before softmax is the only correct approach.
+
+**Bug fixes applied along the way:**
+- LSE type mismatch (BF16 vs FP32 when normalize=True) → guarded with `const_expr(not self.normalize)`
+- O rescale IR explosion at n=128 → guarded with `const_expr(n_kv_tiles > 1)`
+- k_sub tracer IR explosion → replaced hardcoded `if k_sub==0/1` with Python `range()` loop
+- External merge test (cos 0.617) → confirmed mathematically impossible, deleted approach
+
+**Possible paths forward (priority order):**
+1. **Pre-compile hd=512 kernel offline.** Accept 1-2 hour compilation during build. Cache the cubin. This works if the MLIR optimizer eventually finishes (it might just be slow, not stuck — but 3+ hours is excessive even for pre-compilation).
+2. **Add no-softmax mode to the kernel.** Output raw S (QK scores) without softmax. Call twice for k_sub=0 and k_sub=1. Accumulate S_0+S_1 in Python. Apply softmax once. This requires modifying the softmax warp to optionally skip normalization and output S to GMEM instead of P to TMEM/SMEM.
+3. **Write hd=512 kernel in CUTLASS C++.** Bypass CuTeDSL's MLIR backend entirely. Use raw CUTLASS C++ with tcgen05 MMA intrinsics. More work but compilation is fast (seconds).
+4. **Report CuTeDSL MLIR optimizer bug.** The optimizer should handle this IR in reasonable time. File an issue with NVIDIA.
 
 ### D2 — Multi-Query Grid with Head Packing