Stage D — Parameterized FMHA for DSV4
⚠️ IKEA INSTRUCTIONS — READ EVERY TIME BEFORE CODING
The Workflow (DO NOT SKIP STEPS)
- Edit code in ~/dev/nvfp4-megamoe-kernel/dsv4/kernels/attention/fmha.py. This is the ONLY file for the FMHA kernel.
- Commit and push:
  cd ~/dev/nvfp4-megamoe-kernel
  git add -A && git commit -m "description" && git push origin master
- Pull on B200:
  sshpass -p '<B200_PASSWORD>' ssh -o StrictHostKeyChecking=no root@45.76.247.107 \
    "cd /root/dsv4-nvfp4-workspace/kernel && git pull origin master"
- Test on B200:
  sshpass -p '<B200_PASSWORD>' ssh -o StrictHostKeyChecking=no root@45.76.247.107 \
    "cd /root/dsv4-nvfp4-workspace/kernel && source /root/dsv4-nvfp4-workspace/venv/bin/activate && python3 -c '...'"
- Regression check: After every change, verify hd=64 cos 0.972537 still matches. If it doesn't, the change is WRONG. Revert.
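The regression check compares kernel output against an oracle by cosine similarity. A minimal sketch of that comparison in plain Python follows; the helper names `cosine_similarity` and `check_regression` and the tolerance are illustrative assumptions, not code from the repository.

```python
import math


def cosine_similarity(a, b):
    """Cosine similarity between two equal-length flat sequences of floats."""
    if len(a) != len(b):
        raise ValueError("length mismatch")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        raise ValueError("zero-norm input")
    return dot / (norm_a * norm_b)


def check_regression(cos, expected=0.972537, tol=1e-6):
    """True if the measured cosine still matches the recorded hd=64 baseline.

    The tolerance is an assumption; tighten it if the kernel is deterministic.
    """
    return abs(cos - expected) <= tol
```

A higher cosine than the baseline also fails this check on purpose: any change in the hd=64 number means the proven path changed and should be explained before moving on.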
The Rules (BURNED INTO THIS FILE BECAUSE WE BURNED THEM INTO PRODUCTION)
- NEVER edit files directly on the B200. Edit locally, commit, push, pull, test. Every time.
- NEVER delete or modify the test files in tests/unit/. They are the regression oracle.
- NEVER touch drivers, kernels, firmware, or system packages on the B200.
- CuTeDSL variables defined in if blocks are NOT visible in other if blocks. Even compile-time constants. Define all variables unconditionally before any branching.
- Always test at hd=64 FIRST. If the proven path (TMEM-P) regresses, nothing else matters.
- p_cols_fp32 uses pv_mma_tiler[2] (K-dim), NOT pv_mma_tiler[1] (N-dim). We got this wrong twice.
- PV A-operand major mode is OperandMajorMode.K for TMEM-P. Not a_major from Q.
- tOrP0 uses 3-dim indexing (None, None, kb), NOT 4-dim (None, None, kb, 0). The 4th mode was already sliced away by tOrP_base[(None,None,None,0)].
- After every P store to TMEM, call cute.arch.fence_view_async_tmem_store(). Missing this produces NaN.
- PRINT THE SHAPES. ALWAYS. Run print(f"tensor: shape={cute.shape(tensor)}") inside @cute.kernel at trace time. Reasoning about layouts without evidence is how we waste days.
What We Have Now (Starting Point)
File: dsv4/kernels/attention/fmha.py
Class: FmhaKernel
State: Parameterized head_dim (D1.0 done). TMEM-P path works at hd=64 (cos 0.972537). SMEM-P path is a stub that zeros sP.
What it does:
- 6-warp kernel: warps 0-3 (softmax + epilogue), warp 4 (MMA), warp 5 (TMA)
- QK GEMM → S in TMEM → online softmax → P stored to TMEM via register bridge → PV GEMM → O in TMEM
- O rescale (per KV tile, kt>0) + O normalization (1/row_sum) via TMEM round-trip
- Epilogue: TMEM → SMEM → GMEM via TMA store
- SMEM-P flag wired (use_smem_p), PV source switches between TMEM/SMEM, but register→SMEM copy not implemented
The Problem at hd>64
At hd=64, the QK C-fragment TMEM layout and the PV A-fragment TMEM layout agree — the same threads map to the same columns. P can be written to TMEM using the QK partition and read by PV using the same partition. This is why the register bridge (FP32 backing + BF16 view) works.
At hd=512, P is (128, 128) per KV tile (P's columns = number of KV positions, NOT head_dim). But the PV MMA expects P laid out with 512 columns in its A-operand. The QK C-fragment and PV A-fragment TMEM layouts disagree — different threads own different columns. The register bridge can't write P in a layout that PV can read.
The fix: SMEM-P path. P goes through SMEM instead of TMEM:
- Softmax computes P in registers (QK C-fragment partition)
- Write P to SMEM using the p_smem_s layout (PV A-operand SMEM layout)
- MMA warp reads P from SMEM via tCrP = pv_mma.make_fragment_A(sP)
- PV GEMM uses tcgen05.OperandSource.SMEM instead of OperandSource.TMEM
The SMEM rendezvous: SMEM is the meeting point. Softmax threads write at logical (row, col) addresses. MMA reads at the same addresses. A barrier in between. No cross-warp message passing needed — just write-to-address, barrier, read-from-address.
The missing piece (the D1 work): The register→SMEM copy. The softmax warps have P values in QK C-fragment partition. They need to write to SMEM with PV A-operand layout. This requires a TiledCopy that partitions threads by QK's C-fragment and targets the P SMEM layout.
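# The correct approach:
# NOTE (to verify on the B200): St32x32bOp is the tcgen05 TMEM store op. If it is
# rejected for an SMEM destination, swap in an SMEM-capable store atom and keep
# the QK-based thread partitioning below.
store_atom = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), Float32)
tiled_p_copy = cute.make_tiled_copy_C(store_atom, qk_mma)  # NOT pv_mma!
# This gives threads partitioned by QK C-fragment, writing to the P SMEM layout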
Then: softmax threads write their P values through this copy → barrier → MMA reads from SMEM.
TMEM Column Budget at hd=512
This MUST be calculated before writing a single line of SMEM-P code.
Tensor Memory (TMEM) has 512 columns per CTA. Each column is 32 bits wide.
At hd=64 (TMEM-P path):
- S (QK acc): 128 cols FP32
- P (softmax output): 64 cols FP32
  - p_cols_fp32 = pv_mma_tiler[2] * q_dtype.width // qk_acc_dtype.width
  - pv_mma_tiler = (128, 64, 128), so pv_mma_tiler[2] = 128
  - p_cols_fp32 = 128 * 16 // 32 = 64
  - tmem_p0_offset = 32, so P occupies TMEM cols 32-95. S uses cols 0-127, so P and S OVERLAP. This works because S is consumed before P is written (softmax reads S, then writes P to the same TMEM region).
- O offset: o_after = max(s_cols=128, p_end=32+64=96) = 128, so tmem_o0_offset = ((128 + 31) // 32) * 32 = 128
- O (PV acc): find_tmem_tensor_col_offset(tOtO) at hd=64 ≈ 128 cols FP32
- Total: 128 (O offset) + 128 (O size) = 256 cols. Fits in 512. ✅
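The hd=64 arithmetic can be re-run as plain Python whenever a tiler or offset changes. This is a sketch of the same formulas; the function name `tmem_p_budget` and its keyword defaults are illustrative, and o_cols=128 is the approximate figure quoted here, not a measured value.

```python
def tmem_p_budget(pv_mma_tiler, q_width=16, acc_width=32, s_cols=128,
                  tmem_p0_offset=32, o_cols=128, tmem_cols=512):
    """Recompute the TMEM-P column layout from the formulas in this section."""
    # P is stored as BF16 but counted in 32-bit TMEM columns.
    p_cols_fp32 = pv_mma_tiler[2] * q_width // acc_width
    p_end = tmem_p0_offset + p_cols_fp32
    # O starts after whichever of S or P ends last, rounded up to 32 columns.
    o_after = max(s_cols, p_end)
    tmem_o0_offset = ((o_after + 31) // 32) * 32
    total = tmem_o0_offset + o_cols
    return {
        "p_cols_fp32": p_cols_fp32,
        "p_end": p_end,
        "tmem_o0_offset": tmem_o0_offset,
        "total_cols": total,
        "fits": total <= tmem_cols,
    }


budget = tmem_p_budget((128, 64, 128))
print(budget)
```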
At hd=512 (SMEM-P path):
- P is NOT in TMEM. S and O share TMEM (sequential, not concurrent).
- S (QK acc): 128 cols FP32 (same as hd=64 — QK is always (128, 128))
- O (PV acc): at hd=512, PV is (128, 512). A naive count says the (128, 512) FP32 C-fragment needs 512 cols, but that is a guess, not a measurement.
  - tOtO = pv_thr.make_fragment_C(pv_as) where pv_as = pv_thr.partition_shape_C((128, 512))
  - M=128 → 4 warps × 32 threads = 128 rows, each thread owns 1 row
  - N=512 → unverified estimate: viewed as (32×32) subtiles, a (128, 512) MMA is 4 M-tiles × 16 N-tiles = 64 subtiles, each using 32 TMEM columns, distributed across warps
  - find_tmem_tensor_col_offset(tOtO) gives the actual footprint.
- MUST PRINT THIS ON THE B200. Do not guess. Run a shape probe.
- If O needs ~512 cols: S (128) + O (512) = 640 > 512. DOES NOT FIT.
- Fix options:
  - Drop kv_stage from 2 to 1: frees SMEM but loses K/V double-buffering. TMEM budget unchanged.
  - Split O into halves: process (128, 256) PV twice, each O tile is 256 cols. S(128) + O(256) = 384 < 512. ✅
  - Process S and O sequentially: after softmax consumes S, O can reuse S's TMEM region. O at offset 0, 512 cols. Total = 512. ✅ But only if we don't need S anymore when writing O (true: softmax is done before PV starts per KV tile).
Plan: SMEM-P path reuses S's TMEM for O. After softmax reads S and writes P to SMEM, S's TMEM region (cols 0-127) is dead. PV writes O starting at col 0. O at hd=512 needs ~256-512 cols (must measure). If O fits in cols 0-511 with S gone, we're golden.
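The hd=512 options reduce to one comparison against the 512-column limit. The sketch below encodes that comparison so that measured numbers can be plugged in after the probe; the o_cols values in the loop are the placeholder estimates from this section (512 and 256), and `smem_p_tmem_cols` is a hypothetical helper name.

```python
TMEM_COLS = 512  # TMEM columns available per CTA


def smem_p_tmem_cols(s_cols, o_cols, reuse_s_region):
    """Columns needed when P lives in SMEM and only S and O occupy TMEM."""
    if reuse_s_region:
        # O overwrites S once softmax has consumed it, so only the larger counts.
        return max(s_cols, o_cols)
    return s_cols + o_cols


def fits(s_cols, o_cols, reuse_s_region, tmem_cols=TMEM_COLS):
    return smem_p_tmem_cols(s_cols, o_cols, reuse_s_region) <= tmem_cols


for label, o_cols, reuse in [
    ("full O, separate regions", 512, False),
    ("split O (128, 256)", 256, False),
    ("full O reusing S region", 512, True),
]:
    total = smem_p_tmem_cols(128, o_cols, reuse)
    print(f"{label}: {total} cols, fits={fits(128, o_cols, reuse)}")
```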
Action item: Run shape probe on B200 before coding SMEM-P at hd=512.
# Shape probe script to run on B200.
# Import paths and the make_trivial_tiled_mma signature depend on the installed
# CuTeDSL version; adjust them to match the working imports in fmha.py.
import torch, math
import cutlass
import cutlass.cute as cute
import cutlass.utils as utils
import cutlass.nvgpu.tcgen05 as tcgen05
from cutlass import BFloat16, Float32, LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset

a_major = LayoutEnum.ROW_MAJOR  # adjust to match the kernel's PV A-operand major mode
b_major = LayoutEnum.ROW_MAJOR

# If fragment construction fails outside a traced context, move the lines below
# into a @cute.jit function and call it (shapes are printed at trace time).
pv_mma = utils.sm100.make_trivial_tiled_mma(BFloat16, BFloat16, a_major, b_major, Float32, tcgen05.CtaGroup.ONE, (128, 512), tcgen05.OperandSource.SMEM)
pv_thr = pv_mma.get_slice(0)
pv_as = pv_thr.partition_shape_C((128, 512))
tOtO = pv_thr.make_fragment_C(pv_as)
o_cols = find_tmem_tensor_col_offset(tOtO)
print(f"hd=512 PV C-fragment: pv_as={pv_as}, tOtO.layout={tOtO.layout}, o_cols={o_cols}")
# Also print tOtO shape
print(f"tOtO shape: {cute.shape(tOtO)}")
Correctness Gaps — Must Close Before Production
These are NOT optimization gaps. They are cases where the current code produces numerically wrong outputs vs the trained checkpoint.
CG-1: SwiGLU Clamping Missing from Fused Kernel ⚠️ CRITICAL
What: FusedSwiGLUScaledGroupedGemmKernel.__init__ stores self.swiglu_limit but the SwiGLU compute block (lines 2185–2200 in fused_swiglu.py) never references it. The reference path in dsv4/reference/moe_pipeline.py correctly applies clamp(max=swiglu_limit) to gate and clamp(min=-limit, max=+limit) to up. The fused kernel silently skips it.
Why it matters: Paper §4.2.3 explicitly says weights were trained with the gate component capped at 10 and the linear component clamped to [−10, 10]. Without clamping, the fused kernel produces different outputs than the reference at large activation values.
Fix (2 lines in the fused kernel):
# After computing silu_result (gate subtile):
silu_result = cute.math.fmin(silu_result, swiglu_limit)
# Before the gate*up multiply (up subtile):
acc_vec = cute.math.fmin(cute.math.fmax(acc_vec, -swiglu_limit), swiglu_limit)
Where: dsv4/kernels/gemm/fused_swiglu.py, lines ~2192 (gate branch) and ~2198 (up branch).
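A scalar reference makes the clamping semantics easy to test before touching the kernel. The sketch below assumes the reference path clamps the gate before SiLU (worth confirming against moe_pipeline.py); the second function mirrors the 2-line fix, which clamps the SiLU result instead. The two agree except when gate > limit, where they differ slightly because silu(10) is a little below 10. Function names are illustrative.

```python
import math


def silu(x):
    """Numerically stable SiLU: x * sigmoid(x)."""
    if x >= 0:
        return x / (1.0 + math.exp(-x))
    e = math.exp(x)
    return x * e / (1.0 + e)


def swiglu_clamp_gate_first(gate, up, limit=10.0):
    """Assumed reference order: cap gate at +limit, clamp up to [-limit, limit]."""
    g = min(gate, limit)
    u = max(-limit, min(up, limit))
    return silu(g) * u


def swiglu_clamp_after_silu(gate, up, limit=10.0):
    """Order used by the 2-line fix: cap the SiLU output, clamp up."""
    u = max(-limit, min(up, limit))
    return min(silu(gate), limit) * u


# Inside the limit both variants agree; above it they differ by about 4.5e-3.
print(swiglu_clamp_gate_first(1.0, 2.0), swiglu_clamp_after_silu(1.0, 2.0))
print(swiglu_clamp_gate_first(20.0, 20.0), swiglu_clamp_after_silu(20.0, 20.0))
```

If the test in D0 compares the fused kernel to the reference at large activations, the tolerance needs to account for this difference or the kernel should clamp in the same order as the reference.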
Status: 🔴 NOT FIXED. Do this IMMEDIATELY — it's a 2-line fix and affects all MoE layer outputs.
CG-2: FMHA at hd=512 SMEM-P is a Stub ⚠️ CRITICAL
What: FmhaKernel with use_smem_p=True zeros sP and comments "PV will produce garbage." DSV4 head_dim is 512 (§4.2.1). The kernel literally cannot produce correct output at the production head dimension.
Why it matters: This is the D1 work. The path forward is correct (make_tiled_copy_C(store_atom, qk_mma) to partition P registers for SMEM staging). But TMEM column budget at hd=512 must be verified first (see budget section above).
Status: 🔴 D1.2–D1.3 TODO. This document IS the plan.
CG-3: SWA + Sink Merge Not Fused in FMHA ⚠️ CRITICAL
What: The DSV4 attention design (§2.3.3) requires merging compressed top-k attention with sliding window attention via sink weights. Currently the FMHA kernel only does dense attention (one KV source, one softmax, one PV, normalize). The sink merge is implemented in Python fallback (decode_sparse.py) but NOT in the production kernel path.
Why it matters: Without SWA+sink merge, the compressed branch alone cannot capture local dependencies. The paper is explicit: "Additional Branch of Sliding Window Attention." Every CSA and HCA layer produces wrong output without this.
Fix plan (D5, ordered by priority):
- D5a: Emit un-normalized o + lse instead of normalized o. This is the SINGLE MOST IMPORTANT structural change: once the kernel can output (o_unnorm, lse), even a Python merge gives end-to-end correctness. Keep normalize as a flag so standalone tests still work.
- D5b: Run kernel twice externally (compressed_kv + swa_kv), merge in Python. End-to-end correctness without touching kernel structure. This is the correctness baseline.
- D5c: Fuse two passes into one kernel launch (Q stays in SMEM, two sequential MMA loops). Pure optimization.
- D5d: Fuse sink merge into kernel epilogue. Pure optimization.
Status: 🔴 D5 TODO. D5a must be done FIRST — it unblocks D5b which gives us correctness.
CG-4: Inverse RoPE Verification ⚠️ HIGH
What: inverse_rope_bf16 in dsv4/ops/rope.py applies the conjugate rotation to the last rope_dim=64 dims of each head output. The math looks correct: inv[2i] = x[2i] * cos + x[2i+1] * sin, inv[2i+1] = -x[2i] * sin + x[2i+1] * cos. This is the standard inverse rotation for interleaved (GPT-J) RoPE.
What needs verifying:
- The positions argument must be the same positions used for the forward RoPE on Q and K. The inverse looks up the cache at +position (not -position): the code uses cos_sin_cache[positions, :], which is the same table as forward RoPE. The "inverse" is the conjugate rotation: take cos(θ) and sin(θ) at the SAME position and flip the sign of the sin terms relative to the forward rotation, which is equivalent to rotating by -θ. The current code does this correctly: inv_odd = -o_even * sin_all + o_odd * cos_all. ✅
- The nope_dim=448 / rope_dim=64 split must match the model's actual split. If a layer uses a different split, the inverse RoPE would rotate the wrong dims.
- The cos_sin_cache must be the same cache used for forward RoPE. If there's any offset or indexing difference, the angles won't match.
Action: Write a unit test that: (1) applies forward RoPE to random input, (2) applies inverse RoPE, (3) verifies the result matches the original. This is a round-trip test and catches both sign and indexing errors.
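The round-trip idea can be prototyped without the kernel. The sketch below applies interleaved RoPE to the last rope_dim dims of one head vector and then the conjugate rotation from the formulas in this section. The frequency schedule (base 10000) and the function names are assumptions for illustration; the real test should call inverse_rope_bf16 and the project's forward RoPE with the shared cos_sin_cache.

```python
import math
import random


def rope_angles(position, rope_dim, base=10000.0):
    """Per-pair angles; the base and schedule are illustrative assumptions."""
    return [position * base ** (-2.0 * i / rope_dim) for i in range(rope_dim // 2)]


def rotate_tail(x, position, rope_dim, inverse=False):
    """Rotate the last rope_dim dims of x as interleaved (even, odd) pairs."""
    out = list(x)
    start = len(x) - rope_dim
    for i, theta in enumerate(rope_angles(position, rope_dim)):
        c, s = math.cos(theta), math.sin(theta)
        if inverse:
            s = -s  # conjugate rotation: same table entry, sin sign flipped
        even, odd = x[start + 2 * i], x[start + 2 * i + 1]
        out[start + 2 * i] = even * c - odd * s
        out[start + 2 * i + 1] = even * s + odd * c
    return out


random.seed(0)
head = [random.uniform(-1.0, 1.0) for _ in range(448 + 64)]
rotated = rotate_tail(head, position=1234, rope_dim=64)
restored = rotate_tail(rotated, position=1234, rope_dim=64, inverse=True)
max_err = max(abs(a - b) for a, b in zip(head, restored))
print(f"max round-trip error: {max_err:.3e}")
```

The same structure also covers the multi-head and position checks: the first 448 dims must come back bit-identical, and a mismatched position between forward and inverse should make the round trip fail.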
Status: 🟡 Code looks correct but UNTESTED. Add a round-trip unit test.
CG-5: Mixed-Precision KV (BF16 RoPE + FP8 NoPE) — FMHA Load Path ⚠️ HIGH
What: Paper §2.3.4: KV cache stores dims 0..447 as FP8 and dims 448..511 as BF16. The PagedKVPool already implements this split: entries_fp8 (uint8) + entries_rope (BF16) + inv_scale (FP32). The current decode_sparse.py fallback dequantizes in Python before calling the kernel.
Why it matters for FMHA: The FmhaKernel currently takes contiguous BF16 K/V tensors. In production, the kernel must handle the mixed-precision KV directly: read FP8 + BF16 from the paged cache via TMA into SMEM and dequantize on the fly before the MMA consumes the data. This is the proper Blackwell pattern: TMA loads FP8 to SMEM, dequant happens on-chip, then MMA.
The proper approach:
- TMA loads FP8 NoPE dims to SMEM slot 0
- TMA loads BF16 RoPE dims to SMEM slot 1 (or separate TMA)
- Dequantize FP8 → BF16 in SMEM (vectorized, per-entry inv_scale multiply)
- Concatenate [NoPE, RoPE] in SMEM (or use two separate SMEM regions with strided MMA)
- MMA reads contiguous BF16 from SMEM
Prerequisite: This requires D1 (SMEM-P) and D5 (sink merge) to be working first. The mixed-precision load path replaces the current "all BF16" K/V input with the real paged cache format.
Status: 🔴 NOT IMPLEMENTED. Plan as D6 (after D5). The current test harness passes contiguous BF16 K/V, which is fine for correctness testing. The FP8 dequant in SMEM is a performance + memory optimization that doesn't affect numerical correctness (FP8 dequant is well-defined).
CG-6: Per-Token valid_lens in Indexer for Prefill ⚠️ MEDIUM
What: score_topk.py has a TODO that broadcasts request 0's valid_lens for prefill (T > B). For batched prefill, different requests have different numbers of compressed entries in the pool. Broadcasting the first request's count means other requests either score garbage entries (too many) or miss valid ones (too few).
Why it matters: Prefill correctness blocker. The indexer will select wrong entries for all requests except the first in a batch.
Fix: Map each query token to its request ID, then look up valid_lens[request_id]. The request_ids: [T] int32 tensor already exists in the cache handle. The indexer kernel needs this as an input.
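The difference between the current broadcast and the per-token lookup is small enough to show in list form. This is a plain-Python illustration of the indexing only; the function names are hypothetical and the real fix belongs in the indexer kernel.

```python
def per_token_valid_lens(valid_lens, request_ids):
    """Look up each query token's valid length through its request id."""
    return [valid_lens[r] for r in request_ids]


def broadcast_first(valid_lens, num_tokens):
    """Current behavior: every token uses request 0's count."""
    return [valid_lens[0]] * num_tokens


valid_lens = [5, 12]           # compressed entries per request
request_ids = [0, 0, 1, 1, 1]  # one entry per query token, T = 5
print(per_token_valid_lens(valid_lens, request_ids))
print(broadcast_first(valid_lens, len(request_ids)))
```

In this example the broadcast makes the three tokens of request 1 miss 7 valid entries each.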
Status: 🔴 NOT FIXED. This is indexer scope, not FMHA scope. Track separately.
Performance Soft Spots — Important But Not Correctness
These affect throughput but not numerical correctness. Tracked for Stage F+.
PS-1: Indexer Score+TopK is Scalar CUDA — Not Blackwell Native 🔴
What: indexer_score_topk.cu is a CUDA-core scalar implementation:
- Triple loop: for h in n_heads, for g in n_groups, for b in 8
- FP4 nibble dequant to FP32, FP32 dot product
- Shared-memory min-heap protected by single s_lock atomicCAS spinlock
- For 1M-context: ~250K compressed entries scored per query token
Why it's the biggest perf leak: The dot products should use tensor cores. The heap spinlock won't scale to top_k=1024 with hundreds of thousands of candidates.
The correct approach: DeepGEMM's fp8_mqa_logits / fp8_paged_mqa_logits pattern (Sept 2025 PR for V3.2 indexer). Weighted ReLU MQA logits computed with tensor cores, paged variant for decode. Our V4 NVFP4 variant should be that pattern with FP4 inputs and tcgen05 MMA. Beyond the MMA, the heap needs replacing with per-warp partial top-k merged via reduction tree, or radix-select.
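The heap replacement can be checked for correctness on the CPU first. The sketch below models the per-warp partial top-k merged via a reduction tree: each partition (standing in for a warp) keeps its own top-k with no shared lock, and partitions are then merged pairwise. It is a host-side model in plain Python with illustrative names, not the kernel design itself.

```python
import heapq


def local_topk(scores, indices, k):
    """Top-k (score, index) pairs within one partition, best first."""
    return heapq.nlargest(k, ((scores[i], i) for i in indices))


def merge_topk(a, b, k):
    """Merge two partial top-k lists and keep the best k."""
    return heapq.nlargest(k, a + b)


def partitioned_topk(scores, k, num_partitions):
    """Strided partitions, local top-k, then a pairwise reduction tree."""
    parts = [
        local_topk(scores, range(p, len(scores), num_partitions), k)
        for p in range(num_partitions)
    ]
    while len(parts) > 1:
        merged = [merge_topk(parts[i], parts[i + 1], k)
                  for i in range(0, len(parts) - 1, 2)]
        if len(parts) % 2:
            merged.append(parts[-1])  # odd partition carries over to the next level
        parts = merged
    return [idx for _, idx in parts[0]]


scores = [(i * 37) % 101 for i in range(100)]  # distinct values, deterministic
print(partitioned_topk(scores, k=8, num_partitions=4))
```

Because every partition keeps k candidates, the global top-k is always contained in the union, so the tree merge is exact rather than approximate.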
Status: Tracked for Stage F (post Stage E). Not blocking D1–D5.
PS-2: decode_sparse.py BlackwellSparseDecodeKernel is Misleading 🟡
What: dsv4/ops/decode_sparse.py contains BlackwellSparseDecodeKernel — a CuTeDSL kernel that does scalar for d in range(HD): dot += q_val * k_val with no tensor cores. It also has a _fallback_sparse_sdp Python path that uses F.scaled_dot_product_attention.
Why it's misleading: The class name says "Blackwell" but it uses zero Blackwell tensor acceleration. Anyone reading the codebase would assume this is the production kernel. It's a stale early-exploration kernel superseded by FmhaKernel.
Action: Delete BlackwellSparseDecodeKernel and its CuTeDSL code. Keep _fallback_sparse_sdp as a reference implementation (rename to _reference_sparse_sdp_attention). The FMHA kernel in dsv4/kernels/attention/fmha.py is the real path. Do this cleanup as part of E7.
Status: Low urgency. Track for E7 cleanup.
PS-3: mHC Mixing Uses torch.bmm with n_hc=4 🟢
What: mHC mixing operations (A_l @ X_l, B_l @ X_l, C_l ⊗ F_out) use torch.bmm with tiny n_hc=4 inner dimension.
Why it matters: For decode (T=1) this is fine — tiny matmul. For prefill it leaves throughput on the floor. But prefill is not the immediate priority.
Status: Lowest priority of the soft spots. Track for Stage G (prefill optimization).
Stage D Build Order (REVISED)
Priority Principle: Correctness First, Then Performance
D1 (hd=512) and D5 (SWA+sink merge) are both correctness-critical. But D5 depends on D1 (can't merge SWA if the kernel can't even run at hd=512). CG-1 (SwiGLU clamping) is a 2-line fix with no dependencies — do it first.
D0 — SwiGLU Clamping Fix (CG-1) ⚡ DO THIS FIRST
- Add clamping to fused SwiGLU in dsv4/kernels/gemm/fused_swiglu.py
- Gate subtile: silu_result = cute.math.fmin(silu_result, swiglu_limit) after SiLU compute
- Up subtile: acc_vec = cute.math.fmin(cute.math.fmax(acc_vec, -swiglu_limit), swiglu_limit) before gate*up multiply
- Verify: cute.math.fmin / cute.math.fmax work with CuTeDSL vectorized code (they should, since they're elementwise)
- Test: fused MoE output matches reference with clamping at swiglu_limit=10.0
- Commit with clear message: "fix: add SwiGLU clamping to fused kernel (paper §4.2.3)"
D1 — Parameterized HEAD_DIM + SMEM-P (CG-2)
D1.0 — Replace HEAD_DIM constant with constructor parameter ✅ DONE
Already in the kernel. head_dim is a constructor arg. TMEM-P path works at hd=64.
D1.1 — Add SMEM-P path behind use_smem_p flag ✅ WIRED (stub)
The use_smem_p flag exists. PV source switches between TMEM/SMEM. TMEM layout adjusts. But the register→SMEM copy is a stub that zeros sP.
D1.2 — TMEM Column Budget Verification 🔨 DO THIS BEFORE CODING
- Run shape probe on B200: find_tmem_tensor_col_offset(tOtO) at hd=512
- Print pv_as, tOtO.layout, o_cols at hd=128, 256, 512
- Calculate: can S(128) and O(???) share TMEM at hd=512?
- If O > 384 cols: plan for split-PV (two (128, 256) passes)
- Document the budget numbers HERE in this file
D1.3 — Implement register→SMEM copy for P (THE HARD PART)
- Build tiled_p_copy = cute.make_tiled_copy_C(store_atom, qk_mma): QK MMA partitions threads
- Print the shapes: cute.shape(tiled_p_copy), partition source/dest shapes
- Partition sP with tiled_p_copy as destination
- In softmax warps: after computing P in registers, write to SMEM via tiled_p_copy
- Add p_smem_ready_bar NamedBarrier: softmax arrives after write + fence, MMA waits before PV GEMM
- In MMA warp: read P from SMEM via tCrP = pv_mma.make_fragment_A(sP)
- Test: hd=64, n=128, use_smem_p=True → compare against TMEM-P result
- Test: hd=128, n=128 → test against FP32 oracle
- Test: hd=256, n=128 → test against FP32 oracle
D1.4 — Multi-PV-tile for hd>256
- Add pv_n_tile = min(head_dim, 256) and n_pv_tiles = head_dim // pv_n_tile to __init__
- For hd=512: 2 PV tiles of (128, 256) each
- Strategy: kernel processes one PV N-tile per launch. Python orchestrates the tiles.
  - Pass 0: V[:, 0:256] → output[:, 0:256], QK + softmax + PV for cols 0-256
  - Pass 1: V[:, 256:512] → output[:, 256:512], QK + softmax + PV for cols 256-512
  - QK and softmax run identically both passes (P is the same). Only PV changes.
- Alternative (if SMEM-P allows): keep P in SMEM between PV tiles. Run QK+softmax once, PV twice.
- Test: hd=512, n=128 → correct output against FP32 oracle
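Splitting PV along N is exact because each output column of P @ V depends only on the matching column of V. The sketch below checks that on small integer matrices and reproduces the pv_n_tile arithmetic; `pv_tiling`, `matmul`, and `pv_split` are illustrative names for a host-side model, not kernel code.

```python
def pv_tiling(head_dim, max_n=256):
    """Tile width and tile count from the D1.4 formulas."""
    pv_n_tile = min(head_dim, max_n)
    return pv_n_tile, head_dim // pv_n_tile


def matmul(a, b):
    """Plain nested-list matrix multiply."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]


def pv_split(p, v, pv_n_tile):
    """Compute P @ V one N-tile of V at a time and stitch the columns together."""
    head_dim = len(v[0])
    out = [[0] * head_dim for _ in p]
    for t in range(head_dim // pv_n_tile):
        lo, hi = t * pv_n_tile, (t + 1) * pv_n_tile
        v_tile = [row[lo:hi] for row in v]   # V[:, lo:hi]
        o_tile = matmul(p, v_tile)           # same P for every tile
        for r, row in enumerate(o_tile):
            out[r][lo:hi] = row
    return out


p = [[1, 2, 3], [4, 5, 6]]
v = [[1, 0, 2, 1], [0, 1, 1, 3], [2, 2, 0, 1]]
print(pv_split(p, v, pv_n_tile=2) == matmul(p, v))
print(pv_tiling(512), pv_tiling(64))
```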
D1.5 — Correction Epilogue: Fix TMEM Layout Mismatch (3% Error)
The current TMEM round-trip (Ld32x32bOp + St32x32bOp hand-constructed atoms) introduces 3% error at hd=64 (cos 0.973). The proper fix is the CUTLASS correction_epilog pattern:
TMEM --get_tmem_load_op--> reg (normalize + FP32→BF16) --get_smem_store_op--> SMEM --TMA--> GMEM
This is a one-way trip. No TMEM round-trip. No layout mismatch.
- Investigate: can we use get_tmem_load_op + get_smem_store_op paired atoms?
- Investigate: can we inject inv_row_sum into epilogue_tma_store pipeline?
- Investigate: pre-compute TMA partitioning outside if warp_idx blocks (region isolation workaround)
- Test: hd=64, n=128 → cos should jump from 0.973 → ~0.9999
- Test: hd=64, n=256 → cos should jump from 0.793 → ~0.9999
Note: This is NOT blocking for D2–D5. The 3% error is a precision issue, not a correctness issue (the attention math is right, the epilogue just introduces rounding). Fix it properly rather than hacking it. But don't let it block the D2–D5 pipeline.
D2 — Multi-Query Grid with Head Packing
- Grid changes from (1, 1, 1) to (num_q_blocks, 1, batch)
- DSV4 is MQA: all 128 query heads share same K/V
- Head axis folded into M dimension of Q tile: M_tile = 128 covers M = T * n_h rows
- At decode T=1: M = 1 × 128 = 128, so one Q block covers all heads. ✅
- At prefill T=64: M = 64 × 128 = 8192, so 64 Q blocks. Needs grid loop.
- Test: batch=4, T=64, n_h=128, num_kv_heads=1 produces correct attention
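The grid arithmetic for head packing is a ceiling division over the packed M dimension. The sketch below reproduces the decode and prefill numbers; the token-major row order (row = t * n_h + h) is an assumption for illustration, since the packing order is not fixed here, and the function names are hypothetical.

```python
def q_grid(T, n_h, batch, m_tile=128):
    """Launch grid when the head axis is folded into the M dimension of Q."""
    m = T * n_h
    num_q_blocks = (m + m_tile - 1) // m_tile  # ceiling division
    return (num_q_blocks, 1, batch)


def row_to_token_head(row, n_h):
    """Map a packed M row back to (token, head), assuming token-major packing."""
    return divmod(row, n_h)


print(q_grid(T=1, n_h=128, batch=4))    # decode
print(q_grid(T=64, n_h=128, batch=4))   # prefill
print(row_to_token_head(130, 128))
```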
D3 — SWA Sequence Length Mask
- Add swa_lens: [batch] int32 kernel input
- Mask SWA-branch logits to -inf where swa_idx >= swa_lens[b]
- Test: batched input with varying SWA fill levels (position 50 vs 5000)
D4 — Causal Mask on SWA Branch
- Add is_causal: bool constructor flag
- Apply swa_idx > q_pos masking to -inf in SWA pass
- Main path has NO mask (indexer enforces causality upstream)
- Test: prefill mode produces correct output with causal mask
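The two SWA masks from D3 and D4 compose into a single predicate per logit. The sketch below applies both to one row of SWA logits; it assumes swa_idx and q_pos share one coordinate system, as the swa_idx > q_pos condition implies, and `mask_swa_logits` is an illustrative name.

```python
NEG_INF = float("-inf")


def mask_swa_logits(logits, swa_len, q_pos=None, is_causal=False):
    """Mask one row of SWA logits by fill level and, optionally, causality."""
    if is_causal and q_pos is None:
        raise ValueError("q_pos is required when is_causal=True")
    out = []
    for swa_idx, value in enumerate(logits):
        beyond_fill = swa_idx >= swa_len               # D3: unfilled window slots
        in_future = is_causal and swa_idx > q_pos      # D4: causal mask
        out.append(NEG_INF if (beyond_fill or in_future) else value)
    return out


row = [1.0, 2.0, 3.0, 4.0]
print(mask_swa_logits(row, swa_len=3))
print(mask_swa_logits(row, swa_len=4, q_pos=1, is_causal=True))
```

A row that ends up fully masked has no finite maximum, so the softmax needs a guard for that case before exp is applied.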
D5 — SWA + Sink Merge (CG-3) ⚠️ THE WHOLE POINT OF V4 ATTENTION
D5a — Emit un-normalized o + lse ⚡ DO THIS IMMEDIATELY AFTER D1
This is the single most important structural change. Once the kernel can output (o_unnorm, lse), even a Python merge gives end-to-end correctness.
- Change epilogue: instead of O *= 1/row_sum, emit O un-normalized and lse = log(row_sum) + row_max as a separate output
- Add normalize: bool constructor flag (default: True for backward compat, False for merge mode)
- When normalize=False: skip the TMEM round-trip for normalize. O stays as P @ V (un-normalized). lse written to a separate GMEM buffer.
- Test: normalize=True → identical to current behavior (regression)
- Test: normalize=False → o_unnorm * exp(row_max - lse).unsqueeze(-1) ≈ o_normalized (verify math). O is accumulated relative to row_max, so the divisor is row_sum = exp(lse - row_max); dividing by exp(lse) alone is off by a factor of exp(row_max).
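The relationship between o_unnorm, lse, row_max, and row_sum is worth pinning down numerically before changing the epilogue. The sketch below computes one attention row the way online softmax accumulates it (weights relative to row_max) and shows which rescaling recovers the normalized output. It is a scalar model in plain Python with an illustrative function name.

```python
import math


def attention_row_unnormalized(scores, values):
    """One query row: un-normalized output, lse, and the softmax statistics."""
    row_max = max(scores)
    weights = [math.exp(s - row_max) for s in scores]   # relative to row_max
    row_sum = sum(weights)
    dim = len(values[0])
    o_unnorm = [sum(w * v[d] for w, v in zip(weights, values)) for d in range(dim)]
    lse = math.log(row_sum) + row_max
    return o_unnorm, lse, row_max, row_sum


scores = [1.0, 2.0, 3.0]
values = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
o_unnorm, lse, row_max, row_sum = attention_row_unnormalized(scores, values)

by_row_sum = [x / row_sum for x in o_unnorm]
by_lse_and_max = [x * math.exp(row_max - lse) for x in o_unnorm]
by_lse_only = [x / math.exp(lse) for x in o_unnorm]  # off by exp(row_max)
print(by_row_sum, by_lse_and_max, by_lse_only)
```

The practical consequence is that a consumer of (o_unnorm, lse) also needs row_max, unless the kernel emits O already normalized or rescaled so that lse alone is enough.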
D5b — Python merge (correctness baseline)
- Run FmhaKernel twice: once with compressed_kv, once with swa_kv
- Merge in Python:
  exp_lse_sparse = lse_sparse.exp()
  exp_lse_swa = lse_swa.exp()
  exp_sink = sink_logits.exp()
  o = (exp_lse_sparse * o_sparse + exp_sink * exp_lse_swa * o_swa) / (exp_lse_sparse + exp_sink * exp_lse_swa)
- Test against FP32 oracle that does sparse+SWA+sink merge
- This gives us end-to-end correctness. Everything after is optimization.
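The merge formula can be validated against a joint softmax on a single row. The sketch below implements the same expression with a max subtraction so exp does not overflow, and assumes o_sparse and o_swa are the normalized per-branch outputs (un-normalized outputs would first need dividing by their row_sum). With sink_logit = 0 the merge must equal attention over the concatenated KV, which is what the check uses. Function names are illustrative.

```python
import math


def attend(scores, values):
    """Normalized attention output and lse for one query row."""
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    z = sum(w)
    dim = len(values[0])
    o = [sum(wi * v[d] for wi, v in zip(w, values)) / z for d in range(dim)]
    return o, m + math.log(z)


def merge_branches(o_sparse, lse_sparse, o_swa, lse_swa, sink_logit):
    """Document merge formula, with exp(sink) * exp(lse_swa) folded into one exponent."""
    a = lse_sparse
    b = sink_logit + lse_swa
    m = max(a, b)                      # subtract the max for numerical stability
    w_sparse, w_swa = math.exp(a - m), math.exp(b - m)
    denom = w_sparse + w_swa
    return [(w_sparse * x + w_swa * y) / denom for x, y in zip(o_sparse, o_swa)]


s1, v1 = [0.5, 1.5], [[1.0, 2.0], [3.0, 4.0]]
s2, v2 = [2.0, -1.0, 0.0], [[0.0, 1.0], [5.0, 5.0], [2.0, 0.0]]
o1, lse1 = attend(s1, v1)
o2, lse2 = attend(s2, v2)
merged = merge_branches(o1, lse1, o2, lse2, sink_logit=0.0)
joint, _ = attend(s1 + s2, v1 + v2)
print(max(abs(a - b) for a, b in zip(merged, joint)))
```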
D5c — Fuse two passes into one kernel launch
- Q loaded once to SMEM, used by both compressed and SWA MMA loops
- Two sequential QK→softmax→PV passes in one kernel invocation
- K/V have two sources: compressed (contiguous BF16) and SWA (from cache)
- For now: dequantize SWA in a small prep kernel before FMHA, FMHA sees two contiguous BF16 sources
- Test: output matches D5b Python merge
D5d — Fuse sink merge into kernel epilogue
- TMEM holds two O accumulators + two row_max/row_sum per row
- Verify TMEM column budget: two O + two (row_max, row_sum) at hd=512
- Sink merge in TMEM: O = (exp(lse1) * O1 + exp(sink) * exp(lse2) * O2) / (exp(lse1) + exp(sink) * exp(lse2))
- Test: output matches D5b Python merge
D6 — Mixed-Precision KV Load Path (CG-5)
- TMA loads FP8 NoPE dims to SMEM slot 0
- TMA loads BF16 RoPE dims to SMEM slot 1
- Dequantize FP8 → BF16 in SMEM (vectorized * inv_scale)
- Concatenate [NoPE, RoPE] in SMEM
- MMA reads contiguous BF16 from SMEM
- Test: FP8+BF16 split input matches pure BF16 input (dequant is transparent)
- Prerequisite: D1 (SMEM-P) and D5 (sink merge) working first
Inverse RoPE Verification (CG-4) — Separate from D1–D6
- Write unit test: tests/unit/test_inverse_rope.py
- Round-trip test: forward RoPE → inverse RoPE → verify ≈ original
- Multi-head test: verify only last 64 dims are rotated
- Position test: verify cos_sin_cache indexing is correct for positions > 0
- This is a standalone test, not a kernel change. Can be done anytime.
Per-Token valid_lens (CG-6) — Indexer Scope, Not FMHA
- Add request_ids: [T] int32 to indexer kernel input
- Look up valid_lens[request_ids[t]] per query token
- Replace the current broadcast of valid_lens[:1]
- Test: batched prefill with different sequence lengths per request
- Tracked separately from FMHA Stage D. This is indexer work.
Correctness Gap NOT in This Project
CG-7: Indexer Rewrite (PS-1) — Stage F
The indexer needs a full rewrite from scalar CUDA to tcgen05 MMA + radix-select. This is a major work item (2-3 weeks) that is out of scope for Stage D.
Reference: DeepGEMM's fp8_mqa_logits / fp8_paged_mqa_logits (Sept 2025 PR for V3.2 indexer). Our V4 variant: same pattern with FP4 inputs and tcgen05 MMA.
Execution Order (Top to Bottom)
| # | Task | Blocks | Est. |
|---|---|---|---|
| D0 | SwiGLU clamping (CG-1) | Nothing — do first | 30 min |
| D1.2 | TMEM budget probe at hd=512 | D1.3 | 1 hr |
| D1.3 | Register→SMEM copy for P | D1.4, D2 | 1-2 days |
| D1.4 | Multi-PV-tile hd>256 | D2 | 1 day |
| D1.5 | Correction epilog fix (3% → 0.01%) | Nothing (can parallel) | 1-2 days |
| D2 | Multi-query grid + head packing | D3 | 1 day |
| D3 | SWA sequence length mask | D5 | ½ day |
| D4 | Causal mask on SWA | D5 | ½ day |
| D5a | Emit un-normalized o + lse | D5b | 1 day |
| D5b | Python merge (correctness) | D5c | ½ day |
| D5c | Fuse two passes in one launch | D5d | 2 days |
| D5d | Fuse sink merge in epilogue | D6 | 2 days |
| D6 | Mixed-precision KV load | E1 | 2 days |
| CG-4 | Inverse RoPE round-trip test | Nothing | 2 hrs |
| CG-6 | Per-token valid_lens (indexer) | Nothing | ½ day |
Critical path: D0 → D1.2 → D1.3 → D1.4 → D5a → D5b (end-to-end correctness)
D1.5 (correction epilog) and CG-4 (RoPE test) can happen in parallel with D2–D4.