nvfp4-megamoe-kernel/docs/p7_tmem_column_layout.md
biondizzle e747742598 P7: Document TMEM column layout, add multi-row softmax test
docs/p7_tmem_column_layout.md: Verified that tcgen05.ld 32x32b.x8 is
the correct instruction for multi-row softmax. Each call reads 8 KV
positions for 32 rows. No instruction change needed from single-row.

test_p7_multi_row_softmax.py: Tests T=1,4,32,64,128 at various HD and N.
Gate: cos >= 0.999996.
2026-05-30 17:17:54 +00:00


P7: TMEM Column Layout for Multi-Row Softmax

Observed Layout (verified on B200)

The FMHA QK MMA produces a TMEM tensor S of shape (128, s_k) in row-major layout:

  • Row 0: QK dot products for query position 0 (128 BF16 → 128 FP32 in TMEM)
  • Row 1: QK dot products for query position 1
  • ...
  • Row T-1: QK dot products for query position T-1. Only the first T rows hold valid data (T ≤ 128 for a single CTA)

TMEM Organization

For tcgen05.mma.kind::f16 with M=128, N=16 (single PV sub-tile):

  • MMA writes to TMEM at column offset n_sub * 16 where n_sub = 0..N_NSUB-1
  • Each PV sub-tile writes 16 TMEM columns

For QK GEMM (M=128, N=128):

  • QK writes to TMEM columns 0..127 (128 columns)
  • For HD=64: TMEM_N = 128 columns allocated
  • For HD=128: TMEM_N = 128 columns allocated
  • For HD=256: TMEM_N = 256 columns allocated
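
The column bookkeeping above can be sketched in a few lines of Python. The helper names are hypothetical, and the allocation rule `max(128, HD)` is only inferred from the three cases listed above; it is an assumption, not something taken from the kernel source.

```python
PV_SUBTILE_COLS = 16   # each PV sub-tile writes 16 TMEM columns
QK_COLS = 128          # QK GEMM with N=128 writes columns 0..127


def pv_subtile_columns(n_sub):
    """Return (first, last) TMEM column written by PV sub-tile n_sub."""
    first = n_sub * PV_SUBTILE_COLS
    return first, first + PV_SUBTILE_COLS - 1


def tmem_n_for_hd(hd):
    """Assumed allocation rule that reproduces the HD=64/128/256 cases."""
    return max(QK_COLS, hd)


if __name__ == "__main__":
    for n_sub in range(4):
        print("n_sub", n_sub, "-> columns", pv_subtile_columns(n_sub))
    for hd in (64, 128, 256):
        print("HD", hd, "-> TMEM_N", tmem_n_for_hd(hd))
```

The sketch only covers the head dimensions listed in this section; other values would need to be checked against the kernel.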

TMEM Read: tcgen05.ld.sync.aligned.32x32b.x8.b32

Format: Each call reads 8 consecutive TMEM columns for all 32 lanes.

addr = tmem_base + n * 8

Where n is the "step" index (0, 1, 2, ...), so the column offset n * 8 takes the values 0, 8, 16, ...

Lane mapping: For step n, lane i reads 8 FP32 values from row i, columns 8n through 8n+7.

  • Lane 0 reads S[0, 8n] through S[0, 8n+7] (row 0)
  • Lane 1 reads S[1, 8n] through S[1, 8n+7] (row 1)
  • ...
  • Lane 31 reads S[31, 8n] through S[31, 8n+7] (row 31)

This means:

  • One 32x32b.x8 call reads 8 KV positions for 32 query rows simultaneously
  • The instruction IS the correct one for multi-row softmax
  • Each warp (32 lanes) processes 32 consecutive query rows
  • 4 warps together cover TMEM lanes 0-127, i.e. all 128 query rows
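
As a host-side illustration of the lane mapping, the following Python model mimics what one load hands to each lane. It is a sketch, not the instruction itself: `model_ld_32x32b_x8` is a hypothetical name, `step` counts loads (0, 1, 2, ...) so that the column offset is `step * 8` as in the address formula, and the assumption that warp `w` owns rows `w * 32` to `w * 32 + 31` follows the warp split described above.

```python
LANES_PER_WARP = 32
COLS_PER_LOAD = 8


def model_ld_32x32b_x8(S, warp, step):
    """Model one load: lane i of `warp` receives 8 values from row warp*32 + i."""
    col0 = step * COLS_PER_LOAD
    row0 = warp * LANES_PER_WARP
    return [S[row0 + lane][col0:col0 + COLS_PER_LOAD]
            for lane in range(LANES_PER_WARP)]


# Toy 128 x 128 score tile in which S[r][c] encodes its own coordinates.
S = [[r * 1000 + c for c in range(128)] for r in range(128)]

regs = model_ld_32x32b_x8(S, warp=1, step=2)
print(regs[0])    # lane 0 of warp 1: row 32, columns 16..23
print(regs[31])   # lane 31 of warp 1: row 63, columns 16..23
```

Each lane ends up holding 8 values from a single row, which is why the per-row reduction needs no cross-lane communication in this model.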

Multi-Row Softmax Strategy

For T ≤ 32: 1 warp (warp 0) processes all rows

  • my_row = lane (0..31)
  • Each lane computes softmax for its own row

For T ≤ 64: 2 warps (warps 0-1)

  • Warp 0: rows 0-31, Warp 1: rows 32-63
  • my_row = wid * 32 + lane

For T ≤ 128: 4 warps (warps 0-3)

  • Each warp processes 32 rows
  • my_row = wid * 32 + lane

This is exactly what the multi-tile kernel (fmha_6warp_tma_multirow_multitile.cuh) implements.
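
The row assignment can be written out as a small Python sketch. This is not the kernel's code: the function names are hypothetical, and the `valid` flag (marking lanes whose row is at or beyond T) is an assumption about how padding rows would be skipped.

```python
LANES_PER_WARP = 32


def warps_needed(T):
    """Softmax warps used for T query rows, following the three cases above."""
    if not 1 <= T <= 128:
        raise ValueError("T must be in 1..128 for a single CTA")
    if T <= 32:
        return 1
    if T <= 64:
        return 2
    return 4


def row_assignment(T):
    """Return (wid, lane, my_row, valid) for every lane of every active warp."""
    out = []
    for wid in range(warps_needed(T)):
        for lane in range(LANES_PER_WARP):
            my_row = wid * LANES_PER_WARP + lane
            out.append((wid, lane, my_row, my_row < T))
    return out


if __name__ == "__main__":
    rows = row_assignment(40)
    print(len(rows), "lanes,", sum(v for *_, v in rows), "with valid rows")
```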

Key Insight

The 32x32b.x8 instruction is already correct for multi-row softmax, so no instruction change is needed. The earlier guess to use 16x256b.x1 was wrong: that instruction reads 16 rows with 8 FP32 per row (4 FP32 per lane for 2 rows), which is more complex to use and does not match the S tensor layout.

The 32x32b.x8 reads 8 KV positions for 32 rows per call. That fits row-wise softmax, where each row needs its own (max, exp, sum) computed across all KV positions.
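
To make the per-row (max, exp, sum) computation concrete, here is a pure-Python sketch of what a single lane could do with its row, consuming 8 columns per step. The running-max rescaling shown here is one common way to combine chunks and is an assumption; the document does not say whether the kernel uses this scheme or a separate max pass. Function names are hypothetical.

```python
import math


def row_softmax_stats(row, chunk=8):
    """Running (max, sum of exp) over one row, consumed `chunk` columns at a time."""
    m = -math.inf
    s = 0.0
    for c0 in range(0, len(row), chunk):
        vals = row[c0:c0 + chunk]
        m_new = max(m, max(vals))
        # Rescale the previous sum to the new max before adding this chunk.
        s = s * math.exp(m - m_new) + sum(math.exp(v - m_new) for v in vals)
        m = m_new
    return m, s


def row_softmax(row, chunk=8):
    """Softmax of one row using the chunked statistics above."""
    m, s = row_softmax_stats(row, chunk)
    return [math.exp(v - m) / s for v in row]


# One toy row of 128 scores (one lane's view of its query row).
row = [0.1 * ((7 * k) % 13) for k in range(128)]
probs = row_softmax(row)
print(abs(sum(probs) - 1.0) < 1e-9)
```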

Verified Results

All 72 configs pass in the multi-tile kernel:

  • HD=64/128/256/512 × T=1/4/32/128 × s_k=128/256/384/512
  • Cosine similarity ≥ 0.999996 across all configs
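
For reference, a cosine-similarity gate of this kind can be sketched as follows. How test_p7_multi_row_softmax.py actually computes the metric is not shown here, so treating the outputs as flat vectors is an assumption and the function names are hypothetical.

```python
import math

COS_GATE = 0.999996   # threshold quoted in the results above


def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length flat vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return dot / (na * nb)


def passes_gate(out, ref, gate=COS_GATE):
    """True when the kernel output is close enough to the reference."""
    return cosine_similarity(out, ref) >= gate


if __name__ == "__main__":
    ref = [1.0, 2.0, 3.0, 4.0]
    out = [1.0, 2.0, 3.0, 4.001]
    print(passes_gate(out, ref))
```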