nvfp4-megamoe-kernel/CURRENT_ISSUE.md
biondizzle df34cae9c6 UMMA QK GEMM WORKING! Update docs — 4x was scale factor, not bug
Major milestone: UMMA QK GEMM produces correct attention scores at HD=16!
- MMA computes raw dot product; apply 1/sqrt(HD) scaling manually
- tcgen05.fence::after_thread_sync for MMA→TMEM fence
- 32x32b.x8 TMEM reads for Layout D output
- 4 warps (128 threads) required for M=128
- Next: HD=64 multi-K-tile, PV GEMM, full FMHA pipeline
2026-05-28 11:41:19 +00:00


CURRENT ISSUE: UMMA FMHA — Multi-K-tile + PV GEMM + Full Pipeline

What's working

  • UMMA QK GEMM at HD=16, SK=128: Row 0 matches scalar reference with ZERO error
  • SMEM canonical layout: column-major interleaving of 8×8 BF16 core matrices
  • K-major, SWIZZLE_NONE descriptors: LBO=BLOCK_MN*16 bytes (K direction), SBO=128 bytes (M/N direction), lbo_mode=0, layout_type=0
  • TMEM Layout D reads: tcgen05.ld.32x32b.x8.b32 with addr = tmem_base + (row<<16) + col
  • MMA→TMEM fence: tcgen05.fence::after_thread_sync (not tcgen05.wait::st)
  • MMA computes raw dot product — apply 1/sqrt(HD) scaling in the read path
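The LBO/SBO values above follow from the canonical layout. The sketch below computes the byte offset of a BF16 element in a K-major, SWIZZLE_NONE tile built from column-major interleaved 8×8 core matrices, as a host-side check of that layout. It assumes 2-byte BF16, a tile starting at byte 0, and a BLOCK_MN that is a multiple of 8. The helper names (`smem_offset`, `k_tile_offset`, `lbo_bytes`, `sbo_bytes`) are hypothetical and do not come from fmha_umma_desc.cuh.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical host-side model of the K-major, SWIZZLE_NONE canonical layout.
// A core matrix is 8 rows x 8 BF16 = 8 rows x 16 bytes = 128 contiguous bytes.
constexpr uint32_t kCoreRows = 8;
constexpr uint32_t kCoreBytesPerRow = 16;
constexpr uint32_t kCoreMatrixBytes = kCoreRows * kCoreBytesPerRow;  // 128

// Byte stride between core matrices adjacent along K (LBO).
constexpr uint32_t lbo_bytes(uint32_t block_mn) { return block_mn * 16; }
// Byte stride between core matrices adjacent along M/N (SBO).
constexpr uint32_t sbo_bytes() { return kCoreMatrixBytes; }

// Byte offset of element (row, k) in a tile with block_mn rows.
uint32_t smem_offset(uint32_t row, uint32_t k, uint32_t block_mn) {
  assert(block_mn % kCoreRows == 0 && row < block_mn);
  return (k / 8) * lbo_bytes(block_mn)   // core-matrix column along K
       + (row / 8) * sbo_bytes()         // core matrix along M/N (fastest)
       + (row % 8) * kCoreBytesPerRow    // row inside the core matrix
       + (k % 8) * 2;                    // BF16 element inside that row
}

// Start offset of the k_tile-th K=16 slice (two core-matrix columns per tile).
uint32_t k_tile_offset(uint32_t k_tile, uint32_t block_mn) {
  return smem_offset(0, k_tile * 16, block_mn);
}
```

For BLOCK_MN=128, `k_tile_offset(k, 128)` evaluates to `k * 128 * 32`, which is the same K-tile start rule the HD=64 plan relies on.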

Next steps

  1. HD=64 multi-K-tile: Call MMA 4× for K=64 (4 × K=16 tiles), with accumulate=false on the first K-tile (or zero TMEM beforehand) and accumulate=true on the other three

    • Each K-tile needs its own descriptor pointing to the right 16-column slice
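    • gau-nernst pattern: the k-th K-tile starts at byte address A_smem + k * BLOCK_M * 32 (one K=16 tile spans two core-matrix columns of BLOCK_M*16 bytes each); the B operand uses BLOCK_N in place of BLOCK_M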
    • After all K-tiles: read TMEM and apply 1/sqrt(HD) scaling
  2. PV GEMM: tcgen05.mma TS (TMEM P × SMEM V → TMEM O)

    • P is in TMEM after softmax, V is in SMEM
    • Accumulate O across KV tiles with the D5 merge formula
  3. In-kernel softmax: TMEM → regs → max/exp/sum → TMEM

    • Use 32x32b reads to get S, compute softmax, write P back via 32x32b stores
    • Must work around the TMEM multi-store crash: use 32x32b stores, not 16x256b
  4. Full FMHA pipeline: QK → softmax → PV → correction epilogue → GMEM output
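Before steps 1 to 4 exist in the kernel, a scalar fp32 reference for one query row can help check them. The sketch below makes several assumptions. It uses no BF16 rounding and stores row-major Q[hd], K[sk][hd], and V[sk][hd]. Its K-tiles are 16 elements wide, with the first tile overwriting and later tiles accumulating. It also assumes the "D5 merge formula" is the standard online-softmax (flash-attention) rescale, which this document does not confirm. The names `qk_score` and `fmha_row` are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <limits>
#include <vector>

constexpr int kMmaK = 16;  // BF16 MMA K per instruction

// Raw QK dot product built from K=16 tiles, mimicking the MMA loop:
// tile 0 uses accumulate=false, later tiles accumulate=true. Unscaled.
float qk_score(const float* q, const float* k, int hd) {
  assert(hd % kMmaK == 0);
  float acc = 0.0f;
  for (int kt = 0; kt < hd / kMmaK; ++kt) {
    float partial = 0.0f;
    for (int i = kt * kMmaK; i < (kt + 1) * kMmaK; ++i) partial += q[i] * k[i];
    acc = (kt == 0) ? partial : acc + partial;
  }
  return acc;
}

// Attention output for one query row, visiting keys in tiles of kv_tile
// and merging tiles with an online-softmax rescale (assumed D5 formula).
std::vector<float> fmha_row(const std::vector<float>& q,
                            const std::vector<float>& k,
                            const std::vector<float>& v,
                            int sk, int hd, int kv_tile) {
  const float scale = 1.0f / std::sqrt(static_cast<float>(hd));
  float m = -std::numeric_limits<float>::infinity();  // running max
  float l = 0.0f;                                     // running sum of exp
  std::vector<float> o(hd, 0.0f);                     // unnormalized O
  for (int t0 = 0; t0 < sk; t0 += kv_tile) {
    const int t1 = std::min(sk, t0 + kv_tile);
    std::vector<float> s(t1 - t0);
    float m_tile = -std::numeric_limits<float>::infinity();
    for (int j = t0; j < t1; ++j) {  // QK GEMM, then 1/sqrt(HD) scaling
      s[j - t0] = scale * qk_score(q.data(), &k[j * hd], hd);
      m_tile = std::max(m_tile, s[j - t0]);
    }
    const float m_new = std::max(m, m_tile);
    const float alpha = std::exp(m - m_new);  // 0 on the first tile
    l *= alpha;
    for (float& x : o) x *= alpha;  // correction of earlier tiles
    for (int j = t0; j < t1; ++j) {
      const float p = std::exp(s[j - t0] - m_new);  // softmax numerator P
      l += p;
      for (int d = 0; d < hd; ++d) o[d] += p * v[j * hd + d];  // PV GEMM
    }
    m = m_new;
  }
  for (float& x : o) x /= l;  // epilogue normalization
  return o;
}
```

Running `fmha_row` with `kv_tile = sk` gives a plain softmax reference. Comparing it against smaller `kv_tile` values exercises the cross-tile merge independently of the hardware path.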

Key lessons learned

  • 16x256b.x1 TMEM stores crash on 2nd call — use 32x32b format for multi-store
  • MMA output is UNSCALED: the 4× "bug" was just the missing 1/sqrt(HD) attention scale (1/sqrt(16) = 1/4 at HD=16)
  • tcgen05.fence::after_thread_sync is the correct MMA→TMEM load fence
  • 4 warps minimum for M=128 Layout D: each warp can only access its own 32-lane quarter of TMEM, and one 32x32b.x8 load reads 32 rows × 8 columns
  • The MMA K-tile is 16 BF16 elements (32 bytes) per instruction; for HD>16, loop over K-tiles with accumulate
  • TMEM address format: bits [31:16] = row (lane), bits [15:0] = column
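The TMEM rules above can be captured in a few host-side helpers for computing load addresses and loop counts. This is a sketch that assumes the following: a 128-lane (M=128) accumulator whose allocation base has lane field 0; 32x32b.x8 loads covering 8 columns each; and warps numbered within one warpgroup. The names `tmem_addr`, `warp_first_lane`, and `loads_per_warp` are hypothetical.

```cpp
#include <cassert>
#include <cstdint>

constexpr uint32_t kLanesPerWarp = 32;  // one TMEM lane quarter per warp
constexpr uint32_t kColsPerLoad = 8;    // 32x32b with .x8 repetition

// Pack a TMEM address: bits [31:16] = lane (row), bits [15:0] = column.
// Assumes tmem_base has lane field 0, as in the addr formula above.
uint32_t tmem_addr(uint32_t tmem_base, uint32_t row, uint32_t col) {
  assert(row < 128 && col < 0x10000);
  return tmem_base + (row << 16) + col;
}

// Warp w of a warpgroup may only access lanes 32*(w%4) .. 32*(w%4)+31.
uint32_t warp_first_lane(uint32_t warp_id) {
  return kLanesPerWarp * (warp_id % 4);
}

// Number of 32x32b.x8 loads one warp issues to cover n_cols columns of S.
uint32_t loads_per_warp(uint32_t n_cols) {
  assert(n_cols % kColsPerLoad == 0);
  return n_cols / kColsPerLoad;
}
```

At SK=128, each of the 4 warps would issue `loads_per_warp(128)` = 16 loads. Load i of warp w targets `tmem_addr(base, warp_first_lane(w), 8 * i)`.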

Files

  • dsv4/kernels/attention/fmha_umma_desc.cuh — descriptors, SMEM layout, MMA wrappers
  • tests/unit/test_umma_qk.cu — working UMMA QK GEMM test