STAGE_D.md: restructure with correctness gaps, TMEM budget, execution order
This commit is contained in:
427
STAGE_D.md
427
STAGE_D.md
@@ -33,6 +33,7 @@
|
||||
- **PV A-operand major mode is `OperandMajorMode.K` for TMEM-P.** Not `a_major` from Q.
|
||||
- **`tOrP0` uses 3-dim indexing `(None, None, kb)`, NOT 4-dim `(None, None, kb, 0)`.** The 4th mode was already sliced away by `tOrP_base[(None,None,None,0)]`.
|
||||
- **After every P store to TMEM, call `cute.arch.fence_view_async_tmem_store()`.** Missing this produces NaN.
|
||||
- **PRINT THE SHAPES. ALWAYS.** Run `print(f"tensor: shape={cute.shape(tensor)}")` inside `@cute.kernel` at trace time. Reasoning about layouts without evidence is how we waste days.
|
||||
|
||||
---
|
||||
|
||||
@@ -40,15 +41,14 @@
|
||||
|
||||
**File:** `dsv4/kernels/attention/fmha.py`
|
||||
**Class:** `FmhaKernel`
|
||||
**State:** Exact copy of Stage C test. Works at hd=64 only. cos 0.972537 at n=128.
|
||||
**State:** Parameterized `head_dim` (D1.0 done). TMEM-P path works at hd=64 (cos 0.972537). SMEM-P path is a stub that zeros sP.
|
||||
|
||||
**What it does:**
|
||||
- 6-warp kernel: warps 0-3 (softmax + epilogue), warp 4 (MMA), warp 5 (TMA)
|
||||
- QK GEMM → S in TMEM → online softmax → P stored to TMEM via register bridge → PV GEMM → O in TMEM
|
||||
- O rescale (per KV tile, kt>0) + O normalization (1/row_sum) via TMEM round-trip
|
||||
- Epilogue: TMEM → SMEM → GMEM via TMA store
|
||||
|
||||
**Hardcoded constant that must die:** `HEAD_DIM = 64` on line 18, used in 7 places.
|
||||
- SMEM-P flag wired (`use_smem_p`), PV source switches between TMEM/SMEM, but register→SMEM copy not implemented
|
||||
|
||||
---
|
||||
|
||||
@@ -77,120 +77,391 @@ tiled_p_copy = cute.make_tiled_copy_C(store_atom, qk_mma) # NOT pv_mma!
|
||||
|
||||
Then: softmax threads write their P values through this copy → barrier → MMA reads from SMEM.
|
||||
|
||||
**Alternative (from the FlashMLA SM100 reference):** FlashMLA keeps P in TMEM at hd≤128 using `St32x32bOp` with QK C-fragment composition (same as our Stage C). At hd>128, they'd need the SMEM path. They don't support hd>128 yet.
|
||||
---
|
||||
|
||||
## TMEM Column Budget at hd=512
|
||||
|
||||
This MUST be calculated before writing a single line of SMEM-P code.
|
||||
|
||||
**TMA tensor tensor core (TMEM) has 512 columns per CTA.** Each column is 32 bits wide.
|
||||
|
||||
At hd=64 (TMEM-P path):
|
||||
- S (QK acc): 128 cols FP32
|
||||
- P (softmax output): 64 cols FP32 (= `pv_mma_tiler[2] * BF16_width / FP32_width` = 128 * 16/32 = 64... wait, let me recalculate)
|
||||
- `p_cols_fp32 = pv_mma_tiler[2] * q_dtype.width // qk_acc_dtype.width`
|
||||
- pv_mma_tiler = (128, 64, 128). pv_mma_tiler[2] = 128
|
||||
- p_cols_fp32 = 128 * 16 / 32 = 64
|
||||
- P starts at offset 32 (after 32 unused cols? No, S is at 0 with 128 cols, P at offset 32 overlaps??)
|
||||
- Actually: `tmem_p0_offset = 32` means P starts at TMEM col 32. But S uses cols 0-127. P at 32 means they OVERLAP. This works because S is consumed before P is written (softmax reads S, then writes P to same TMEM region).
|
||||
- After P: `o_after = max(s_cols=128, p_end=32+64=96) = 128`. `tmem_o0_offset = ((128 + 31) // 32) * 32 = 128`
|
||||
- O (PV acc): `find_tmem_tensor_col_offset(tOtO)` at hd=64 ≈ 128 cols FP32
|
||||
- Total: 128 (O offset) + 128 (O size) = 256 cols. Fits in 512. ✅
|
||||
|
||||
At hd=512 (SMEM-P path):
|
||||
- P is NOT in TMEM. S and O share TMEM (sequential, not concurrent).
|
||||
- S (QK acc): 128 cols FP32 (same as hd=64 — QK is always (128, 128))
|
||||
- O (PV acc): at hd=512, PV is (128, 512). PV MMA C-fragment is (128, 512) FP32 = 512 cols? NO.
|
||||
- `tOtO = pv_thr.make_fragment_C(pv_as)` where `pv_as = pv_thr.partition_shape_C((128, 512))`
|
||||
- The C-fragment for a tcgen05 MMA with shape (128, 512) in FP32:
|
||||
- M=128 → 4 warps × 32 threads = 128 rows, each thread owns 1 row
|
||||
- N=512 → 512/32 = 16 TMEM columns per thread? No, tcgen05 MMA writes (32, 32) tiles.
|
||||
- For (128, 512) MMA: 4 M-tiles × 16 N-tiles = 64 (32×32) subtiles
|
||||
- Each subtile uses 32 TMEM columns. But they're distributed across warps.
|
||||
- `find_tmem_tensor_col_offset(tOtO)` gives the actual footprint.
|
||||
- **MUST PRINT THIS ON THE B200.** Do not guess. Run a shape probe.
|
||||
- If O needs ~512 cols: S (128) + O (512) = 640 > 512. **DOES NOT FIT.**
|
||||
- Fix options:
|
||||
1. Drop `kv_stage` from 2 to 1 — frees SMEM but loses K/V double-buffering. TMEM budget unchanged.
|
||||
2. Split O into halves: process (128, 256) PV twice, each O tile is 256 cols. S(128) + O(256) = 384 < 512. ✅
|
||||
3. Process S and O sequentially: after softmax consumes S, O can reuse S's TMEM region. O at offset 0, 512 cols. Total = 512. ✅ But only if we don't need S anymore when writing O (true — softmax is done before PV starts per KV tile).
|
||||
|
||||
**Plan: SMEM-P path reuses S's TMEM for O.** After softmax reads S and writes P to SMEM, S's TMEM region (cols 0-127) is dead. PV writes O starting at col 0. O at hd=512 needs ~256-512 cols (must measure). If O fits in cols 0-511 with S gone, we're golden.
|
||||
|
||||
**Action item: Run shape probe on B200 before coding SMEM-P at hd=512.**
|
||||
|
||||
```python
|
||||
# Shape probe script to run on B200:
|
||||
import torch, math, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.nvgpu.tcgen05 as tcgen05
|
||||
from cutlass import BFloat16, Float32, LayoutEnum
|
||||
|
||||
a_major = LayoutEnum.ROW_MAJOR # adjust to match
|
||||
b_major = LayoutEnum.ROW_MAJOR
|
||||
pv_mma = utils.sm100.make_trivial_tiled_mma(BFloat16, BFloat16, a_major, b_major, Float32, tcgen05.CtaGroup.ONE, (128,512), tcgen05.OperandSource.SMEM)
|
||||
pv_thr = pv_mma.get_slice(0)
|
||||
pv_as = pv_thr.partition_shape_C((128, 512))
|
||||
tOtO = pv_thr.make_fragment_C(pv_as)
|
||||
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
|
||||
o_cols = find_tmem_tensor_col_offset(tOtO)
|
||||
print(f"hd=512 PV C-fragment: pv_as={pv_as}, tOtO.layout={tOtO.layout}, o_cols={o_cols}")
|
||||
# Also print tOtO shape
|
||||
print(f"tOtO shape: {cute.shape(tOtO)}")
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Stage D TODO List
|
||||
## Correctness Gaps — Must Close Before Production
|
||||
|
||||
### D1.0 — Replace `HEAD_DIM = 64` with constructor parameter ✅ (next step)
|
||||
These are NOT optimization gaps. They are cases where the current code produces **numerically wrong outputs** vs the trained checkpoint.
|
||||
|
||||
- [ ] Add `head_dim` to `FmhaKernel.__init__()`
|
||||
- [ ] Replace all 7 uses of `HEAD_DIM` with `self.head_dim`
|
||||
- [ ] Keep `use_smem_p=False` as default (TMEM-P path)
|
||||
- [ ] **Test:** hd=64, n=128 → cos 0.972537 (must match exactly)
|
||||
- [ ] **Test:** hd=64, n=256 → cos 0.792775 (must match exactly)
|
||||
- [ ] **DO NOT add SMEM-P code yet.** Just parameterize. Test first.
|
||||
### CG-1: SwiGLU Clamping Missing from Fused Kernel ⚠️ CRITICAL
|
||||
|
||||
The 7 places `HEAD_DIM` is used:
|
||||
1. `__init__`: `1.0 / math.sqrt(HEAD_DIM)` → `1.0 / math.sqrt(head_dim)`
|
||||
2. `_setup`: `self.pv_mma_tiler = (128, HEAD_DIM, ...)` → `(128, self.head_dim, ...)`
|
||||
3. `_setup`: `self.cta_tile_shape_mnk = (..., HEAD_DIM, ...)` → `(..., self.head_dim, ...)`
|
||||
4. `__call__`: `cute.make_layout((HEAD_DIM, self.s_k, 1), stride=(1, HEAD_DIM, HEAD_DIM * self.s_k))`
|
||||
5. `__call__`: `pv_mma = ... (128, HEAD_DIM) ...`
|
||||
6. softmax: `n_corr_tiles = HEAD_DIM // corr_tile_size`
|
||||
7. (Check for any others: `grep HEAD_DIM dsv4/kernels/attention/fmha.py`)
|
||||
**What:** `FusedSwiGLUScaledGroupedGemmKernel.__init__` stores `self.swiglu_limit` but the SwiGLU compute block (lines 2185–2200 in `fused_swiglu.py`) **never references it**. The reference path in `dsv4/reference/moe_pipeline.py` correctly applies `clamp(max=swiglu_limit)` to gate and `clamp(min=-limit, max=+limit)` to up. The fused kernel silently skips it.
|
||||
|
||||
### D1.1 — Add SMEM-P path behind `use_smem_p` flag
|
||||
**Why it matters:** Paper §4.2.3 explicitly says weights were trained with the gate component capped at 10 and the linear component clamped to [−10, 10]. Without clamping, the fused kernel produces different outputs than the reference at large activation values.
|
||||
|
||||
- [ ] Add `use_smem_p` to `__init__` (default: `head_dim > 64`)
|
||||
- [ ] In `_setup`: conditional TMEM layout (TMEM-P has `tmem_p0_offset=32`, SMEM-P has `tmem_p0_offset=-1` and `tmem_o0_offset=0`)
|
||||
- [ ] In `_setup`: allocate `p_smem_s` for SMEM-P (PV A-operand SMEM layout)
|
||||
- [ ] In `__call__`: `pv_mma` uses `OperandSource.SMEM` when `use_smem_p`, `OperandSource.TMEM` otherwise
|
||||
- [ ] In `__call__`: PV A-operand major mode is `a_major` for SMEM-P, `OperandMajorMode.K` for TMEM-P
|
||||
- [ ] **CuTeDSL scoping:** Define ALL variables unconditionally before any `if use_smem_p` blocks. Both `tOrP0` (TMEM) and `tCrP` (SMEM) must exist before the warp-branching starts.
|
||||
- [ ] **Test:** hd=64, n=128, `use_smem_p=False` → cos 0.972537 (regression)
|
||||
**Fix (2 lines in the fused kernel):**
|
||||
```python
|
||||
# After computing silu_result (gate subtile):
|
||||
silu_result = cute.math.fmin(silu_result, swiglu_limit)
|
||||
|
||||
### D1.2 — Implement register→SMEM copy for P (the hard part)
|
||||
# Before the gate*up multiply (up subtile):
|
||||
acc_vec = cute.math.fmin(cute.math.fmax(acc_vec, -swiglu_limit), swiglu_limit)
|
||||
```
|
||||
|
||||
**Where:** `dsv4/kernels/gemm/fused_swiglu.py`, lines ~2192 (gate branch) and ~2198 (up branch).
|
||||
|
||||
**Status:** 🔴 NOT FIXED. Do this IMMEDIATELY — it's a 2-line fix and affects all MoE layer outputs.
|
||||
|
||||
### CG-2: FMHA at hd=512 SMEM-P is a Stub ⚠️ CRITICAL
|
||||
|
||||
**What:** `FmhaKernel` with `use_smem_p=True` zeros `sP` and comments "PV will produce garbage." DSV4 head_dim is 512 (§4.2.1). The kernel literally cannot produce correct output at the production head dimension.
|
||||
|
||||
**Why it matters:** This is the D1 work. The path forward is correct (`make_tiled_copy_C(store_atom, qk_mma)` to partition P registers for SMEM staging). But TMEM column budget at hd=512 must be verified first (see budget section above).
|
||||
|
||||
**Status:** 🔴 D1.2–D1.3 TODO. This document IS the plan.
|
||||
|
||||
### CG-3: SWA + Sink Merge Not Fused in FMHA ⚠️ CRITICAL
|
||||
|
||||
**What:** The DSV4 attention design (§2.3.3) requires merging compressed top-k attention with sliding window attention via sink weights. Currently the FMHA kernel only does dense attention (one KV source, one softmax, one PV, normalize). The sink merge is implemented in Python fallback (`decode_sparse.py`) but NOT in the production kernel path.
|
||||
|
||||
**Why it matters:** Without SWA+sink merge, the compressed branch alone cannot capture local dependencies. The paper is explicit: "Additional Branch of Sliding Window Attention." Every CSA and HCA layer produces wrong output without this.
|
||||
|
||||
**Fix plan (D5, ordered by priority):**
|
||||
1. **D5a:** Emit un-normalized `o` + `lse` instead of normalized `o`. This is the SINGLE MOST IMPORTANT structural change — once the kernel can output (o_unnorm, lse), even a Python merge gives end-to-end correctness. Keep `normalize` as a flag so standalone tests still work.
|
||||
2. **D5b:** Run kernel twice externally (compressed_kv + swa_kv), merge in Python. End-to-end correctness without touching kernel structure. This is the correctness baseline.
|
||||
3. **D5c:** Fuse two passes into one kernel launch (Q stays in SMEM, two sequential MMA loops). Pure optimization.
|
||||
4. **D5d:** Fuse sink merge into kernel epilogue. Pure optimization.
|
||||
|
||||
**Status:** 🔴 D5 TODO. D5a must be done FIRST — it unblocks D5b which gives us correctness.
|
||||
|
||||
### CG-4: Inverse RoPE Verification ⚠️ HIGH
|
||||
|
||||
**What:** `inverse_rope_bf16` in `dsv4/ops/rope.py` applies the conjugate rotation to the last `rope_dim=64` dims of each head output. The math looks correct: `inv[2i] = x[2i] * cos + x[2i+1] * sin`, `inv[2i+1] = -x[2i] * sin + x[2i+1] * cos`. This is the standard inverse rotation for interleaved (GPT-J) RoPE.
|
||||
|
||||
**What needs verifying:**
|
||||
1. The `positions` argument must be the **same** positions used for the forward RoPE on Q and K. The inverse RoPE applies RoPE with position = +position (not -position). The "inverse" is the conjugate rotation, not a negated angle. The code uses `cos_sin_cache[positions, :]` which is the same table as forward RoPE. For conjugate rotation, we need cos(θ) and sin(θ) at the SAME position, then flip the sign on the sin terms in the odd positions. The current code does this correctly: `inv_odd = -o_even * sin_all + o_odd * cos_all`. ✅
|
||||
2. The `nope_dim=448` / `rope_dim=64` split must match the model's actual split. If a layer uses a different split, the inverse RoPE would rotate the wrong dims.
|
||||
3. The cos_sin_cache must be the **same** cache used for forward RoPE. If there's any offset or indexing difference, the angles won't match.
|
||||
|
||||
**Action:** Write a unit test that: (1) applies forward RoPE to random input, (2) applies inverse RoPE, (3) verifies the result matches the original. This is a round-trip test and catches both sign and indexing errors.
|
||||
|
||||
**Status:** 🟡 Code looks correct but UNTESTED. Add a round-trip unit test.
|
||||
|
||||
### CG-5: Mixed-Precision KV (BF16 RoPE + FP8 NoPE) — FMHA Load Path ⚠️ HIGH
|
||||
|
||||
**What:** Paper §2.3.4: KV cache stores dims 0..447 as FP8 and dims 448..511 as BF16. The `PagedKVPool` already implements this split: `entries_fp8` (uint8) + `entries_rope` (BF16) + `inv_scale` (FP32). The current decode_sparse.py fallback dequantizes in Python before calling the kernel.
|
||||
|
||||
**Why it matters for FMHA:** The FmhaKernel currently takes contiguous BF16 K/V tensors. At production, the kernel must handle the mixed-precision KV directly — reading FP8 + BF16 from the paged cache and dequantizing on the fly during TMA→SMEM transfer. This is the proper Blackwell pattern: TMA loads FP8 to SMEM, on-the-fly dequant in the SMEM→register path, then MMA.
|
||||
|
||||
**The proper approach:**
|
||||
1. TMA loads FP8 NoPE dims to SMEM slot 0
|
||||
2. TMA loads BF16 RoPE dims to SMEM slot 1 (or separate TMA)
|
||||
3. Dequantize FP8 → BF16 in SMEM (vectorized, per-entry `inv_scale` multiply)
|
||||
4. Concatenate [NoPE, RoPE] in SMEM (or use two separate SMEM regions with strided MMA)
|
||||
5. MMA reads contiguous BF16 from SMEM
|
||||
|
||||
**Prerequisite:** This requires D1 (SMEM-P) and D5 (sink merge) to be working first. The mixed-precision load path replaces the current "all BF16" K/V input with the real paged cache format.
|
||||
|
||||
**Status:** 🔴 NOT IMPLEMENTED. Plan as D6 (after D5). The current test harness passes contiguous BF16 K/V, which is fine for correctness testing. The FP8 dequant in SMEM is a performance + memory optimization that doesn't affect numerical correctness (FP8 dequant is well-defined).
|
||||
|
||||
### CG-6: Per-Token valid_lens in Indexer for Prefill ⚠️ MEDIUM
|
||||
|
||||
**What:** `score_topk.py` has a `TODO` that broadcasts request 0's `valid_lens` for prefill (T > B). For batched prefill, different requests have different numbers of compressed entries in the pool. Broadcasting the first request's count means other requests either score garbage entries (too many) or miss valid ones (too few).
|
||||
|
||||
**Why it matters:** Prefill correctness blocker. The indexer will select wrong entries for all requests except the first in a batch.
|
||||
|
||||
**Fix:** Map each query token to its request ID, then look up `valid_lens[request_id]`. The `request_ids: [T] int32` tensor already exists in the cache handle. The indexer kernel needs this as an input.
|
||||
|
||||
**Status:** 🔴 NOT FIXED. This is indexer scope, not FMHA scope. Track separately.
|
||||
|
||||
---
|
||||
|
||||
## Performance Soft Spots — Important But Not Correctness
|
||||
|
||||
These affect throughput but not numerical correctness. Tracked for Stage F+.
|
||||
|
||||
### PS-1: Indexer Score+TopK is Scalar CUDA — Not Blackwell Native 🔴
|
||||
|
||||
**What:** `indexer_score_topk.cu` is a CUDA-core scalar implementation:
|
||||
- Triple loop: `for h in n_heads, for g in n_groups, for b in 8`
|
||||
- FP4 nibble dequant to FP32, FP32 dot product
|
||||
- Shared-memory min-heap protected by single `s_lock` atomicCAS spinlock
|
||||
- For 1M-context: ~250K compressed entries scored per query token
|
||||
|
||||
**Why it's the biggest perf leak:** The dot products should use tensor cores. The heap spinlock won't scale to top_k=1024 with hundreds of thousands of candidates.
|
||||
|
||||
**The correct approach:** DeepGEMM's `fp8_mqa_logits` / `fp8_paged_mqa_logits` pattern (Sept 2025 PR for V3.2 indexer). Weighted ReLU MQA logits computed with tensor cores, paged variant for decode. Our V4 NVFP4 variant should be that pattern with FP4 inputs and tcgen05 MMA. Beyond the MMA, the heap needs replacing with per-warp partial top-k merged via reduction tree, or radix-select.
|
||||
|
||||
**Status:** Tracked for Stage F (post Stage E). Not blocking D1–D5.
|
||||
|
||||
### PS-2: decode_sparse.py BlackwellSparseDecodeKernel is Misleading 🟡
|
||||
|
||||
**What:** `dsv4/ops/decode_sparse.py` contains `BlackwellSparseDecodeKernel` — a CuTeDSL kernel that does scalar `for d in range(HD): dot += q_val * k_val` with no tensor cores. It also has a `_fallback_sparse_sdp` Python path that uses `F.scaled_dot_product_attention`.
|
||||
|
||||
**Why it's misleading:** The class name says "Blackwell" but it uses zero Blackwell tensor acceleration. Anyone reading the codebase would assume this is the production kernel. It's a stale early-exploration kernel superseded by `FmhaKernel`.
|
||||
|
||||
**Action:** Delete `BlackwellSparseDecodeKernel` and its CuTeDSL code. Keep `_fallback_sparse_sdp` as a reference implementation (rename to `_reference_sparse_sdp_attention`). The FMHA kernel in `dsv4/kernels/attention/fmha.py` is the real path. Do this cleanup as part of E7.
|
||||
|
||||
**Status:** Low urgency. Track for E7 cleanup.
|
||||
|
||||
### PS-3: mHC Mixing Uses torch.bmm with n_hc=4 🟢
|
||||
|
||||
**What:** mHC mixing operations (`A_l @ X_l`, `B_l @ X_l`, `C_l ⊗ F_out`) use `torch.bmm` with tiny `n_hc=4` inner dimension.
|
||||
|
||||
**Why it matters:** For decode (T=1) this is fine — tiny matmul. For prefill it leaves throughput on the floor. But prefill is not the immediate priority.
|
||||
|
||||
**Status:** Lowest priority of the soft spots. Track for Stage G (prefill optimization).
|
||||
|
||||
---
|
||||
|
||||
## Stage D Build Order (REVISED)
|
||||
|
||||
### Priority Principle: Correctness First, Then Performance
|
||||
|
||||
D1 (hd=512) and D5 (SWA+sink merge) are both correctness-critical. But D5 depends on D1 (can't merge SWA if the kernel can't even run at hd=512). CG-1 (SwiGLU clamping) is a 2-line fix with no dependencies — do it first.
|
||||
|
||||
### D0 — SwiGLU Clamping Fix (CG-1) ⚡ DO THIS FIRST
|
||||
|
||||
- [ ] Add clamping to fused SwiGLU in `dsv4/kernels/gemm/fused_swiglu.py`
|
||||
- [ ] Gate subtile: `silu_result = cute.math.fmin(silu_result, swiglu_limit)` after SiLU compute
|
||||
- [ ] Up subtile: `acc_vec = cute.math.fmin(cute.math.fmax(acc_vec, -swiglu_limit), swiglu_limit)` before gate*up multiply
|
||||
- [ ] Verify: `cute.math.fmin` / `cute.math.fmax` work with CuTeDSL vectorized code (they should — they're elementwise)
|
||||
- [ ] Test: fused MoE output matches reference with clamping at swiglu_limit=10.0
|
||||
- [ ] Commit with clear message: "fix: add SwiGLU clamping to fused kernel (paper §4.2.3)"
|
||||
|
||||
### D1 — Parameterized HEAD_DIM + SMEM-P (CG-2)
|
||||
|
||||
#### D1.0 — Replace HEAD_DIM constant with constructor parameter ✅ DONE
|
||||
|
||||
Already in the kernel. `head_dim` is a constructor arg. TMEM-P path works at hd=64.
|
||||
|
||||
#### D1.1 — Add SMEM-P path behind `use_smem_p` flag ✅ WIRED (stub)
|
||||
|
||||
The `use_smem_p` flag exists. PV source switches between TMEM/SMEM. TMEM layout adjusts. But the register→SMEM copy is a stub that zeros sP.
|
||||
|
||||
#### D1.2 — TMEM Column Budget Verification 🔨 DO THIS BEFORE CODING
|
||||
|
||||
- [ ] Run shape probe on B200: `find_tmem_tensor_col_offset(tOtO)` at hd=512
|
||||
- [ ] Print `pv_as`, `tOtO.layout`, `o_cols` at hd=128, 256, 512
|
||||
- [ ] Calculate: can S(128) and O(???) share TMEM at hd=512?
|
||||
- [ ] If O > 384 cols: plan for split-PV (two (128, 256) passes)
|
||||
- [ ] Document the budget numbers HERE in this file
|
||||
|
||||
#### D1.3 — Implement register→SMEM copy for P (THE HARD PART)
|
||||
|
||||
- [ ] Build `tiled_p_copy = cute.make_tiled_copy_C(store_atom, qk_mma)` — QK MMA partitions threads
|
||||
- [ ] **Print the shapes:** `cute.shape(tiled_p_copy)`, partition source/dest shapes
|
||||
- [ ] Partition `sP` with `tiled_p_copy` as destination
|
||||
- [ ] In softmax warps: after computing P in registers, write to SMEM via `tiled_p_copy`
|
||||
- [ ] Add `p_smem_ready_bar` barrier: softmax arrives after write, MMA waits before PV GEMM
|
||||
- [ ] Add `p_smem_ready_bar` NamedBarrier: softmax arrives after write + fence, MMA waits before PV GEMM
|
||||
- [ ] In MMA warp: read P from SMEM via `tCrP = pv_mma.make_fragment_A(sP)`
|
||||
- [ ] **Test:** hd=64, n=128, `use_smem_p=True` → compare against TMEM-P result (should be close)
|
||||
- [ ] **Test:** hd=64, n=128, `use_smem_p=True` → compare against TMEM-P result
|
||||
- [ ] **Test:** hd=128, n=128 → test against FP32 oracle
|
||||
- [ ] **Test:** hd=256, n=128 → test against FP32 oracle
|
||||
- [ ] **Test:** hd=512, n=128 → test against FP32 oracle (DSV4's real value)
|
||||
|
||||
### D1.3 — Multi-PV-tile for hd>256
|
||||
#### D1.4 — Multi-PV-tile for hd>256
|
||||
|
||||
- [ ] When `head_dim > 256`, the MMA instruction can only process 256 columns at a time
|
||||
- [ ] `pv_n_tile = min(head_dim, 256)`, `n_pv_tiles = head_dim // pv_n_tile`
|
||||
- [ ] Multiple PV GEMM passes per KV tile, accumulating O
|
||||
- [ ] V must be re-constructed with `v_n = pv_n_tile` per pass
|
||||
- [ ] This may require multiple kernel launches at Python level (or a loop inside the kernel)
|
||||
- [ ] Add `pv_n_tile = min(head_dim, 256)` and `n_pv_tiles = head_dim // pv_n_tile` to `__init__`
|
||||
- [ ] For hd=512: 2 PV tiles of (128, 256) each
|
||||
- [ ] Strategy: kernel processes one PV N-tile per launch. Python orchestrates the tiles.
|
||||
- Pass 0: V[:, 0:256] → output[:, 0:256], QK + softmax + PV for cols 0-256
|
||||
- Pass 1: V[:, 256:512] → output[:, 256:512], QK + softmax + PV for cols 256-512
|
||||
- QK and softmax run identically both passes (P is the same). Only PV changes.
|
||||
- [ ] Alternative (if SMEM-P allows): keep P in SMEM between PV tiles. Run QK+softmax once, PV twice.
|
||||
- [ ] **Test:** hd=512, n=128 → correct output against FP32 oracle
|
||||
|
||||
### D1.4 — Cleanup and regression
|
||||
#### D1.5 — Correction Epilogue: Fix TMEM Layout Mismatch (3% Error)
|
||||
|
||||
- [ ] Remove `HEAD_DIM = 64` constant entirely
|
||||
- [ ] Add `head_dim` as first constructor arg (no default — always explicit)
|
||||
- [ ] Default `use_smem_p=None` → auto-detect from `head_dim > 64`
|
||||
- [ ] Test matrix: hd ∈ {64, 128, 256, 512} × n ∈ {128, 256}
|
||||
- [ ] Update README status table: D1 → ✅ COMPLETE
|
||||
- [ ] Cross off D1.0–D1.4 in this file
|
||||
The current TMEM round-trip (Ld32x32bOp + St32x32bOp hand-constructed atoms) introduces 3% error at hd=64 (cos 0.973). The proper fix is the CUTLASS `correction_epilog` pattern:
|
||||
|
||||
---
|
||||
```
|
||||
TMEM --get_tmem_load_op--> reg (normalize + FP32→BF16) --get_smem_store_op--> SMEM --TMA--> GMEM
|
||||
```
|
||||
|
||||
## D2 — Multi-query grid with head packing (after D1)
|
||||
This is a one-way trip. No TMEM round-trip. No layout mismatch.
|
||||
|
||||
- [ ] Investigate: can we use `get_tmem_load_op` + `get_smem_store_op` paired atoms?
|
||||
- [ ] Investigate: can we inject `inv_row_sum` into `epilogue_tma_store` pipeline?
|
||||
- [ ] Investigate: pre-compute TMA partitioning outside `if warp_idx` blocks (region isolation workaround)
|
||||
- [ ] **Test:** hd=64, n=128 → cos should jump from 0.973 → ~0.9999
|
||||
- [ ] **Test:** hd=64, n=256 → cos should jump from 0.793 → ~0.9999
|
||||
|
||||
**Note:** This is NOT blocking for D2–D5. The 3% error is a precision issue, not a correctness issue (the attention math is right, the epilogue just introduces rounding). Fix it properly rather than hacking it. But don't let it block the D2–D5 pipeline.
|
||||
|
||||
### D2 — Multi-Query Grid with Head Packing
|
||||
|
||||
- [ ] Grid changes from `(1, 1, 1)` to `(num_q_blocks, 1, batch)`
|
||||
- [ ] DSV4 is MQA: all 128 query heads share same K/V
|
||||
- [ ] Head axis folded into M dimension of Q tile
|
||||
- [ ] **Test:** batch=4, T=64, n_h=128, num_kv_heads=1
|
||||
- [ ] Head axis folded into M dimension of Q tile: `M_tile = 128` covers `M = T * n_h` rows
|
||||
- [ ] At decode T=1: M = 1 × 128 = 128 — one Q block covers all heads. ✅
|
||||
- [ ] At prefill T=64: M = 64 × 128 = 8192 — 64 Q blocks. Needs grid loop.
|
||||
- [ ] **Test:** batch=4, T=64, n_h=128, num_kv_heads=1 produces correct attention
|
||||
|
||||
## D3 — SWA sequence length mask
|
||||
### D3 — SWA Sequence Length Mask
|
||||
|
||||
- [ ] Add `swa_lens: [batch] int32` kernel input
|
||||
- [ ] Mask SWA-branch logits to `-inf` where `swa_idx >= swa_lens[b]`
|
||||
- [ ] **Test:** varying SWA fill levels
|
||||
- [ ] **Test:** batched input with varying SWA fill levels (position 50 vs 5000)
|
||||
|
||||
## D4 — Causal mask on SWA branch
|
||||
### D4 — Causal Mask on SWA Branch
|
||||
|
||||
- [ ] Add `is_causal: bool` constructor flag
|
||||
- [ ] Apply `swa_idx > q_pos` masking in SWA pass
|
||||
- [ ] Apply `swa_idx > q_pos` masking to `-inf` in SWA pass
|
||||
- [ ] Main path has NO mask (indexer enforces causality upstream)
|
||||
- [ ] **Test:** prefill mode produces correct output with causal mask
|
||||
|
||||
## D5 — SWA + sink merge
|
||||
### D5 — SWA + Sink Merge (CG-3) ⚠️ THE WHOLE POINT OF V4 ATTENTION
|
||||
|
||||
- [ ] D5a: Emit un-normalized `o` + `lse` instead of normalized `o` (keep normalize as flag)
|
||||
- [ ] D5b: Run kernel twice externally (compressed_kv + swa_kv), merge in Python
|
||||
- [ ] D5c: Fuse two passes into one kernel launch (Q stays in SMEM)
|
||||
- [ ] D5d: Fuse sink merge into kernel epilogue
|
||||
#### D5a — Emit un-normalized o + lse ⚡ DO THIS IMMEDIATELY AFTER D1
|
||||
|
||||
This is the single most important structural change. Once the kernel can output (o_unnorm, lse), even a Python merge gives end-to-end correctness.
|
||||
|
||||
- [ ] Change epilogue: instead of `O *= 1/row_sum`, emit `O` un-normalized and `lse = log(row_sum) + row_max` as a separate output
|
||||
- [ ] Add `normalize: bool` constructor flag (default: True for backward compat, False for merge mode)
|
||||
- [ ] When `normalize=False`: skip the TMEM round-trip for normalize. O stays as `PV @ V` (un-normalized). lse written to a separate GMEM buffer.
|
||||
- [ ] **Test:** `normalize=True` → identical to current behavior (regression)
|
||||
- [ ] **Test:** `normalize=False` → `o_unnorm / exp(lse).unsqueeze(-1)` ≈ `o_normalized` (verify math)
|
||||
|
||||
#### D5b — Python merge (correctness baseline)
|
||||
|
||||
- [ ] Run FmhaKernel twice: once with compressed_kv, once with swa_kv
|
||||
- [ ] Merge in Python:
|
||||
```python
|
||||
exp_lse_sparse = lse_sparse.exp()
|
||||
exp_lse_swa = lse_swa.exp()
|
||||
exp_sink = sink_logits.exp()
|
||||
o = (exp_lse_sparse * o_sparse + exp_sink * exp_lse_swa * o_swa) / (exp_lse_sparse + exp_sink * exp_lse_swa)
|
||||
```
|
||||
- [ ] Test against FP32 oracle that does sparse+SWA+sink merge
|
||||
- [ ] **This gives us end-to-end correctness.** Everything after is optimization.
|
||||
|
||||
#### D5c — Fuse two passes into one kernel launch
|
||||
|
||||
- [ ] Q loaded once to SMEM, used by both compressed and SWA MMA loops
|
||||
- [ ] Two sequential QK→softmax→PV passes in one kernel invocation
|
||||
- [ ] K/V have two sources: compressed (contiguous BF16) and SWA (from cache)
|
||||
- [ ] For now: dequantize SWA in a small prep kernel before FMHA, FMHA sees two contiguous BF16 sources
|
||||
- [ **Test:** output matches D5b Python merge
|
||||
|
||||
#### D5d — Fuse sink merge into kernel epilogue
|
||||
|
||||
- [ ] TMEM holds two O accumulators + two row_max/row_sum per row
|
||||
- [ ] Verify TMEM column budget: two O + two (row_max, row_sum) at hd=512
|
||||
- [ ] Sink merge in TMEM: `O = (exp(lse1) * O1 + exp(sink) * exp(lse2) * O2) / (exp(lse1) + exp(sink) * exp(lse2))`
|
||||
- [ **Test:** output matches D5b Python merge
|
||||
|
||||
### D6 — Mixed-Precision KV Load Path (CG-5)
|
||||
|
||||
- [ ] TMA loads FP8 NoPE dims to SMEM slot 0
|
||||
- [ ] TMA loads BF16 RoPE dims to SMEM slot 1
|
||||
- [ ] Dequantize FP8 → BF16 in SMEM (vectorized `* inv_scale`)
|
||||
- [ ] Concatenate [NoPE, RoPE] in SMEM
|
||||
- [ ] MMA reads contiguous BF16 from SMEM
|
||||
- [ ] **Test:** FP8+BF16 split input matches pure BF16 input (dequant is transparent)
|
||||
- [ ] **Prerequisite:** D1 (SMEM-P) and D5 (sink merge) working first
|
||||
|
||||
---
|
||||
|
||||
## Key References
|
||||
## Inverse RoPE Verification (CG-4) — Separate from D1–D6
|
||||
|
||||
| What | Where |
|
||||
|------|-------|
|
||||
| Working FMHA kernel (hd=64) | `dsv4/kernels/attention/fmha.py` — `FmhaKernel` |
|
||||
| Stage C test (oracle) | `tests/unit/test_fmha_v3_stage_c.py` — `FmhaV3StageCMulti` |
|
||||
| Stage A+B test | `tests/unit/test_fmha_v3.py` |
|
||||
| FlashMLA SM100 reference | `/root/dsv4-nvfp4-workspace/vllm/.deps/flashmla-src/csrc/cutlass/examples/python/CuTeDSL/blackwell/fmha.py` (on B200) |
|
||||
| CUTLASS FMHA reference | `/root/cutlass/examples/python/CuTeDSL/cute/blackwell/kernel/attention/fmha/fmha.py` (on B200) |
|
||||
| Sink merge spec | `dsv4/ops/decode_sparse.py` |
|
||||
| SWA decode | `dsv4/ops/decode_swa.py` |
|
||||
| Attention reference | `dsv4/reference/attention.py` |
|
||||
| CSA attention reference | `dsv4/reference/csa_attention.py` |
|
||||
- [ ] Write unit test: `tests/unit/test_inverse_rope.py`
|
||||
- [ ] Round-trip test: forward RoPE → inverse RoPE → verify ≈ original
|
||||
- [ ] Multi-head test: verify only last 64 dims are rotated
|
||||
- [ ] Position test: verify cos_sin_cache indexing is correct for positions > 0
|
||||
- [ ] This is a standalone test, not a kernel change. Can be done anytime.
|
||||
|
||||
## B200 Environment
|
||||
---
|
||||
|
||||
```
|
||||
Server: root@45.76.247.107 (password: 6)Jr)B@dcX[mN?dx)
|
||||
Kernel repo: /root/dsv4-nvfp4-workspace/kernel
|
||||
Venv: source /root/dsv4-nvfp4-workspace/venv/bin/activate
|
||||
PYTHONPATH: /root/dsv4-nvfp4-workspace/kernel
|
||||
Test command: python3 tests/unit/test_fmha_v3_stage_c.py
|
||||
```
|
||||
## Per-Token valid_lens (CG-6) — Indexer Scope, Not FMHA
|
||||
|
||||
- [ ] Add `request_ids: [T] int32` to indexer kernel input
|
||||
- [ ] Look up `valid_lens[request_ids[t]]` per query token
|
||||
- [ ] Replace the current broadcast of `valid_lens[:1]`
|
||||
- [ ] **Test:** batched prefill with different sequence lengths per request
|
||||
- [ ] Tracked separately from FMHA Stage D. This is indexer work.
|
||||
|
||||
---
|
||||
|
||||
## Correctness Gap NOT in This Project
|
||||
|
||||
### CG-7: Indexer Rewrite (PS-1) — Stage F
|
||||
|
||||
The indexer needs a full rewrite from scalar CUDA to tcgen05 MMA + radix-select. This is a major work item (2-3 weeks) that is out of scope for Stage D.
|
||||
|
||||
**Reference:** DeepGEMM's `fp8_mqa_logits` / `fp8_paged_mqa_logits` (Sept 2025 PR for V3.2 indexer). Our V4 variant: same pattern with FP4 inputs and tcgen05 MMA.
|
||||
|
||||
---
|
||||
|
||||
## Execution Order (Top to Bottom)
|
||||
|
||||
| # | Task | Blocks | Est. |
|
||||
|---|------|--------|------|
|
||||
| D0 | SwiGLU clamping (CG-1) | Nothing — do first | 30 min |
|
||||
| D1.2 | TMEM budget probe at hd=512 | D1.3 | 1 hr |
|
||||
| D1.3 | Register→SMEM copy for P | D1.4, D2 | 1-2 days |
|
||||
| D1.4 | Multi-PV-tile hd>256 | D2 | 1 day |
|
||||
| D1.5 | Correction epilog fix (3% → 0.01%) | Nothing (can parallel) | 1-2 days |
|
||||
| D2 | Multi-query grid + head packing | D3 | 1 day |
|
||||
| D3 | SWA sequence length mask | D5 | ½ day |
|
||||
| D4 | Causal mask on SWA | D5 | ½ day |
|
||||
| D5a | Emit un-normalized o + lse | D5b | 1 day |
|
||||
| D5b | Python merge (correctness) | D5c | ½ day |
|
||||
| D5c | Fuse two passes in one launch | D5d | 2 days |
|
||||
| D5d | Fuse sink merge in epilogue | D6 | 2 days |
|
||||
| D6 | Mixed-precision KV load | E1 | 2 days |
|
||||
| CG-4 | Inverse RoPE round-trip test | Nothing | 2 hrs |
|
||||
| CG-6 | Per-token valid_lens (indexer) | Nothing | ½ day |
|
||||
|
||||
**Critical path:** D0 → D1.2 → D1.3 → D1.4 → D5a → D5b (end-to-end correctness)
|
||||
|
||||
**D1.5 (correction epilog) and CG-4 (RoPE test) can happen in parallel with D2–D4.**
|
||||
@@ -166,19 +166,15 @@ class FmhaKernel:
|
||||
tOtO = pv_thr.make_fragment_C(pv_as)
|
||||
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
|
||||
|
||||
# PV A-operand: define both tOrP0 (TMEM-P) and tCrP (SMEM-P) unconditionally
|
||||
# When pv_mma uses TMEM source, make_fragment_A needs a TMEM-based tensor (tP from tStS).
|
||||
# When pv_mma uses SMEM source, make_fragment_A needs an SMEM-based tensor (sP).
|
||||
# We construct both paths using the appropriate tensor for make_fragment_A.
|
||||
# PV A-operand: define both tOrP0 (TMEM-P) and tCrP (SMEM-P) unconditionally.
|
||||
# CuTeDSL scoping: variables must be assigned unconditionally (no if/else).
|
||||
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
|
||||
# For TMEM source PV: fragment_A from TMEM tensor tP
|
||||
tOrP_base = pv_thr.make_fragment_A(tP if not self.use_smem_p else sP)
|
||||
tOrP = tOrP_base[(None,None,None,0)]
|
||||
tOrP0 = cute.make_tensor(
|
||||
tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * max(self.tmem_p0_offset, 0),
|
||||
tOrP.layout)
|
||||
# For SMEM source PV: fragment_A from SMEM tensor sP
|
||||
tCrP = pv_mma.make_fragment_A(sP)
|
||||
# tOrP0 always defined as tOrP. The TMEM-P path in the MMA warp applies
|
||||
# the p0 column offset inline when constructing the gemm arguments.
|
||||
tOrP0 = tOrP
|
||||
|
||||
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_as, self.num_acc_stage))
|
||||
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
|
||||
|
||||
Reference in New Issue
Block a user