UMMA QK GEMM WORKING! Update docs — 4x was scale factor, not bug

Major milestone: UMMA QK GEMM produces correct attention scores at HD=16!
- MMA computes raw dot product; apply 1/sqrt(HD) scaling manually
- tcgen05.fence::after_thread_sync for MMA→TMEM fence
- 32x32b.x8 TMEM reads for Layout D output
- 4 warps (128 threads) required for M=128
- Next: HD=64 multi-K-tile, PV GEMM, full FMHA pipeline
This commit is contained in:
2026-05-28 11:41:19 +00:00
parent 1874a70a6d
commit df34cae9c6

View File

@@ -1,52 +1,37 @@
# CURRENT ISSUE: UMMA QK GEMM — 4× Scaling Bug
# CURRENT ISSUE: UMMA FMHA — Multi-K-tile + PV GEMM + Full Pipeline
## What's working
- UMMA SMEM descriptors: K-major NONE, LBO=BLOCK_MN*16, SBO=128
- SMEM canonical layout: column-major interleaving of 8×8 BF16 core matrices
- Q and K SMEM data verified EXACT match with originals
- tcgen05.mma produces non-zero output — descriptor and data layout are valid
- TMEM Layout D read with tcgen05.ld.32x32b.x8 works (no crash)
- TMEM alloc/dealloc works
## What's working
- **UMMA QK GEMM at HD=16, SK=128**: Row 0 matches scalar reference with ZERO error
- **SMEM canonical layout**: column-major interleaving of 8×8 BF16 core matrices
- **K-major NONE descriptors**: LBO=BLOCK_MN*16, SBO=128, lbo_mode=0, layout_type=0
- **TMEM Layout D reads**: `tcgen05.ld.32x32b.x8.b32` with `addr = tmem_base + (row<<16) + col`
- **MMA→TMEM fence**: `tcgen05.fence::after_thread_sync` (not `tcgen05.wait::st`)
- **MMA computes raw dot product** — apply 1/sqrt(HD) scaling in the read path
## The 4× Bug
MMA output is exactly 4× the scalar reference for ALL output values.
- S[0,0] MMA = 0.1529, scalar = 0.0382, ratio = 4.0000
- Persists with different N in idesc (8, 32, 128)
- Persists with 4 warp leaders calling MMA (vs 1 thread)
- Persists with 8KB zero padding between Q and K in SMEM
## Next steps
1. **HD=64 multi-K-tile**: Call MMA 4× with accumulate=true for K=64 (4 × K=16 tiles)
- Each K-tile needs its own descriptor pointing to the right 16-column slice
- gau-nernst pattern: `A_smem + k * BLOCK_M * 32` for the k-th K-tile start address
- After all K-tiles: read TMEM and apply 1/sqrt(HD) scaling
### Root cause hypothesis
The MMA with cta_group::1 and M=128 uses 4 "warpgroups" internally (Layout D).
The TMEM output is written in a format where each warpgroup contributes to
different rows. When we read with 32x32b.x8 (warp 0, rows 0-31), we get
the correct S[0,0] but multiplied by 4 because the MMA accumulates contributions
from all 4 warpgroups into the same TMEM columns.
2. **PV GEMM**: `tcgen05.mma TS` (TMEM P × SMEM V → TMEM O)
- P is in TMEM after softmax, V is in SMEM
- Accumulate O across KV tiles with the D5 merge formula
Alternatively: the TMEM Layout D has a specific column mapping that we're not
accounting for. The MMA output columns might not correspond 1:1 with the
attention score columns.
3. **In-kernel softmax**: TMEM → regs → max/exp/sum → TMEM
- Use 32x32b reads to get S, compute softmax, write P back via 32x32b stores
- Must handle the TMEM multi-store issue (use 32x32b, not 16x256b)
### How to fix
1. Study CUTLASS FMHA Python reference (fmha.py on B200) for TMEM output layout
2. Check if the 4× factor is a known issue with single-CTA MMA
3. Try M=64 (2 warpgroups) — should give 2× if warpgroup count is the cause
4. Look at gau-nernst's GEMM example to see how he reads the MMA output
5. Check if the MMA output needs to be divided by the number of warpgroups
4. **Full FMHA pipeline**: QK → softmax → PV → correction epilogue → GMEM output
## TMEM multi-store bug
Calling tcgen05.st.16x256b.x1.b32 more than once causes "misaligned address".
- Single store: works
- 2+ stores: crash (even with fence+sync between them)
- CUTLASS uses different TMEM store atoms (St32x32bOp)
- Need to investigate: is 16x256b.x1 not meant for multiple stores?
## Key lessons learned
- **16x256b.x1 TMEM stores crash on 2nd call** — use 32x32b format for multi-store
- **MMA output is UNSCALED** — the 4× "bug" was just the 1/sqrt(HD) attention scale
- **`tcgen05.fence::after_thread_sync`** is the correct MMA→TMEM load fence
- **4 warps minimum** for M=128 Layout D (each warp reads 32 rows × 8 columns)
- **MMA K-tile size is 16 BF16** — for HD>16, loop with accumulate
- **TMEM address format**: bits [31:16] = row, bits [15:0] = column
## Files
- `dsv4/kernels/attention/fmha_umma_desc.cuh` — descriptor construction, write_smem_*
- `tests/unit/test_umma_qk.cu` — UMMA QK GEMM test (HD=16, SK=128)
- `tests/unit/test_tmem_cols.cu` — TMEM multi-store debug test
## Key references
- gau-nernst tcgen05 tutorial: https://gau-nernst.github.io/tcgen05/
- CUTLASS SM100 UMMA: include/cute/arch/mma_sm100_umma.hpp
- CUTLASS InstrDescriptor: include/cute/arch/mma_sm100_desc.hpp
- CUTLASS FMHA reference on B200: /root/cutlass/examples/python/CuTeDSL/cute/blackwell/kernel/attention/fmha/fmha.py
- `dsv4/kernels/attention/fmha_umma_desc.cuh` — descriptors, SMEM layout, MMA wrappers
- `tests/unit/test_umma_qk.cu` working UMMA QK GEMM test