From a7fd2761dfef53a2c189c2bd59052abeab73ad71 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Thu, 21 May 2026 05:17:12 +0000 Subject: [PATCH] =?UTF-8?q?README:=20Bug=204=20root=20cause=20=E2=80=94=20?= =?UTF-8?q?TMEM=20layout=20mismatch=20(128,64)=20PV=20A-fragment=20vs=20so?= =?UTF-8?q?ftmax=20P=20write?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - (128,64) PV MMA A-fragment has N_MMA=64, reads P with wrong stride - Softmax writes P with QK C-fragment layout (N_MMA=128) - O[m,d] ≈ P[m,2d] — every other column effect confirmed - All-ones and single-element V pass (uniform/sparse data hides mismatch) - epi_tile must use PV cta_tile (partial fix: 0.01 → 0.876) - Added footguns #9 (TMEM alias N_MMA match) and #10 (epi_tile) - Added diagnostic test results to test table --- README.md | 94 +++++++++++++++++++++++++++++++++++-------------------- 1 file changed, 60 insertions(+), 34 deletions(-) diff --git a/README.md b/README.md index 6316d0ef..72f31fe5 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ CuTeDSL kernels for DeepSeek-V4 (Blackwell B200, SM100). All kernels use `cutlass.cute` (CuTeDSL) with Blackwell tensor cores. -## Status (May 21, 2026 — 04:35 UTC) +## Status (May 21, 2026 — 05:15 UTC) ### ✅ Stage A: Bare Q@K^T via tcgen05.mma → TMEM → GMEM — COMPLETE @@ -12,19 +12,52 @@ CuTeDSL kernels for DeepSeek-V4 (Blackwell B200, SM100). All kernels use `cutlas ### 🔨 Stage B: Two MMAs + Identity Softmax — IN PROGRESS **Pipeline deadlock: FIXED. Kernel runs without deadlock.** -**Bug 1 (V MN-major): Fix applied.** -**Bug 2 (softmax packing): Confirmed correct (V=I test: cosine 1.0).** -**Bug 3 (ACCUMULATE): Fix applied.** -**Bug 4 (non-square PV): PV works for (128,128) output, broken for (128,64) output.** +**Bug 1 (V MN-major): ✅ Fix applied.** +**Bug 2 (softmax packing): ✅ Confirmed correct (V=I test: cosine 1.0).** +**Bug 3 (ACCUMULATE): ✅ Fix applied.** +**Bug 4 (non-square PV): 🔨 ROOT CAUSE IDENTIFIED — TMEM layout mismatch.** -#### Bug 1: V B-Operand Must Be MN-Major — ✅ FIX APPLIED +#### Bug 4 (CURRENT): PV MMA Broken for (128,64) Output — ROOT CAUSE IDENTIFIED + +**Root Cause: The (128,64) PV MMA's A-fragment reads P from TMEM with a different layout than the softmax packing writes it.** + +The softmax packing writes P using the **QK C-fragment layout** (MMA atom = (128,128,16), N_MMA=128). The PV MMA reads P using its **A-fragment layout** (MMA atom = (128,64,16), N_MMA=64). These two layouts produce different physical TMEM addresses for the same logical (m,k) coordinate. + +**Evidence:** +- Truncated identity V (64,128) MN-major: O[m,d] ≈ P[m, 2d] — the MMA reads every other column of P +- All-ones V: cosine 0.999999 ✅ (uniform data hides the layout mismatch) +- Single-element V: cosine 1.0 ✅ (sparse data also hides it) +- (128,128) PV with same softmax packing: cosine 0.999999 ✅ (N_MMA=128 matches QK, no mismatch) + +**C++ TMEM Fragment Layout (from mma_traits_sm100.hpp):** +```cpp +// For M_MMA = 128, N_MMA varies with the MMA atom's N dimension +Layout tmem_atom = Layout>, + Stride< _1, _128>>{}; +``` +- QK C-fragment: N_MMA=128 → 128 TMEM columns, stride 128 +- PV A-fragment (128,64): N_MMA=64 → 64 TMEM columns, stride 128 + +When the softmax packing writes P at `tmem_p0_offset` using the QK C-fragment layout (N_MMA=128), P's (m,k) elements land at TMEM address `m + 128*k`. But the PV A-fragment (N_MMA=64) reads the same TMEM region as if P were stored with N_MMA=64, so it interprets the data with stride 64 instead of 128, causing the every-other-column effect (O[m,d] ≈ P[m, 2d]). + +**Fix (not yet applied): The softmax packing must write P using the PV MMA's A-fragment layout, not the QK C-fragment layout.** FMHA does this correctly because its softmax writes P using a composition that matches the PV A-fragment — the `tStS_P` layout is derived from `tStS.layout` (QK C-fragment) but the TMEM store uses a C-fragment composition that's based on the PV MMA's tiling. The key is that FMHA's `tilePlikeFP32` computation adapts the packing width to match the PV output N. + +**Additional fix: `epi_tile` must be computed from PV cta_tile, not QK cta_tile.** Using QK's cta_tile for the epilogue produces `epi_tile=(128,128)` which is wrong for a (128,64) output. Computing from PV's cta_tile gives `epi_tile=(128:1, 32:1)`. This fix alone improved cosine from 0.01 to 0.876, but the TMEM layout mismatch remains. + +**V SMEM Layouts (confirmed correct):** +- `PV(128,64) V SMEM: outer=((64,16),1,8,1):((1,64),0,1024,0), inner=S<3,4,3>` +- `PV(128,128) V SMEM: outer=(((64,2),16),1,8,1):(((1,8192),64),0,1024,0), inner=S<3,4,3>` + +--- + +### Bug 1: V B-Operand Must Be MN-Major — ✅ FIX APPLIED V must be shaped (head_dim, seq) = (64, 128) with strides (1, 64) — MN-major. PV MMA uses `v_major` (OperandMajorMode.MN) instead of `b_major` (K). V must use `as_strided` — default PyTorch (64,128) gives strides (128,1) which is K-major. -#### Bug 2 (Packing): C-Fragment Composition Store — ✅ CONFIRMED CORRECT +### Bug 2: C-Fragment Composition Store — ✅ CONFIRMED CORRECT FP32→BF16 packing via C-fragment composition store (FMHA pattern) is correct. Proven by V=I test (cosine 1.0) and random V 128x128 test (cosine 0.999999). @@ -32,36 +65,17 @@ Proven by V=I test (cosine 1.0) and random V 128x128 test (cosine 0.999999). ⛔ **FOOTGUN**: `St32x32bOp` MUST use Float32, NOT BFloat16. ⚠️ The recast view for P packing uses the LOAD layout (128 BF16 elements), not the store composition shape. -#### Bug 3 (ACCUMULATE): First PV Must Use ACCUMULATE=False — ✅ FIX APPLIED +### Bug 3: First PV Must Use ACCUMULATE=False — ✅ FIX APPLIED If ACCUMULATE=True on the first PV, `O = P@V + old_O` adds uninitialized TMEM. Always ACCUMULATE=False for first PV, then True for subsequent tiles. -#### Bug 4 (CURRENT): PV MMA Broken for Non-Square Output — 🔨 ROOT CAUSE UNKNOWN +--- -**What works:** -- PV with (128,128) output, V=I: cosine 1.0 ✅ -- PV with (128,128) output, random V: cosine 0.999999 ✅ - -**What doesn't work:** -- PV with (128,64) output, V MN-major (64,128): cosine ~0.01 ❌ - -**Possible causes:** -1. `make_trivial_tiled_mma` with (128,64) produces different A-fragment layout — alias with softmax P may break -2. V TMA load wrong for (128,64) PV — SMEM layout, TMA descriptor, or partitioning incorrect -3. Epilogue/gC mismatch — output c is (128,64) but epilogue may write (128,128) tile -4. PV mma_tiler_mn doesn't affect the MMA atom (which is always (128,128,16)) - -**Diagnostic findings:** -- Pointer arithmetic correct: softmax P and PV A-fragment address same TMEM location -- Layout aliasing correct: C-fragment composition and A-fragment produce same physical addresses -- Pipeline ordering correct: softmax completes before PV starts -- Softmax packing correct: proven by V=I test - -### 🔨 Stage C: Online Softmax — AFTER B +## 🔨 Stage C: Online Softmax — AFTER B Per the pseudocode: epilogue warps compute per-row tile_max, rescale, exp, store P back to TMEM. -### 🔨 Stage D: FP8 Paged KV Gather — AFTER C +## 🔨 Stage D: FP8 Paged KV Gather — AFTER C Replace BF16 TMA load with FP8 paged KV gather + per-position dequant. @@ -103,7 +117,7 @@ FMHA requires `v_major_mode == OperandMajorMode.MN`. Passing K's K-major mode fo ### 3. CuTe Nested Layout Modes Flatten Sequentially -A layout like `((128,16),1,(4,2)):((65536,1),0,(16,64))` looks "non-sequential" but flattens to `addr = m*65536 + k` when k = k0 + 16*k1 + 64*k2 (CuTe row-major order). Do NOT assume nested modes imply non-sequential physical addressing. The C-fragment composition and A-fragment alias the same TMEM columns. +A layout like `((128,16),1,(4,2)):((65536,1),0,(16,64))` looks "non-sequential" but flattens to `addr = m*65536 + k` when k = k0 + 16*k1 + 64*k2 (CuTe row-major order). Do NOT assume nested modes imply non-sequential physical addressing. The C-fragment composition and A-fragment alias the same TMEM columns — BUT ONLY WHEN N_MMA MATCHES (i.e., (128,128) PV). For (128,64) PV, N_MMA=64 and the alias breaks. ### 4. PipelineUmmaAsync Consumer Group = Thread Count, NOT Warp Count @@ -137,6 +151,14 @@ tOrP0 = cute.make_tensor(tOrP.iterator + p_offset, tOrP.layout) ``` Both must address the same physical TMEM column. The 2× scaling accounts for FP32→BF16 element size difference. +### 9. C-Fragment → A-Fragment TMEM Alias Only Works When N_MMA Matches + +The softmax packing writes P using the QK C-fragment layout. The PV A-fragment reads P. These alias correctly ONLY when both MMA atoms have the same N_MMA (i.e., both (128,128,16) → N_MMA=128). When the PV MMA uses (128,64,16) → N_MMA=64, the A-fragment has a different TMEM stride and reads garbage. **The softmax packing must be adapted to write P in the PV A-fragment's layout.** + +### 10. epi_tile Must Match PV Output Shape, Not QK + +`compute_epilogue_tile_shape` must use PV's `cta_tile_shape_mnk`, not QK's. Also, `self.cta_tile_shape_mnk` must be set to PV's cta tile before calling `epilogue_tma_store` (it reads `gemm_kernel.cta_tile_shape_mnk` internally). FMHA sets `self.epi_tile = self.pv_mma_tiler[:2]` directly. + --- ## Architecture: Per-Tile Flow @@ -168,10 +190,14 @@ After all tiles: epilogue warps tcgen05.ld tmem_output, divide by row_sum, cast | `test_mma_si_only.py` | Q@K^T + mma_si pipeline (no PV) | 0.999999 | ✅ PASS | | `test_softmax_only.py` | Q@K^T + softmax packing, output S | 0.52 | ❌ S overwritten by P (expected) | | `test_mma_si_pv.py` | Q@K^T + softmax + P@V (V MN-major, 128x64) | 0.01 | ❌ PV output garbage | -| `test_pv_diag.py` | Q@K^T + softmax + P@V (V=I/random, 128x128) | 1.0 / 0.999999 | ✅ PASS | +| `test_pv_diag.py` | Q@K^T + softmax + P@V (V=I 128x128) | 1.0 | ✅ PASS | +| `test_pv_diag.py` | Q@K^T + softmax + P@V (random V 128x128) | 0.999999 | ✅ PASS | +| `test_diag_v_truncid.py` | Q@K^T + softmax + P@V (trunc identity 64x128, epi from PV) | 0.02 | ❌ O[m,d]≈P[m,2d] — TMEM alias mismatch | +| `test_diag_v_ones.py` | All-ones V (64x128) | 0.999999 | ✅ uniform data hides mismatch | +| `test_diag_v_ones.py` | Single-element V (64x128) | 1.0 | ✅ sparse data hides mismatch | +| `test_diag_layout.py` | (128,64) PV with epi from PV cta_tile | 0.876 | ❌ partial fix — epi correct, TMEM alias still broken | +| `test_diag_smem_layout.py` | Print V SMEM layouts for (128,64) vs (128,128) | N/A | ℹ️ layouts confirmed correct | | `test_layout_compare.py` | Print TMEM layouts for QK S and PV A-fragment | N/A | ℹ️ layout inspection | -| `test_stage_b_v7.py` | Q@K^T + C-fragment softmax (V=K, wrong major) | -0.02 | ❌ wrong major + P packing | -| `test_stage_b_v20.py` | Q@K^T + softmax (V=K, PipelineTmaStore bug) | N/A | ❌ compile error | ---