Add STAGE_D2.md: Multi-query grid + head packing plan
This commit is contained in:
311
STAGE_D2.md
Normal file
311
STAGE_D2.md
Normal file
@@ -0,0 +1,311 @@
|
||||
# STAGE_D2.md — Multi-Query Grid + Head Packing
|
||||
|
||||
## ⚠️ IKEA INSTRUCTIONS — READ EVERY TIME BEFORE CODING
|
||||
|
||||
### The Workflow (DO NOT SKIP STEPS)
|
||||
|
||||
1. **Edit code in** `~/dev/nvfp4-megamoe-kernel/dsv4/kernels/attention/fmha.py` — this is the ONLY file for the FMHA kernel.
|
||||
2. **Commit and push:**
|
||||
```bash
|
||||
cd ~/dev/nvfp4-megamoe-kernel
|
||||
git add -A && git commit -m "description" && git push origin master
|
||||
```
|
||||
3. **Test on B200 using the test harness:**
|
||||
```bash
|
||||
~/.openclaw/workspace/fire_b200_test tests/unit/test_d2_<name>.py
|
||||
```
|
||||
4. **Regression check:** After every change, verify hd=64 cos ~0.999998 still matches. If it doesn't, the change is WRONG. Revert.
|
||||
|
||||
### The Rules (BURNED INTO THIS FILE)
|
||||
|
||||
- **NEVER edit files directly on the B200.** Edit locally, commit, push, pull, test. Every time.
|
||||
- **ALWAYS use the test harness** — `fire_b200_test`, `run_test.sh`, `check_log.sh`.
|
||||
- **CuTeDSL variables defined in `if` blocks are NOT visible in other `if` blocks.** Define all variables unconditionally before any branching.
|
||||
- **Guard dead code with `const_expr`.** CuTeDSL compiles BOTH branches of Python `if` statements.
|
||||
- **PRINT THE SHAPES. ALWAYS.** Reasoning about layouts without evidence is how we waste days.
|
||||
- **After every P store to TMEM, call `cute.arch.fence_view_async_tmem_store()`.** Missing this produces NaN.
|
||||
|
||||
---
|
||||
|
||||
## Goal
|
||||
|
||||
Change the FMHA kernel from single-CTA (`grid=(1,1,1)`) to multi-CTA grid that handles:
|
||||
- **Multiple query heads** (DSV4: 64 Flash, 128 Pro)
|
||||
- **Batch dimension** (multiple requests in flight)
|
||||
- **MQA**: All query heads share the same K/V (num_kv_heads=1)
|
||||
|
||||
The kernel currently processes one head, one batch, one Q tile. D2 makes it process all heads and batches in parallel.
|
||||
|
||||
---
|
||||
|
||||
## DSV4 Attention Dimensions
|
||||
|
||||
| Config | num_query_heads | num_kv_heads | head_dim | top_k | sliding_window |
|
||||
|--------|----------------:|-------------:|---------:|------:|---------------:|
|
||||
| Flash | 64 | 1 | 512 | 512 | 128 |
|
||||
| Pro | 128 | 1 | 512 | 1024 | 128 |
|
||||
|
||||
DSV4 is **MQA** — 64 or 128 query heads all share the SAME K/V. This means:
|
||||
- Q shape: `(batch, num_query_heads, T, head_dim)` — each head has its own Q
|
||||
- K/V shape: `(batch, 1, s_k, head_dim)` — shared across all heads
|
||||
- O shape: `(batch, num_query_heads, T, head_dim)` — per-head output
|
||||
|
||||
---
|
||||
|
||||
## Architecture: Two Grid Strategies
|
||||
|
||||
### Strategy A: Head-Packed M Dimension (RECOMMENDED for decode)
|
||||
|
||||
Fold the head dimension into M. Each CTA processes ALL heads' queries for its M tile.
|
||||
|
||||
```
|
||||
Q reshaped: (batch, T * num_query_heads, head_dim) — heads packed into M
|
||||
K/V: (batch, s_k, head_dim) — shared, loaded once
|
||||
|
||||
Grid: (ceil_div(T * n_h, 128), 1, batch)
|
||||
At decode T=1, n_h=128: grid = (1, 1, batch) — single CTA per batch!
|
||||
```
|
||||
|
||||
**Pros:**
|
||||
- K/V loaded once per CTA, shared across all heads in the M tile
|
||||
- Maximum arithmetic intensity (128 heads × 1 token = 128 rows per tile)
|
||||
- Simplest code change — just reshape Q and adjust grid
|
||||
- The current kernel already processes M=128 rows; at T=1+n_h=128, M=128 fits exactly
|
||||
|
||||
**Cons:**
|
||||
- At larger T, M = T * n_h gets large. T=64, n_h=128 → M=8192 → 64 CTA tiles. But each tile still shares K/V.
|
||||
- The M tile covers multiple heads' rows. The softmax operates per-row, so each row is one head's attention. This is fine — softmax already operates per-row.
|
||||
|
||||
**Key constraint:** T * n_h must be a multiple of the M tile size (128). At T=1, n_h=128 → M=128 ✅. At T=2, n_h=128 → M=256 (2 tiles). At T=3, n_h=128 → M=384 (3 tiles). Always a multiple since n_h is a multiple of 128 and M_tile=128.
|
||||
|
||||
Wait — T * n_h = 3 * 128 = 384. 384/128 = 3 tiles. Each tile has 128 rows. The first 128 rows are heads 0-127 of token 0. The next 128 rows are heads 0-127 of token 1. The third tile is heads 0-127 of token 2. This works, but the Q TMA needs to know that rows come from different (token, head) pairs. The Q tensor layout must be `(T * n_h, head_dim)` in M-major order.
|
||||
|
||||
Actually, for DSV4 decode, T is typically 1 (single token). So M = n_h = 128 (Flash: 64, Pro: 128). At T=1, Flash needs M=64 (half a tile) and Pro needs M=128 (full tile). Flash with M=64 is a problem — the MMA tile is 128 rows. We'd need to handle the edge case or pad.
|
||||
|
||||
**For now, start with Strategy B (head as grid dim) and consider Strategy A as an optimization.**
|
||||
|
||||
### Strategy B: Head as Grid Dimension (CUTLASS reference approach)
|
||||
|
||||
Each CTA handles one query head. Grid = `(num_M_tiles, num_query_heads, batch)`.
|
||||
|
||||
```
|
||||
Q shape: (batch, num_query_heads, T, head_dim) — per-head TMA
|
||||
K/V: (batch, 1, s_k, head_dim) — shared (each CTA loads its own copy)
|
||||
O shape: (batch, num_query_heads, T, head_dim) — per-head output
|
||||
|
||||
Grid: (ceil_div(T, 128), num_query_heads, batch)
|
||||
At decode T=1, n_h=128: grid = (1, 128, batch) — 128 CTAs, each with 1 Q row (padded to 128)
|
||||
```
|
||||
|
||||
**Pros:**
|
||||
- Matches the CUTLASS reference FMHA architecture
|
||||
- Each CTA is independent — no inter-CTA coordination needed
|
||||
- Tile scheduler from `fmha_helpers.py` handles work distribution
|
||||
- Simple, well-understood pattern
|
||||
|
||||
**Cons:**
|
||||
- At decode T=1, each CTA only processes 1 query row (padded to 128). Wastes 99% of MMA compute.
|
||||
- K/V loaded 128 times (once per CTA) instead of once. Wastes HBM bandwidth.
|
||||
- 128 CTAs competing for GPU resources.
|
||||
|
||||
**Despite the waste, this is the right starting point** because:
|
||||
1. It matches the CUTLASS reference — we can copy patterns
|
||||
2. It's correct — we can optimize later with Strategy A
|
||||
3. Decode T=1 with n_h=128 and M=128 means each CTA has M=1 padded to 128 — the MMA still works, just underutilized
|
||||
|
||||
---
|
||||
|
||||
## Implementation Plan
|
||||
|
||||
### Step 1: Tensor Layout Definitions
|
||||
|
||||
Define the Q, K, V, O tensor shapes for multi-head multi-batch.
|
||||
|
||||
```python
|
||||
# Q: (batch, n_h, T, head_dim) BF16
|
||||
# K: (batch, 1, s_k, head_dim) BF16 (MQA: shared KV)
|
||||
# V: (batch, 1, s_k, head_dim) BF16 (MQA: shared KV)
|
||||
# O: (batch, n_h, T, head_dim) BF16
|
||||
# LSE: (batch, n_h, T) FP32
|
||||
```
|
||||
|
||||
The CuTeDSL TMA tensors need to be 4D for Q/O (batch, heads, seq, dim) and 3D for K/V (batch, seq, dim). The head dimension is handled by the grid — each CTA gets one head.
|
||||
|
||||
**Key:** K/V have `num_kv_heads=1` in the head dimension. The TMA atom for K/V is created from the full tensor (with head dim=1), and each CTA loads from head 0.
|
||||
|
||||
### Step 2: Grid Shape Computation
|
||||
|
||||
```python
|
||||
M_tile = 128 # rows per CTA
|
||||
num_M_tiles = math.ceil(T / M_tile)
|
||||
grid = (num_M_tiles, num_query_heads, batch)
|
||||
```
|
||||
|
||||
At decode (T=1): `grid = (1, 128, batch)`.
|
||||
At prefill (T=64): `grid = (1, 128, batch)` (T=64 < 128, still 1 M tile per head).
|
||||
|
||||
### Step 3: Block Coordinate → Tensor Index Mapping
|
||||
|
||||
Inside the kernel, map `(block_idx_x, block_idx_y, block_idx_z)` to:
|
||||
- `m_tile_idx = block_idx_x` → which M tile within this head
|
||||
- `head_idx = block_idx_y` → which query head
|
||||
- `batch_idx = block_idx_z` → which batch element
|
||||
|
||||
Then Q TMA loads from `Q[batch_idx, head_idx, m_tile_idx*M_tile:(m_tile_idx+1)*M_tile, :]`.
|
||||
K/V TMA loads from `K[batch_idx, 0, :, :]` (head 0, all tokens).
|
||||
O TMA stores to `O[batch_idx, head_idx, m_tile_idx*M_tile:(m_tile_idx+1)*M_tile, :]`.
|
||||
|
||||
### Step 4: TMA Tensor Construction
|
||||
|
||||
The TMA tensors are constructed from the PyTorch tensors before launch. For multi-head, Q needs to be indexed by head. CuTeDSL's `make_tiled_tma_atom_A` creates a TMA descriptor from the tensor shape.
|
||||
|
||||
**Key question:** Can the TMA descriptor be created once and reused across heads? Yes — if Q is contiguous in the head dimension, the TMA descriptor covers all heads, and each CTA indexes into its head via the coordinate.
|
||||
|
||||
Actually, the CUTLASS reference creates the TMA descriptor from the FULL tensor shape `(batch, n_h, T, head_dim)` and the CTA's block coordinate selects the right head/tile.
|
||||
|
||||
### Step 5: Per-CTA Q Offset
|
||||
|
||||
Each CTA needs to load Q for its specific head. The TMA load already handles batch and M-tile indexing. For the head dimension, the TMA tensor has a mode for it.
|
||||
|
||||
In the kernel:
|
||||
```python
|
||||
bidx, bidy, bidz = cute.arch.block_idx()
|
||||
# Q tile: m_tile=bidx, head=bidy, batch=bidz
|
||||
```
|
||||
|
||||
The TMA copy: `cute.copy(tma_q, mQ[(None, bidy, bidx, bidz)], sQ[(None, qh.index)])` — but this depends on the TMA tensor layout.
|
||||
|
||||
### Step 6: K/V Shared Across Heads (MQA Optimization)
|
||||
|
||||
For MQA, all CTAs in the head dimension (same `block_idx_x`, same `block_idx_z`, different `block_idx_y`) load the same K/V. Two approaches:
|
||||
|
||||
1. **Independent loads (simple):** Each CTA loads its own K/V. Wastes bandwidth but correct.
|
||||
2. **Cluster-wide load (optimized):** Use `cluster_shape_mn = (1, num_query_heads, 1)` so all heads in a cluster share the same K/V SMEM. Requires cluster barriers.
|
||||
|
||||
**Start with independent loads.** The K/V are small at decode (s_k=128 or 512, head_dim=512). Even 128× loads of 128×512 BF16 = 16MB is fine.
|
||||
|
||||
### Step 7: Output TMA Store
|
||||
|
||||
O has shape `(batch, n_h, T, head_dim)`. Each CTA writes its head's output. The TMA store uses the same block coordinate mapping.
|
||||
|
||||
---
|
||||
|
||||
## To-Do List
|
||||
|
||||
- [ ] **D2.1:** Define multi-head tensor shapes in `FmhaKernel.__init__`
|
||||
- Add `num_query_heads` constructor parameter
|
||||
- Add `batch_size` parameter (or infer from tensor shapes at launch time)
|
||||
- Keep `num_kv_heads=1` hardcoded for now (MQA)
|
||||
|
||||
- [ ] **D2.2:** Create test file `tests/unit/test_d2_multihead.py`
|
||||
- Test with n_h=2, batch=1, T=1, hd=64 (minimal multi-head)
|
||||
- Test with n_h=8, batch=2, T=1, hd=64 (small batch)
|
||||
- Reference: PyTorch SDPA with MQA (all heads share K/V)
|
||||
- Start with single-head regression (n_h=1, batch=1) to verify nothing breaks
|
||||
|
||||
- [ ] **D2.3:** Modify `__call__` to construct multi-head TMA descriptors
|
||||
- Q tensor: `(batch, n_h, T, head_dim)` — TMA over (batch, n_h, T, head_dim)
|
||||
- K/V tensors: `(batch, 1, s_k, head_dim)` — TMA over (batch, s_k, head_dim) (squeeze kv_heads)
|
||||
- O tensor: `(batch, n_h, T, head_dim)` — TMA over (batch, n_h, T, head_dim)
|
||||
- LSE tensor: `(batch, n_h, T)` — per-head LSE
|
||||
|
||||
- [ ] **D2.4:** Compute grid shape and pass to launch
|
||||
- `grid = (ceil_div(T, 128), num_query_heads, batch)`
|
||||
- The `cta_tiler` remains `(128, 128, head_dim)` for QK, `(128, pv_n_tile, head_dim)` for PV
|
||||
|
||||
- [ ] **D2.5:** Add block coordinate mapping inside `_kernel`
|
||||
- `bidx, bidy, bidz = cute.arch.block_idx()`
|
||||
- Map to `m_tile_idx, head_idx, batch_idx`
|
||||
- Use head_idx to index the Q TMA tensor's head mode
|
||||
- Use batch_idx for batch mode
|
||||
|
||||
- [ ] **D2.6:** Adjust TMA loads for per-head Q and shared K/V
|
||||
- Q load: index by `(batch_idx, head_idx, m_tile_idx, ...)`
|
||||
- K/V load: index by `(batch_idx, ...)` (no head index)
|
||||
- O store: index by `(batch_idx, head_idx, m_tile_idx, ...)`
|
||||
|
||||
- [ ] **D2.7:** Test multi-head correctness
|
||||
- n_h=2, batch=1, T=1, hd=64 → cos ≥0.999 per head
|
||||
- n_h=8, batch=2, T=1, hd=64 → cos ≥0.999 per head
|
||||
- n_h=64, batch=1, T=1, hd=64 → Flash decode config
|
||||
|
||||
- [ ] **D2.8:** Test at hd=128, hd=256
|
||||
- Same tests as D2.7 but with larger head dims
|
||||
- Verify SMEM budget still fits (no per-head SMEM change — each CTA has its own SMEM)
|
||||
|
||||
- [ ] **D2.9:** LSE output for multi-head
|
||||
- LSE shape: `(batch, n_h, T)` or `(batch, n_h, ceil_div(T, 128) * 128)` padded
|
||||
- Each CTA writes its head's LSE to the correct position
|
||||
- Needed for D5 merge
|
||||
|
||||
---
|
||||
|
||||
## Key Technical Decisions
|
||||
|
||||
### Q Tensor Layout for TMA
|
||||
|
||||
The Q tensor must be laid out so the TMA can efficiently load per-head tiles. Two options:
|
||||
|
||||
**Option 1: (batch, n_h, T, head_dim) — head is a TMA mode**
|
||||
- TMA descriptor covers all heads
|
||||
- Each CTA selects its head via the block coordinate
|
||||
- This is how the CUTLASS reference works: `o_shape = (s, d, ((h_r, h_k), b))`
|
||||
|
||||
**Option 2: (batch, T * n_h, head_dim) — heads packed into M**
|
||||
- TMA descriptor is 2D (batch, M, head_dim)
|
||||
- Each CTA selects its M tile
|
||||
- No head mode in TMA — simpler but can't distinguish heads in the kernel
|
||||
|
||||
**We'll use Option 1** because it matches the CUTLASS reference and allows per-head LSE output.
|
||||
|
||||
### Grid vs. Persistent Tile Scheduler
|
||||
|
||||
The CUTLASS reference uses a persistent tile scheduler for load balancing. For DSV4 decode (T=1), each CTA has exactly one tile, so a static grid suffices. For prefill with variable sequence lengths, a persistent scheduler would help. **Start with static grid, add persistent scheduler later if needed.**
|
||||
|
||||
### MQA K/V Sharing
|
||||
|
||||
At decode T=1, n_h=128: 128 CTAs each load the same K/V. Each K/V tile is s_k × head_dim BF16 = 128 × 512 × 2 = 128KB. Total K/V read = 128 × 128KB = 16MB. HBM bandwidth on B200 is ~8TB/s, so 16MB takes ~2μs. The MMA compute per CTA is ~5μs. So K/V redundancy is not a bottleneck at decode.
|
||||
|
||||
At prefill T=64, s_k=1024: K/V = 1024 × 512 × 2 = 1MB per CTA. 128 CTAs × 1MB = 128MB. Still ~16μs, comparable to compute. Not a bottleneck.
|
||||
|
||||
**Start with independent K/V loads. Optimize later with cluster-wide sharing if profiling shows it's needed.**
|
||||
|
||||
---
|
||||
|
||||
## Risks and Mitigations
|
||||
|
||||
| Risk | Mitigation |
|
||||
|------|-----------|
|
||||
| TMA tensor layout mismatch for multi-head Q | Print shapes at trace time. Start with n_h=1 regression. |
|
||||
| Grid dimension confusion (which axis is heads) | Follow CUTLASS reference: (M_tiles, heads, batch) |
|
||||
| LSE output indexing for multi-head | Each CTA writes LSE to its (batch, head) position |
|
||||
| SMEM budget unchanged per CTA | Each CTA has independent SMEM. No budget change. |
|
||||
| Edge case: T not divisible by M_tile | For decode T=1, M=1 padded to 128. Need to mask output rows > T. |
|
||||
|
||||
---
|
||||
|
||||
## CUTLASS Reference FMHA Architecture (for reference)
|
||||
|
||||
The CUTLASS Blackwell FMHA (`fmha.py`) uses:
|
||||
- **Grid:** `(num_M_tiles, num_q_heads, batch)` computed via `fmha_utils.compute_grid`
|
||||
- **TMA shapes:** Q/K/V/O as `(seq, dim, ((h_r, h_k), batch))`
|
||||
- **Tile scheduler:** `FmhaStaticTileScheduler` with persistent mode
|
||||
- **12-warp layout:** Separate TMA, MMA, softmax, correction, epilogue warps
|
||||
- **Per-CTA:** One head's M tiles. GQA support via `h_r` and `h_k` in the grid.
|
||||
|
||||
Key files:
|
||||
- `/root/nvidia-meeting/venv/lib/python3.10/site-packages/flashinfer/data/cutlass/examples/python/CuTeDSL/blackwell/fmha.py` — main kernel
|
||||
- `/root/nvidia-meeting/venv/lib/python3.10/site-packages/flashinfer/data/cutlass/examples/python/CuTeDSL/helpers/fmha_helpers.py` — grid/scheduler
|
||||
|
||||
Our kernel uses a simpler 6-warp layout and no persistent scheduler. We'll add the multi-head grid on top of our existing architecture.
|
||||
|
||||
---
|
||||
|
||||
## Dependencies
|
||||
|
||||
- **D1 (parameterized HEAD_DIM):** ✅ DONE for hd≤256. hd=512 blocked by MLIR compilation but not needed for D2 development.
|
||||
- **D5a (un-normalized O + LSE):** ✅ DONE. Needed for D5 merge but not for D2.
|
||||
|
||||
## Next Stage After D2
|
||||
|
||||
D3 — SWA sequence length mask. Once we have multi-head multi-batch, add the SWA window mask (`swa_lens` per batch).
|
||||
Reference in New Issue
Block a user