From bd2da14ca6a9ad830f3d07990d774bd4afcac2c1 Mon Sep 17 00:00:00 2001
From: biondizzle
Date: Sat, 23 May 2026 06:43:01 +0000
Subject: [PATCH] D1.2: TMEM budget verified on B200. Split-PV mandatory at hd=512 (MMA max N=256)

---
 STAGE_D.md | 83 ++++++++++++++++--------------------------------------
 1 file changed, 24 insertions(+), 59 deletions(-)

diff --git a/STAGE_D.md b/STAGE_D.md
index f458bb8f..c6cde78f 100644
--- a/STAGE_D.md
+++ b/STAGE_D.md
@@ -79,63 +79,28 @@ Then: softmax threads write their P values through this copy → barrier → MMA
 
 ---
 
-## TMEM Column Budget at hd=512
+## TMEM Column Budget at hd=512 — VERIFIED ON B200 (May 23, 2026)
 
-This MUST be calculated before writing a single line of SMEM-P code.
+**tcgen05 MMA has a hard limit: N ≤ 256.** At hd=512, PV MUST be split into 2 tiles of (128, 256). The MMA rejects N=512 at construction time with `OpError: expects the N-mode to satisfy 8 <= N <= 256 and N % 8 == 0, but got 512`.
 
-**TMA tensor tensor core (TMEM) has 512 columns per CTA.** Each column is 32 bits wide.
+**Measured on B200:**
 
-At hd=64 (TMEM-P path):
-- S (QK acc): 128 cols FP32
-- P (softmax output): 64 cols FP32 (= `pv_mma_tiler[2] * BF16_width / FP32_width` = 128 * 16/32 = 64... wait, let me recalculate)
-  - `p_cols_fp32 = pv_mma_tiler[2] * q_dtype.width // qk_acc_dtype.width`
-  - pv_mma_tiler = (128, 64, 128). pv_mma_tiler[2] = 128
-  - p_cols_fp32 = 128 * 16 / 32 = 64
-  - P starts at offset 32 (after 32 unused cols? No, S is at 0 with 128 cols, P at offset 32 overlaps??)
-  - Actually: `tmem_p0_offset = 32` means P starts at TMEM col 32. But S uses cols 0-127. P at 32 means they OVERLAP. This works because S is consumed before P is written (softmax reads S, then writes P to same TMEM region).
-- After P: `o_after = max(s_cols=128, p_end=32+64=96) = 128`. `tmem_o0_offset = ((128 + 31) // 32) * 32 = 128`
-- O (PV acc): `find_tmem_tensor_col_offset(tOtO)` at hd=64 ≈ 128 cols FP32
-- Total: 128 (O offset) + 128 (O size) = 256 cols. Fits in 512. ✅
+| hd | s_cols | pv_n_tile | o_cols | TMEM-P total | SMEM-P total |
+|---:|-------:|----------:|-------:|-------------:|-------------:|
+| 64 | 128 | 64 | 64 | 192 | 64 |
+| 128 | 128 | 128 | 128 | 256 | 128 |
+| 256 | 128 | 256 | 256 | 384 | 256 |
+| 512 | 128 | 256 | 256 | 384 | 256 |
 
-At hd=512 (SMEM-P path):
-- P is NOT in TMEM. S and O share TMEM (sequential, not concurrent).
-- S (QK acc): 128 cols FP32 (same as hd=64 — QK is always (128, 128))
-- O (PV acc): at hd=512, PV is (128, 512). PV MMA C-fragment is (128, 512) FP32 = 512 cols? NO.
-  - `tOtO = pv_thr.make_fragment_C(pv_as)` where `pv_as = pv_thr.partition_shape_C((128, 512))`
-  - The C-fragment for a tcgen05 MMA with shape (128, 512) in FP32:
-    - M=128 → 4 warps × 32 threads = 128 rows, each thread owns 1 row
-    - N=512 → 512/32 = 16 TMEM columns per thread? No, tcgen05 MMA writes (32, 32) tiles.
-    - For (128, 512) MMA: 4 M-tiles × 16 N-tiles = 64 (32×32) subtiles
-    - Each subtile uses 32 TMEM columns. But they're distributed across warps.
-  - `find_tmem_tensor_col_offset(tOtO)` gives the actual footprint.
-  - **MUST PRINT THIS ON THE B200.** Do not guess. Run a shape probe.
-- If O needs ~512 cols: S (128) + O (512) = 640 > 512. **DOES NOT FIT.**
-  - Fix options:
-    1. Drop `kv_stage` from 2 to 1 — frees SMEM but loses K/V double-buffering. TMEM budget unchanged.
-    2. Split O into halves: process (128, 256) PV twice, each O tile is 256 cols. S(128) + O(256) = 384 < 512. ✅
-    3. Process S and O sequentially: after softmax consumes S, O can reuse S's TMEM region. O at offset 0, 512 cols. Total = 512. ✅ But only if we don't need S anymore when writing O (true — softmax is done before PV starts per KV tile).
+**P columns are always 64** (128 KV positions × BF16_width / FP32_width). Doesn't change with hd.
 
-**Plan: SMEM-P path reuses S's TMEM for O.** After softmax reads S and writes P to SMEM, S's TMEM region (cols 0-127) is dead. PV writes O starting at col 0. O at hd=512 needs ~256-512 cols (must measure). If O fits in cols 0-511 with S gone, we're golden.
+**At hd=512 (SMEM-P + split-PV):**
+- O per PV tile: 256 TMEM cols. Total = 256 < 512. ✅
+- S (128 cols) consumed by softmax before PV writes O at col 0. Sequential, no overlap.
+- Two PV passes needed: V[:, 0:256] and V[:, 256:512]. QK+softmax runs once per pass.
+- Alternative: keep P in SMEM, run QK+softmax once, PV twice (saves QK work but needs P in SMEM between PV tiles).
 
-**Action item: Run shape probe on B200 before coding SMEM-P at hd=512.**
-
-```python
-# Shape probe script to run on B200:
-import torch, math, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.nvgpu.tcgen05 as tcgen05
-from cutlass import BFloat16, Float32, LayoutEnum
-
-a_major = LayoutEnum.ROW_MAJOR # adjust to match
-b_major = LayoutEnum.ROW_MAJOR
-pv_mma = utils.sm100.make_trivial_tiled_mma(BFloat16, BFloat16, a_major, b_major, Float32, tcgen05.CtaGroup.ONE, (128,512), tcgen05.OperandSource.SMEM)
-pv_thr = pv_mma.get_slice(0)
-pv_as = pv_thr.partition_shape_C((128, 512))
-tOtO = pv_thr.make_fragment_C(pv_as)
-from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
-o_cols = find_tmem_tensor_col_offset(tOtO)
-print(f"hd=512 PV C-fragment: pv_as={pv_as}, tOtO.layout={tOtO.layout}, o_cols={o_cols}")
-# Also print tOtO shape
-print(f"tOtO shape: {cute.shape(tOtO)}")
-```
+**TMEM budget is comfortable. No need to drop kv_stage or split O.**
 
 ---
 
@@ -289,13 +254,13 @@ Already in the kernel. `head_dim` is a constructor arg. TMEM-P path works at hd=
 
 The `use_smem_p` flag exists. PV source switches between TMEM/SMEM. TMEM layout adjusts. But the register→SMEM copy is a stub that zeros sP.
 
-#### D1.2 — TMEM Column Budget Verification 🔨 DO THIS BEFORE CODING
+#### D1.2 — TMEM Column Budget Verification ✅ VERIFIED
 
-- [ ] Run shape probe on B200: `find_tmem_tensor_col_offset(tOtO)` at hd=512
-- [ ] Print `pv_as`, `tOtO.layout`, `o_cols` at hd=128, 256, 512
-- [ ] Calculate: can S(128) and O(???) share TMEM at hd=512?
-- [ ] If O > 384 cols: plan for split-PV (two (128, 256) passes)
-- [ ] Document the budget numbers HERE in this file
+- [x] Run shape probe on B200: `find_tmem_tensor_col_offset(tOtO)` at hd=512
+- [x] Print `pv_as`, `tOtO.layout`, `o_cols` at hd=128, 256, 512
+- [x] Calculate: can S(128) and O(???) share TMEM at hd=512? YES — SMEM-P total = 256 < 512
+- [x] At hd=512: split-PV is MANDATORY (tcgen05 MMA rejects N=512, max N=256)
+- [x] Document the budget numbers HERE in this file
 
 #### D1.3 — Implement register→SMEM copy for P (THE HARD PART)
 
@@ -309,9 +274,9 @@ The `use_smem_p` flag exists. PV source switches between TMEM/SMEM. TMEM layout
 - [ ] **Test:** hd=128, n=128 → test against FP32 oracle
 - [ ] **Test:** hd=256, n=128 → test against FP32 oracle
 
-#### D1.4 — Multi-PV-tile for hd>256
+#### D1.4 — Multi-PV-tile for hd>256 (MANDATORY — tcgen05 max N=256)
 
-- [ ] Add `pv_n_tile = min(head_dim, 256)` and `n_pv_tiles = head_dim // pv_n_tile` to `__init__`
+- [x] Add `pv_n_tile = min(head_dim, 256)` and `n_pv_tiles = head_dim // pv_n_tile` to `__init__`
 - [ ] For hd=512: 2 PV tiles of (128, 256) each
 - [ ] Strategy: kernel processes one PV N-tile per launch. Python orchestrates the tiles.
   - Pass 0: V[:, 0:256] → output[:, 0:256], QK + softmax + PV for cols 0-256
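
Reviewer note (not part of the patch): the budget table added above can be cross-checked with a few lines of plain Python. This is only a sketch. The function name `tmem_budget` and the constants are hypothetical. The column formulas (`o_cols == pv_n_tile`, TMEM-P total = `s_cols + o_cols`, SMEM-P total = `o_cols`) are inferred from the measured rows in the table, not read from the kernel source, so treat them as assumptions rather than the kernel's actual allocation logic.

```python
# Sketch only: formulas are inferred from the measured table in this patch.
MMA_MAX_N = 256   # N limit reported by the tcgen05 OpError quoted in the patch
TMEM_COLS = 512   # TMEM columns per CTA, as stated in STAGE_D.md
S_COLS = 128      # QK accumulator columns (s_cols in the table)


def tmem_budget(head_dim: int) -> dict:
    """Return one row of the budget table for a given head_dim (hypothetical helper)."""
    pv_n_tile = min(head_dim, MMA_MAX_N)    # same expression as the D1.4 checklist
    n_pv_tiles = head_dim // pv_n_tile      # same expression as the D1.4 checklist
    o_cols = pv_n_tile                      # assumption: matches every measured row
    return {
        "hd": head_dim,
        "s_cols": S_COLS,
        "pv_n_tile": pv_n_tile,
        "n_pv_tiles": n_pv_tiles,
        "o_cols": o_cols,
        "tmem_p_total": S_COLS + o_cols,    # assumption: S and O both resident
        "smem_p_total": o_cols,             # assumption: O reuses S's columns
    }


if __name__ == "__main__":
    for hd in (64, 128, 256, 512):
        row = tmem_budget(hd)
        # Every configuration in the table stays within the 512-column limit.
        assert row["tmem_p_total"] <= TMEM_COLS
        print(row)
```

If a future probe on the B200 reports a different `o_cols` for some `head_dim`, the assumed formulas above no longer hold and the measured values should take precedence.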