# STAGE_D2.md: Multi-Query Grid + Head Packing
## ⚠️ IKEA INSTRUCTIONS: READ EVERY TIME BEFORE CODING
### The Workflow (DO NOT SKIP STEPS)
- Edit code in `~/dev/nvfp4-megamoe-kernel/dsv4/kernels/attention/fmha.py`. This is the ONLY file for the FMHA kernel.
- Commit and push:
  ```bash
  cd ~/dev/nvfp4-megamoe-kernel
  git add -A && git commit -m "description" && git push origin master
  ```
- Test on B200 using the test harness:
  ```bash
  ~/.openclaw/workspace/fire_b200_test tests/unit/test_d2_<name>.py
  ```
- Regression check: After every change, verify that hd=64 still reaches cos ~0.999998. If it doesn't, the change is WRONG. Revert.
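The regression check boils down to a cosine similarity between the kernel output and a reference, computed over all flattened elements. A minimal pure-Python sketch, assuming both outputs are already available as flat lists of floats (on the B200 they would be tensors); `cos_sim` and `passes_regression` are hypothetical helper names, and the threshold is the hd=64 baseline quoted above.

```python
import math
from typing import Sequence

# hd=64 baseline quoted in the workflow; dropping below it means the change is wrong.
COS_THRESHOLD = 0.999998

def cos_sim(out: Sequence[float], ref: Sequence[float]) -> float:
    """Cosine similarity of two equally sized, flattened outputs."""
    if len(out) != len(ref):
        raise ValueError(f"shape mismatch: {len(out)} vs {len(ref)} elements")
    dot = sum(a * b for a, b in zip(out, ref))
    norm_out = math.sqrt(sum(a * a for a in out))
    norm_ref = math.sqrt(sum(b * b for b in ref))
    return dot / (norm_out * norm_ref)

def passes_regression(out: Sequence[float], ref: Sequence[float]) -> bool:
    """True if the kernel output still matches the reference closely enough."""
    c = cos_sim(out, ref)
    # A NaN output (e.g. from a missing TMEM fence) makes c NaN, and
    # NaN >= threshold is False, so NaN always fails the check.
    return c >= COS_THRESHOLD
```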
### The Rules (BURNED INTO THIS FILE)
- NEVER edit files directly on the B200. Edit locally, commit, push, pull, test. Every time.
- ALWAYS use the test harness: `fire_b200_test`, `run_test.sh`, `check_log.sh`.
- CuTeDSL variables defined in `if` blocks are NOT visible in other `if` blocks. Define all variables unconditionally before any branching.
- Guard dead code with `const_expr`. CuTeDSL compiles BOTH branches of Python `if` statements.
- PRINT THE SHAPES. ALWAYS. Reasoning about layouts without evidence is how we waste days.
- After every P store to TMEM, call `cute.arch.fence_view_async_tmem_store()`. Missing this produces NaN.
## Goal
Change the FMHA kernel from a single-CTA launch (`grid=(1,1,1)`) to a multi-CTA grid that handles:
- Multiple query heads (DSV4: 64 Flash, 128 Pro)
- Batch dimension (multiple requests in flight)
- MQA: All query heads share the same K/V (num_kv_heads=1)
The kernel currently processes one head, one batch, one Q tile. D2 makes it process all heads and batches in parallel.
## DSV4 Attention Dimensions
| Config | num_query_heads | num_kv_heads | head_dim | top_k | sliding_window |
|---|---|---|---|---|---|
| Flash | 64 | 1 | 512 | 512 | 128 |
| Pro | 128 | 1 | 512 | 1024 | 128 |
DSV4 is MQA — 64 or 128 query heads all share the SAME K/V. This means:
- Q shape: `(batch, num_query_heads, T, head_dim)` (each head has its own Q)
- K/V shape: `(batch, 1, s_k, head_dim)` (shared across all heads)
- O shape: `(batch, num_query_heads, T, head_dim)` (per-head output)
## Architecture: Two Grid Strategies
### Strategy A: Head-Packed M Dimension (RECOMMENDED for decode)
Fold the head dimension into M. Each CTA processes ALL heads' queries for its M tile.
```
Q reshaped: (batch, T * num_query_heads, head_dim)  # heads packed into M
K/V:        (batch, s_k, head_dim)                  # shared, loaded once
Grid:       (ceil_div(T * n_h, 128), 1, batch)
```
At decode T=1, n_h=128: grid = (1, 1, batch), a single CTA per batch!
Pros:
- K/V loaded once per CTA, shared across all heads in the M tile
- Maximum arithmetic intensity (128 heads × 1 token = 128 rows per tile)
- Simplest code change — just reshape Q and adjust grid
- The current kernel already processes M=128 rows; at T=1 and n_h=128, M=128 fits exactly
Cons:
- At larger T, M = T * n_h gets large. T=64, n_h=128 → M=8192 → 64 CTA tiles. But each tile still shares K/V.
- The M tile covers multiple heads' rows. The softmax operates per-row, so each row is one head's attention. This is fine — softmax already operates per-row.
Key constraint: T * n_h must be a multiple of the M tile size (128), otherwise the last tile must be padded. For Pro (n_h=128) this always holds, since n_h is itself a multiple of M_tile=128: T=1 → M=128 ✅, T=2 → M=256 (2 tiles), T=3 → M=384 (3 tiles). For Flash (n_h=64) it holds only for even T.
Take T=3, n_h=128: M = 3 * 128 = 384, i.e. 3 tiles of 128 rows. Tile 0 holds heads 0-127 of token 0, tile 1 holds heads 0-127 of token 1, and tile 2 holds heads 0-127 of token 2. This works, but the Q TMA needs to know that rows come from different (token, head) pairs: the Q tensor must be laid out as (T * n_h, head_dim) in token-major order, i.e. row = token * n_h + head.
For DSV4 decode, T is typically 1 (a single token), so M = n_h: 128 for Pro (a full tile) but only 64 for Flash (half a tile). Flash with M=64 is a problem because the MMA tile is 128 rows; we'd need to handle that edge case or pad.
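The packing rule and tile count can be checked with a few lines of plain Python. This is a sketch of the index arithmetic only, assuming the token-major packing described above and M_tile=128; `packed_row`, `unpack_row` and `packed_tiling` are hypothetical helpers, not part of the kernel.

```python
M_TILE = 128  # MMA tile height of the current kernel

def ceil_div(a: int, b: int) -> int:
    return -(-a // b)

def packed_row(t: int, h: int, n_h: int) -> int:
    """Row of (token t, head h) in the head-packed M dimension.

    Token-major packing: rows 0..n_h-1 are all heads of token 0, then token 1, ...
    """
    return t * n_h + h

def unpack_row(m: int, n_h: int) -> tuple[int, int]:
    """Inverse of packed_row: recover (token, head) from a packed row index."""
    return divmod(m, n_h)

def packed_tiling(T: int, n_h: int, m_tile: int = M_TILE) -> tuple[int, int]:
    """(num_M_tiles, padding_rows) when M = T * n_h rows are split into m_tile-row tiles."""
    M = T * n_h
    num_tiles = ceil_div(M, m_tile)
    return num_tiles, num_tiles * m_tile - M
```

For example, `packed_tiling(1, 128)` gives one full tile with no padding (Pro decode), while `packed_tiling(1, 64)` gives one tile with 64 padding rows (the Flash edge case). Note that if Q is stored as (batch, n_h, T, head_dim), a plain reshape yields head-major rows (head * T + token); token-major packing needs a transpose to (batch, T, n_h, head_dim) first.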
For now, start with Strategy B (head as grid dim) and consider Strategy A as an optimization.
### Strategy B: Head as Grid Dimension (CUTLASS reference approach)
Each CTA handles one query head. Grid = (num_M_tiles, num_query_heads, batch).
```
Q shape: (batch, num_query_heads, T, head_dim)  # per-head TMA
K/V:     (batch, 1, s_k, head_dim)              # shared (each CTA loads its own copy)
O shape: (batch, num_query_heads, T, head_dim)  # per-head output
Grid:    (ceil_div(T, 128), num_query_heads, batch)
```
At decode T=1, n_h=128: grid = (1, 128, batch), i.e. 128 CTAs per batch element, each with 1 Q row (padded to 128).
Pros:
- Matches the CUTLASS reference FMHA architecture
- Each CTA is independent — no inter-CTA coordination needed
- Tile scheduler from `fmha_helpers.py` handles work distribution
- Simple, well-understood pattern
Cons:
- At decode T=1, each CTA only processes 1 query row (padded to 128). Wastes 99% of MMA compute.
- K/V loaded 128 times (once per CTA) instead of once. Wastes HBM bandwidth.
- 128 CTAs competing for GPU resources.
Despite the waste, this is the right starting point because:
- It matches the CUTLASS reference — we can copy patterns
- It's correct — we can optimize later with Strategy A
- At decode T=1 with n_h=128 and M_tile=128, each CTA has 1 real query row padded to 128; the MMA still works, just underutilized
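To put numbers on the decode trade-off, the sketch below counts CTAs, MMA row utilization and K/V loads for both strategies. It assumes M_tile=128 and one full K/V load per CTA (the independent-load plan); `strategy_stats` is a hypothetical helper for back-of-the-envelope comparison only.

```python
def ceil_div(a: int, b: int) -> int:
    return -(-a // b)

def strategy_stats(T: int, n_h: int, batch: int, m_tile: int = 128) -> dict:
    """Compare head-packed (A) and head-as-grid-dim (B) launches.

    Assumes every CTA loads the full K/V once (no cluster sharing).
    """
    ctas_a = ceil_div(T * n_h, m_tile) * batch
    ctas_b = ceil_div(T, m_tile) * n_h * batch
    valid_rows = T * n_h * batch  # query rows that do useful work
    return {
        "ctas_a": ctas_a,
        "ctas_b": ctas_b,
        "mma_row_util_a": valid_rows / (ctas_a * m_tile),
        "mma_row_util_b": valid_rows / (ctas_b * m_tile),
        "kv_loads_a": ctas_a,  # one K/V load per CTA
        "kv_loads_b": ctas_b,
    }
```

At Pro decode (`strategy_stats(1, 128, 1)`) Strategy A uses 1 CTA at full row utilization, while Strategy B uses 128 CTAs at 1/128 row utilization with 128 K/V loads, which is the ~99% waste listed in the cons.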
## Implementation Plan
### Step 1: Tensor Layout Definitions
Define the Q, K, V, O tensor shapes for multi-head multi-batch.
```python
# Q:   (batch, n_h, T, head_dim)  BF16
# K:   (batch, 1, s_k, head_dim)  BF16 (MQA: shared KV)
# V:   (batch, 1, s_k, head_dim)  BF16 (MQA: shared KV)
# O:   (batch, n_h, T, head_dim)  BF16
# LSE: (batch, n_h, T)            FP32
```
The CuTeDSL TMA tensors need to be 4D for Q/O (batch, heads, seq, dim) and can be treated as 3D for K/V (batch, seq, dim), since the K/V head mode has extent 1. The head dimension is handled by the grid: each CTA gets one head.
Key: K/V have num_kv_heads=1 in the head dimension. The TMA atom for K/V is created from the full tensor (with head dim=1), and each CTA loads from head 0.
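The shape definitions above, including the 3D view of K/V obtained by dropping the extent-1 head mode, can be captured in a small helper. This is a hypothetical Python sketch for shape bookkeeping (for instance to print and compare against what the kernel sees at trace time), not code from `fmha.py`.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class AttnShapes:
    """Logical tensor shapes for multi-head, multi-batch MQA (num_kv_heads=1)."""
    batch: int
    n_h: int       # query heads (64 Flash, 128 Pro)
    T: int         # query tokens per request
    s_k: int       # key/value tokens
    head_dim: int

    def q(self) -> tuple:    # BF16
        return (self.batch, self.n_h, self.T, self.head_dim)

    def kv(self) -> tuple:   # BF16, one KV head shared by all query heads
        return (self.batch, 1, self.s_k, self.head_dim)

    def o(self) -> tuple:    # BF16, same shape as Q
        return (self.batch, self.n_h, self.T, self.head_dim)

    def lse(self) -> tuple:  # FP32, one value per query row
        return (self.batch, self.n_h, self.T)

    def kv_3d(self) -> tuple:
        """K/V with the extent-1 head mode dropped: (batch, s_k, head_dim)."""
        b, kv_heads, s_k, d = self.kv()
        assert kv_heads == 1, "MQA expects exactly one KV head"
        return (b, s_k, d)
```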
### Step 2: Grid Shape Computation
```python
import math

M_tile = 128  # rows per CTA
num_M_tiles = math.ceil(T / M_tile)
grid = (num_M_tiles, num_query_heads, batch)
```
At decode (T=1): grid = (1, 128, batch).
At prefill (T=64): grid = (1, 128, batch) (T=64 < 128, still 1 M tile per head).
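A self-contained version of the grid computation with basic input validation might look like the sketch below. `fmha_grid` is a hypothetical name; the 65535 bound is CUDA's limit on `gridDim.y` and `gridDim.z`, which the head and batch axes must respect in this (M_tiles, heads, batch) layout.

```python
import math

M_TILE = 128         # rows per CTA
MAX_GRID_YZ = 65535  # CUDA limit on gridDim.y and gridDim.z

def fmha_grid(T: int, num_query_heads: int, batch: int, m_tile: int = M_TILE) -> tuple:
    """Strategy B launch grid: (num_M_tiles, num_query_heads, batch)."""
    if min(T, num_query_heads, batch) < 1:
        raise ValueError("T, num_query_heads and batch must all be >= 1")
    if num_query_heads > MAX_GRID_YZ or batch > MAX_GRID_YZ:
        raise ValueError("heads and batch must fit in gridDim.y / gridDim.z")
    num_M_tiles = math.ceil(T / m_tile)
    return (num_M_tiles, num_query_heads, batch)
```

This reproduces both cases above: `fmha_grid(1, 128, b)` and `fmha_grid(64, 128, b)` both return `(1, 128, b)`; a second M tile only appears once T exceeds 128.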
### Step 3: Block Coordinate → Tensor Index Mapping
Inside the kernel, map (block_idx_x, block_idx_y, block_idx_z) to:
- `m_tile_idx = block_idx_x` → which M tile within this head
- `head_idx = block_idx_y` → which query head
- `batch_idx = block_idx_z` → which batch element
Then the Q TMA loads from `Q[batch_idx, head_idx, m_tile_idx*M_tile:(m_tile_idx+1)*M_tile, :]`.
The K/V TMA loads from `K[batch_idx, 0, :, :]` and `V[batch_idx, 0, :, :]` (head 0, all tokens).
The O TMA stores to `O[batch_idx, head_idx, m_tile_idx*M_tile:(m_tile_idx+1)*M_tile, :]`.
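The mapping can be written as a host-side reference function that returns the index tuples a given CTA touches, handy for checking the kernel's coordinates against a PyTorch reference. `cta_views` is a hypothetical helper; the axis assignment (x = M tile, y = head, z = batch) is the one in this plan.

```python
M_TILE = 128

def cta_views(bidx: int, bidy: int, bidz: int, m_tile: int = M_TILE):
    """Index tuples (batch, head, seq, dim) that one CTA reads and writes.

    Grid axes follow the plan: x = M tile, y = query head, z = batch.
    """
    m_tile_idx, head_idx, batch_idx = bidx, bidy, bidz
    rows = slice(m_tile_idx * m_tile, (m_tile_idx + 1) * m_tile)
    q_idx = (batch_idx, head_idx, rows, slice(None))
    kv_idx = (batch_idx, 0, slice(None), slice(None))  # MQA: always KV head 0
    o_idx = q_idx  # O is written at the same (batch, head, rows) that Q is read from
    return q_idx, kv_idx, o_idx
```

On a PyTorch tensor, `Q[q_idx]` gives the CTA's tile; a slice that runs past T is simply truncated, so the reference view contains only the real rows.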
### Step 4: TMA Tensor Construction
The TMA tensors are constructed from the PyTorch tensors before launch. For multi-head, Q needs to be indexed by head. CuTeDSL's `make_tiled_tma_atom_A` creates a TMA descriptor from the tensor shape.
Key question: Can the TMA descriptor be created once and reused across heads? Yes: if Q is a single strided tensor with the head dimension as one of its modes, one TMA descriptor covers all heads, and each CTA indexes into its head via the coordinate.
This is what the CUTLASS reference does: it creates the TMA descriptor from the FULL tensor shape (batch, n_h, T, head_dim), and the CTA's block coordinate selects the right head/tile.
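Whether one descriptor can cover the full Q tensor can be pre-checked on the host against two `cuTensorMapEncodeTiled` rules: tensor rank of at most 5, and every global stride other than the innermost one a multiple of 16 bytes (and below 2^40). The sketch assumes a contiguous BF16 tensor and uses hypothetical helper names; it is only a pre-flight check, since the real encode has further rules (box sizes, swizzle, address alignment).

```python
from typing import List, Sequence

def contiguous_strides_bytes(shape: Sequence[int], elem_bytes: int) -> List[int]:
    """Row-major (innermost-last) strides in bytes for a contiguous tensor."""
    strides, acc = [], elem_bytes
    for extent in reversed(shape):
        strides.append(acc)
        acc *= extent
    return strides[::-1]

def tma_encodable(shape: Sequence[int], elem_bytes: int = 2) -> bool:
    """Rank and stride checks for one TMA descriptor over the whole tensor."""
    if not 1 <= len(shape) <= 5:
        return False
    strides = contiguous_strides_bytes(shape, elem_bytes)
    # The innermost dimension is implicitly contiguous; the outer strides
    # must each be a multiple of 16 bytes and below 2**40.
    return all(s % 16 == 0 and s < 2**40 for s in strides[:-1])
```

For Pro decode Q `(batch=2, n_h=128, T=1, head_dim=512)` in BF16, the strides are `[131072, 1024, 1024, 2]` bytes, so a single 4D descriptor is encodable; an odd head_dim such as 100 would fail the 16-byte rule.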
### Step 5: Per-CTA Q Offset
Each CTA needs to load Q for its specific head. The TMA load already handles batch and M-tile indexing. For the head dimension, the TMA tensor has a mode for it.
In the kernel:
```python
import cutlass.cute as cute

bidx, bidy, bidz = cute.arch.block_idx()
# Q tile: m_tile=bidx, head=bidy, batch=bidz
```
The TMA copy: `cute.copy(tma_q, mQ[(None, bidy, bidx, bidz)], sQ[(None, qh.index)])`, but the exact coordinate order depends on the TMA tensor layout.
### Step 6: K/V Shared Across Heads (MQA Optimization)
For MQA, all CTAs in the head dimension (same block_idx_x, same block_idx_z, different block_idx_y) load the same K/V. Two approaches:
- Independent loads (simple): Each CTA loads its own K/V. Wastes bandwidth but correct.
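- Cluster-wide load (optimized): Use `cluster_shape_mn = (1, num_query_heads, 1)` so all heads in a cluster share the same K/V SMEM. Requires cluster barriers. Caveat: the portable maximum cluster size is 8 CTAs (16 with the non-portable opt-in), so a cluster can only span a group of heads, not all 64 or 128.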
Start with independent loads. The K/V are small at decode (s_k=128 or 512, head_dim=512): even 128 independent loads of K and V at s_k=128 (each 128×512 BF16 = 128 KiB) total only 32 MiB, which is fine.
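To see what cluster sharing would buy, the sketch below counts global K/V fetches for Strategy B when `cluster_heads` CTAs along the head axis share one load. The cluster size is a free, hypothetical parameter here (1 means independent loads); realistic values are bounded by the hardware cluster limit rather than by n_h.

```python
def ceil_div(a: int, b: int) -> int:
    return -(-a // b)

def kv_fetches(n_h: int, batch: int, num_m_tiles: int, cluster_heads: int = 1) -> int:
    """Global K/V fetches when `cluster_heads` CTAs along the head axis share one load.

    cluster_heads=1 models the independent-load plan (one fetch per CTA).
    """
    if cluster_heads < 1:
        raise ValueError("cluster_heads must be >= 1")
    return batch * num_m_tiles * ceil_div(n_h, cluster_heads)
```

At Pro decode, independent loads give 128 fetches per batch element, while a hypothetical 8-head cluster would cut that to 16.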
### Step 7: Output TMA Store
O has shape (batch, n_h, T, head_dim). Each CTA writes its head's output. The TMA store uses the same block coordinate mapping.
## To-Do List
### ✅ Completed (Per-Head Launch Approach)
- D2.2: Create test file `tests/unit/test_d2_perhead.py`
  - Per-head launch: kernel with `grid=(1,1,1)`, Python iterates over heads/batches
  - Verified: n_h=1,2,8,16,64 at hd=64; n_h=2,8 at hd=128; n_h=2 at hd=256
  - All cos ≥ 0.999998
- D2.7 + D2.8: Multi-head correctness across configs
  - n_h=64, batch=1, hd=64 → Flash decode config (Flash head count, but hd=64 rather than 512): PASS
  - hd=128, n_h=8: PASS
  - hd=256, n_h=2: PASS
### 🟡 Blocked (Multi-CTA Grid)
- D2.1: Add `num_query_heads` and `batch_size` to `FmhaKernel.__init__`
  - Simple to add, but the grid change is blocked (see below)
- D2.3–D2.6: Multi-CTA grid with runtime block coordinates
  - BLOCKED: `cute.local_tile` does not support runtime coordinates. Must use `cute.flat_divide` instead.
  - BLOCKED: `flat_divide` + `epilogue_tma_store` layout mismatch. The epilogue pipeline expects `tCgC` from `local_tile`, but `flat_divide` produces a different coordinate system.
  - Requires: Full refactor of `tma_partition` + `epilogue_tma_store` to work with `flat_divide`-based GMEM views. This means moving ALL GMEM tensor partitioning into the kernel (like the CUTLASS reference does).
  - CUTLASS reference approach: Uses `flat_divide` + `tma_partition` inside the TMA warp block, and a custom epilogue that handles the `flat_divide` coordinate system. Estimated 1-2 day effort.
- D2.9: LSE output for multi-head
  - Per-row LSE verified correct (max err 0.000001), but CuTe tensor indexing needs work
  - Currently only row 0 is written (`sfw_idx==0` guard)
  - Full per-row output needed for D5 KV merge
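Per-row LSE matters because it is what lets two partial attention results over disjoint key sets be merged exactly. The scalar sketch below shows the standard log-sum-exp merge for one query row, assuming each partial output is already normalized by its own softmax denominator (an un-normalized O, as in D5a, would need a different weighting); `row_lse` and `merge_rows` are hypothetical names, not the D5 implementation.

```python
import math
from typing import List, Tuple

def row_lse(scores: List[float]) -> float:
    """Numerically stable log-sum-exp of one row of attention scores."""
    m = max(scores)
    return m + math.log(sum(math.exp(s - m) for s in scores))

def merge_rows(o1: List[float], lse1: float,
               o2: List[float], lse2: float) -> Tuple[List[float], float]:
    """Merge two normalized partial outputs of the same query row."""
    # lse = log(exp(lse1) + exp(lse2)), computed stably
    lse = max(lse1, lse2) + math.log1p(math.exp(-abs(lse1 - lse2)))
    w1, w2 = math.exp(lse1 - lse), math.exp(lse2 - lse)
    return [w1 * a + w2 * b for a, b in zip(o1, o2)], lse
```

Splitting a row's keys into two halves, attending to each half separately and merging gives the same output and LSE as attending to all keys at once; this only works if every row's LSE is written, not just row 0.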
## Key Technical Decisions
### Q Tensor Layout for TMA
The Q tensor must be laid out so the TMA can efficiently load per-head tiles. Two options:
Option 1: (batch, n_h, T, head_dim) — head is a TMA mode
- TMA descriptor covers all heads
- Each CTA selects its head via the block coordinate
- This is how the CUTLASS reference works: `o_shape = (s, d, ((h_r, h_k), b))`
Option 2: (batch, T * n_h, head_dim) — heads packed into M
- TMA descriptor is 3D (batch, M, head_dim), with no separate head mode
- Each CTA selects its M tile
- No head mode in TMA — simpler but can't distinguish heads in the kernel
We'll use Option 1 because it matches the CUTLASS reference and allows per-head LSE output.
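In the CUTLASS reference's `((h_r, h_k), b)` mode, a query head is addressed by a coordinate (i_r, i_k). Our reading (not verified against the reference source) is that h_k is the number of KV heads and h_r the number of query heads per KV head. CuTe orders nested modes colexicographically (leftmost index fastest), so the flat query head is i_r + i_k * h_r and its KV head is q_head // h_r, the usual GQA rule. The sketch below uses hypothetical helper names.

```python
def gqa_coord(q_head: int, h_r: int, h_k: int) -> tuple:
    """Split a flat query-head index into (i_r, i_k) for an (h_r, h_k) mode.

    Colexicographic order (leftmost index fastest): q_head = i_r + i_k * h_r.
    """
    if not 0 <= q_head < h_r * h_k:
        raise IndexError(f"q_head {q_head} out of range for {h_r * h_k} heads")
    i_k, i_r = divmod(q_head, h_r)
    return i_r, i_k

def kv_head_for(q_head: int, h_r: int, h_k: int) -> int:
    """KV head a query head reads from (its i_k coordinate)."""
    return gqa_coord(q_head, h_r, h_k)[1]
```

For DSV4's MQA, h_k = 1 and h_r = n_h, so every query head maps to KV head 0, consistent with every CTA loading K/V from head 0.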
### Grid vs. Persistent Tile Scheduler
The CUTLASS reference uses a persistent tile scheduler for load balancing. For DSV4 decode (T=1), each CTA has exactly one tile, so a static grid suffices. For prefill with variable sequence lengths, a persistent scheduler would help. Start with static grid, add persistent scheduler later if needed.
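For comparison with the static grid, a persistent scheduler launches a fixed number of CTAs that loop over a linearized tile space. The simulation below uses a simple grid-stride (round-robin) assignment over (m_tile, head, batch) tiles. It is an illustrative stand-in, not the ordering used by CUTLASS's `FmhaStaticTileScheduler`, and the helper names are hypothetical.

```python
def decode_tile(linear: int, num_m_tiles: int, n_h: int) -> tuple:
    """Linear tile index -> (m_tile, head, batch), with m_tile fastest (like grid x)."""
    m_tile = linear % num_m_tiles
    head = (linear // num_m_tiles) % n_h
    batch = linear // (num_m_tiles * n_h)
    return m_tile, head, batch

def persistent_schedule(num_ctas: int, num_m_tiles: int, n_h: int, batch: int) -> dict:
    """Tiles each persistent CTA would process with a grid-stride loop."""
    total = num_m_tiles * n_h * batch
    return {
        cta: [decode_tile(t, num_m_tiles, n_h) for t in range(cta, total, num_ctas)]
        for cta in range(num_ctas)
    }
```

Every tile is visited exactly once and per-CTA work differs by at most one tile. With equal-sized tiles this gains little over a static grid, which is why it mainly pays off for variable-length prefill.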
### MQA K/V Sharing
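At decode T=1, n_h=128: 128 CTAs each load the same K/V. K and V are each s_k × head_dim BF16 = 128 × 512 × 2 B = 128 KiB, so 256 KiB per CTA. Total K/V read = 128 × 256 KiB = 32 MiB. HBM bandwidth on B200 is ~8 TB/s, so 32 MiB takes ~4 μs in the worst case; repeated reads of the same K/V should largely hit in L2, so actual HBM traffic is likely lower. The MMA compute per CTA is ~5 μs, so K/V redundancy is not expected to be the bottleneck at decode.
At prefill T=64, s_k=1024: K and V are each 1024 × 512 × 2 B = 1 MiB, so 2 MiB per CTA. 128 CTAs × 2 MiB = 256 MiB, ~34 μs at ~8 TB/s before any L2 reuse. Likely still not a bottleneck, but this is the case to watch when profiling.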
Start with independent K/V loads. Optimize later with cluster-wide sharing if profiling shows it's needed.
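The traffic figures above follow from a short calculation. This sketch computes the worst-case K+V bytes when every CTA loads its own copy, and the corresponding time at an assumed ~8 TB/s HBM bandwidth; it deliberately ignores L2 hits, so it is an upper bound, and `kv_traffic` is a hypothetical helper.

```python
BF16_BYTES = 2

def kv_traffic(s_k: int, head_dim: int, num_ctas: int,
               hbm_bytes_per_s: float = 8e12) -> tuple:
    """(bytes per CTA, total bytes, seconds) for independent K+V loads.

    Counts both K and V; assumes no L2 reuse, so this is a worst case.
    """
    per_cta = 2 * s_k * head_dim * BF16_BYTES  # K and V tiles
    total = per_cta * num_ctas
    return per_cta, total, total / hbm_bytes_per_s
```

`kv_traffic(128, 512, 128)` gives 256 KiB per CTA, 32 MiB total and about 4.2 μs (decode); `kv_traffic(1024, 512, 128)` gives 2 MiB per CTA, 256 MiB total and about 34 μs (prefill).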
## Risks and Mitigations
| Risk | Mitigation |
|---|---|
| TMA tensor layout mismatch for multi-head Q | Print shapes at trace time. Start with n_h=1 regression. |
| Grid dimension confusion (which axis is heads) | Follow CUTLASS reference: (M_tiles, heads, batch) |
| LSE output indexing for multi-head | Each CTA writes LSE to its (batch, head) position |
| SMEM budget unchanged per CTA | Each CTA has independent SMEM. No budget change. |
| Edge case: T not divisible by M_tile | For decode T=1, M=1 padded to 128. Need to mask output rows ≥ T (0-indexed). |
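The masking rule for padded rows is simple index arithmetic: in M tile `m_tile_idx`, global row `m_tile_idx * M_tile + r` is real only if it is below T. A hypothetical host-side sketch, useful for checking which O/LSE rows the kernel is allowed to write:

```python
M_TILE = 128

def valid_rows(m_tile_idx: int, T: int, m_tile: int = M_TILE) -> int:
    """Number of real query rows in this M tile; the rest are padding."""
    return max(0, min(m_tile, T - m_tile_idx * m_tile))

def row_mask(m_tile_idx: int, T: int, m_tile: int = M_TILE) -> list:
    """Per-row store mask: True where the global (0-indexed) row is < T."""
    start = m_tile_idx * m_tile
    return [start + r < T for r in range(m_tile)]
```

At decode (T=1) only row 0 of the single tile is valid; at T=200 the second tile has 72 valid rows.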
## CUTLASS Reference FMHA Architecture (for reference)
The CUTLASS Blackwell FMHA (`fmha.py`) uses:
- Grid: `(num_M_tiles, num_q_heads, batch)` computed via `fmha_utils.compute_grid`
- TMA shapes: Q/K/V/O as `(seq, dim, ((h_r, h_k), batch))`
- Tile scheduler: `FmhaStaticTileScheduler` with persistent mode
- 12-warp layout: Separate TMA, MMA, softmax, correction, epilogue warps
- Per-CTA: One head's M tiles. GQA support via `h_r` and `h_k` in the grid.
Key files:
- `/root/nvidia-meeting/venv/lib/python3.10/site-packages/flashinfer/data/cutlass/examples/python/CuTeDSL/blackwell/fmha.py`: main kernel
- `/root/nvidia-meeting/venv/lib/python3.10/site-packages/flashinfer/data/cutlass/examples/python/CuTeDSL/helpers/fmha_helpers.py`: grid/scheduler
Our kernel uses a simpler 6-warp layout and no persistent scheduler. We'll add the multi-head grid on top of our existing architecture.
## Dependencies
- D1 (parameterized HEAD_DIM): ✅ DONE for hd≤256. hd=512 blocked by MLIR compilation but not needed for D2 development.
- D5a (un-normalized O + LSE): ✅ DONE. Needed for D5 merge but not for D2.
## Next Stage After D2
D3 — SWA sequence length mask. Once we have multi-head multi-batch, add the SWA window mask (swa_lens per batch).