From df34cae9c66a1dad7c4d81b2cd9912b34bfc9db8 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Thu, 28 May 2026 11:41:19 +0000 Subject: [PATCH] =?UTF-8?q?UMMA=20QK=20GEMM=20WORKING!=20Update=20docs=20?= =?UTF-8?q?=E2=80=94=204x=20was=20scale=20factor,=20not=20bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Major milestone: UMMA QK GEMM produces correct attention scores at HD=16! - MMA computes raw dot product; apply 1/sqrt(HD) scaling manually - tcgen05.fence::after_thread_sync for MMA→TMEM fence - 32x32b.x8 TMEM reads for Layout D output - 4 warps (128 threads) required for M=128 - Next: HD=64 multi-K-tile, PV GEMM, full FMHA pipeline --- CURRENT_ISSUE.md | 73 +++++++++++++++++++----------------------------- 1 file changed, 29 insertions(+), 44 deletions(-) diff --git a/CURRENT_ISSUE.md b/CURRENT_ISSUE.md index 9c994b56..abed5635 100644 --- a/CURRENT_ISSUE.md +++ b/CURRENT_ISSUE.md @@ -1,52 +1,37 @@ -# CURRENT ISSUE: UMMA QK GEMM — 4× Scaling Bug +# CURRENT ISSUE: UMMA FMHA — Multi-K-tile + PV GEMM + Full Pipeline -## What's working -- UMMA SMEM descriptors: K-major NONE, LBO=BLOCK_MN*16, SBO=128 -- SMEM canonical layout: column-major interleaving of 8×8 BF16 core matrices -- Q and K SMEM data verified EXACT match with originals -- tcgen05.mma produces non-zero output — descriptor and data layout are valid -- TMEM Layout D read with tcgen05.ld.32x32b.x8 works (no crash) -- TMEM alloc/dealloc works +## What's working ✅ +- **UMMA QK GEMM at HD=16, SK=128**: Row 0 matches scalar reference with ZERO error +- **SMEM canonical layout**: column-major interleaving of 8×8 BF16 core matrices +- **K-major NONE descriptors**: LBO=BLOCK_MN*16, SBO=128, lbo_mode=0, layout_type=0 +- **TMEM Layout D reads**: `tcgen05.ld.32x32b.x8.b32` with `addr = tmem_base + (row<<16) + col` +- **MMA→TMEM fence**: `tcgen05.fence::after_thread_sync` (not `tcgen05.wait::st`) +- **MMA computes raw dot product** — apply 1/sqrt(HD) scaling in the read path -## The 4× Bug -MMA output is exactly 4× the scalar reference for ALL output values. -- S[0,0] MMA = 0.1529, scalar = 0.0382, ratio = 4.0000 -- Persists with different N in idesc (8, 32, 128) -- Persists with 4 warp leaders calling MMA (vs 1 thread) -- Persists with 8KB zero padding between Q and K in SMEM +## Next steps +1. **HD=64 multi-K-tile**: Call MMA 4× with accumulate=true for K=64 (4 × K=16 tiles) + - Each K-tile needs its own descriptor pointing to the right 16-column slice + - gau-nernst pattern: `A_smem + k * BLOCK_M * 32` for the k-th K-tile start address + - After all K-tiles: read TMEM and apply 1/sqrt(HD) scaling -### Root cause hypothesis -The MMA with cta_group::1 and M=128 uses 4 "warpgroups" internally (Layout D). -The TMEM output is written in a format where each warpgroup contributes to -different rows. When we read with 32x32b.x8 (warp 0, rows 0-31), we get -the correct S[0,0] but multiplied by 4 because the MMA accumulates contributions -from all 4 warpgroups into the same TMEM columns. +2. **PV GEMM**: `tcgen05.mma TS` (TMEM P × SMEM V → TMEM O) + - P is in TMEM after softmax, V is in SMEM + - Accumulate O across KV tiles with the D5 merge formula -Alternatively: the TMEM Layout D has a specific column mapping that we're not -accounting for. The MMA output columns might not correspond 1:1 with the -attention score columns. +3. **In-kernel softmax**: TMEM → regs → max/exp/sum → TMEM + - Use 32x32b reads to get S, compute softmax, write P back via 32x32b stores + - Must handle the TMEM multi-store issue (use 32x32b, not 16x256b) -### How to fix -1. Study CUTLASS FMHA Python reference (fmha.py on B200) for TMEM output layout -2. Check if the 4× factor is a known issue with single-CTA MMA -3. Try M=64 (2 warpgroups) — should give 2× if warpgroup count is the cause -4. Look at gau-nernst's GEMM example to see how he reads the MMA output -5. Check if the MMA output needs to be divided by the number of warpgroups +4. **Full FMHA pipeline**: QK → softmax → PV → correction epilogue → GMEM output -## TMEM multi-store bug -Calling tcgen05.st.16x256b.x1.b32 more than once causes "misaligned address". -- Single store: works -- 2+ stores: crash (even with fence+sync between them) -- CUTLASS uses different TMEM store atoms (St32x32bOp) -- Need to investigate: is 16x256b.x1 not meant for multiple stores? +## Key lessons learned +- **16x256b.x1 TMEM stores crash on 2nd call** — use 32x32b format for multi-store +- **MMA output is UNSCALED** — the 4× "bug" was just the 1/sqrt(HD) attention scale +- **`tcgen05.fence::after_thread_sync`** is the correct MMA→TMEM load fence +- **4 warps minimum** for M=128 Layout D (each warp reads 32 rows × 8 columns) +- **MMA K-tile size is 16 BF16** — for HD>16, loop with accumulate +- **TMEM address format**: bits [31:16] = row, bits [15:0] = column ## Files -- `dsv4/kernels/attention/fmha_umma_desc.cuh` — descriptor construction, write_smem_* -- `tests/unit/test_umma_qk.cu` — UMMA QK GEMM test (HD=16, SK=128) -- `tests/unit/test_tmem_cols.cu` — TMEM multi-store debug test - -## Key references -- gau-nernst tcgen05 tutorial: https://gau-nernst.github.io/tcgen05/ -- CUTLASS SM100 UMMA: include/cute/arch/mma_sm100_umma.hpp -- CUTLASS InstrDescriptor: include/cute/arch/mma_sm100_desc.hpp -- CUTLASS FMHA reference on B200: /root/cutlass/examples/python/CuTeDSL/cute/blackwell/kernel/attention/fmha/fmha.py +- `dsv4/kernels/attention/fmha_umma_desc.cuh` — descriptors, SMEM layout, MMA wrappers +- `tests/unit/test_umma_qk.cu` — working UMMA QK GEMM test