UMMA QK GEMM WORKING! Update docs — 4x was scale factor, not bug
Major milestone: UMMA QK GEMM produces correct attention scores at HD=16! - MMA computes raw dot product; apply 1/sqrt(HD) scaling manually - tcgen05.fence::after_thread_sync for MMA→TMEM fence - 32x32b.x8 TMEM reads for Layout D output - 4 warps (128 threads) required for M=128 - Next: HD=64 multi-K-tile, PV GEMM, full FMHA pipeline
This commit is contained in:
@@ -1,52 +1,37 @@
|
||||
# CURRENT ISSUE: UMMA QK GEMM — 4× Scaling Bug
|
||||
# CURRENT ISSUE: UMMA FMHA — Multi-K-tile + PV GEMM + Full Pipeline
|
||||
|
||||
## What's working
|
||||
- UMMA SMEM descriptors: K-major NONE, LBO=BLOCK_MN*16, SBO=128
|
||||
- SMEM canonical layout: column-major interleaving of 8×8 BF16 core matrices
|
||||
- Q and K SMEM data verified EXACT match with originals
|
||||
- tcgen05.mma produces non-zero output — descriptor and data layout are valid
|
||||
- TMEM Layout D read with tcgen05.ld.32x32b.x8 works (no crash)
|
||||
- TMEM alloc/dealloc works
|
||||
## What's working ✅
|
||||
- **UMMA QK GEMM at HD=16, SK=128**: Row 0 matches scalar reference with ZERO error
|
||||
- **SMEM canonical layout**: column-major interleaving of 8×8 BF16 core matrices
|
||||
- **K-major NONE descriptors**: LBO=BLOCK_MN*16, SBO=128, lbo_mode=0, layout_type=0
|
||||
- **TMEM Layout D reads**: `tcgen05.ld.32x32b.x8.b32` with `addr = tmem_base + (row<<16) + col`
|
||||
- **MMA→TMEM fence**: `tcgen05.fence::after_thread_sync` (not `tcgen05.wait::st`)
|
||||
- **MMA computes raw dot product** — apply 1/sqrt(HD) scaling in the read path
|
||||
|
||||
## The 4× Bug
|
||||
MMA output is exactly 4× the scalar reference for ALL output values.
|
||||
- S[0,0] MMA = 0.1529, scalar = 0.0382, ratio = 4.0000
|
||||
- Persists with different N in idesc (8, 32, 128)
|
||||
- Persists with 4 warp leaders calling MMA (vs 1 thread)
|
||||
- Persists with 8KB zero padding between Q and K in SMEM
|
||||
## Next steps
|
||||
1. **HD=64 multi-K-tile**: Call MMA 4× with accumulate=true for K=64 (4 × K=16 tiles)
|
||||
- Each K-tile needs its own descriptor pointing to the right 16-column slice
|
||||
- gau-nernst pattern: `A_smem + k * BLOCK_M * 32` for the k-th K-tile start address
|
||||
- After all K-tiles: read TMEM and apply 1/sqrt(HD) scaling
|
||||
|
||||
### Root cause hypothesis
|
||||
The MMA with cta_group::1 and M=128 uses 4 "warpgroups" internally (Layout D).
|
||||
The TMEM output is written in a format where each warpgroup contributes to
|
||||
different rows. When we read with 32x32b.x8 (warp 0, rows 0-31), we get
|
||||
the correct S[0,0] but multiplied by 4 because the MMA accumulates contributions
|
||||
from all 4 warpgroups into the same TMEM columns.
|
||||
2. **PV GEMM**: `tcgen05.mma TS` (TMEM P × SMEM V → TMEM O)
|
||||
- P is in TMEM after softmax, V is in SMEM
|
||||
- Accumulate O across KV tiles with the D5 merge formula
|
||||
|
||||
Alternatively: the TMEM Layout D has a specific column mapping that we're not
|
||||
accounting for. The MMA output columns might not correspond 1:1 with the
|
||||
attention score columns.
|
||||
3. **In-kernel softmax**: TMEM → regs → max/exp/sum → TMEM
|
||||
- Use 32x32b reads to get S, compute softmax, write P back via 32x32b stores
|
||||
- Must handle the TMEM multi-store issue (use 32x32b, not 16x256b)
|
||||
|
||||
### How to fix
|
||||
1. Study CUTLASS FMHA Python reference (fmha.py on B200) for TMEM output layout
|
||||
2. Check if the 4× factor is a known issue with single-CTA MMA
|
||||
3. Try M=64 (2 warpgroups) — should give 2× if warpgroup count is the cause
|
||||
4. Look at gau-nernst's GEMM example to see how he reads the MMA output
|
||||
5. Check if the MMA output needs to be divided by the number of warpgroups
|
||||
4. **Full FMHA pipeline**: QK → softmax → PV → correction epilogue → GMEM output
|
||||
|
||||
## TMEM multi-store bug
|
||||
Calling tcgen05.st.16x256b.x1.b32 more than once causes "misaligned address".
|
||||
- Single store: works
|
||||
- 2+ stores: crash (even with fence+sync between them)
|
||||
- CUTLASS uses different TMEM store atoms (St32x32bOp)
|
||||
- Need to investigate: is 16x256b.x1 not meant for multiple stores?
|
||||
## Key lessons learned
|
||||
- **16x256b.x1 TMEM stores crash on 2nd call** — use 32x32b format for multi-store
|
||||
- **MMA output is UNSCALED** — the 4× "bug" was just the 1/sqrt(HD) attention scale
|
||||
- **`tcgen05.fence::after_thread_sync`** is the correct MMA→TMEM load fence
|
||||
- **4 warps minimum** for M=128 Layout D (each warp reads 32 rows × 8 columns)
|
||||
- **MMA K-tile size is 16 BF16** — for HD>16, loop with accumulate
|
||||
- **TMEM address format**: bits [31:16] = row, bits [15:0] = column
|
||||
|
||||
## Files
|
||||
- `dsv4/kernels/attention/fmha_umma_desc.cuh` — descriptor construction, write_smem_*
|
||||
- `tests/unit/test_umma_qk.cu` — UMMA QK GEMM test (HD=16, SK=128)
|
||||
- `tests/unit/test_tmem_cols.cu` — TMEM multi-store debug test
|
||||
|
||||
## Key references
|
||||
- gau-nernst tcgen05 tutorial: https://gau-nernst.github.io/tcgen05/
|
||||
- CUTLASS SM100 UMMA: include/cute/arch/mma_sm100_umma.hpp
|
||||
- CUTLASS InstrDescriptor: include/cute/arch/mma_sm100_desc.hpp
|
||||
- CUTLASS FMHA reference on B200: /root/cutlass/examples/python/CuTeDSL/cute/blackwell/kernel/attention/fmha/fmha.py
|
||||
- `dsv4/kernels/attention/fmha_umma_desc.cuh` — descriptors, SMEM layout, MMA wrappers
|
||||
- `tests/unit/test_umma_qk.cu` — working UMMA QK GEMM test
|
||||
|
||||
Reference in New Issue
Block a user