# P7: TMEM Column Layout for Multi-Row Softmax

## Observed Layout (verified on B200)

The FMHA QK MMA produces a TMEM tensor S of shape (128, s_k) in row-major layout: row r of S sits in TMEM lane r, and KV positions run along the TMEM columns.

- Row 0: QK scores for query position 0 (BF16 inputs; 128 FP32 scores per 128-column QK tile in TMEM)
- Row 1: QK scores for query position 1
- ...
- Row T-1: the last valid row. Only the first T rows hold valid data (T ≤ 128 for a single CTA).

### TMEM Organization

For `tcgen05.mma.kind::f16` with M=128, N=16 (a single PV sub-tile):

- The MMA writes to TMEM at column offset `n_sub * 16`, where `n_sub = 0..N_NSUB-1`
- Each PV sub-tile occupies 16 TMEM columns

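The following minimal Python sketch lists the TMEM column range that each N=16 PV sub-tile writes. It assumes `N_NSUB = HD / 16`. It also uses a hypothetical `o_col_base` for the first PV column, because this note does not say where the PV accumulator starts relative to S.

```python
def pv_subtile_columns(hd: int, o_col_base: int = 0) -> list:
    """TMEM column range written by each N=16 PV sub-tile.

    Assumption: N_NSUB = hd // 16; o_col_base is a hypothetical start column.
    """
    if hd % 16 != 0:
        raise ValueError("HD must be a multiple of the N=16 sub-tile width")
    n_nsub = hd // 16
    # Sub-tile n_sub starts at column n_sub * 16 and spans 16 columns.
    return [range(o_col_base + n_sub * 16, o_col_base + n_sub * 16 + 16)
            for n_sub in range(n_nsub)]

for hd in (64, 128):
    cols = pv_subtile_columns(hd)
    print(hd, len(cols), [c.start for c in cols])
```

For HD=64, this gives four sub-tiles that start at columns 0, 16, 32 and 48 relative to the base, with no gaps between them.
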
For the QK GEMM (M=128, N=128):

- QK writes to TMEM columns 0..127 (128 columns)
- For HD=64: TMEM_N = 128 columns allocated
- For HD=128: TMEM_N = 128 columns allocated
- For HD=256: TMEM_N = 256 columns allocated

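You can sanity-check the allocation sizes above against the `tcgen05.alloc` rule, which requires the column count to be a power of two between 32 and 512. This sketch records only the three head dims listed. HD=512 is left out because its allocation is not stated here.

```python
# TMEM_N values from the list above, keyed by head dim.
TMEM_N = {64: 128, 128: 128, 256: 256}
QK_COLS = 128  # the QK tile occupies columns 0..127

def valid_alloc(ncols: int) -> bool:
    """tcgen05.alloc column count: a power of two in [32, 512]."""
    return 32 <= ncols <= 512 and (ncols & (ncols - 1)) == 0

for hd, ncols in TMEM_N.items():
    # Each allocation must be legal and wide enough for the QK tile.
    print(hd, ncols, valid_alloc(ncols), ncols >= QK_COLS)
```
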
### TMEM Read: tcgen05.ld.sync.aligned.32x32b.x8.b32

**Format:** Each call reads 8 consecutive TMEM columns from each of the warp's 32 TMEM lanes, so every thread receives 8 FP32 registers.

```
// n = column offset of this 8-column chunk: 0, 8, 16, ...
addr = tmem_base + n
```

Here `n` is the column offset of the current step (0, 8, 16, ...), so consecutive steps advance the address by 8 columns.

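Here is a minimal sketch of the per-step addresses. It rests on three assumptions that are not spelled out above:

- A TMEM address packs the lane in bits 31:16 and the column in bits 15:0, as in the PTX Tensor Memory addressing scheme.
- `tmem_base` points at lane 0, column 0 of the allocation.
- Warp `wid` may only touch TMEM lanes `32 * (wid % 4)` through `32 * (wid % 4) + 31`, so warps 1-3 add their lane base to the address.

The helper names are hypothetical.

```python
def tmem_addr(tmem_base: int, lane: int, col: int) -> int:
    """Pack a TMEM address: lane in bits 31:16, column in bits 15:0."""
    return tmem_base + (lane << 16) + col

def x8_step_addrs(tmem_base: int, wid: int, tile_cols: int) -> list:
    """Addresses of one warp's 32x32b.x8 loads over a tile_cols-wide S tile.

    Assumes warp wid reads the 32-lane slice starting at lane 32 * (wid % 4).
    """
    lane_base = 32 * (wid % 4)
    # One load per 8-column chunk: column offsets 0, 8, 16, ...
    return [tmem_addr(tmem_base, lane_base, n) for n in range(0, tile_cols, 8)]

addrs = x8_step_addrs(tmem_base=0, wid=1, tile_cols=128)
print(len(addrs), hex(addrs[0]), hex(addrs[1]))
```
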
**Lane mapping:** For the step at column offset `n`, lane `i` of the warp reads 8 FP32 values, columns `n` through `n+7`, from TMEM lane `i`. That is row `i` of S for warp 0, and row `32*wid + i` in general:

- Lane 0 reads S[0, n] through S[0, n+7] (row 0)
- Lane 1 reads S[1, n] through S[1, n+7] (row 1)
- ...
- Lane 31 reads S[31, n] through S[31, n+7] (row 31)

This means:

- One `32x32b.x8` call reads 8 KV positions for 32 query rows simultaneously
- The instruction **is** the correct one for multi-row softmax
- Each warp (32 lanes) processes 32 consecutive query rows
- 4 warps (TMEM lanes 0-127) process 128 query rows total

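To make the lane mapping concrete, this pure-Python model stands in for the hardware. `ld_32x32b_x8` is a hypothetical helper, not a real API; it returns what each of a warp's 32 threads receives from one load. `read_tile` then checks that 4 warps × `tile_cols / 8` loads rebuild the full 128-row tile.

```python
def ld_32x32b_x8(S, wid: int, n: int) -> list:
    """Model of one warp-wide 32x32b.x8 load: thread i gets row 32*wid + i,
    columns n..n+7 (8 FP32 registers)."""
    return [S[32 * wid + i][n:n + 8] for i in range(32)]

def read_tile(S, tile_cols: int) -> tuple:
    """Rebuild a 128-row S tile from 4 warps x (tile_cols // 8) loads."""
    out = [[None] * tile_cols for _ in range(128)]
    loads = 0
    for wid in range(4):                      # warps 0-3 cover TMEM lanes 0-127
        for n in range(0, tile_cols, 8):      # column offsets 0, 8, 16, ...
            for i, regs in enumerate(ld_32x32b_x8(S, wid, n)):
                out[32 * wid + i][n:n + 8] = regs
            loads += 1
    return out, loads

S = [[r * 1000 + c for c in range(128)] for r in range(128)]
rebuilt, loads = read_tile(S, 128)
print(rebuilt == S, loads)
```

For a 128-column tile, each warp issues 16 loads, for 64 loads in total.
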
### Multi-Row Softmax Strategy

For T ≤ 32: 1 warp (warp 0) processes all rows

- my_row = lane (0..31)
- Each lane computes softmax for its own row

For 32 < T ≤ 64: 2 warps (warps 0-1)

- Warp 0: rows 0-31, Warp 1: rows 32-63
- my_row = wid * 32 + lane

For 64 < T ≤ 128: 4 warps (warps 0-3)

- Each warp processes 32 rows
- my_row = wid * 32 + lane

This is exactly what the multi-tile kernel (`fmha_6warp_tma_multirow_multitile.cuh`) implements.

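The warp-count rule and row assignment above can be written out as a small model. Masking threads whose `my_row` is at or beyond T is an assumption made for illustration. The note itself only says that rows past T-1 hold no valid data. The function names are hypothetical.

```python
def softmax_warps(T: int) -> int:
    """Warps used for T query rows, per the cases above: 1, 2, or 4."""
    if not 1 <= T <= 128:
        raise ValueError("a single CTA holds at most 128 query rows")
    if T <= 32:
        return 1
    if T <= 64:
        return 2
    return 4

def my_row(wid: int, lane: int) -> int:
    return wid * 32 + lane

def active_threads(T: int) -> list:
    """(wid, lane, row) for each thread whose row is valid (assumed mask: row < T)."""
    return [(wid, lane, my_row(wid, lane))
            for wid in range(softmax_warps(T))
            for lane in range(32)
            if my_row(wid, lane) < T]

for T in (1, 4, 32, 128):
    print(T, softmax_warps(T), len(active_threads(T)))
```
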
### Key Insight

The `32x32b.x8` instruction is already correct for multi-row softmax; no instruction change is needed. The earlier guess to "use 16x256b.x1" was wrong. That shape reads 16 rows with 8 FP32 per row, so each thread gets 4 FP32 values split across 2 rows. It is more complex to use and does not map one row of S onto one thread.

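You can check the fragment arithmetic behind this comparison in a few lines. The row counts and the 2-rows-per-thread split for `16x256b.x1` come from the paragraph above; the helper itself is only an illustration.

```python
WARP = 32

def per_thread_layout(rows: int, cols_per_row: int, rows_per_thread: int) -> tuple:
    """(values per thread, threads sharing each row) for one warp-wide load."""
    total = rows * cols_per_row                    # 32-bit values loaded by the warp
    per_thread = total // WARP
    per_row_per_thread = per_thread // rows_per_thread
    threads_per_row = cols_per_row // per_row_per_thread
    return per_thread, threads_per_row

# 32x32b.x8: 32 rows x 8 columns, one row per thread.
print(per_thread_layout(32, 8, 1))   # -> (8, 1)
# 16x256b.x1: 16 rows x 8 FP32 (256 bits), each thread holds parts of 2 rows.
print(per_thread_layout(16, 8, 2))   # -> (4, 4)
```

With `16x256b.x1`, each row's 8 values are spread over 4 threads. A per-row max or sum then needs a cross-thread combine, such as warp shuffles, which `32x32b.x8` avoids.
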
The `32x32b.x8` reads 8 KV positions for 32 rows per call, which suits row-wise softmax. We need (max, exp, sum) per row across all KV positions. Because each thread owns a whole row, these reductions stay in that thread's registers with no cross-thread shuffles.

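Because each thread owns one row, it can accumulate the per-row statistics one `x8` chunk at a time. This sketch uses the online (running-max) softmax formulation, a swapped-in technique chosen for illustration. The kernel may instead compute the max and the exp-sum in separate passes.

```python
import math

def row_softmax_stats(row, chunk: int = 8) -> tuple:
    """Running (max, sum of exp) for one row, consumed in 8-value chunks."""
    m, s = -math.inf, 0.0
    for n in range(0, len(row), chunk):
        vals = row[n:n + chunk]            # the 8 registers from one x8 load
        m_new = max(m, max(vals))
        # Rescale the old sum to the new max, then add this chunk's terms.
        s = s * math.exp(m - m_new) + sum(math.exp(v - m_new) for v in vals)
        m = m_new
    return m, s

def row_softmax(row) -> list:
    m, s = row_softmax_stats(row)
    return [math.exp(v - m) / s for v in row]

print(row_softmax([float(i) for i in range(16)]))
```

The rescale step `s * exp(m - m_new)` keeps the running sum consistent when a later chunk raises the row max.
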
### Verified Results

All 72 configs pass in the multi-tile kernel:

- HD=64/128/256/512 × T=1/4/32/128 × s_k=128/256/384/512
- Cosine similarity ≥ 0.999996 across all configs
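
For reference, here is a minimal sketch of a cosine-similarity pass/fail check at the 0.999996 threshold. It assumes the metric compares the flattened kernel output with a reference attention output. The note does not say exactly what is compared, and the function names are hypothetical.

```python
import math

def cosine_similarity(a, b) -> float:
    """cos(a, b) = <a, b> / (|a| * |b|) over flattened outputs."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

def passes(kernel_out, reference_out, threshold: float = 0.999996) -> bool:
    """Pass if the flattened outputs point in (almost) the same direction."""
    if len(kernel_out) != len(reference_out):
        raise ValueError("outputs must have the same number of elements")
    return cosine_similarity(kernel_out, reference_out) >= threshold
```
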