From bcfe52df0d0023efdb73acb40ee426bf56bc93b2 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sun, 24 May 2026 21:43:04 +0000 Subject: [PATCH] Add STAGE_D2.md: Multi-query grid + head packing plan --- STAGE_D2.md | 311 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 311 insertions(+) create mode 100644 STAGE_D2.md diff --git a/STAGE_D2.md b/STAGE_D2.md new file mode 100644 index 00000000..08beaa42 --- /dev/null +++ b/STAGE_D2.md @@ -0,0 +1,311 @@ +# STAGE_D2.md — Multi-Query Grid + Head Packing + +## ⚠️ IKEA INSTRUCTIONS — READ EVERY TIME BEFORE CODING + +### The Workflow (DO NOT SKIP STEPS) + +1. **Edit code in** `~/dev/nvfp4-megamoe-kernel/dsv4/kernels/attention/fmha.py` — this is the ONLY file for the FMHA kernel. +2. **Commit and push:** + ```bash + cd ~/dev/nvfp4-megamoe-kernel + git add -A && git commit -m "description" && git push origin master + ``` +3. **Test on B200 using the test harness:** + ```bash + ~/.openclaw/workspace/fire_b200_test tests/unit/test_d2_.py + ``` +4. **Regression check:** After every change, verify hd=64 cos ~0.999998 still matches. If it doesn't, the change is WRONG. Revert. + +### The Rules (BURNED INTO THIS FILE) + +- **NEVER edit files directly on the B200.** Edit locally, commit, push, pull, test. Every time. +- **ALWAYS use the test harness** — `fire_b200_test`, `run_test.sh`, `check_log.sh`. +- **CuTeDSL variables defined in `if` blocks are NOT visible in other `if` blocks.** Define all variables unconditionally before any branching. +- **Guard dead code with `const_expr`.** CuTeDSL compiles BOTH branches of Python `if` statements. +- **PRINT THE SHAPES. ALWAYS.** Reasoning about layouts without evidence is how we waste days. +- **After every P store to TMEM, call `cute.arch.fence_view_async_tmem_store()`.** Missing this produces NaN. + +--- + +## Goal + +Change the FMHA kernel from single-CTA (`grid=(1,1,1)`) to multi-CTA grid that handles: +- **Multiple query heads** (DSV4: 64 Flash, 128 Pro) +- **Batch dimension** (multiple requests in flight) +- **MQA**: All query heads share the same K/V (num_kv_heads=1) + +The kernel currently processes one head, one batch, one Q tile. D2 makes it process all heads and batches in parallel. + +--- + +## DSV4 Attention Dimensions + +| Config | num_query_heads | num_kv_heads | head_dim | top_k | sliding_window | +|--------|----------------:|-------------:|---------:|------:|---------------:| +| Flash | 64 | 1 | 512 | 512 | 128 | +| Pro | 128 | 1 | 512 | 1024 | 128 | + +DSV4 is **MQA** — 64 or 128 query heads all share the SAME K/V. This means: +- Q shape: `(batch, num_query_heads, T, head_dim)` — each head has its own Q +- K/V shape: `(batch, 1, s_k, head_dim)` — shared across all heads +- O shape: `(batch, num_query_heads, T, head_dim)` — per-head output + +--- + +## Architecture: Two Grid Strategies + +### Strategy A: Head-Packed M Dimension (RECOMMENDED for decode) + +Fold the head dimension into M. Each CTA processes ALL heads' queries for its M tile. + +``` +Q reshaped: (batch, T * num_query_heads, head_dim) — heads packed into M +K/V: (batch, s_k, head_dim) — shared, loaded once + +Grid: (ceil_div(T * n_h, 128), 1, batch) +At decode T=1, n_h=128: grid = (1, 1, batch) — single CTA per batch! +``` + +**Pros:** +- K/V loaded once per CTA, shared across all heads in the M tile +- Maximum arithmetic intensity (128 heads × 1 token = 128 rows per tile) +- Simplest code change — just reshape Q and adjust grid +- The current kernel already processes M=128 rows; at T=1+n_h=128, M=128 fits exactly + +**Cons:** +- At larger T, M = T * n_h gets large. T=64, n_h=128 → M=8192 → 64 CTA tiles. But each tile still shares K/V. +- The M tile covers multiple heads' rows. The softmax operates per-row, so each row is one head's attention. This is fine — softmax already operates per-row. + +**Key constraint:** T * n_h must be a multiple of the M tile size (128). At T=1, n_h=128 → M=128 ✅. At T=2, n_h=128 → M=256 (2 tiles). At T=3, n_h=128 → M=384 (3 tiles). Always a multiple since n_h is a multiple of 128 and M_tile=128. + +Wait — T * n_h = 3 * 128 = 384. 384/128 = 3 tiles. Each tile has 128 rows. The first 128 rows are heads 0-127 of token 0. The next 128 rows are heads 0-127 of token 1. The third tile is heads 0-127 of token 2. This works, but the Q TMA needs to know that rows come from different (token, head) pairs. The Q tensor layout must be `(T * n_h, head_dim)` in M-major order. + +Actually, for DSV4 decode, T is typically 1 (single token). So M = n_h = 128 (Flash: 64, Pro: 128). At T=1, Flash needs M=64 (half a tile) and Pro needs M=128 (full tile). Flash with M=64 is a problem — the MMA tile is 128 rows. We'd need to handle the edge case or pad. + +**For now, start with Strategy B (head as grid dim) and consider Strategy A as an optimization.** + +### Strategy B: Head as Grid Dimension (CUTLASS reference approach) + +Each CTA handles one query head. Grid = `(num_M_tiles, num_query_heads, batch)`. + +``` +Q shape: (batch, num_query_heads, T, head_dim) — per-head TMA +K/V: (batch, 1, s_k, head_dim) — shared (each CTA loads its own copy) +O shape: (batch, num_query_heads, T, head_dim) — per-head output + +Grid: (ceil_div(T, 128), num_query_heads, batch) +At decode T=1, n_h=128: grid = (1, 128, batch) — 128 CTAs, each with 1 Q row (padded to 128) +``` + +**Pros:** +- Matches the CUTLASS reference FMHA architecture +- Each CTA is independent — no inter-CTA coordination needed +- Tile scheduler from `fmha_helpers.py` handles work distribution +- Simple, well-understood pattern + +**Cons:** +- At decode T=1, each CTA only processes 1 query row (padded to 128). Wastes 99% of MMA compute. +- K/V loaded 128 times (once per CTA) instead of once. Wastes HBM bandwidth. +- 128 CTAs competing for GPU resources. + +**Despite the waste, this is the right starting point** because: +1. It matches the CUTLASS reference — we can copy patterns +2. It's correct — we can optimize later with Strategy A +3. Decode T=1 with n_h=128 and M=128 means each CTA has M=1 padded to 128 — the MMA still works, just underutilized + +--- + +## Implementation Plan + +### Step 1: Tensor Layout Definitions + +Define the Q, K, V, O tensor shapes for multi-head multi-batch. + +```python +# Q: (batch, n_h, T, head_dim) BF16 +# K: (batch, 1, s_k, head_dim) BF16 (MQA: shared KV) +# V: (batch, 1, s_k, head_dim) BF16 (MQA: shared KV) +# O: (batch, n_h, T, head_dim) BF16 +# LSE: (batch, n_h, T) FP32 +``` + +The CuTeDSL TMA tensors need to be 4D for Q/O (batch, heads, seq, dim) and 3D for K/V (batch, seq, dim). The head dimension is handled by the grid — each CTA gets one head. + +**Key:** K/V have `num_kv_heads=1` in the head dimension. The TMA atom for K/V is created from the full tensor (with head dim=1), and each CTA loads from head 0. + +### Step 2: Grid Shape Computation + +```python +M_tile = 128 # rows per CTA +num_M_tiles = math.ceil(T / M_tile) +grid = (num_M_tiles, num_query_heads, batch) +``` + +At decode (T=1): `grid = (1, 128, batch)`. +At prefill (T=64): `grid = (1, 128, batch)` (T=64 < 128, still 1 M tile per head). + +### Step 3: Block Coordinate → Tensor Index Mapping + +Inside the kernel, map `(block_idx_x, block_idx_y, block_idx_z)` to: +- `m_tile_idx = block_idx_x` → which M tile within this head +- `head_idx = block_idx_y` → which query head +- `batch_idx = block_idx_z` → which batch element + +Then Q TMA loads from `Q[batch_idx, head_idx, m_tile_idx*M_tile:(m_tile_idx+1)*M_tile, :]`. +K/V TMA loads from `K[batch_idx, 0, :, :]` (head 0, all tokens). +O TMA stores to `O[batch_idx, head_idx, m_tile_idx*M_tile:(m_tile_idx+1)*M_tile, :]`. + +### Step 4: TMA Tensor Construction + +The TMA tensors are constructed from the PyTorch tensors before launch. For multi-head, Q needs to be indexed by head. CuTeDSL's `make_tiled_tma_atom_A` creates a TMA descriptor from the tensor shape. + +**Key question:** Can the TMA descriptor be created once and reused across heads? Yes — if Q is contiguous in the head dimension, the TMA descriptor covers all heads, and each CTA indexes into its head via the coordinate. + +Actually, the CUTLASS reference creates the TMA descriptor from the FULL tensor shape `(batch, n_h, T, head_dim)` and the CTA's block coordinate selects the right head/tile. + +### Step 5: Per-CTA Q Offset + +Each CTA needs to load Q for its specific head. The TMA load already handles batch and M-tile indexing. For the head dimension, the TMA tensor has a mode for it. + +In the kernel: +```python +bidx, bidy, bidz = cute.arch.block_idx() +# Q tile: m_tile=bidx, head=bidy, batch=bidz +``` + +The TMA copy: `cute.copy(tma_q, mQ[(None, bidy, bidx, bidz)], sQ[(None, qh.index)])` — but this depends on the TMA tensor layout. + +### Step 6: K/V Shared Across Heads (MQA Optimization) + +For MQA, all CTAs in the head dimension (same `block_idx_x`, same `block_idx_z`, different `block_idx_y`) load the same K/V. Two approaches: + +1. **Independent loads (simple):** Each CTA loads its own K/V. Wastes bandwidth but correct. +2. **Cluster-wide load (optimized):** Use `cluster_shape_mn = (1, num_query_heads, 1)` so all heads in a cluster share the same K/V SMEM. Requires cluster barriers. + +**Start with independent loads.** The K/V are small at decode (s_k=128 or 512, head_dim=512). Even 128× loads of 128×512 BF16 = 16MB is fine. + +### Step 7: Output TMA Store + +O has shape `(batch, n_h, T, head_dim)`. Each CTA writes its head's output. The TMA store uses the same block coordinate mapping. + +--- + +## To-Do List + +- [ ] **D2.1:** Define multi-head tensor shapes in `FmhaKernel.__init__` + - Add `num_query_heads` constructor parameter + - Add `batch_size` parameter (or infer from tensor shapes at launch time) + - Keep `num_kv_heads=1` hardcoded for now (MQA) + +- [ ] **D2.2:** Create test file `tests/unit/test_d2_multihead.py` + - Test with n_h=2, batch=1, T=1, hd=64 (minimal multi-head) + - Test with n_h=8, batch=2, T=1, hd=64 (small batch) + - Reference: PyTorch SDPA with MQA (all heads share K/V) + - Start with single-head regression (n_h=1, batch=1) to verify nothing breaks + +- [ ] **D2.3:** Modify `__call__` to construct multi-head TMA descriptors + - Q tensor: `(batch, n_h, T, head_dim)` — TMA over (batch, n_h, T, head_dim) + - K/V tensors: `(batch, 1, s_k, head_dim)` — TMA over (batch, s_k, head_dim) (squeeze kv_heads) + - O tensor: `(batch, n_h, T, head_dim)` — TMA over (batch, n_h, T, head_dim) + - LSE tensor: `(batch, n_h, T)` — per-head LSE + +- [ ] **D2.4:** Compute grid shape and pass to launch + - `grid = (ceil_div(T, 128), num_query_heads, batch)` + - The `cta_tiler` remains `(128, 128, head_dim)` for QK, `(128, pv_n_tile, head_dim)` for PV + +- [ ] **D2.5:** Add block coordinate mapping inside `_kernel` + - `bidx, bidy, bidz = cute.arch.block_idx()` + - Map to `m_tile_idx, head_idx, batch_idx` + - Use head_idx to index the Q TMA tensor's head mode + - Use batch_idx for batch mode + +- [ ] **D2.6:** Adjust TMA loads for per-head Q and shared K/V + - Q load: index by `(batch_idx, head_idx, m_tile_idx, ...)` + - K/V load: index by `(batch_idx, ...)` (no head index) + - O store: index by `(batch_idx, head_idx, m_tile_idx, ...)` + +- [ ] **D2.7:** Test multi-head correctness + - n_h=2, batch=1, T=1, hd=64 → cos ≥0.999 per head + - n_h=8, batch=2, T=1, hd=64 → cos ≥0.999 per head + - n_h=64, batch=1, T=1, hd=64 → Flash decode config + +- [ ] **D2.8:** Test at hd=128, hd=256 + - Same tests as D2.7 but with larger head dims + - Verify SMEM budget still fits (no per-head SMEM change — each CTA has its own SMEM) + +- [ ] **D2.9:** LSE output for multi-head + - LSE shape: `(batch, n_h, T)` or `(batch, n_h, ceil_div(T, 128) * 128)` padded + - Each CTA writes its head's LSE to the correct position + - Needed for D5 merge + +--- + +## Key Technical Decisions + +### Q Tensor Layout for TMA + +The Q tensor must be laid out so the TMA can efficiently load per-head tiles. Two options: + +**Option 1: (batch, n_h, T, head_dim) — head is a TMA mode** +- TMA descriptor covers all heads +- Each CTA selects its head via the block coordinate +- This is how the CUTLASS reference works: `o_shape = (s, d, ((h_r, h_k), b))` + +**Option 2: (batch, T * n_h, head_dim) — heads packed into M** +- TMA descriptor is 2D (batch, M, head_dim) +- Each CTA selects its M tile +- No head mode in TMA — simpler but can't distinguish heads in the kernel + +**We'll use Option 1** because it matches the CUTLASS reference and allows per-head LSE output. + +### Grid vs. Persistent Tile Scheduler + +The CUTLASS reference uses a persistent tile scheduler for load balancing. For DSV4 decode (T=1), each CTA has exactly one tile, so a static grid suffices. For prefill with variable sequence lengths, a persistent scheduler would help. **Start with static grid, add persistent scheduler later if needed.** + +### MQA K/V Sharing + +At decode T=1, n_h=128: 128 CTAs each load the same K/V. Each K/V tile is s_k × head_dim BF16 = 128 × 512 × 2 = 128KB. Total K/V read = 128 × 128KB = 16MB. HBM bandwidth on B200 is ~8TB/s, so 16MB takes ~2μs. The MMA compute per CTA is ~5μs. So K/V redundancy is not a bottleneck at decode. + +At prefill T=64, s_k=1024: K/V = 1024 × 512 × 2 = 1MB per CTA. 128 CTAs × 1MB = 128MB. Still ~16μs, comparable to compute. Not a bottleneck. + +**Start with independent K/V loads. Optimize later with cluster-wide sharing if profiling shows it's needed.** + +--- + +## Risks and Mitigations + +| Risk | Mitigation | +|------|-----------| +| TMA tensor layout mismatch for multi-head Q | Print shapes at trace time. Start with n_h=1 regression. | +| Grid dimension confusion (which axis is heads) | Follow CUTLASS reference: (M_tiles, heads, batch) | +| LSE output indexing for multi-head | Each CTA writes LSE to its (batch, head) position | +| SMEM budget unchanged per CTA | Each CTA has independent SMEM. No budget change. | +| Edge case: T not divisible by M_tile | For decode T=1, M=1 padded to 128. Need to mask output rows > T. | + +--- + +## CUTLASS Reference FMHA Architecture (for reference) + +The CUTLASS Blackwell FMHA (`fmha.py`) uses: +- **Grid:** `(num_M_tiles, num_q_heads, batch)` computed via `fmha_utils.compute_grid` +- **TMA shapes:** Q/K/V/O as `(seq, dim, ((h_r, h_k), batch))` +- **Tile scheduler:** `FmhaStaticTileScheduler` with persistent mode +- **12-warp layout:** Separate TMA, MMA, softmax, correction, epilogue warps +- **Per-CTA:** One head's M tiles. GQA support via `h_r` and `h_k` in the grid. + +Key files: +- `/root/nvidia-meeting/venv/lib/python3.10/site-packages/flashinfer/data/cutlass/examples/python/CuTeDSL/blackwell/fmha.py` — main kernel +- `/root/nvidia-meeting/venv/lib/python3.10/site-packages/flashinfer/data/cutlass/examples/python/CuTeDSL/helpers/fmha_helpers.py` — grid/scheduler + +Our kernel uses a simpler 6-warp layout and no persistent scheduler. We'll add the multi-head grid on top of our existing architecture. + +--- + +## Dependencies + +- **D1 (parameterized HEAD_DIM):** ✅ DONE for hd≤256. hd=512 blocked by MLIR compilation but not needed for D2 development. +- **D5a (un-normalized O + LSE):** ✅ DONE. Needed for D5 merge but not for D2. + +## Next Stage After D2 + +D3 — SWA sequence length mask. Once we have multi-head multi-batch, add the SWA window mask (`swa_lens` per batch).