5. **Regression check:** After every change, verify that hd=64 still reproduces cosine 0.972537. If it doesn't, the change is WRONG. Revert.
### The Rules (BURNED INTO THIS FILE BECAUSE WE BURNED THEM INTO PRODUCTION)
- **NEVER edit files directly on the B200.** Edit locally, commit, push, pull, test. Every time.
- **NEVER delete or modify the test files in `tests/unit/`.** They are the regression oracle.
- **NEVER touch drivers, kernels, firmware, or system packages on the B200.**
- **CuTeDSL variables defined in `if` blocks are NOT visible in other `if` blocks.** Even compile-time constants. Define all variables unconditionally before any branching.
- **Always test at hd=64 FIRST.** If the proven path (TMEM-P) regresses, nothing else matters.
- **`p_cols_fp32` uses `pv_mma_tiler[2]` (K-dim), NOT `pv_mma_tiler[1]` (N-dim).** We got this wrong twice.
- **PV A-operand major mode is `OperandMajorMode.K` for TMEM-P.** Not `a_major` from Q.
- **`tOrP0` uses 3-dim indexing `(None, None, kb)`, NOT 4-dim `(None, None, kb, 0)`.** The 4th mode was already sliced away by `tOrP_base[(None,None,None,0)]`.
- **`tOrP0` MUST include the `tmem_p0_offset` column offset.** The softmax warps store P at `tmem_p0_offset=32` (32 FP32 columns = 64 BF16 elements), and the PV MMA must read from the same offset. Missing this offset causes NaN/zeros: the MMA reads from column 0, where S is, instead of column 32, where P is. Use `const_expr` for the conditional: under `if const_expr(self.tOrP0_offset > 0):` set `tOrP0 = cute.make_tensor(tOrP.iterator + self.tOrP0_offset, tOrP.layout)`, otherwise set `tOrP0 = tOrP`.
- **PRINT THE SHAPES. ALWAYS.** Run `print(f"tensor: shape={cute.shape(tensor)}")` inside `@cute.kernel` at trace time. Reasoning about layouts without evidence is how we waste days.
At hd=64, the QK C-fragment TMEM layout and the PV A-fragment TMEM layout agree — the same threads map to the same columns. P can be written to TMEM using the QK partition and read by PV using the same partition. This is why the register bridge (FP32 backing + BF16 view) works.
At hd=512, P is still (128, 128) per KV tile (P's columns = number of KV positions, NOT head_dim); head_dim=512 lands on PV's N dimension. But in this configuration the QK C-fragment and PV A-fragment TMEM layouts **disagree**: different threads own different columns. The register bridge can't write P in a layout that PV can read.
**The fix: SMEM-P path.** P goes through SMEM instead of TMEM:
1. Softmax computes P in registers (QK C-fragment partition)
2. Write P to SMEM using the `p_smem_s` layout (PV A-operand SMEM layout)
3. MMA warp reads P from SMEM via `tCrP = pv_mma.make_fragment_A(sP)`
4. PV GEMM uses `tcgen05.OperandSource.SMEM` instead of `OperandSource.TMEM`
**The SMEM rendezvous:** SMEM is the meeting point. Softmax threads write at logical (row, col) addresses. MMA reads at the same addresses. A barrier in between. No cross-warp message passing needed — just write-to-address, barrier, read-from-address.
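The rendezvous idea can be modeled as a toy in plain Python: writer threads own rows of P (standing in for the QK C-fragment partition), reader threads own columns (a deliberately different partition), and both go through one shared address function with a `threading.Barrier` in between. `smem_addr` and its column-major layout are illustrative assumptions, not `p_smem_s`; in the real kernel the read side is the tcgen05 MMA itself, not threads.

```python
import threading

ROWS, COLS = 4, 4


def smem_addr(row: int, col: int) -> int:
    # Illustrative layout (column-major). Writers and readers MUST share it;
    # the physical order is free to differ from either side's partition.
    return col * ROWS + row


def rendezvous(P):
    smem = [None] * (ROWS * COLS)
    barrier = threading.Barrier(ROWS + COLS)  # every writer and reader waits here
    cols_seen = [[None] * ROWS for _ in range(COLS)]

    def writer(r: int) -> None:
        # "Softmax thread" r owns row r of P (QK C-fragment-style ownership).
        for c in range(COLS):
            smem[smem_addr(r, c)] = P[r][c]
        barrier.wait()  # no reader proceeds until every write has landed

    def reader(c: int) -> None:
        # Consumer c owns column c: a different partition of the same tile.
        barrier.wait()
        for r in range(ROWS):
            cols_seen[c][r] = smem[smem_addr(r, c)]

    threads = [threading.Thread(target=writer, args=(r,)) for r in range(ROWS)]
    threads += [threading.Thread(target=reader, args=(c,)) for c in range(COLS)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    # Reassemble the readers' columns into rows to compare with P.
    return [[cols_seen[c][r] for c in range(COLS)] for r in range(ROWS)]


P = [[10 * r + c for c in range(COLS)] for r in range(ROWS)]
assert rendezvous(P) == P
```

Neither side needs to know the other's partition; they only have to agree on the address function.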
**The missing piece (the D1 work):** The register→SMEM copy. The softmax warps have P values in QK C-fragment partition. They need to write to SMEM with PV A-operand layout. This requires a `TiledCopy` that partitions threads by QK's C-fragment and targets the P SMEM layout.
**tcgen05 MMA has a hard limit: N ≤ 256.** At hd=512, PV MUST be split into 2 tiles of (128, 256). The MMA rejects N=512 at construction time with `OpError: expects the N-mode to satisfy 8 <= N <= 256 and N % 8 == 0, but got 512`.
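A hypothetical helper (`split_pv_n` is not in the codebase) that turns head_dim into PV N-tiles satisfying the constraint quoted in the `OpError`. At hd=512 it yields the two 256-wide tiles, each covering a 256-column slice of V and of the O accumulator.

```python
def split_pv_n(n: int, max_n: int = 256) -> list[tuple[int, int]]:
    """Split PV's N (= head_dim) into (n_offset, n_size) tiles that satisfy
    the tcgen05 constraint 8 <= N <= 256 and N % 8 == 0."""
    if n < 8 or n % 8 != 0:
        raise ValueError(f"N={n} cannot be tiled: need N >= 8 and N % 8 == 0")
    tiles = []
    offset = 0
    while offset < n:
        size = min(max_n, n - offset)  # the remainder is also a multiple of 8
        tiles.append((offset, size))
        offset += size
    return tiles


print(split_pv_n(512))  # two tiles: columns 0..255 and 256..511
```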
**Measured on B200:**
| hd | s_cols | pv_n_tile | o_cols | TMEM-P total | SMEM-P total |
## Correctness Gaps — Must Close Before Production
These are NOT optimization gaps. They are cases where the current code produces **numerically wrong outputs** vs the trained checkpoint.
### CG-1: SwiGLU Clamping Missing from Fused Kernel ⚠️ CRITICAL
**What:** `FusedSwiGLUScaledGroupedGemmKernel.__init__` stores `self.swiglu_limit` but the SwiGLU compute block (lines 2185–2200 in `fused_swiglu.py`) **never references it**. The reference path in `dsv4/reference/moe_pipeline.py` correctly applies `clamp(max=swiglu_limit)` to gate and `clamp(min=-limit, max=+limit)` to up. The fused kernel silently skips it.
**Why it matters:** Paper §4.2.3 explicitly says weights were trained with the gate component capped at 10 and the linear component clamped to [−10, 10]. Without clamping, the fused kernel produces different outputs than the reference at large activation values.
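A scalar Python reference for the clamped SwiGLU in §4.2.3 (gate capped above at the limit, linear/up component clamped to [−limit, +limit]) is handy for spot-checking the fused kernel. It assumes, as the reference path's `clamp(max=swiglu_limit)` on gate suggests, that the gate is clamped before SiLU; the last lines show that clamping after SiLU is not the same thing near the limit.

```python
import math


def silu(x: float) -> float:
    # Numerically stable SiLU: x * sigmoid(x) without overflowing exp().
    if x >= 0:
        return x / (1.0 + math.exp(-x))
    e = math.exp(x)
    return x * e / (1.0 + e)


def swiglu_clamped(gate: float, up: float, limit: float = 10.0) -> float:
    gate = min(gate, limit)           # gate: upper cap only
    up = max(-limit, min(up, limit))  # up / linear: symmetric clamp
    return silu(gate) * up


def swiglu_unclamped(gate: float, up: float) -> float:
    return silu(gate) * up


# Large activations are where the two paths diverge.
print(swiglu_clamped(25.0, -30.0), swiglu_unclamped(25.0, -30.0))
# Order matters: silu(min(x, 10)) tops out at silu(10) < 10,
# while min(silu(x), 10) can reach exactly 10.
print(silu(min(10.5, 10.0)), min(silu(10.5), 10.0))
```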
**Where:** `dsv4/kernels/gemm/fused_swiglu.py`, lines ~2192 (gate branch) and ~2198 (up branch).
**Status:** 🔴 NOT FIXED. Do this IMMEDIATELY — it's a 2-line fix and affects all MoE layer outputs.
### CG-2: FMHA at hd=512 SMEM-P is a Stub ⚠️ CRITICAL
**What:** `FmhaKernel` with `use_smem_p=True` zeros `sP` and comments "PV will produce garbage." DSV4 head_dim is 512 (§4.2.1). The kernel literally cannot produce correct output at the production head dimension.
**Why it matters:** This is the D1 work. The path forward is correct (`make_tiled_copy_C(store_atom, qk_mma)` to partition P registers for SMEM staging). But TMEM column budget at hd=512 must be verified first (see budget section above).
**Status:** 🟡 D1.3 SMEM-P still a stub. hd=64 TMEM-P works (cos 0.973). `make_tiled_copy_C` gives rank mismatch. Need proper layout-aware P register→SMEM copy.
### CG-3: SWA + Sink Merge Not Fused in FMHA ⚠️ CRITICAL
**What:** The DSV4 attention design (§2.3.3) requires merging compressed top-k attention with sliding window attention via sink weights. Currently the FMHA kernel only does dense attention (one KV source, one softmax, one PV, normalize). The sink merge is implemented in Python fallback (`decode_sparse.py`) but NOT in the production kernel path.
**Why it matters:** Without SWA+sink merge, the compressed branch alone cannot capture local dependencies. The paper is explicit: "Additional Branch of Sliding Window Attention." Every CSA and HCA layer produces wrong output without this.
**Fix plan (D5, ordered by priority):**
1. **D5a:** Emit un-normalized `o` + `lse` instead of normalized `o`. This is the SINGLE MOST IMPORTANT structural change: once the kernel can output (o_unnorm, lse), even a Python merge gives end-to-end correctness. Keep `normalize` as a flag so standalone tests still work.
2. **D5b:** Run the kernel twice externally (compressed_kv + swa_kv) and merge in Python. End-to-end correctness without touching kernel structure. This is the correctness baseline.
3. **D5c:** Fuse the two passes into one kernel launch (Q stays in SMEM, two sequential MMA loops). Pure optimization.
4. **D5d:** Fuse the sink merge into the kernel epilogue. Pure optimization.
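The D5b merge can be sketched for a single query in plain Python, using the common (o, lse) formulation with per-pass normalized outputs. The sink is modeled as one extra logit that joins the softmax denominator with a zero value vector (gpt-oss-style attention sinks); whether DSV4's sink weights behave exactly this way is an assumption to check against the reference.

```python
import math


def attend(scores, values):
    """One query vs one KV source. Returns (normalized output, log-sum-exp)."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    denom = sum(weights)
    dim = len(values[0])
    o = [sum(w * v[d] for w, v in zip(weights, values)) / denom for d in range(dim)]
    return o, m + math.log(denom)


def merge(parts, sink_logit=None):
    """Merge (o_i, lse_i) partials from independent KV sources into one softmax.
    An optional sink logit joins the denominator but contributes no value."""
    lses = [lse for _, lse in parts]
    if sink_logit is not None:
        lses.append(sink_logit)
    m = max(lses)
    lse_total = m + math.log(sum(math.exp(x - m) for x in lses))
    dim = len(parts[0][0])
    out = [0.0] * dim
    for o, lse in parts:
        w = math.exp(lse - lse_total)  # this source's share of the total mass
        for d in range(dim):
            out[d] += w * o[d]
    return out, lse_total


compressed = attend([0.5, -1.0, 2.0], [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
swa = attend([1.5, 0.2], [[2.0, 0.0], [0.0, 3.0]])
o, lse = merge([compressed, swa], sink_logit=0.7)
```

The result equals a single softmax over the concatenated compressed + SWA scores with the sink in the denominator, which is what makes the two-pass D5b baseline exact.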
**What:** `inverse_rope_bf16` in `dsv4/ops/rope.py` applies the conjugate rotation to the last `rope_dim=64` dims of each head output. The math looks correct: `inv[2i] = x[2i] * cos + x[2i+1] * sin`, `inv[2i+1] = -x[2i] * sin + x[2i+1] * cos`. This is the standard inverse rotation for interleaved (GPT-J) RoPE.
**What needs verifying:**
1. The `positions` argument must be the **same** positions used for the forward RoPE on Q and K. The inverse RoPE looks up the cache at +position, not −position. Mathematically the "inverse" is the conjugate (transposed) rotation, i.e. a rotation by −θ; because cos(−θ) = cos θ and sin(−θ) = −sin θ, it is implemented by taking cos(θ) and sin(θ) at the SAME position and flipping the sign of the sin terms in both the even and the odd outputs (compare the two formulas above). The code uses `cos_sin_cache[positions, :]`, the same table as forward RoPE, and does this correctly, e.g. `inv_odd = -o_even * sin_all + o_odd * cos_all`. ✅
2. The `nope_dim=448` / `rope_dim=64` split must match the model's actual split. If a layer uses a different split, the inverse RoPE would rotate the wrong dims.
3. The cos_sin_cache must be the **same** cache used for forward RoPE. If there's any offset or indexing difference, the angles won't match.
**Action:** Write a unit test that: (1) applies forward RoPE to random input, (2) applies inverse RoPE, (3) verifies the result matches the original. This is a round-trip test and catches both sign and indexing errors.
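A pure-Python sketch of that round trip, assuming interleaved (GPT-J) pairs, `nope_dim=448` / `rope_dim=64`, and a hypothetical table built with base 10000 instead of the real `cos_sin_cache`. The real unit test should run the production forward RoPE, then `inverse_rope_bf16`, with a BF16-appropriate tolerance; the deliberately wrong inverse position shows the check catches indexing errors.

```python
import math
import random


def rope_cos_sin(position, rope_dim, base=10000.0):
    # Hypothetical table: theta_i = position * base^(-2i / rope_dim).
    table = []
    for i in range(rope_dim // 2):
        theta = position * base ** (-2.0 * i / rope_dim)
        table.append((math.cos(theta), math.sin(theta)))
    return table


def rope_interleaved(x, cos_sin, inverse=False):
    y = list(x)
    for i, (c, s) in enumerate(cos_sin):
        if inverse:
            s = -s  # conjugate rotation: same cos, negated sin
        x0, x1 = x[2 * i], x[2 * i + 1]
        y[2 * i] = x0 * c - x1 * s
        y[2 * i + 1] = x0 * s + x1 * c
    return y


def apply_rope_to_head(head, position, nope_dim=448, rope_dim=64, inverse=False):
    # Only the last rope_dim dims rotate; the NoPE dims pass through untouched.
    assert len(head) == nope_dim + rope_dim
    table = rope_cos_sin(position, rope_dim)
    return head[:nope_dim] + rope_interleaved(head[nope_dim:], table, inverse)


def round_trip_error(position, inverse_position=None, head_dim=512, seed=0):
    rng = random.Random(seed)
    head = [rng.uniform(-1.0, 1.0) for _ in range(head_dim)]
    fwd = apply_rope_to_head(head, position)
    inv_pos = position if inverse_position is None else inverse_position
    back = apply_rope_to_head(fwd, inv_pos, inverse=True)
    return max(abs(a - b) for a, b in zip(head, back))


print(round_trip_error(12345))                          # ~1e-16: round trip holds
print(round_trip_error(12345, inverse_position=12346))  # large: indexing bug caught
```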
**Status:** 🟡 Code looks correct but UNTESTED. Add a round-trip unit test.
**What:** Paper §2.3.4: KV cache stores dims 0..447 as FP8 and dims 448..511 as BF16. The `PagedKVPool` already implements this split: `entries_fp8` (uint8) + `entries_rope` (BF16) + `inv_scale` (FP32). The current decode_sparse.py fallback dequantizes in Python before calling the kernel.
**Why it matters for FMHA:** The FmhaKernel currently takes contiguous BF16 K/V tensors. In production, the kernel must handle the mixed-precision KV directly: read FP8 + BF16 from the paged cache and dequantize on the fly after the TMA→SMEM transfer. This is the proper Blackwell pattern: TMA loads FP8 to SMEM, dequant happens in the SMEM→register path, then MMA.
**The proper approach:**
1. TMA loads FP8 NoPE dims to SMEM slot 0
2. TMA loads BF16 RoPE dims to SMEM slot 1 (or separate TMA)
3. Dequantize the FP8 NoPE dims to BF16 using `inv_scale`
4. Concatenate [NoPE, RoPE] in SMEM (or use two separate SMEM regions with strided MMA)
5. MMA reads contiguous BF16 from SMEM
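For reference, what the load path computes for one KV row, as a scalar Python sketch: decode the FP8 E4M3 NoPE bytes, scale them, and concatenate with the BF16 RoPE dims. It assumes `inv_scale` is a single multiplier per row and that dequant is `fp8_value * inv_scale`; the actual granularity and convention in `PagedKVPool` should be checked before relying on this.

```python
import math


def e4m3_to_float(byte: int) -> float:
    """Decode one FP8 E4M3 (FN variant) byte: 1 sign, 4 exponent (bias 7), 3 mantissa bits."""
    sign = -1.0 if byte & 0x80 else 1.0
    exp = (byte >> 3) & 0xF
    man = byte & 0x7
    if exp == 0xF and man == 0x7:
        return math.nan  # E4M3FN has NaN but no infinities
    if exp == 0:
        return sign * (man / 8.0) * 2.0 ** -6  # subnormal
    return sign * (1.0 + man / 8.0) * 2.0 ** (exp - 7)


def dequant_kv_row(nope_fp8, rope_bf16, inv_scale, nope_dim=448, rope_dim=64):
    """Rebuild one 512-dim KV row: dequantized NoPE dims followed by untouched RoPE dims."""
    if len(nope_fp8) != nope_dim or len(rope_bf16) != rope_dim:
        raise ValueError("unexpected NoPE/RoPE split")
    nope = [e4m3_to_float(b) * inv_scale for b in nope_fp8]
    return nope + list(rope_bf16)


row = dequant_kv_row(bytes([0x38] * 448), [0.25] * 64, inv_scale=0.5)
```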
**Prerequisite:** This requires D1 (SMEM-P) and D5 (sink merge) to be working first. The mixed-precision load path replaces the current "all BF16" K/V input with the real paged cache format.
**Status:** 🔴 NOT IMPLEMENTED. Plan as D6 (after D5). The current test harness passes contiguous BF16 K/V, which is fine for correctness testing. The FP8 dequant in SMEM is a performance + memory optimization that doesn't affect numerical correctness (FP8 dequant is well-defined).
### CG-6: Per-Token valid_lens in Indexer for Prefill ⚠️ MEDIUM
**What:** `score_topk.py` has a `TODO` that broadcasts request 0's `valid_lens` for prefill (T > B). For batched prefill, different requests have different numbers of compressed entries in the pool. Broadcasting the first request's count means other requests either score garbage entries (too many) or miss valid ones (too few).
**Why it matters:** Prefill correctness blocker. The indexer will select wrong entries for all requests except the first in a batch.
**Fix:** Map each query token to its request ID, then look up `valid_lens[request_id]`. The `request_ids: [T] int32` tensor already exists in the cache handle. The indexer kernel needs this as an input.
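The per-token lookup is a gather through `request_ids`. A minimal sketch, assuming compressed entries are indexed 0..valid_len−1 within each request and masked to −inf before top-k; `request_ids_from_cu_seqlens` is a hypothetical helper for packed prefill batches (the cache handle already carries `request_ids`).

```python
NEG_INF = float("-inf")


def request_ids_from_cu_seqlens(cu_seqlens):
    # cu_seqlens = [0, len_0, len_0 + len_1, ...] for a packed prefill batch.
    ids = []
    for r in range(len(cu_seqlens) - 1):
        ids.extend([r] * (cu_seqlens[r + 1] - cu_seqlens[r]))
    return ids


def per_token_valid_lens(request_ids, valid_lens):
    # The fix: each token reads ITS request's count, not valid_lens[0].
    return [valid_lens[r] for r in request_ids]


def mask_scores(scores, request_ids, valid_lens):
    # scores[t][j]: score of query token t against compressed entry j.
    lens = per_token_valid_lens(request_ids, valid_lens)
    return [[s if j < n else NEG_INF for j, s in enumerate(row)]
            for row, n in zip(scores, lens)]


ids = request_ids_from_cu_seqlens([0, 3, 5, 6])  # 3 + 2 + 1 tokens
lens = per_token_valid_lens(ids, [5, 9, 2])
```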
**Status:** 🔴 NOT FIXED. This is indexer scope, not FMHA scope. Track separately.
---
## Performance Soft Spots — Important But Not Correctness
These affect throughput but not numerical correctness. Tracked for Stage F+.
### PS-1: Indexer Score+TopK is Scalar CUDA — Not Blackwell Native 🔴
**What:** `indexer_score_topk.cu` is a CUDA-core scalar implementation:
- Triple loop: `for h in n_heads, for g in n_groups, for b in 8`
- FP4 nibble dequant to FP32, FP32 dot product
- Shared-memory min-heap protected by single `s_lock` atomicCAS spinlock
- For 1M-context: ~250K compressed entries scored per query token
**Why it's the biggest perf leak:** The dot products should use tensor cores. The heap spinlock won't scale to top_k=1024 with hundreds of thousands of candidates.
**The correct approach:** DeepGEMM's `fp8_mqa_logits` / `fp8_paged_mqa_logits` pattern (Sept 2025 PR for V3.2 indexer). Weighted ReLU MQA logits computed with tensor cores, paged variant for decode. Our V4 NVFP4 variant should be that pattern with FP4 inputs and tcgen05 MMA. Beyond the MMA, the heap needs replacing with per-warp partial top-k merged via reduction tree, or radix-select.
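A single-threaded sketch of the radix-select alternative to the heap: map each FP32 score to an order-preserving uint32 key, then run four 8-bit histogram passes from the most significant byte to pin down the k-th largest key, and finally gather everything above it plus enough ties. On the GPU each pass would be a block-wide histogram in SMEM; that mapping is an assumption, not a description of DeepGEMM.

```python
import struct


def float_key(x: float) -> int:
    """Order-preserving map from an FP32 value to uint32 (larger float -> larger key)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return (~bits & 0xFFFFFFFF) if bits & 0x80000000 else (bits | 0x80000000)


def radix_topk(scores, k):
    """Indices of the k largest scores (ties at the threshold taken in index order)."""
    assert 1 <= k <= len(scores)
    keys = [float_key(s) for s in scores]
    prefix, prefix_mask, remaining = 0, 0, k
    for shift in (24, 16, 8, 0):
        hist = [0] * 256
        for key in keys:
            if key & prefix_mask == prefix:  # still a candidate for the k-th key
                hist[(key >> shift) & 0xFF] += 1
        digit = 255
        while hist[digit] < remaining:  # walk buckets from the top
            remaining -= hist[digit]
            digit -= 1
        prefix |= digit << shift
        prefix_mask |= 0xFF << shift
    threshold = prefix  # exact key of the k-th largest score
    above = [i for i, key in enumerate(keys) if key > threshold]
    ties = [i for i, key in enumerate(keys) if key == threshold]
    return above + ties[: k - len(above)]


print(radix_topk([0.5, -2.0, 3.25, 3.25, -0.0, 0.0, 1e-3, -7.5, 10.0], 3))  # [8, 2, 3]
```

The cost is four linear passes regardless of k, with no lock, which is what makes it attractive at top_k=1024 over hundreds of thousands of candidates.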
**What:** `dsv4/ops/decode_sparse.py` contains `BlackwellSparseDecodeKernel` — a CuTeDSL kernel that does scalar `for d in range(HD): dot += q_val * k_val` with no tensor cores. It also has a `_fallback_sparse_sdp` Python path that uses `F.scaled_dot_product_attention`.
**Why it's misleading:** The class name says "Blackwell" but it uses zero Blackwell tensor acceleration. Anyone reading the codebase would assume this is the production kernel. It's a stale early-exploration kernel superseded by `FmhaKernel`.
**Action:** Delete `BlackwellSparseDecodeKernel` and its CuTeDSL code. Keep `_fallback_sparse_sdp` as a reference implementation (rename to `_reference_sparse_sdp_attention`). The FMHA kernel in `dsv4/kernels/attention/fmha.py` is the real path. Do this cleanup as part of E7.
**What:** mHC mixing operations (`A_l @ X_l`, `B_l @ X_l`, `C_l ⊗ F_out`) use `torch.bmm` with tiny `n_hc=4` inner dimension.
**Why it matters:** For decode (T=1) this is fine — tiny matmul. For prefill it leaves throughput on the floor. But prefill is not the immediate priority.
**Status:** Lowest priority of the soft spots. Track for Stage G (prefill optimization).
---
## Stage D Build Order (REVISED)
### Priority Principle: Correctness First, Then Performance
D1 (hd=512) and D5 (SWA+sink merge) are both correctness-critical. But D5 depends on D1 (can't merge SWA if the kernel can't even run at hd=512). CG-1 (SwiGLU clamping) is a 2-line fix with no dependencies — do it first.
### D0 — SwiGLU Clamping Fix (CG-1) ⚡ DO THIS FIRST
- [ ] Add clamping to fused SwiGLU in `dsv4/kernels/gemm/fused_swiglu.py`
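- [ ] Gate subtile: clamp the gate pre-activation with `cute.math.fmin(..., swiglu_limit)` before the SiLU compute, matching the reference order (clamp, then SiLU). Clamping `silu_result` after SiLU is not equivalent near the limit: silu(10) ≈ 9.99955, so the two orders differ by up to ~5e-4 in the gate value.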
- [ ] Up subtile: `acc_vec = cute.math.fmin(cute.math.fmax(acc_vec, -swiglu_limit), swiglu_limit)` before gate*up multiply
- [ ] Verify: `cute.math.fmin` / `cute.math.fmax` work with CuTeDSL vectorized code (they should — they're elementwise)
- [ ] Test: fused MoE output matches reference with clamping at swiglu_limit=10.0
- [ ] Commit with clear message: "fix: add SwiGLU clamping to fused kernel (paper §4.2.3)"
**Why D6 is no longer a separate stage:** Designing SMEM-P around BF16 KV and then retrofitting FP4 is the detour trap. FP4 KV at hd=512 shrinks each KV tile 4× vs BF16, which changes the fundamental pipeline depth (kv_stage 2→4-6) and SMEM budget. The kernel we ship will run with FP4 KV — so plan for that architecture now.
**Paper §2.3.4:** KV cache stores dims 0..447 as FP8 and dims 448..511 as BF16. The paged cache already implements this split (`entries_fp8` + `entries_rope` + `inv_scale`). For FMHA, we take it further: TMA loads FP4 (or FP8) KV to SMEM, dequantize on-the-fly in the SMEM→register path, then MMA.
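**FP4 KV pipeline depth win:** At BF16 hd=512, one K tile = 128 × 512 × 2 B = 128 KB, so 2 stages of K+V = 512 KB, already well beyond the ~228 KB of shared memory per SM on B200. At FP4 (with FP8 scale overhead): 32 KB of packed nibbles + 4 KB of scales (one per 16 elements) ≈ 36 KB per K tile, ~3.5× smaller than BF16, so any given SMEM budget holds ~3.5× as many stages. Each extra stage hides more TMA latency. At 1M-context decode, deeper stages matter a lot.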
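The tile arithmetic as a small calculator. It counts only one K and one V tile per stage (Q, P, O staging, and barriers are ignored, so real stage counts are lower) and uses 227 KB, the per-block shared-memory limit on compute capability 10.0, as the budget.

```python
KIB = 1024


def kv_tile_bytes(rows, head_dim, bits_per_elem, scale_block=0, scale_bytes=1):
    """Bytes for one K (or V) tile, optionally with one scale per scale_block elements."""
    data = rows * head_dim * bits_per_elem // 8
    scales = (rows * head_dim // scale_block) * scale_bytes if scale_block else 0
    return data + scales


def kv_stages(smem_budget_bytes, tile_bytes):
    # One pipeline stage holds one K tile and one V tile.
    return smem_budget_bytes // (2 * tile_bytes)


BF16_TILE = kv_tile_bytes(128, 512, 16)                               # 128 KiB
FP4_TILE = kv_tile_bytes(128, 512, 4, scale_block=16, scale_bytes=1)  # 32 + 4 KiB
SM100_BLOCK_SMEM = 227 * KIB

for name, tile in (("bf16", BF16_TILE), ("fp4", FP4_TILE)):
    print(name, tile // KIB, "KiB/tile,", kv_stages(SM100_BLOCK_SMEM, tile), "K+V stages")
```

At this tile shape BF16 does not fit even one K+V stage, while FP4 fits three before accounting for Q and the rest; shrinking the KV tile rows is the other lever.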
The indexer needs a full rewrite from scalar CUDA to tcgen05 MMA + radix-select. This is a major work item (2-3 weeks) that is out of scope for Stage D.
**Reference:** DeepGEMM's `fp8_mqa_logits` / `fp8_paged_mqa_logits` (Sept 2025 PR for V3.2 indexer). Our V4 variant: same pattern with FP4 inputs and tcgen05 MMA.
Three honest buckets. A fourth speculative bucket flagged at the end.
### NVFP4-0: Verify Right Blackwell FP4 Primitives ⚡ DO FIRST
**No model-quality risk; this is purely about implementation correctness.** If these are wrong, we're silently running the wrong MMA shapes.
#### NVFP4-0.1 — sf_dtype tracing
**What:** Trace the SF dtype through the full pipeline: `gemm_runner.py` → `dense.py` → `blockscaled_utils` → TMEM layout.
**The problem:** `dense.py` line 137 says NVF4 supports `Float8E8M0FNU/Float8E4M3FN` at sf_vec_size=16. But UE8M0 is the MX scale format (MXFP4/MXFP8); NVFP4 uses **FP8 E4M3** scales. The examples on lines 90/100 use `Float8E8M0FNU` at sf_vec_size=16, i.e. MX-style power-of-two scales on 16-element blocks, which is not NVFP4. **Need to verify the runner is passing E4M3, not E8M0.**
Action:
- [ ] Print `sf_dtype` in `gemm_runner.py` at construction: `print(f"sf_dtype={sf_dtype}, sf_vec_size={SF_VEC_SIZE}")`
- [ ] Print `self.sf_dtype` in `BlockScaledGEMM.__init__` (`dense.py`)
- [ ] Print `self.sf_vec_size` in `dense.py`
- [ ] Trace through `blockscaled_utils.make_sm100_sf_layout` — does it produce E4M3 packing (4 FP8 E4M3 → 1 int32) or UE8M0 packing?
- [ ] **If wrong sf_dtype is found:** fix the SF_DTYPE constant in `gemm_runner.py`, retest MoE cosine
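When printing raw SF bytes, decoding each byte under both interpretations makes a wrong `sf_dtype` obvious: 0x38 is 1.0 as E4M3 but 2^-71 (about 4e-22) as UE8M0. A minimal sketch; the little-endian byte order inside the packed 32-bit word is an assumption to check against `make_sm100_sf_layout`.

```python
import math


def e4m3_to_float(byte: int) -> float:
    """FP8 E4M3 (FN): 1 sign, 4 exponent (bias 7), 3 mantissa bits; 0x7F/0xFF are NaN."""
    sign = -1.0 if byte & 0x80 else 1.0
    exp, man = (byte >> 3) & 0xF, byte & 0x7
    if exp == 0xF and man == 0x7:
        return math.nan
    if exp == 0:
        return sign * (man / 8.0) * 2.0 ** -6
    return sign * (1.0 + man / 8.0) * 2.0 ** (exp - 7)


def ue8m0_to_float(byte: int) -> float:
    """UE8M0: unsigned, exponent only, value = 2**(byte - 127); 0xFF is NaN."""
    return math.nan if byte == 0xFF else 2.0 ** (byte - 127)


def decode_sf_word(word: int):
    """Split a packed 32-bit SF word into 4 bytes and decode each both ways."""
    return [(b, e4m3_to_float(b), ue8m0_to_float(b)) for b in word.to_bytes(4, "little")]


for raw, as_e4m3, as_ue8m0 in decode_sf_word(0x40383830):
    print(hex(raw), as_e4m3, as_ue8m0)
```

Plausible NVFP4 block scales read as sensible magnitudes under E4M3 and as absurdly small ones under UE8M0 (and vice versa for a real E8M0 buffer).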
#### NVFP4-0.2 — SF TMEM layout verification
**What:** NVFP4 expects scale factors in TMEM in a specific transposed-packed layout. UE4M3 for NVFP4 (4 packed FP8 E4M3 per int32 word). The comment in `dense.py` about "SM100 requires scaling factors in packed UE8M0 format" is for **MXFP8**, not NVFP4.
**What:** `float4_e2m1fn_x2` must survive all the way into TMA descriptor creation. Blackwell TMA supports the packed-FP4 `e2m1_x2` element type directly. Loading as `uint8` works but loses tensor-core awareness.
Action:
- [ ] Trace `float4_e2m1fn_x2` through `quantize.py` → TMA atom creation in `fmha.py`
- [ ] Print the GMEM tensor dtype at FMHA kernel input
- [ ] Print the TMA atom dtype at construction
- [ ] Verify `cpasync.tma_partition` receives `float4_e2m1fn_x2` element type, not uint8
- [ ] **If uint8 fallback:** fix TMA atom creation in `fmha.py`
#### NVFP4-0.4 — MMA kind is mxf4nvf4
**What:** Blackwell has a single MMA kind for both MXFP4 and NVFP4. NVFP4 = scales are FP8 E4M3, 16-element block. MXFP4 = scales are UE8M0, 32-element block. The MMA kind is determined by scale-factor type at runtime. Need to confirm tcgen05 is inferring NVFP4.
Action:
- [ ] Print `tcgen05.mma.kind` at GEMM construction (if accessible)
- [ ] Print the MMA instruction shape `(M, N, K)` confirmed by JIT compile
- [ ] Verify it matches Blackwell MMA shape for NVFP4 (not MXFP4)
**Execution:** These are 5-minute print jobs. Do all 4 NVFP4-0 items before touching any code. If any of them reveals a wrong dtype, fix it FIRST before D1.3. A wrong sf_dtype poisons every FP4 GEMM result.
---
### NVFP4-1: Eliminate BF16 Round-Trips After FP4 GEMMs 🔴 PURE-WIN, NO QUALITY RISK
**These are pure bandwidth/compute wins. The math doesn't change — we just avoid precision loss and kernel launch overhead.**
The amax reduction is in-registers: for an epi tile with 16 contiguous elements per thread, each tile produces one FP8 E4M3 scale and 64 bits of packed FP4 nibbles. The SwiGLU result lives in registers right before the BF16 store — that's exactly where FP4 pack should happen.
**What you save:**
- ~3.5× less GMEM traffic between L1 and L2 (FP4 instead of BF16: 4 bits per element plus one 8-bit E4M3 scale per 16 elements is 4.5 bits, vs 16 bits)
4. Write packed nibbles to GMEM as `float4_e2m1fn_x2`
5. Write FP8 scale to SF TMA buffer
**The amax subtlety:** For NVFP4 the microblock is 16 elements. Port the same 16-element logic from `quantize.py` into the epilogue. Do NOT use 32-element MXFP4 microblocks.
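The 16-element microblock logic, as a scalar Python sketch of one epilogue tile: amax → block scale → nearest E2M1 code per element → 16 nibbles packed into 64 bits. Simplifications, all assumptions: the scale is left in float instead of being rounded to FP8 E4M3, the per-tensor FP32 global scale of real NVFP4 is omitted, ties round toward the smaller code instead of round-to-nearest-even, and element i goes in the low-first nibble position (check against `quantize.py`).

```python
E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]  # codes 0b000..0b111
E2M1_MAX = 6.0
BLOCK = 16  # NVFP4 microblock (NOT the 32-element MXFP4 block)


def quantize_nvfp4_block(block):
    """Return (scale, packed): element i occupies bits [4*i, 4*i + 4) of packed."""
    if len(block) != BLOCK:
        raise ValueError("NVFP4 microblocks are 16 elements")
    amax = max(abs(v) for v in block)
    scale = amax / E2M1_MAX if amax > 0 else 1.0  # real kernel: stored as E4M3
    packed = 0
    for i, v in enumerate(block):
        target = abs(v) / scale
        code = min(range(8), key=lambda c, t=target: abs(E2M1_MAGNITUDES[c] - t))
        if v < 0:
            code |= 0x8  # sign bit
        packed |= code << (4 * i)
    return scale, packed


def dequantize_nvfp4_block(scale, packed):
    out = []
    for i in range(BLOCK):
        code = (packed >> (4 * i)) & 0xF
        mag = E2M1_MAGNITUDES[code & 0x7] * scale
        out.append(-mag if code & 0x8 else mag)
    return out


block = [0.0, 0.5, -1.0, 1.5, 2.0, -3.0, 4.0, 6.0] + [0.0] * 8
scale, packed = quantize_nvfp4_block(block)  # one scale + 64 bits of nibbles
```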
**Note:** NVFP4-1.2 and NVFP4-1.3 depend on D1.5 (correction epilogue fix) because those epilogues need the clean one-way TMEM path. NVFP4-1.1 (MoE SwiGLU) is independent.
---
### NVFP4-2: FP4 KV Pipeline Depth in FMHA 🔴 STAGE D, DEPENDS ON D1.3
**FP4 KV shrinks tiles 4×, same SMEM budget buys 3× more pipeline stages.**
- [ ] **Prerequisite:** D1.3 (SMEM-P) working at BF16 first. Cannot skip.
- [ ] **Test:** FP4+BF16 split input → identical output to pure BF16 input (dequant is transparent)
---
### NVFP4-3: use_2cta_instrs for Production MoE 🟢 30 MINUTES, PURE PERF
**This is the biggest single-knob perf win for FP4 GEMMs on B200.**
**What:** `FusedSwiGLUScaledGroupedGemmKernel` supports 2-CTA UMMA but defaults to `False`. With 2-CTA, the two CTAs of a cluster pair cooperate on one MMA: each CTA loads half of the B tile, and the pair shares it on-chip through the cluster's distributed shared memory (no network link involved). Effective MMA tile M doubles across the pair (e.g. 128→256).
**Measured win:** 1.7–1.9× throughput over single-CTA at prefill/batch shapes.
**Decision tree:**
- M < 128 (decode single-token): 1-CTA is correct. 2-CTA wastes hardware.
- M ≥ 256 (prefill or batched decode): 2-CTA is free perf.
- [ ] **Test:** throughput comparison at M=256, 512, 1024
- [ ] **Scope:** MoE-side, `gemm_runner.py`. Does not affect FMHA.
---
### ⚠️ Speculative: Beyond V4 Paper Validation
The following are real potential wins but go beyond what the V4 paper explicitly validated for FP4. Listed for completeness, do NOT implement without explicit sign-off from Mike.
**These are NOT on the roadmap until validated:**
1. **Indexer FP4 tensor-core scoring (paper §5.2.1 "QK path in the indexer... cached, loaded, and multiplied entirely in FP4")**
- Paper says the indexer SHOULD do QK in FP4 with tensor cores
- Current: scalar FP32 dot products with no tensor cores (PS-1)
- This is the paper's intended design, but it requires the full DeepGEMM `fp8_paged_mqa_logits` port to FP4 inputs
- Huge scope: 2-3 weeks minimum
- **Risk:** FP4 dot product precision for index selection needs recall validation. Paper says 99.7% recall with FP32→BF16 score quant — not the scoring itself.
- **Verdict:** Track for Stage F. Do NVFP4-0.4 first to ensure FP4 MMA is producing correct results, then re-evaluate.
2. **MXFP4 vs NVFP4 for indexer scoring**
- Paper §5.2.1: "QK activations cached" in FP4
- But also: MXFP4 (UE8M0 scales, 32-element blocks) has a wider scale dynamic range than NVFP4's E4M3 scales, at the cost of coarser blocks
- For the indexer, where recall matters more than per-score precision, MXFP4 may be the better choice
- Not validated in the paper for indexer scoring specifically
- **Verdict:** Evaluate after PS-1 rewrite. Do NVFP4-0 first.
3. **NVFP4 for full attention Q×K^T GEMM**
- We already know NVFP4 Q×K^T is too lossy (cosine 0.86 vs FP32 reference)
- This is NOT coming back regardless of speculation
| NVFP4-2 | FP4 KV pipeline depth | FMHA | NONE — perf only | D1.3 | 1 day |
**NVFP4-0 results gate the critical path.** If NVFP4-0.1–0.4 find a wrong sf_dtype or wrong MMA kind, the fix comes before D1.3. Everything else is either parallel or post-D1.3.
**Root Cause:** QK C-fragment tiling and PV A-operand tiling are fundamentally different, and a tiled copy expects the source and destination to share the same tiling pattern.
**Status:** STUCK. Manual addressing is harder than expected due to CuTeDSL JIT constraints.
**Problems Encountered:**
1. `cute.coord` doesn't exist, so we can't get a thread's logical coordinates
2. Array indexing requires compile-time constants or vectorized loops
3. Layouts are completely different:
- TMEM P layout: `((128,128),1,1):((65536,1),0,0)`
- SMEM P layout: `((128,16),1,(4,2),1):((64,1),0,(16,8192),0)`
4. No clear mapping from TMEM coordinates to SMEM coordinates
**Root Issue:** Manual layout conversion in CuTeDSL requires understanding coordinate systems and offset computation, which is complex without proper documentation/examples.
**Options:**
1. Continue trying to implement manual conversion (high risk, time-consuming)
2. Find existing example of layout conversion in codebase
3. Ask for more specific guidance on coordinate mapping
4. Try different approach: make PV read from TMEM with different layout
**Blocked:** Need coordinate mapping formula or example.
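A minimal pure-Python evaluator for the two printed layouts, following CuTe's convention that a coordinate maps to an offset via the inner product with the stride, with an integer coordinate into a nested mode unpacked colexicographically (leftmost sub-mode fastest). The mode meanings are assumptions read off the shapes: TMEM stride 65536 = 1<<16 is taken as the lane field of a TMEM address, so P(m, k) sits at lane m, column k; the SMEM modes are taken as ((M, 16 K-elements), rest-M, (4 K-blocks, 2 K-halves), stage). The swizzle that `p_smem_s` likely composes on top is NOT modeled. If, as with 32x32b `tcgen05.ld`, softmax thread t holds row t of P, each thread would write P[t, k] to `smem_offset(t, k)` (then swizzled).

```python
# Printed layouts from the trace as (shape, stride); swizzle not modeled.
TMEM_P_SHAPE, TMEM_P_STRIDE = ((128, 128), 1, 1), ((65536, 1), 0, 0)
SMEM_P_SHAPE, SMEM_P_STRIDE = ((128, 16), 1, (4, 2), 1), ((64, 1), 0, (16, 8192), 0)


def size(shape):
    if isinstance(shape, int):
        return shape
    total = 1
    for s in shape:
        total *= size(s)
    return total


def crd2offset(coord, shape, stride):
    """CuTe-style coordinate -> offset. An int coordinate into a tuple mode
    is unpacked colexicographically (leftmost sub-mode fastest)."""
    if isinstance(shape, int):
        return coord * stride
    if isinstance(coord, int):
        offset = 0
        for s, d in zip(shape, stride):
            offset += crd2offset(coord % size(s), s, d)
            coord //= size(s)
        return offset
    return sum(crd2offset(c, s, d) for c, s, d in zip(coord, shape, stride))


def tmem_addr(m: int, k: int) -> int:
    # Assumed modes: ((row, col), rest_M, rest_N).
    return crd2offset(((m, k), 0, 0), TMEM_P_SHAPE, TMEM_P_STRIDE)


def smem_offset(m: int, k: int) -> int:
    # Assumed modes: ((M, K within 16), rest_M, (K block of 16, K half of 64), stage).
    crd = ((m, k % 16), 0, ((k // 16) % 4, k // 64), 0)
    return crd2offset(crd, SMEM_P_SHAPE, SMEM_P_STRIDE)


def tmem_to_smem(addr: int) -> int:
    # Decode lane/column from the TMEM address, then re-address in SMEM.
    lane, col = addr >> 16, addr & 0xFFFF
    return smem_offset(lane, col)


# Closed form implied by the strides (k < 64 and k >= 64 halves):
#   smem_offset(m, k) = 64*m + (k % 64) + 8192*(k // 64)
print(smem_offset(5, 70), tmem_to_smem(tmem_addr(5, 70)))
```

Under these assumptions the SMEM layout is two 128×64 K-major halves, a bijection onto offsets 0..16383, which gives the per-element target for a hand-written register→SMEM store.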