nvfp4-megamoe-kernel/docs/p7_tmem_column_layout.md
biondizzle e747742598 P7: Document TMEM column layout, add multi-row softmax test
docs/p7_tmem_column_layout.md: Verified that tcgen05.ld 32x32b.x8 is
the correct instruction for multi-row softmax. Each call reads 8 KV
positions for 32 rows. No instruction change needed from single-row.

test_p7_multi_row_softmax.py: Tests T=1,4,32,64,128 at various HD and N.
Gate: cos >= 0.999996.
2026-05-30 17:17:54 +00:00

# P7: TMEM Column Layout for Multi-Row Softmax
## Observed Layout (verified on B200)
The FMHA QK MMA produces a TMEM tensor S of shape (128, s_k) in row-major layout:
- Row 0: QK dot product for query position 0 (128 BF16 → 128 FP32 in TMEM)
- Row 1: QK dot product for query position 1
- ...
- Row T-1: Only T rows have valid data (T ≤ 128 for single CTA)
### TMEM Organization
For `tcgen05.mma.kind::f16` with M=128, N=16 (single PV sub-tile):
- MMA writes to TMEM at column offset `n_sub * 16` where n_sub = 0..N_NSUB-1
- Each PV sub-tile writes 16 TMEM columns
For QK GEMM (M=128, N=128):
- QK writes to TMEM columns 0..127 (128 columns)
- For HD=64: TMEM_N = 128 columns allocated
- For HD=128: TMEM_N = 128 columns allocated
- For HD=256: TMEM_N = 256 columns allocated
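The column bookkeeping above can be written down as a small host-side model. This is only a sketch: the names `pv_subtile_columns` and `TMEM_N_BY_HD` are hypothetical, and the table holds just the three allocations listed above rather than a formula taken from the kernel.

```python
PV_SUBTILE_N = 16  # TMEM columns written by one PV sub-tile (from the text above)

# Allocations stated above; other HD values are intentionally not guessed.
TMEM_N_BY_HD = {64: 128, 128: 128, 256: 256}


def pv_subtile_columns(n_sub: int) -> range:
    """Columns written by PV sub-tile n_sub: offset n_sub * 16, 16 columns wide."""
    start = n_sub * PV_SUBTILE_N
    return range(start, start + PV_SUBTILE_N)


def tmem_columns_for_hd(hd: int) -> int:
    """Look up the documented TMEM_N for a head dimension (KeyError if undocumented)."""
    return TMEM_N_BY_HD[hd]
```

For example, `pv_subtile_columns(2)` covers columns 32 through 47.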
### TMEM Read: tcgen05.ld.sync.aligned.32x32b.x8.b32
**Format:** Each call reads 8 consecutive TMEM columns for all 32 lanes.
```
addr = tmem_base + n
```
Where `n` is the starting column of the step (0, 8, 16, ...), i.e. `n = 8 * step` for step 0, 1, 2, ....
**Lane mapping:** For a step starting at column `n`, lane `i` reads 8 FP32 values from row `i`, columns `n` through `n+7`:
- Lane 0 reads S[0, n] through S[0, n+7] (row 0)
- Lane 1 reads S[1, n] through S[1, n+7] (row 1)
- ...
- Lane 31 reads S[31, n] through S[31, n+7] (row 31)
This means:
- One `32x32b.x8` call reads 8 KV positions for 32 query rows simultaneously
- The instruction IS the correct one for multi-row softmax
- Each warp (32 lanes) processes 32 consecutive query rows
- 4 warps (threads 0-127) process 128 query rows total, with warp `w` covering rows `32*w` through `32*w + 31`
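The lane mapping can be modeled on the host to check indexing before touching the kernel. The sketch below is a plain-Python model, not the PTX instruction: `ld_32x32b_x8` is a hypothetical helper, `S` is an ordinary list of rows, and the `32 * warp` row offset follows the `wid * 32 + lane` convention used in this document.

```python
def ld_32x32b_x8(S, warp, n):
    """Model one 32x32b.x8 read for a warp.

    Returns regs, where regs[lane] holds the 8 values that lane would
    receive: row (32 * warp + lane), columns n through n + 7.
    """
    base_row = 32 * warp
    return [list(S[base_row + lane][n:n + 8]) for lane in range(32)]


# Tag each element with its coordinates so the mapping is easy to inspect:
# S[r][c] = r * 1000 + c
S = [[r * 1000 + c for c in range(128)] for r in range(128)]

regs = ld_32x32b_x8(S, warp=0, n=8)
```

Here `regs[1]` holds the tags for row 1, columns 8 through 15, so each lane ends up with a contiguous 8-wide slice of its own row.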
### Multi-Row Softmax Strategy
For T ≤ 32: 1 warp (warp 0) processes all rows
- my_row = lane (0..31)
- Each lane computes softmax for its own row
For T ≤ 64: 2 warps (warps 0-1)
- Warp 0: rows 0-31, Warp 1: rows 32-63
- my_row = wid * 32 + lane
For T ≤ 128: 4 warps (warps 0-3)
- Each warp processes 32 rows
- my_row = wid * 32 + lane
This is exactly what the multi-tile kernel (`fmha_6warp_tma_multirow_multitile.cuh`) implements.
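The warp and row assignment can be sketched as follows. The function names are hypothetical, and masking rows with `my_row >= T` is an assumption based on the note that only T rows hold valid data; the kernel itself may handle inactive rows differently.

```python
def active_warps(T: int) -> int:
    """Number of softmax warps for T query rows, per the three tiers above."""
    if not 1 <= T <= 128:
        raise ValueError("T must be in 1..128 for a single CTA")
    if T <= 32:
        return 1
    if T <= 64:
        return 2
    return 4


def my_row(wid: int, lane: int) -> int:
    """Query row owned by a given (warp, lane) pair."""
    return wid * 32 + lane


def valid_rows(T: int) -> list:
    """Rows that carry real data; assumes lanes with my_row >= T are masked."""
    return [
        my_row(wid, lane)
        for wid in range(active_warps(T))
        for lane in range(32)
        if my_row(wid, lane) < T
    ]
```

For T=4, one warp is active and only lanes 0 through 3 own valid rows.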
### Key Insight
The `32x32b.x8` instruction is already correct for multi-row softmax; no instruction change is needed. The earlier guess to use `16x256b.x1` was wrong: that instruction reads 16 rows with 8 FP32 per row (4 FP32 per lane for 2 rows), which is more complex to use and doesn't match the S tensor layout.
Each `32x32b.x8` call reads 8 KV positions for 32 rows, which suits row-wise softmax, where each row needs (max, exp, sum) computed across all KV positions.
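To illustrate the per-row (max, exp, sum) computation over 8-column reads, here is a minimal Python sketch of what a single lane might do for its row. It uses a streaming (running max with rescaled sum) formulation, which is one possible approach and not necessarily what the kernel does; `row_softmax_stats` and `row_softmax` are hypothetical names, and the inputs are assumed to be finite.

```python
import math


def row_softmax_stats(row, chunk=8):
    """Running max and sum of exp(v - max) for one row, 8 columns at a time."""
    m = -math.inf  # running row max
    s = 0.0        # running sum of exp(v - m)
    for n in range(0, len(row), chunk):
        vals = row[n:n + chunk]          # the 8 values one read would deliver
        m_new = max(m, max(vals))
        # Rescale the old sum to the new max, then add this chunk.
        s = s * math.exp(m - m_new) + sum(math.exp(v - m_new) for v in vals)
        m = m_new
    return m, s


def row_softmax(row, chunk=8):
    """Normalized softmax of one row using the streamed statistics."""
    m, s = row_softmax_stats(row, chunk)
    return [math.exp(v - m) / s for v in row]
```

Because each lane owns one full row, no cross-lane reduction appears in this model; every lane runs the same loop on its own data.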
### Verified Results
All 72 configs pass in the multi-tile kernel:
- HD=64/128/256/512 × T=1/4/32/128 × s_k=128/256/384/512
- Cos ≥ 0.999996 across all configs
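For reference, a cosine-similarity gate of this kind can be sketched in a few lines of Python. This is an illustrative stand-in, not the code in `test_p7_multi_row_softmax.py`: the function names are hypothetical, and it assumes the kernel output and the reference have already been flattened to equal-length sequences of floats.

```python
import math

COS_GATE = 0.999996  # threshold quoted above


def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length, non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def passes_gate(out, ref, gate=COS_GATE):
    """True when the output is within the cosine gate of the reference."""
    return cosine_similarity(out, ref) >= gate
```

Cosine similarity ignores a uniform scale error, so a gate like this is usually paired with an absolute or relative error check when magnitude matters.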