nvfp4-megamoe-kernel/CURRENT_ISSUE.md
2026-05-28 21:08:13 +00:00

CURRENT_ISSUE.md — FMHA 6-Warp Specialization

Status: Milestone 4 IN PROGRESS (multi-row softmax for prefill T>1)

Milestone 5 DONE — multi-head grid launch

Milestone 4: T≤32 PASSING (cos 0.999996+), T>32 BLOCKED on TMEM row read

CRITICAL BUG FIXED: Q/K SMEM canonical layout used full_d instead of local d (0..15)
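The index fix above can be illustrated with a small host-side sketch. It assumes that the head dimension is consumed in 16-element K-steps and that each step has its own SMEM tile; `kKBlock`, `DSplit`, and `split_head_dim` are hypothetical names, not taken from the kernel source. Only the split of `full_d` into a step index and a local index in 0..15 is shown, not the canonical layout itself.

```cpp
#include <cassert>

// Assumption: one UMMA K-step covers 16 elements of the head dimension.
constexpr int kKBlock = 16;

struct DSplit {
    int k_block;  // which 16-wide K-step the element belongs to
    int local_d;  // position inside that K-step, always in [0, 16)
};

// Split a full head-dimension index into (K-step, local index).
// The canonical SMEM position should be computed from local_d only;
// using full_d places elements outside the tile the descriptor reads.
inline DSplit split_head_dim(int full_d) {
    assert(full_d >= 0);
    return DSplit{full_d / kKBlock, full_d % kKBlock};
}
```

For example, `split_head_dim(37)` yields K-step 2 and local index 5, so the element would be written at local position 5 of the third tile rather than at position 37.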

What works:

  • 6-warp kernel: Warps 0-3 softmax/epilogue, Warp 4 MMA, Warp 5 data staging
  • All HD values: HD=16/64/128/256 pass with cos 0.999997+
  • Warp role separation: MMA and data loading on separate warps
  • CTA-wide sync: __syncthreads() between phases
  • Multi-head grid launch: grid=(1, n_h, batch), each CTA handles one head
  • MQA: k_head_stride=0 / v_head_stride=0 for shared KV heads
  • Batched: blockIdx.z for batch dimension
  • LSE output: per-row LSE for multi-segment KV merge
  • FmhaParams struct: stride-based tensor addressing, future-proof for GQA
  • Multi-row softmax T≤32: cos 0.999996+ with per-lane per-row softmax (no wmax/wsum)
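The per-row LSE output listed above is what makes a multi-segment KV merge possible. The following host-side reference is a minimal sketch of that merge for a single row, assuming each segment produces an already normalized output vector and a natural-log LSE (the kernel may use a different log base or scaling). `Partial` and `merge_segments` are illustrative names, not part of the kernel.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// One row's result from attending over a single KV segment.
struct Partial {
    std::vector<float> o;  // normalized attention output for the row
    float lse;             // log(sum(exp(score))) over this segment (natural log assumed)
};

// Merge two segment results into the result over the union of both segments.
// Each output is weighted by its share of the combined softmax denominator.
inline Partial merge_segments(const Partial& a, const Partial& b) {
    assert(a.o.size() == b.o.size());
    const float m = std::max(a.lse, b.lse);  // subtract the max for numerical stability
    const float wa = std::exp(a.lse - m);
    const float wb = std::exp(b.lse - m);
    const float denom = wa + wb;
    Partial out;
    out.o.resize(a.o.size());
    for (std::size_t i = 0; i < a.o.size(); ++i) {
        out.o[i] = (wa * a.o[i] + wb * b.o[i]) / denom;
    }
    out.lse = m + std::log(denom);
    return out;
}
```

Two segments with equal LSE contribute equally, so their outputs are averaged; a segment with a much smaller LSE has almost no influence on the merged row.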

Architecture:

Warp 0-3 (tid 0-127):  Softmax + correction + epilogue
  - Read S from TMEM → softmax → write P to SMEM
  - After PV: read O from TMEM → BF16 → GMEM
  - T=1 decode: only warp 0 processes row 0
Warp 4 (tid 128-159): MMA
  - tcgen05.mma SS for QK (N=128) and PV (N=16 sub-tiles)
  - TMEM alloc/dealloc
Warp 5 (tid 160-191): Data staging
  - Load Q/K/V from GMEM to SMEM (canonical layout)
  - Fill sPk from s_p_vals
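The thread-to-role mapping above can be written as a small dispatch function. This is a host-side sketch assuming a 192-thread CTA with 32-thread warps, as the tid ranges imply; `WarpRole` and `warp_role` are hypothetical names. In the kernel the same test would presumably be applied to `threadIdx.x`.

```cpp
#include <cassert>

enum class WarpRole { Softmax, Mma, Staging };

constexpr int kWarpSize = 32;
constexpr int kNumThreads = 192;  // 6 warps per CTA

// Map a thread index to its role:
// warps 0-3 (tid 0-127) softmax/epilogue, warp 4 (tid 128-159) MMA,
// warp 5 (tid 160-191) data staging.
inline WarpRole warp_role(int tid) {
    assert(tid >= 0 && tid < kNumThreads);
    const int warp = tid / kWarpSize;
    if (warp < 4) return WarpRole::Softmax;
    if (warp == 4) return WarpRole::Mma;
    return WarpRole::Staging;
}
```

Because the role depends only on the warp index, every thread in a warp takes the same branch and the dispatch causes no intra-warp divergence.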

Next milestones:

  1. TMA loads (Milestone 2): Replace direct GMEM reads with cp.async.bulk.tensor
    • Requires CUtensorMap creation on host
    • mbarrier synchronization
    • BLOCKED: cuTensorMapEncodeTiled 2D/3D/5D returns INVALID_VALUE on B200 driver v580.126.20
    • Alternative: Study CuTeDSL's TMA descriptor creation source code
  2. Pipeline overlap (Milestone 3): Double-buffer K/V loads
    • Load next K/V while computing current QK
    • mbarrier producer-consumer sync between warp 5 and warp 4
    • Depends on TMA loads (Milestone 2)
  3. Multi-row softmax (Milestone 4): Process all 128 rows (prefill T>1) 🚧 IN PROGRESS
    • T≤32: WORKING — warp 0, lane l handles row l, 32x32b.x8 TMEM reads
    • T>32: BLOCKED — 32x32b.x8 only reads rows 0-31
    • NEXT: Use 16x256b.x1 TMEM reads (reads all 128 rows per column)
    • Each of the 4 softmax warps w handles rows [w*32, (w+1)*32) ∩ [0, T)
    • Per-lane row assignment in 16x256b: lane j gets rows j*4+0..3
    • No cross-warp reduction needed (disjoint row sets)
    • KEY LESSON: Q/K SMEM canonical positions MUST use local d (0..15), NOT full_d. The UMMA descriptor always reads from the start of sQ0/sK0, not from an offset.
  4. Multi-head launch (Milestone 5): grid=(1, n_h, batch) DONE
  5. Production integration (Milestone 6): Hook into production.py
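The row split planned for Milestone 4, where softmax warp w owns rows [w*32, (w+1)*32) clipped to [0, T), can be sketched on the host as follows. `RowRange` and `softmax_rows` are hypothetical names, and the bound of 128 rows per tile is taken from the milestone description. The sketch covers only the warp-level partition, not the per-lane row assignment of the 16x256b read.

```cpp
#include <algorithm>
#include <cassert>

// Half-open row interval [begin, end); empty when begin == end.
struct RowRange {
    int begin;
    int end;
};

// Rows owned by softmax warp w (0..3) for a tile with T valid rows (T <= 128).
// Ranges of different warps are disjoint, so no cross-warp reduction is needed.
inline RowRange softmax_rows(int w, int T) {
    assert(w >= 0 && w < 4);
    assert(T >= 0 && T <= 128);
    const int begin = std::min(w * 32, T);
    const int end = std::min((w + 1) * 32, T);
    return RowRange{begin, end};
}
```

With T=40, warp 0 gets rows 0..31, warp 1 gets rows 32..39, and warps 2 and 3 get empty ranges and can skip the softmax phase. With T=1 (decode), only warp 0 has work, matching the current behavior.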

Files:

  • dsv4/kernels/attention/fmha_6warp.cuh — 6-warp kernel (single-head)
  • dsv4/kernels/attention/fmha_6warp_multihead.cuh — Multi-head grid launch kernel
  • tests/unit/test_fmha_6warp_multihead.cu — Multi-head test harness
  • tests/unit/test_fmha_6warp_multihead_hd{16,64,128,256}.cu — HD-specific wrappers

Layout D N=64 Bug (documented for NVIDIA):

  • tcgen05.mma with make_idesc(128, 64) skips TMEM cols 32-35, 48-51
  • Workaround: N=16 sub-tiles with TMEM offset n*16
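The workaround amounts to issuing one N=16 MMA per sub-tile and advancing the TMEM column offset by 16 each time. A host-side sketch of the offset computation, assuming the output width is a multiple of 16 (`kSubTileN` and `pv_subtile_cols` are illustrative names):

```cpp
#include <cassert>
#include <vector>

constexpr int kSubTileN = 16;  // width of one MMA sub-tile in TMEM columns

// TMEM column offsets for covering n_total output columns with N=16 sub-tiles.
// Sub-tile n starts at column n * 16.
inline std::vector<int> pv_subtile_cols(int n_total) {
    assert(n_total > 0 && n_total % kSubTileN == 0);
    std::vector<int> cols;
    for (int n = 0; n < n_total / kSubTileN; ++n) {
        cols.push_back(n * kSubTileN);
    }
    return cols;
}
```

For the N=64 case that triggers the bug, this gives four sub-tiles at columns 0, 16, 32, and 48, so the column ranges 32-35 and 48-51 are each written by their own MMA instead of being skipped.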

Mike, I've hit a wall with TMA. Here's the situation:

  1. Raw cuTensorMapEncodeTiled doesn't work for ANY multi-dimensional descriptor on the B200. 1D works, 2D/3D/5D all return INVALID_VALUE. This might be a driver issue (v580.126.20) or a parameter format I haven't figured out.

  2. CuTeDSL's TMA works perfectly (the existing FMHA kernel uses it), but I can't mix it with raw CUDA inline PTX easily.

  3. The CuTeDSL FMHA with pv_n_tile=16 crashes (illegal memory access) — needs CuTeDSL debugging.

What I recommend: focus on what works. The raw CUDA 6-warp kernel is fully working at all HD values with direct GMEM reads. The TMA optimization can wait for:

  • A driver update that fixes cuTensorMapEncodeTiled for multi-dimensional descriptors, OR
  • Debugging the CuTeDSL pv_n_tile=16 crash (CuTeDSL-specific, different skill set), OR
  • Using CuTeDSL's TMA in a separate staging kernel

Should I:

  • A) Continue debugging the CuTeDSL pv_n_tile=16 crash (might take a while, but gives us TMA + correct PV)
  • B) Move on to multi-head launch (more immediate production impact, no TMA needed)
  • C) Try to find the correct cuTensorMapEncodeTiled parameters by studying the CuTeDSL source code that creates the descriptors