D1.2: TMEM budget verified on B200. Split-PV mandatory at hd=512 (MMA max N=256)

2026-05-23 06:43:01 +00:00
parent a951a95276
commit bd2da14ca6


@@ -79,63 +79,28 @@ Then: softmax threads write their P values through this copy → barrier → MMA
---
## TMEM Column Budget at hd=512
## TMEM Column Budget at hd=512 — VERIFIED ON B200 (May 23, 2026)
This MUST be calculated before writing a single line of SMEM-P code.
**tcgen05 MMA has a hard limit: N ≤ 256.** At hd=512, PV MUST be split into 2 tiles of (128, 256). The MMA rejects N=512 at construction time with `OpError: expects the N-mode to satisfy 8 <= N <= 256 and N % 8 == 0, but got 512`.
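
The constraint quoted in the error message can be checked in plain Python before any MMA is constructed. A minimal sketch, assuming the hypothetical helper names `check_mma_n` and `plan_pv_tiles` (neither is part of CUTLASS or of this kernel):

```python
MMA_N_MIN, MMA_N_MAX, MMA_N_STEP = 8, 256, 8  # bounds taken from the OpError text

def check_mma_n(n: int) -> bool:
    """True if n satisfies 8 <= N <= 256 and N % 8 == 0."""
    return MMA_N_MIN <= n <= MMA_N_MAX and n % MMA_N_STEP == 0

def plan_pv_tiles(head_dim: int) -> list[tuple[int, int]]:
    """Split head_dim into equal PV N-tiles no wider than MMA_N_MAX (hypothetical helper)."""
    pv_n_tile = min(head_dim, MMA_N_MAX)
    if head_dim % pv_n_tile != 0 or not check_mma_n(pv_n_tile):
        raise ValueError(f"head_dim={head_dim} cannot be tiled for the PV MMA")
    # Each entry is a half-open column range [start, stop) of V and O.
    return [(start, start + pv_n_tile) for start in range(0, head_dim, pv_n_tile)]

print(check_mma_n(512))    # False
print(plan_pv_tiles(512))  # [(0, 256), (256, 512)]
```

At hd ≤ 256 the plan collapses to a single tile, so the same code path covers the unsplit case.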
**Tensor Memory (TMEM) has 512 columns per CTA.** Each column is 32 bits wide.
**Measured on B200:**
At hd=64 (TMEM-P path):
- S (QK acc): 128 cols FP32
- P (softmax output): 64 cols (BF16 values packed into FP32-width columns)
- `p_cols_fp32 = pv_mma_tiler[2] * q_dtype.width // qk_acc_dtype.width`
- pv_mma_tiler = (128, 64, 128), so pv_mma_tiler[2] = 128
- p_cols_fp32 = 128 * 16 // 32 = 64
- `tmem_p0_offset = 32` means P starts at TMEM col 32 and occupies cols 32-95. S uses cols 0-127, so P and S OVERLAP. This works because S is consumed before P is written (softmax reads S, then writes P to the same TMEM region).
- After P: `o_after = max(s_cols=128, p_end=32+64=96) = 128`. `tmem_o0_offset = ((128 + 31) // 32) * 32 = 128`
- O (PV acc): `find_tmem_tensor_col_offset(tOtO)` at hd=64 ≈ 128 cols FP32
- Total: 128 (O offset) + 128 (O size) = 256 cols. Fits in 512. ✅
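
The offset arithmetic in the bullets above can be replayed as a few lines of Python. This is only a worked check of the numbers quoted there: `align_up` is a hypothetical helper, and the O footprint itself comes from `find_tmem_tensor_col_offset`, which is not modelled here.

```python
def align_up(x: int, multiple: int) -> int:
    """Round x up to the next multiple (hypothetical helper)."""
    return ((x + multiple - 1) // multiple) * multiple

bf16_width, fp32_width = 16, 32   # q_dtype.width, qk_acc_dtype.width
pv_mma_tiler = (128, 64, 128)     # hd=64 tiler from the notes
s_cols = 128                      # S occupies TMEM cols 0-127
tmem_p0_offset = 32               # P starts inside S's region

p_cols_fp32 = pv_mma_tiler[2] * bf16_width // fp32_width  # 64
p_end = tmem_p0_offset + p_cols_fp32                      # 96
o_after = max(s_cols, p_end)                              # 128
tmem_o0_offset = align_up(o_after, 32)                    # 128

print(p_cols_fp32, p_end, o_after, tmem_o0_offset)
```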
| hd | s_cols | pv_n_tile | o_cols | TMEM-P total | SMEM-P total |
|---:|-------:|----------:|-------:|-------------:|-------------:|
| 64 | 128 | 64 | 64 | 192 | 64 |
| 128 | 128 | 128 | 128 | 256 | 128 |
| 256 | 128 | 256 | 256 | 384 | 256 |
| 512 | 128 | 256 | 256 | 384 | 256 |
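
The measured rows are consistent with a simple set of relations: `pv_n_tile = min(hd, 256)`, `o_cols = pv_n_tile`, TMEM-P total = `s_cols + o_cols`, and SMEM-P total = `o_cols`. These relations are inferred from the table, not taken from the allocator, so treat the sketch below as a sanity check rather than a substitute for the probe. `tmem_budget` is a hypothetical name.

```python
TMEM_COLS = 512   # TMEM columns available per CTA
S_COLS = 128      # QK accumulator footprint, same at every head_dim
MMA_N_MAX = 256   # tcgen05 MMA upper bound on N

def tmem_budget(head_dim: int) -> dict:
    """Reproduce one table row from the inferred relations (assumption, not allocator output)."""
    pv_n_tile = min(head_dim, MMA_N_MAX)
    o_cols = pv_n_tile                    # assumed: matches every measured row
    return {
        "hd": head_dim,
        "s_cols": S_COLS,
        "pv_n_tile": pv_n_tile,
        "o_cols": o_cols,
        "tmem_p_total": S_COLS + o_cols,  # assumed: S and O side by side
        "smem_p_total": o_cols,           # assumed: equals o_cols in the measured rows
    }

for hd in (64, 128, 256, 512):
    row = tmem_budget(hd)
    assert max(row["tmem_p_total"], row["smem_p_total"]) <= TMEM_COLS
    print(row)
```

If a future head_dim or tiler breaks these relations, the probe output is authoritative and the table should be re-measured.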
At hd=512 (SMEM-P path):
- P is NOT in TMEM. S and O share TMEM (sequential, not concurrent).
- S (QK acc): 128 cols FP32 (same as hd=64 — QK is always (128, 128))
- O (PV acc): at hd=512, PV is (128, 512). PV MMA C-fragment is (128, 512) FP32 = 512 cols? NO.
- `tOtO = pv_thr.make_fragment_C(pv_as)` where `pv_as = pv_thr.partition_shape_C((128, 512))`
- The C-fragment for a tcgen05 MMA with shape (128, 512) in FP32:
- M=128 → 4 warps × 32 threads = 128 rows, each thread owns 1 row
- N=512 → 512/32 = 16 TMEM columns per thread? No, tcgen05 MMA writes (32, 32) tiles.
- For (128, 512) MMA: 4 M-tiles × 16 N-tiles = 64 (32×32) subtiles
- Each subtile uses 32 TMEM columns. But they're distributed across warps.
- `find_tmem_tensor_col_offset(tOtO)` gives the actual footprint.
- **MUST PRINT THIS ON THE B200.** Do not guess. Run a shape probe.
- If O needs ~512 cols: S (128) + O (512) = 640 > 512. **DOES NOT FIT.**
- Fix options:
1. Drop `kv_stage` from 2 to 1 — frees SMEM but loses K/V double-buffering. TMEM budget unchanged.
2. Split O into halves: process (128, 256) PV twice, each O tile is 256 cols. S(128) + O(256) = 384 < 512. ✅
3. Process S and O sequentially: after softmax consumes S, O can reuse S's TMEM region. O at offset 0, 512 cols. Total = 512. ✅ But only if we don't need S anymore when writing O (true — softmax is done before PV starts per KV tile).
**P columns are always 64** (128 KV positions × BF16_width / FP32_width). Doesn't change with hd.
**Plan: SMEM-P path reuses S's TMEM for O.** After softmax reads S and writes P to SMEM, S's TMEM region (cols 0-127) is dead. PV writes O starting at col 0. O at hd=512 needs ~256-512 cols (must measure). If O fits in cols 0-511 with S gone, we're golden.
**At hd=512 (SMEM-P + split-PV):**
- O per PV tile: 256 TMEM cols. Total = 256 < 512. ✅
- S (128 cols) consumed by softmax before PV writes O at col 0. Sequential, no overlap.
- Two PV passes needed: V[:, 0:256] and V[:, 256:512]. QK+softmax runs once per pass.
- Alternative: keep P in SMEM, run QK+softmax once, PV twice (saves QK work but needs P in SMEM between PV tiles).
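
The split is numerically safe because each output column of PV depends only on the matching column of V. A minimal pure-Python reference model of the "softmax once, PV twice" alternative is sketched below; it is not kernel code, and `matmul` and `split_pv` are hypothetical names. It also ignores the online-softmax rescaling that happens across KV tiles.

```python
def matmul(a, b):
    """Plain-Python (rows x inner) @ (inner x cols) product."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def split_pv(p, v, pv_n_tile=256):
    """Compute P @ V one N-tile of V at a time and concatenate the column blocks."""
    head_dim = len(v[0])
    out = [[] for _ in p]
    for start in range(0, head_dim, pv_n_tile):
        v_tile = [row[start:start + pv_n_tile] for row in v]  # V[:, start:start+tile]
        o_tile = matmul(p, v_tile)                            # one (rows, tile) PV pass
        for out_row, tile_row in zip(out, o_tile):
            out_row.extend(tile_row)                          # output[:, start:start+tile]
    return out

# Tiny integer example: 4 query rows, 8 KV positions, head_dim=512.
p = [[(i + j) % 3 for j in range(8)] for i in range(4)]
v = [[(j * 7 + k) % 5 for k in range(512)] for j in range(8)]
assert split_pv(p, v) == matmul(p, v)
```

Integer inputs keep the comparison exact; with BF16 or FP32 data the two results should still agree per column, since the split does not change the order of accumulation along the KV dimension.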
**Action item: Run shape probe on B200 before coding SMEM-P at hd=512.**
```python
# Shape probe script to run on B200.
import torch, math
import cutlass
import cutlass.cute as cute
import cutlass.utils as utils
import cutlass.nvgpu.tcgen05 as tcgen05
from cutlass import BFloat16, Float32, LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset

a_major = LayoutEnum.ROW_MAJOR  # adjust to match the kernel's A operand layout
b_major = LayoutEnum.ROW_MAJOR  # adjust to match the kernel's B operand layout

# PV MMA at hd=512. Per the verified result, N=512 is rejected here with OpError
# (8 <= N <= 256 required); use (128, 256) to probe the split-PV tile instead.
pv_mma = utils.sm100.make_trivial_tiled_mma(BFloat16, BFloat16, a_major, b_major, Float32, tcgen05.CtaGroup.ONE, (128,512), tcgen05.OperandSource.SMEM)

# Build the C (accumulator) fragment and ask how many TMEM columns it spans.
pv_thr = pv_mma.get_slice(0)
pv_as = pv_thr.partition_shape_C((128, 512))
tOtO = pv_thr.make_fragment_C(pv_as)
o_cols = find_tmem_tensor_col_offset(tOtO)
print(f"hd=512 PV C-fragment: pv_as={pv_as}, tOtO.layout={tOtO.layout}, o_cols={o_cols}")
# Also print the tOtO shape.
print(f"tOtO shape: {cute.shape(tOtO)}")
```
**TMEM budget is comfortable. No need to drop kv_stage. Splitting O at hd=512 is forced by the MMA limit (N ≤ 256), not by TMEM capacity.**
---
@@ -289,13 +254,13 @@ Already in the kernel. `head_dim` is a constructor arg. TMEM-P path works at hd=
The `use_smem_p` flag exists. PV source switches between TMEM/SMEM. TMEM layout adjusts. But the register→SMEM copy is a stub that zeros sP.
#### D1.2 — TMEM Column Budget Verification 🔨 DO THIS BEFORE CODING
#### D1.2 — TMEM Column Budget Verification ✅ VERIFIED
- [ ] Run shape probe on B200: `find_tmem_tensor_col_offset(tOtO)` at hd=512
- [ ] Print `pv_as`, `tOtO.layout`, `o_cols` at hd=128, 256, 512
- [ ] Calculate: can S(128) and O(???) share TMEM at hd=512?
- [ ] If O > 384 cols: plan for split-PV (two (128, 256) passes)
- [ ] Document the budget numbers HERE in this file
- [x] Run shape probe on B200: `find_tmem_tensor_col_offset(tOtO)` at hd=512
- [x] Print `pv_as`, `tOtO.layout`, `o_cols` at hd=128, 256, 512
- [x] Calculate: can S(128) and O(???) share TMEM at hd=512? YES — SMEM-P total = 256 < 512
- [x] At hd=512: split-PV is MANDATORY (tcgen05 MMA rejects N=512, max N=256)
- [x] Document the budget numbers HERE in this file
#### D1.3 — Implement register→SMEM copy for P (THE HARD PART)
@@ -309,9 +274,9 @@ The `use_smem_p` flag exists. PV source switches between TMEM/SMEM. TMEM layout
- [ ] **Test:** hd=128, n=128 → test against FP32 oracle
- [ ] **Test:** hd=256, n=128 → test against FP32 oracle
#### D1.4 — Multi-PV-tile for hd>256
#### D1.4 — Multi-PV-tile for hd>256 (MANDATORY — tcgen05 max N=256)
- [ ] Add `pv_n_tile = min(head_dim, 256)` and `n_pv_tiles = head_dim // pv_n_tile` to `__init__`
- [x] Add `pv_n_tile = min(head_dim, 256)` and `n_pv_tiles = head_dim // pv_n_tile` to `__init__`
- [ ] For hd=512: 2 PV tiles of (128, 256) each
- [ ] Strategy: kernel processes one PV N-tile per launch. Python orchestrates the tiles.
- Pass 0: V[:, 0:256] → output[:, 0:256], QK + softmax + PV for cols 0-255