Major milestone: UMMA QK GEMM produces correct attention scores at HD=16!
- MMA computes raw dot product; apply 1/sqrt(HD) scaling manually
- tcgen05.fence::after_thread_sync for MMA→TMEM fence
- 32x32b.x8 TMEM reads for Layout D output
- 4 warps (128 threads) required for M=128
- Next: HD=64 multi-K-tile, PV GEMM, full FMHA pipeline
CURRENT ISSUE: UMMA FMHA — Multi-K-tile + PV GEMM + Full Pipeline
What's working ✅
- UMMA QK GEMM at HD=16, SK=128: Row 0 matches scalar reference with ZERO error
- SMEM canonical layout: column-major interleaving of 8×8 BF16 core matrices
- K-major NONE descriptors: LBO=BLOCK_MN*16, SBO=128, lbo_mode=0, layout_type=0
- TMEM Layout D reads: tcgen05.ld.32x32b.x8.b32 with addr = tmem_base + (row<<16) + col
- MMA→TMEM fence: tcgen05.fence::after_thread_sync (not tcgen05.wait::st)
- MMA computes raw dot product; apply 1/sqrt(HD) scaling in the read path
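The SMEM layout and descriptor strides listed above can be cross-checked with a small host-side offset function. This is only a sketch: the name canonical_offset_bytes is hypothetical, and it assumes that SBO=128 is the byte stride between 8-row core matrices along M/N and that LBO=BLOCK_MN*16 is the byte stride between 8-column core matrices along K, which is how the values above are read here.

```cpp
#include <cstddef>

// Hypothetical helper: byte offset of BF16 element (row, k) in a K-major,
// no-swizzle tile built from 8x8 BF16 core matrices.
// Assumed layout: each core-matrix row is 8 BF16 = 16 bytes, so one core
// matrix is 128 contiguous bytes. Core matrices are stacked along M/N first
// (stride SBO = 128 bytes), then along K (stride LBO = block_mn * 16 bytes).
constexpr std::size_t canonical_offset_bytes(std::size_t row,
                                             std::size_t k,
                                             std::size_t block_mn) {
    const std::size_t lbo = block_mn * 16;  // next 8-column group along K
    const std::size_t sbo = 128;            // next 8-row group along M/N
    return (k / 8) * lbo      // which core-matrix column
         + (row / 8) * sbo    // which core-matrix row
         + (row % 8) * 16     // row inside the core matrix
         + (k % 8) * 2;       // BF16 element inside the row
}
```

Under these assumptions, one K=16 tile spans two core-matrix columns, so with BLOCK_MN=128 the element at k=16 starts 128 * 32 bytes from the tile base.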
Next steps
- HD=64 multi-K-tile: Call MMA 4× for K=64 (4 × K=16 tiles), with accumulate=true on every tile after the first (or on all four if the TMEM accumulator is zeroed beforehand)
  - Each K-tile needs its own descriptor pointing to the right 16-column slice
  - gau-nernst pattern: A_smem + k * BLOCK_M * 32 for the k-th K-tile start address
  - After all K-tiles: read TMEM and apply 1/sqrt(HD) scaling
- PV GEMM: tcgen05.mma TS (TMEM P × SMEM V → TMEM O)
  - P is in TMEM after softmax, V is in SMEM
  - Accumulate O across KV tiles with the D5 merge formula
- In-kernel softmax: TMEM → regs → max/exp/sum → TMEM
  - Use 32x32b reads to get S, compute softmax, write P back via 32x32b stores
  - Must handle the TMEM multi-store issue (use 32x32b, not 16x256b)
- Full FMHA pipeline: QK → softmax → PV → correction epilogue → GMEM output
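The multi-K-tile step can be modelled on the host before touching the kernel. The sketch below is not the kernel code: k_tile_offset_bytes and qk_score_reference are hypothetical names, the offset follows the k * BLOCK_M * 32 pattern quoted above, and the reference loop only mirrors the intended accumulate behaviour (first tile overwrites, later tiles add) with the 1/sqrt(HD) scale applied once at the end.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical helper: byte offset of the k-th K=16 tile from the SMEM tile
// base, following the "A_smem + k * BLOCK_M * 32" pattern (16 BF16 columns
// = 2 core-matrix columns of BLOCK_M * 16 bytes each).
constexpr std::size_t k_tile_offset_bytes(std::size_t k_tile, std::size_t block_m) {
    return k_tile * block_m * 32;
}

// Hypothetical host reference for a single attention score. It walks the
// head dimension one K=16 tile at a time, the way the MMA loop would, and
// scales the raw dot product only after the last tile.
float qk_score_reference(const std::vector<float>& q,
                         const std::vector<float>& k,
                         std::size_t hd) {
    constexpr std::size_t kTileK = 16;  // MMA K-tile size for BF16
    float acc = 0.0f;
    for (std::size_t tile = 0; tile < hd / kTileK; ++tile) {
        const bool accumulate = tile > 0;  // first tile overwrites the accumulator
        float partial = 0.0f;
        for (std::size_t i = 0; i < kTileK; ++i) {
            partial += q[tile * kTileK + i] * k[tile * kTileK + i];
        }
        acc = accumulate ? acc + partial : partial;
    }
    return acc / std::sqrt(static_cast<float>(hd));  // 1/sqrt(HD) attention scale
}
```

A reference like this makes a missing scale easy to spot: at HD=16 the unscaled result is exactly 4× too large, and at HD=64 it would be 8× too large.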
Key lessons learned
- 16x256b.x1 TMEM stores crash on 2nd call — use 32x32b format for multi-store
- MMA output is UNSCALED — the 4× "bug" was just the 1/sqrt(HD) attention scale
- tcgen05.fence::after_thread_sync is the correct MMA→TMEM load fence
- 4 warps minimum for M=128 Layout D (each warp reads 32 rows × 8 columns)
- MMA K-tile size is 16 BF16 — for HD>16, loop with accumulate
- TMEM address format: bits [31:16] = row, bits [15:0] = column
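The address format and the warp-to-row split can be written down as a few host-side helpers. This is a minimal sketch with hypothetical names (tmem_addr, tmem_row, tmem_col, warp_first_row). It assumes the row field is the TMEM lane index and that warp w of the 4-warp group owns rows 32*w through 32*w + 31, which is consistent with the 4-warp requirement noted above.

```cpp
#include <cstdint>

// Hypothetical helper: pack a TMEM address as base + (row << 16) + col,
// i.e. bits [31:16] hold the row and bits [15:0] hold the column.
constexpr std::uint32_t tmem_addr(std::uint32_t tmem_base,
                                  std::uint32_t row,
                                  std::uint32_t col) {
    return tmem_base + (row << 16) + col;
}

// Unpack the two fields again (useful when debugging printed addresses).
constexpr std::uint32_t tmem_row(std::uint32_t addr) { return addr >> 16; }
constexpr std::uint32_t tmem_col(std::uint32_t addr) { return addr & 0xFFFFu; }

// Assumed mapping: with M=128 split over 4 warps, warp w starts at row 32*w.
constexpr std::uint32_t warp_first_row(std::uint32_t warp_id) {
    return (warp_id % 4) * 32;
}
```

For example, warp 1 reading the second group of 8 columns would use tmem_addr(tmem_base, warp_first_row(1), 8).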
Files
- dsv4/kernels/attention/fmha_umma_desc.cuh: descriptors, SMEM layout, MMA wrappers
- tests/unit/test_umma_qk.cu: working UMMA QK GEMM test