DSV4 Inference Kernel
⚠️⚠️⚠️ CRITICAL: TMA Partition Tensor Mode Ordering ⚠️⚠️⚠️
THIS BUG COST US AN ENTIRE DAY. READ THIS. BURN IT INTO YOUR BRAIN.
After cpasync.tma_partition(), the output GMEM tensor has 4 modes (verified on B200):
tBgK shape: (((64, 128), 1), ?, KV_tiles, ?)
mode 0 = ((64, 128), 1), mode 1 = ?, mode 2 = KV_tiles, mode 3 = ?
Mode 2 is the GMEM tile dimension: the one you index with kt to load different K/V tiles.
THE WRONG WAY (what we did — silently loads from tile 0 forever):
# ❌❌❌ (None,None,0,0) KEEPS MODES 0,1 FREE, SETS MODE 2 TO 0 ❌❌❌
# Mode 2 (the KV tile dim) gets collapsed to coordinate 0.
# TMA ALWAYS reads from tile 0.
tBgK = tBgK[(None, None, 0, 0)] # ← WRONG! Mode 2 pinned to 0!
# The copy "works" but kv_coord indexes mode 1 (inner GEMM K, not KV tiles).
cute.copy(tma_k, tBgK[(None, kv_coord)], ...) # ← kv_coord indexes wrong mode!
THE RIGHT WAY (verified on B200 at n=128 and n=256):
# ✅ (None,0,None,0) keeps modes 0 and 2 free → 2D tensor
# Mode 2 (KV tiles) survives as the second mode.
tBgK = tBgK[(None, 0, None, 0)]
# ✅ [None, kt] indexes the surviving mode 1 (originally mode 2 = KV tiles)
cute.copy(tma_k, tBgK[None, kt], ...)
# ^^ THIS IS THE KV TILE DIM
Verified shapes on B200 (May 22, n=256, inside @cute.kernel):
Before slice: tBgK = (((64,128),1), Int32(?), Int32(?), Int32(?)) — 4 modes
After (None,0,None,0): tBgK = (((64,128),1), Int32(?)) — 2 modes
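The mode bookkeeping can be checked without a GPU. The sketch below is a pure-Python toy, not CuTeDSL. It models only the rule that None keeps a mode free and an integer pins it. The labels for the four tBgK modes are made up; only "kv_tiles" (mode 2) matters.

```python
# Toy model of CuTe-style slicing: None keeps a mode free, an integer pins it.
# The mode labels are illustrative; only "kv_tiles" (mode 2) matters here.
TBGK_MODES = ("tma_box", "mode1", "kv_tiles", "mode3")


def surviving_modes(modes, coord):
    """Return the modes that stay free after slicing with `coord`."""
    if len(coord) != len(modes):
        # Real CuTe rejects incongruent slices at JIT compile time; the toy just raises.
        raise ValueError("slice coordinate must have one entry per mode")
    return tuple(m for m, c in zip(modes, coord) if c is None)


wrong = surviving_modes(TBGK_MODES, (None, None, 0, 0))
right = surviving_modes(TBGK_MODES, (None, 0, None, 0))

# After the pre-slice, tBgK[None, kt] indexes position 1 of the surviving modes.
print(wrong[1])  # mode1     (kt lands on the wrong mode; the KV tile stays pinned to 0)
print(right[1])  # kv_tiles  (kt selects the KV tile)
```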
WHY THIS IS SO INSIDIOUS
- No error, no warning. The slice tBgK[(None,None,0,0)] silently sets mode 2 to 0.
- Single-tile (n=128) works perfectly. With only 1 KV tile, mode 2 has size 1, so the bug is invisible.
- Multi-tile tests produce "reasonable" output. The TMA loads from tile 0 every time, so you get a valid (but wrong) attention computation. Cosine similarity is 0.7-0.9, not NaN.
- The strides are all 0. Printing tBgK.layout.stride shows all zeros for TMA tensors, so you can't detect the bug from strides alone.
- cute.printf shows kv_coord=0. We thought the JIT was constant-folding the variable. It wasn't: the variable was fine, but it was indexing the wrong mode.
- The 8-mode theory was wrong. We assumed tma_partition produced 8 TMA coordinate dimensions. It produces 4. The 8-None no-op slice fails with "weakly congruent" at JIT compile.
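One way to make this bug loud instead of silent is a tile-ID probe: fill every element of KV tile kt with the value kt, so a wrong tile shows up as an exact mismatch rather than a plausible cosine. The sketch below simulates the loader in plain Python. The helper names are hypothetical; on the GPU the same check would inspect what actually lands in SMEM or in the output.

```python
# Tile-ID probe: every element of KV tile kt holds the value kt, so a wrong tile
# shows up as an exact mismatch instead of a plausible-looking cosine.

def make_probe_kv(n_tiles, tile_rows=128, head_dim=4):
    """Rows of a K/V matrix where each row's values equal its tile index."""
    return [[float(r // tile_rows)] * head_dim for r in range(n_tiles * tile_rows)]


def simulate_loads(kv, n_tiles, tile_rows, pinned):
    """Stand-in for the TMA loop; pinned=True mimics the (None,None,0,0) bug."""
    loads = []
    for kt in range(n_tiles):
        src = 0 if pinned else kt
        loads.append(kv[src * tile_rows:(src + 1) * tile_rows])
    return loads


def tiles_seen(loaded_tiles):
    """Distinct values found in each loaded tile (exactly one tile id expected per tile)."""
    return [sorted({v for row in tile for v in row}) for tile in loaded_tiles]


kv = make_probe_kv(n_tiles=2)
print(tiles_seen(simulate_loads(kv, 2, 128, pinned=False)))  # [[0.0], [1.0]]
print(tiles_seen(simulate_loads(kv, 2, 128, pinned=True)))   # [[0.0], [0.0]]
```

With a single tile the two loaders are indistinguishable, which is exactly why n=128 passed. If V is filled this way, the correct attention output is a softmax-weighted mix of tile ids, while the pinned loader produces exactly 0.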
THE LESSON
PRINT THE SHAPES. ALWAYS. Run print(f"tBgK: shape={cute.shape(tBgK)}") inside @cute.kernel at trace time. The shapes are your ground truth. Reasoning about mode counts without evidence is how we wasted a day.
The correct pre-slice depends on which mode is the GMEM tile iteration axis. For our local_tile + partition_B + group_modes(0,3) pattern, mode 2 is the KV tile axis. (None,0,None,0) keeps it free. (None,None,0,0) collapses it to 0.
# ALWAYS verify the shape at trace time:
print(f"tBgK shape: {cute.shape(tBgK)}") # 4 modes
print(f"tBgK after slice: {cute.shape(tBgK[(None,0,None,0)])}") # 2 modes
# Then index the 2D tensor:
cute.copy(tma_k, tBgK[None, kt], ...)
IF YOU USE (None,None,0,0) INSTEAD OF (None,0,None,0), MULTI-TILE TMA WILL BE SILENTLY BROKEN.
Architecture
DSV4 is not MLA. It uses CSA (Compressed Sparse Attention, m=4) and HCA (Heavily Compressed Attention, m′=128). The KV latent is (T, 512) and is shared across all 128 query heads. Per-head sink logits weight the merge of the sparse and SWA attention branches. vLLM misnames this as "MLA"; it is not, and the architecture is fundamentally different.
DSV4 inference pipeline — component status
==========================================
Legend:
[✓] built and tested
[~] partial — reference or seam exists, native pending
[✗] to build
┌────────────────────────────────────┐
│ [✗] Embedding + mHC init │
│ token embed + n_hc=4 streams │
└────────────────┬───────────────────┘
│
▼
┌─ Transformer layer × L ──────────────────────────────────────────────┐
│ HCA on layers 0–1 of Pro, alternating CSA / HCA after │
│ │
│ ┌─ Attention sub-block ──────────────────────────────────────────┐ │
│ │ [✓] Residual mHC pre + post mix │ │
│ │ [~] Norms + RoPE RMSNorm + partial RoPE │ │
│ │ [✓] Q / KV projection NVFP4 linears + LoRA │ │
│ │ [✓] Token compressor CSA m=4 / HCA m′=128 │ │
│ │ [✓] Indexer + top-k CSA, FP32 dot + top-k │ │
│ │ [~] FMHA core QK → online softmax → PV │ │
│ │ + SWA branch + sink merge │ │
│ │ [✓] Output projection inv RoPE + wo_a grouped + wo_b │ │
│ └────────────────────────────────────────────────────────────────┘ │
│ │
│ ┌─ FFN sub-block ────────────────────────────────────────────────┐ │
│ │ [✓] Residual mHC pre + post mix │ │
│ │ [✓] Pre-FFN norm RMSNorm │ │
│ │ [✓] Router sqrt(softplus) + topk + hash │ │
│ │ [✓] Routed MoE fused SwiGLU L1 + L2 │ │
│ │ [✓] Shared expert NVFP4 single-group GEMM │ │
│ └────────────────────────────────────────────────────────────────┘ │
└──────────────────────────────────┬───────────────────────────────────┘
│
▼
┌──────────────────────────────────────────────────────────────────────┐
│ [✗] Final RMSNorm → [✗] LM head → [✗] MTP (depth=1) → [✗] Sampler │
└──────────────────────────────────────────────────────────────────────┘
┌─ Supporting infrastructure ──────────────────────────────────────────┐
│ [✗] KV cache management │
│ • state cache: SWA window + uncompressed tail per layer │
│ • classical paged cache: lcm(m, m′) = 128 tokens per block │
│ • heterogeneous layout per layer │
└──────────────────────────────────────────────────────────────────────┘
Summary
-------
Built [✓] : 10 — mHC ×2, Q/KV proj, output proj, routed MoE,
                 shared expert, token compressor, indexer+topk,
                 router, pre-FFN norm
Partial [~] : 2 — norms+RoPE, FMHA core
To build [✗] : 6 — embedding+init, final norm, LM head, MTP, sampler, KV cache
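The KV cache block size in the diagram is lcm(m, m′) = lcm(4, 128) = 128 tokens. Our reading of why: a 128-token block boundary is also a group boundary for both ratios, so no CSA or HCA entry straddles two blocks. The sketch below does that arithmetic. It also splits a sequence into complete compressed entries plus the uncompressed tail that the state cache keeps, assuming only full groups of m tokens are compressed. Helper names are ours.

```python
from math import lcm  # Python 3.9+

CSA_M = 4     # CSA compression ratio m
HCA_M = 128   # HCA compression ratio m′

BLOCK_TOKENS = lcm(CSA_M, HCA_M)  # 128 source tokens per paged block


def entries_per_block(m, block_tokens=BLOCK_TOKENS):
    """Compressed entries one block holds for ratio m (exact because m divides the block)."""
    assert block_tokens % m == 0
    return block_tokens // m


def cache_split(num_tokens, m):
    """(complete compressed entries, uncompressed tail tokens), assuming only full groups compress."""
    return divmod(num_tokens, m)


def blocks_needed(num_tokens, block_tokens=BLOCK_TOKENS):
    return -(-num_tokens // block_tokens)  # ceiling division


print(BLOCK_TOKENS, entries_per_block(CSA_M), entries_per_block(HCA_M))  # 128 32 1
print(cache_split(1000, CSA_M), cache_split(1000, HCA_M))               # (250, 0) (7, 104)
print(blocks_needed(1000))                                              # 8
```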
Status (May 24, 2026 — 21:30 UTC)
| Stage | Status | Description |
|---|---|---|
| A | ✅ COMPLETE | Q@K^T via tcgen05.mma → TMEM → GMEM |
| B | ✅ COMPLETE | QK → identity softmax → P@V pipeline (TMEM alias, KV-tile interleaving) |
| C | ✅ COMPLETE | Real online softmax. Kernel outputs un-norm O + LSE (no TMEM round-trip). Migrated to dsv4/kernels/attention/fmha.py as FmhaKernel. |
| D1 | 🟡 hd≤256 DONE | Parameterized HEAD_DIM. qk_mma_tiler fix (hd=64/128/256 cos 0.999998). hd=512 SMEM fits but MLIR compilation hangs (>3hr). External k_sub merge proven impossible. |
| D2 | TODO | Multi-query grid with head packing (128 Q heads, MQA) |
| D3 | TODO | SWA sequence length mask (swa_lens per batch) |
| D4 | TODO | Causal mask on SWA branch only |
| D5 | 🟢 D5a+D5b DONE | D5a: normalize flag + LSE output (err=0.0). D5b: Python SWA+sink merge (cos 0.961). D5c/D5d: fused kernel merge TODO. |
| E1-E7 | TODO | Production extraction (class, custom op, cache, cleanup) |
Package Structure
dsv4/
├── kernels/ Pure GPU code (CuTeDSL @cute.jit, .cu files)
│ ├── gemm/ NVFP4 MoE GEMM kernels (grouped, fused_swiglu, dense, scheduler)
│ ├── attention/ FMHA kernel — FmhaKernel (hd=64, TMEM-P proven; SMEM-P stub for hd>64)
│ ├── compressor/ CSA/HCA token-level compressor (CuTeDSL, 419 lines)
│ ├── indexer/ CSA indexer — score+topk (FP32 dot products, top-k selection)
│ ├── router/ Dense router decode kernel (warp-specialized persistent GEMM)
│ ├── cache/ Cache kernels — append_swa (write KV to split state cache layout)
│ ├── decode/ Decode-time attention (sparse, SWA — future)
│ └── cuda/ Raw .cu files (deinterleave_quantize, sparse_topk_metadata)
├── ops/ PyTorch ↔ kernel bridges
│ ├── quantize.py BF16 ↔ NVFP4 conversion, scale factors
│ ├── layouts.py Scale swizzle, gate/up interleave, K-major, offsets
│ ├── gemm_runner.py Warmup, compile, run grouped/fused GEMMs
│ ├── custom_ops.py torch.library.custom_op registrations
│ ├── decode_sparse.py native_sparse_decode dispatcher
│ ├── decode_swa.py native_swa_decode dispatcher
│ ├── rope.py Forward + inverse RoPE
│ ├── topk.py Python wrapper for sparse_topk_metadata.cu
│ ├── topk_select.py Top-k selection wrapper
│ └── router.py Router op bridge
├── layers/ nn.Module-style components
│ ├── linear.py Nvfp4Linear
│ ├── grouped_linear.py Nvfp4GroupedLinear
│ ├── moe.py Nvfp4MoE
│ ├── shared_expert.py Nvfp4SharedExpert
│ ├── mhc.py mHCLayer
│ ├── attention.py DSV4 attention sub-block (CSA/HCA/SWA variants, 245 lines)
│ ├── norm.py RMSNorm (PyTorch ref, fused kernel later)
│ ├── router.py Router — token-to-expert assignment (273 lines)
│ ├── embedding.py Token embedding + mHC init (stub)
│ └── ffn.py FFN sub-block
├── model/ Model assembly
│ ├── config.py Model config
│ ├── layer.py Transformer layer
│ ├── layer_schedule.py Layer scheduling
│ ├── mtp.py Multi-token prediction
│ ├── sampler.py Token sampler
│ └── dsv4.py Full model (stub — Phase 1)
├── cache/ KV cache infra
│ ├── allocator.py Cache memory allocator
│ ├── block_table.py Paged cache block table
│ ├── flush.py Cache flush
│ ├── handle.py Cache handle
│ ├── manager.py Cache manager
│ ├── paged_cache.py Paged KV cache
│ ├── prepare_forward.py Forward prep
│ ├── schema.py Cache schema
│ └── state_cache.py State cache (SWA ring buffer)
├── loader/ Checkpoint I/O
│ ├── hf_checkpoint.py HuggingFace checkpoint loader
│ └── layout_convert.py Weight layout conversion
└── reference/ Slow PyTorch oracles (never imported by production code)
├── attention.py RoPE, KV cache, causal attention, SWA
├── csa_attention.py CSA/HCA sparse attention
├── compressor.py Compressor PyTorch example
└── moe_pipeline.py MoE pipeline reference
Mental model: kernels/ → ops/ → layers/ → model/ (dependency flows left to right). reference/ and loader/ are sidecars.
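The left-to-right rule, and the rule that reference/ is never imported by production code, can be enforced mechanically. Below is a minimal sketch using Python's ast module. It assumes kernels/, ops/, layers/, and model/ may import anything to their left (or themselves). It considers only absolute dsv4.* imports, and it does not check relative imports or the loader/ and reference/ sidecars themselves. The function names are hypothetical, not part of the repo.

```python
import ast

# Assumed rule from the mental model: a production package may import packages
# to its left (or itself), never to its right, and never dsv4.reference.
LAYER_ORDER = ["kernels", "ops", "layers", "model"]


def dsv4_subpackages(source):
    """Yield the dsv4 subpackage targeted by each absolute import in `source`."""
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.Import):
            names = [alias.name for alias in node.names]
        elif isinstance(node, ast.ImportFrom) and node.level == 0 and node.module:
            if node.module == "dsv4":  # e.g. "from dsv4 import ops"
                names = ["dsv4." + alias.name for alias in node.names]
            else:
                names = [node.module]
        else:
            continue  # relative imports are not checked in this sketch
        for name in names:
            parts = name.split(".")
            if parts[0] == "dsv4" and len(parts) > 1:
                yield parts[1]


def layering_violations(package, source):
    """Imports in a module under dsv4/<package>/ that break the left-to-right rule."""
    if package not in LAYER_ORDER:
        return []  # sidecars (loader/, reference/) are not checked here
    rank = LAYER_ORDER.index(package)
    bad = []
    for target in dsv4_subpackages(source):
        if target == "reference":
            bad.append(target)
        elif target in LAYER_ORDER and LAYER_ORDER.index(target) > rank:
            bad.append(target)
    return bad
```

In CI this would run over every dsv4/<package>/**/*.py file and fail on any non-empty result.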
Active Test Files
FMHA (Stages A/B/C/D1) — in tests/unit/
| File | Stage | Status |
|---|---|---|
| test_fmha_v3.py | A+B | ✅ Full QK→identity softmax→PV, cosine 0.999999 |
| test_fmha_v3_12w.py | A+B | ✅ 12-warp QK→PV, cosine 0.999999 |
| test_fmha_v3_stage_c.py | C | ✅ Real online softmax + normalize, n=128 cos 0.973. Also in module as FmhaKernel. |
| test_fmha_v3_stage_d1.py | D1 | ✅ hd=64/128/256 PASS (cos 0.999998, TMEM-P). hd=512 SMEM overflow. |
| test_fmha_v3_stage_d5b.py | D5b | ✅ Python SWA+sink merge (cos 0.961, LSE err=0.0) |
| test_d1_*.py | D1 | 🔨 Debug/diagnostic variants (hd512, regression, sweep, raw, debug) |
| test_paired_epilog.py | C | ✅ Paired atom epilogue experiments |
| test_pv64_with_softmax.py | B | ✅ (128,64) PV, single AB pipeline |
| test_128_128_vdiag.py | A+B | ✅ (128,128) PV baseline |
| test_qkonly.py | A | ✅ QK with split Q/KV pipelines |
| test_qk_softmax.py | A+B | ✅ QK + identity softmax, no PV |
MoE / GEMM — in tests/unit/
| File | What |
|---|---|
| test_cutedsl.py | NVFP4 grouped GEMM kernel |
| cudagraph_test.py | Cudagraph capture + replay |
| layertest.py | Per-layer correctness |
| test_custom_op.py | torch.library custom ops |
| test_compile_custom_op.py | Compile + warmup |
| test_fp4_roundtrip.py | BF16 → NVFP4 → BF16 roundtrip |
| test_interleave.py | Gate/up weight interleaving |
| test_interleave_gemm.py | Interleaved GEMM correctness |
| test_fused_step1.py | Fused SwiGLU GEMM |
Test Harness
Scripts in tests/ for running tests on the B200 (root@45.76.247.107):
run_test.sh — Run a test in a screen session
# On the B200:
cd /root/dsv4-nvfp4-workspace/kernel
bash tests/run_test.sh tests/unit/test_fmha_v3.py
What it does:
- Kills any existing kernel-test screen and SIGKILLs all child processes (handles deadlocked GPU procs that ignore SIGHUP)
- Deletes the old log file
- Starts a new screen -dmS kernel-test running the test
- Logs output to /tmp/kernel-test.log
- Verifies the screen started
check_log.sh — Check test progress
bash tests/check_log.sh
Shows the log contents and whether the screen is still running.
Local → B200 workflow
# 1. Edit locally, commit, push
cd ~/dev/nvfp4-megamoe-kernel
git add -A && git commit -m "my change" && git push
# 2. SSH to B200, pull, run
ssh root@45.76.247.107
cd /root/dsv4-nvfp4-workspace/kernel && git pull
bash tests/run_test.sh tests/unit/test_fmha_v3_stage_c_full.py
# 3. Check results
bash tests/check_log.sh
fire_b200_test — One-command local test runner
Lives in ~/.openclaw/workspace/fire_b200_test (NOT in the repo — project-specific tooling).
# From your local machine, one command to push, run, and get results:
~/.openclaw/workspace/fire_b200_test tests/unit/test_fmha_v3.py
What it does:
- Auto-commits and pushes any local changes
- SSHes to the B200, pulls, and starts run_test.sh in a screen
- Polls every 15s until the screen exits
- Dumps the full test log to your terminal
This is strictly for the DSV4 NVFP4 kernel project. It hardcodes the B200 IP, repo paths, and git remote.
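The polling step could look roughly like this in Python. fire_b200_test's actual language and internals aren't shown here, so both helpers are hypothetical. The sleep and clock functions are injectable so the loop can be tested without waiting.

```python
import time


def screen_alive(screen_ls_output, name="kernel-test"):
    """True if `screen -ls` output lists a session named `name` (shown as PID.name)."""
    for line in screen_ls_output.splitlines():
        fields = line.strip().split("\t")
        if fields and fields[0].endswith("." + name):
            return True
    return False


def wait_for_exit(is_running, interval_s=15.0, timeout_s=None,
                  sleep=time.sleep, clock=time.monotonic):
    """Call is_running() every interval_s seconds until it returns False.

    Returns the number of polls made. Raises TimeoutError once timeout_s has elapsed.
    """
    start = clock()
    polls = 0
    while True:
        polls += 1
        if not is_running():
            return polls
        if timeout_s is not None and clock() - start >= timeout_s:
            raise TimeoutError("kernel-test screen is still running")
        sleep(interval_s)
```

In practice, is_running would capture the stdout of `ssh root@45.76.247.107 screen -ls` (for example via subprocess.run with capture_output=True, text=True) and pass it to screen_alive.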
Stage C: Online Softmax — TMEM Layout Mismatch Issue
Current Results (test_fmha_v3_stage_c.py)
| n | cos | Status |
|---|---|---|
| 128 | 0.973 | ⚠️ 3% error from TMEM layout mismatch |
| 256 | 0.793 | ⚠️ Two TMEM round-trips compound the error |
| 384+ | N/A | Pipeline doesn't cycle past 2 KV tiles |
Root Cause: TMEM Layout Mismatch
The MMA instruction writes O to TMEM using the C-fragment layout. The epilogue_tma_store helper reads O from TMEM using get_tmem_load_op, which uses the correct C-fragment-compatible layout. Raw PV output is perfect (cos 0.999998) when epilogue_tma_store reads directly without any round-trip.
The problem appears when we do a TMEM round-trip (load O → modify → store back) using hand-constructed Ld32x32bOp/St32x32bOp atoms. These atoms use a different column mapping than the MMA's C-fragment layout, causing ~3% data corruption per round-trip. Both the NO-OP round-trip (previously used to "fix" layout) and the normalize round-trip (multiply by 1/row_sum) suffer from this error.
Fix proven but not yet integrated: The epilogue_tmem_copy_and_partition + epilogue_smem_copy_and_partition pattern from CUTLASS's cutlass.utils.gemm.sm100 reads O from TMEM using the correct get_tmem_load_op layout and writes to SMEM using get_smem_store_op. This is a one-way trip (TMEM→reg→SMEM→GMEM) that eliminates the layout mismatch entirely. Integration requires proper flat_divide and tma_partition handling inside the kernel's warp-specific if blocks.
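Below is a toy model of why an independently built Ld/St pair can corrupt even a NO-OP round trip, while a one-way read with the matching op cannot. TMEM is modeled as a flat list of columns, and a "copy atom" is just a logical-to-physical column map. The interleaved atom_map is invented and is not the real tcgen05 32x32b layout, so the size of the error here says nothing about the measured ~3%. The cosine helper is the same metric quoted throughout this document.

```python
import math
import random


def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def store(values, col_map, n_cols):
    """Write logical element j to physical TMEM column col_map[j]."""
    tmem = [0.0] * n_cols
    for j, v in enumerate(values):
        tmem[col_map[j]] = v
    return tmem


def load(tmem, col_map):
    """Read logical element j from physical TMEM column col_map[j]."""
    return [tmem[c] for c in col_map]


N = 32
mma_map = list(range(N))  # stand-in for the MMA C-fragment mapping
# An invented even/odd interleave standing in for an independently built atom.
atom_map = [(j % 2) * (N // 2) + j // 2 for j in range(N)]

random.seed(0)
o_logical = [random.gauss(0.0, 1.0) for _ in range(N)]
tmem = store(o_logical, mma_map, N)                 # what the PV MMA wrote

one_way = load(tmem, mma_map)                       # epilogue read with the matching op
no_op = store(load(tmem, atom_map), mma_map, N)     # Ld with one map, St with another
after_round_trip = load(no_op, mma_map)

print(cosine(one_way, o_logical))           # ~1.0
print(cosine(after_round_trip, o_logical))  # well below 1: the "NO-OP" scrambled O
```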
Key Bug Fix: tOrP0 TMEM Column Offset (May 23)
The softmax warps store P at tmem_p0_offset=32 FP32 columns (64 BF16 elements). PV MMA must read from the same offset. tOrP0 was missing this offset, causing PV to read from TMEM column 0 (where S is) instead of column 32 (where P is). This was the root cause of NaN/zeros in D1 tests. Fixed with:
if const_expr(self.tOrP0_offset > 0):
tOrP0 = cute.make_tensor(tOrP.iterator + self.tOrP0_offset, tOrP.layout)
else:
tOrP0 = tOrP
Use a const_expr conditional (not a plain Python if): CuTeDSL compiles both branches, and tOrP.iterator + 0 fails with an MLIR type error.
Architecture (6-warp, current)
Warps 0-3: Softmax + Epilogue (row_max, row_sum, P store, O rescale, final normalize)
Warp 4: MMA (QK, PV)
Warp 5: TMA (Q/K/V load)
TMEM Layout
Col 0-31: S (QK acc, 128 FP32 via Ld32x32bOp Repetition(32))
Col 32-95: P (64 FP32 via St32x32bOp Repetition(32), register bridge BF16 view)
Col 128+: O (PV acc, 64 FP32, rescale via Ld32x32bOp Repetition(16))
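The column map above can be sanity-checked on the host. The sketch below assumes O's 64 FP32 columns start at 128 (so O ends at 192). It checks that the documented regions are disjoint and rounds the allocation up to a power of two. Our understanding of the PTX ISA is that tcgen05.alloc takes a power-of-two column count of at least 32 and at most 512. Stages that alias S and P on purpose (the Stage B TMEM alias) would need an explicit exception. The same arithmetic gives the tOrP0 offset.

```python
# Column ranges from the TMEM layout above, as [start, end) in 32-bit columns.
# O's end (192) assumes its 64 FP32 columns start at 128.
TMEM_REGIONS = {"S": (0, 32), "P": (32, 96), "O": (128, 192)}


def check_disjoint(regions):
    """Raise if any two regions share a column (sorted by start, adjacent checks suffice)."""
    spans = sorted(regions.items(), key=lambda kv: kv[1])
    for (n0, (s0, e0)), (n1, (s1, e1)) in zip(spans, spans[1:]):
        if s1 < e0:
            raise ValueError(f"{n0} [{s0},{e0}) overlaps {n1} [{s1},{e1})")


def tmem_alloc_columns(regions, min_cols=32, max_cols=512):
    """Smallest power-of-two column count (>= 32) covering every region."""
    need = max(end for _, end in regions.values())
    cols = min_cols
    while cols < need:
        cols *= 2
    if cols > max_cols:
        raise ValueError(f"needs {cols} TMEM columns, only {max_cols} exist")
    return cols


# tOrP0 offset: a 32-bit column holds two BF16 values, so 32 FP32 columns = 64 BF16 elements.
TMEM_P0_OFFSET_FP32 = 32
TMEM_P0_OFFSET_BF16 = TMEM_P0_OFFSET_FP32 * 32 // 16

check_disjoint(TMEM_REGIONS)
print(tmem_alloc_columns(TMEM_REGIONS), TMEM_P0_OFFSET_BF16)  # 256 64
```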
Remaining for Multi-Tile Production
- Fix TMEM layout mismatch — replace hand-constructed atom round-trips with correction_epilog pattern
- Pipeline state cycling for n≥384 — kv_stage=2 can only buffer 2 tiles
- 12-warp layout — separate softmax/correction/epilogue warps
- O rescale for kt > 0 — must also use paired atoms or correction_epilog
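When checking the O rescale for kt > 0, it can help to have a scalar FP64 reference of the per-row math across KV tiles. The sketch tracks the running max in the log2 domain and rescales by 2^(old_max - new_max). It overwrites O on the first tile (the analogue of ACCUMULATE=False) and finishes with the LSE formula listed under CuTeDSL Constraints, lse = ln(row_sum) + row_max * ln(2). This is standard online softmax written to match the description here, not a transcription of FmhaKernel.

```python
import math
import random

LOG2E = 1.0 / math.log(2.0)


def online_softmax_row(scores, values, scale, tile=128):
    """Tiled online softmax for one query row, in the log2 domain.

    scores[j] = q·k_j, values[j] = row j of V. Returns (o_unnorm, row_sum, row_max, lse),
    where o_unnorm / row_sum is the attention output and row_max = max(S * scale * log2(e)).
    """
    scale_log2 = scale * LOG2E
    row_max, row_sum, o = -math.inf, 0.0, None
    for start in range(0, len(scores), tile):
        x = [s * scale_log2 for s in scores[start:start + tile]]
        v_tile = values[start:start + tile]
        new_max = max(row_max, max(x))
        alpha = 2.0 ** (row_max - new_max)  # rescale factor; 0.0 on the first tile
        p = [2.0 ** (xi - new_max) for xi in x]
        row_sum = row_sum * alpha + sum(p)
        pv = [sum(pi * v[d] for pi, v in zip(p, v_tile)) for d in range(len(v_tile[0]))]
        # First tile overwrites O (cf. ACCUMULATE=False); later tiles rescale, then add.
        o = pv if o is None else [oi * alpha + pvi for oi, pvi in zip(o, pv)]
        row_max = new_max
    lse = math.log(row_sum) + row_max * math.log(2.0)
    return o, row_sum, row_max, lse


def reference_row(scores, values, scale):
    """Plain softmax(scores * scale) @ V and its natural-log LSE."""
    z = [s * scale for s in scores]
    m = max(z)
    w = [math.exp(zi - m) for zi in z]
    total = sum(w)
    o = [sum(wi * v[d] for wi, v in zip(w, values)) / total for d in range(len(values[0]))]
    return o, m + math.log(total)


random.seed(1)
n, hd = 300, 8  # three KV tiles of 128, the last one partial
scores = [random.gauss(0.0, 3.0) for _ in range(n)]
values = [[random.gauss(0.0, 1.0) for _ in range(hd)] for _ in range(n)]
o_un, row_sum, row_max, lse = online_softmax_row(scores, values, scale=1.0 / math.sqrt(hd))
o_ref, lse_ref = reference_row(scores, values, scale=1.0 / math.sqrt(hd))
print(max(abs(a / row_sum - b) for a, b in zip(o_un, o_ref)), abs(lse - lse_ref))
```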
CuTeDSL Constraints (hard-won)
- vectorize=True loops: ONLY load/store/print. No fmax, no cmpf, no inner loops, no carry.
- reduce(cute.ReductionOp.MAX) reduces the ENTIRE C-fragment to a scalar: a global max, not a per-row max.
- cute.arch.fmax is impure for the vectorizer. Use a plain range() loop.
- TMA partition tensors have 4 modes: (((64,128),1), ?, KV_tiles, ?). (None,0,None,0) keeps mode 2 (KV tiles) free, and [None, kt] indexes it.
- tBgK[(None, None, 0, 0)] pins mode 2 to 0 and silently reads tile 0 forever. Use (None,0,None,0) instead.
- The softmax_done_bar NamedBarrier is reusable across tiles.
- Hand-constructed TMEM atoms corrupt data on round-trip: Ld32x32bOp + St32x32bOp built independently introduce ~3% error. Use get_tmem_load_op + get_smem_store_op paired atoms for one-way trips.
- CuTeDSL region isolation: flat_divide and tma_partition can't be called inside if warp_idx blocks. Do partitioning outside the if blocks or in regular (non-@cute.kernel) helper functions.
- composition vs logical_divide: both re-tile a tensor, but they produce different layouts. The CUTLASS correction_rescale uses composition; correction_epilog uses logical_divide. The copy atoms must match the tensor layout they were created with.
- Variables defined in CuTeDSL if blocks are NOT visible in other if blocks. Even when the condition is a compile-time constant (self.use_smem_p), CuTeDSL's MLIR lowering creates separate regions. Variables must be defined unconditionally before the first if that uses them. This applies across if warp_idx == X blocks, for loops, and nested branches. If a variable is set in if not use_smem_p: and read in another if not use_smem_p: inside a for loop inside an if warp_idx < mma_warp_id:, it won't be visible. Define all such variables before any branching.
- tOrP0 MUST include the tmem_p0_offset column offset. The softmax warps store P at tmem_p0_offset=32 (FP32 columns = 64 BF16 elements). PV MMA must read from the same offset. Missing this causes NaN/zeros (the MMA reads S from column 0, not P from column 32). Use a const_expr conditional: if const_expr(self.tOrP0_offset > 0): tOrP0 = cute.make_tensor(tOrP.iterator + self.tOrP0_offset, tOrP.layout) else: tOrP0 = tOrP. You cannot use tOrP.iterator + 0 (MLIR OpResult + int fails).
- LSE formula: lse = ln(row_sum) + row_max * ln(2). row_max is in the scale_log2 domain (max(S * scale * log2(e))), so multiply by ln(2) to convert it to the natural-log domain: attn_max = row_max * ln(2). Verified: LSE err=0.000000.
- The CuTeDSL MLIR backend cannot handle complex pipeline loops. The MLIR→PTX optimizer shows exponential-or-worse behavior for kernels with TMA pipeline acquire/release inside loops. Both Python range() (unrolled) and cutlass.range(unroll=1) (runtime) trigger 3+ hour compilation for hd=512. Consider raw CUDA C++ for complex kernels. Pre-compilation + cubin caching is a viable workaround if the optimizer eventually finishes.
- Guard dead code with const_expr. CuTeDSL compiles BOTH branches of Python if statements. Use const_expr(condition) to eliminate dead code at compile time. Critical for: O rescale (only when n_kv_tiles>1), LSE (only when normalize=False), SMEM-P path (only when use_smem_p=True), k_sub path (only when n_k_sub_tiles>1).
- External k_sub merge is mathematically impossible. k_sub segments are additive in LOGIT space (S = S_0 + S_1), not attention weight space. You cannot recover softmax(S_0+S_1)@V from softmax(S_0)@V and softmax(S_1)@V. The D5 merge formula works for different token sets (additive in weight space), NOT for partial dot products. In-kernel k_sub accumulation before softmax is the only correct approach.
- pv_n_tile reduction is the easiest SMEM knob. At hd>256, reducing pv_n_tile from 256 to 128 shrinks sV and sC by 2× each. Cost: 4 PV GEMM passes instead of 2 (at hd=512). But PV is typically not the bottleneck, and this is simpler than SMEM overlap or Q tiling.
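A small numeric check of the k_sub argument follows. It does not prove impossibility (the logit-space argument above does that). It only shows that the D5-style LSE merge is exact for a token split and wrong for a head-dim (k_sub) split of the same logits. The data and helper names are illustrative.

```python
import math


def attend(logits, values):
    """softmax(logits) @ values for one row, plus its natural-log LSE."""
    m = max(logits)
    w = [math.exp(l - m) for l in logits]
    total = sum(w)
    o = [sum(wi * v[d] for wi, v in zip(w, values)) / total for d in range(len(values[0]))]
    return o, m + math.log(total)


def lse_merge(o_a, lse_a, o_b, lse_b):
    """LSE-weighted merge of two partial attentions (valid for DISJOINT key sets)."""
    m = max(lse_a, lse_b)
    wa, wb = math.exp(lse_a - m), math.exp(lse_b - m)
    return [(wa * x + wb * y) / (wa + wb) for x, y in zip(o_a, o_b)]


def max_err(a, b):
    return max(abs(x - y) for x, y in zip(a, b))


V = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, -1.0]]
S0 = [1.0, 0.0, 2.0, -1.0]   # partial dot products from head-dim slice 0
S1 = [0.5, 1.5, -0.5, 1.0]   # partial dot products from head-dim slice 1
S = [a + b for a, b in zip(S0, S1)]
exact, _ = attend(S, V)

# Token split (what D5 merges): keys {0,1} and {2,3} with the full logits -> exact.
token_split = lse_merge(*attend(S[:2], V[:2]), *attend(S[2:], V[2:]))
# k_sub split: same keys, logits split across head-dim slices -> not recoverable this way.
k_sub_split = lse_merge(*attend(S0, V), *attend(S1, V))
print(max_err(token_split, exact))   # close to 0
print(max_err(k_sub_split, exact))   # clearly nonzero for this data
```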
Key Lessons
- NEVER use find_tmem_tensor_col_offset() as a TMEM placement offset. It returns the footprint size, not a safe offset.
- FMHA never trusts DLPack tensor layouts. Reconstruct V as (hd, s_k) MN-major inside CuTe.
- TMEM allocation must be power of 2.
- Square hides bugs. (128,128) worked for every wrong approach. Always test non-square.
- St32x32bOp MUST use Float32, NOT BFloat16. BFloat16 causes illegal memory access.
- The first PV MMA must use ACCUMULATE=False. Otherwise it adds uninitialized TMEM to the output.
- FMHA P store uses QK C-fragment composition, NOT PV A-fragment. Two aliases, same TMEM.
- Register bridge: FP32 backing (store partition) + BF16 view (QK-load layout). Do not skip this.
- PRINT THE SHAPES. ALWAYS. Reasoning about TMEM layouts without evidence is how we waste days.
- Never assume TMEM round-trips are safe. Verify with NO-OP tests before adding logic.
Stage D: Full Decode Attention (revised May 23)
Key Insight: The Indexer Solves Paging Upstream
The indexer now hands the kernel selected_kv: [T, top_k, head_dim] BF16 — a dense, materialized, dequantized K/V tile. FMHA sees a dense [T, top_k, 512] tile, exactly like Stage A/B's existing k and v inputs. The kernel doesn't need to know it's sparse. Paged TMA, scattered HBM reads, FP8 dequantization — all handled by gather_selected_kv upstream.
The SWA branch is the only "irregular" thing: it reads from the state cache's ring buffer with a position mask. SWA is small (n_win=128 per query), so it's a separate fused branch with a sink-weighted merge.
One FMHA kernel serves all three DSV4 attention types:
- CSA: compressed_kv = top-k from the indexer, swa_kv from cache → sink merge
- HCA: compressed_kv = all classical pool entries (gather-all mode), swa_kv from cache → sink merge
- SWA-only (Flash layers 0-1): compressed_kv = empty (top_k=0), only SWA runs. The sink merge degenerates to just o_swa after renormalization.
Build Order
D1 — Parameterize HEAD_DIM + SMEM-P (~1 day, MOSTLY DONE)
Currently hardcoded at 64. Promote to constructor arg, thread through _setup. Test at 64, then 512 (DSV4's real value).
hd≤256: ✅ DONE. cos 0.999998 at hd=64/128/256. Both TMEM-P and SMEM-P paths work.
hd=512: ❌ BLOCKED. The SMEM budget is fixed (192KB, fits the 232KB limit) and the kernel is structurally correct (tracing finishes in 0.8s), but CuTeDSL's MLIR→PTX backend optimizer hangs for 3+ hours when compiling the k_sub loop. External k_sub merge is mathematically impossible (k_sub segments are additive in logit space, not weight space). Options: (a) pre-compile offline and cache the cubin, (b) add a no-softmax mode for S accumulation in Python, or (c) write the hd=512 path in raw CUDA C++.
Done when: identical result at HEAD_DIM=64 (regression), passes at HEAD_DIM=512 against FP32 oracle.
D2 — Multi-query grid with head packing (~1 day)
Grid changes from (1, 1, 1) to (num_q_blocks, 1, batch). DSV4 is MQA — all n_h=128 query heads share the same K/V. The query-head axis is folded into the M dimension of the Q tile: M_tile = 128 covers M = T * n_h rows. At decode T is small (1-16), so packing heads into M fills the MMA. At prefill T=64, M is already 8192 with heads packed.
Done when: batch=4, T=64, n_h=128, num_kv_heads=1 produces correct attention against FP32 oracle.
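The head-packing arithmetic, as a sketch. It assumes token-major packing (row = t * n_h + h) and that T counts the query tokens of one batch entry (the grid's z dimension being batch). Neither detail is pinned down above. unpack_row recovers (t, h), which the kernel still needs for per-token SWA masking (request_ids, positions) and the per-head sink logit.

```python
def packed_row(t, h, n_h):
    """Row in the packed M = T * n_h dimension for token t and query head h (token-major)."""
    return t * n_h + h


def unpack_row(m, n_h):
    """Inverse of packed_row: returns (t, h)."""
    return divmod(m, n_h)


def num_q_blocks(T, n_h, m_tile=128):
    """CTAs along grid x if each CTA owns one M_tile-row slice of the packed M."""
    return -(-(T * n_h) // m_tile)  # ceiling division


for T in (1, 16, 64):
    print(T, T * 128, num_q_blocks(T, 128))
# 1 128 1
# 16 2048 16
# 64 8192 64
```

With Flash's 64 heads, a single decode token fills only half of a 128-row M tile (num_q_blocks(1, 64) is still 1).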
D3 — SWA sequence length mask (~½ day)
The indexer's top_k is fixed (512 for Flash, 1024 for Pro). Compressed-K input is always [T, top_k, head_dim] with the same top_k at compile time.
What varies: the SWA window holds up to n_win=128 tokens but starts with fewer. Add swa_lens: [batch] int32 as kernel input. Mask SWA-branch logits to -inf where swa_idx >= swa_lens[b].
Done when: batched input with varying SWA fill levels (some requests at position 50, some at 5000) produces correct masked output.
D4 — Causal mask on SWA branch (~½ day)
The compressed K the indexer selects is already from s < floor(t/m) (paper eq. 17). The indexer enforces causality at selection time. FMHA sees only causally-valid candidates. The main path has no mask.
The SWA branch needs a causal mask within the window. Add is_causal: bool constructor flag, apply swa_idx > q_pos masking to -inf in the SWA pass.
Done when: prefill mode produces correct output with the causal mask applied to SWA.
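A scalar sketch of the two SWA masks (the D3 length mask and the D4 causal mask) for one query row follows. It assumes each window slot's absolute key position is available as key_pos. The plan's swa_idx > q_pos check is the same test when both are expressed in the same position frame. Ring-buffer wraparound is not modeled, and the names are ours.

```python
NEG_INF = float("-inf")


def swa_mask_row(logits, swa_len, key_pos=None, q_pos=None, is_causal=False):
    """Mask one query's SWA logits (one entry per window slot).

    D3: slots with index >= swa_len hold no token yet -> -inf.
    D4: with is_causal, slots whose key position is after q_pos -> -inf.
    """
    out = []
    for idx, s in enumerate(logits):
        masked = idx >= swa_len
        if is_causal and not masked:
            masked = key_pos[idx] > q_pos
        out.append(NEG_INF if masked else s)
    return out


def mask_tokens(logits_per_token, swa_lens, request_ids):
    """Non-causal D3 mask for a flattened batch: token t uses its request's fill level."""
    return [swa_mask_row(row, swa_lens[request_ids[t]])
            for t, row in enumerate(logits_per_token)]


print(swa_mask_row([0.5, 1.0, 2.0, 3.0], swa_len=2))
# [0.5, 1.0, -inf, -inf]
print(swa_mask_row([0.5, 1.0, 2.0, 3.0], swa_len=4,
                   key_pos=[10, 11, 12, 13], q_pos=11, is_causal=True))
# [0.5, 1.0, -inf, -inf]
```

If every slot of a row is masked, a naive online softmax computes exp(-inf - (-inf)) = NaN. Presumably that branch's LSE should come out as -inf so the sink merge ignores it.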
D5 — SWA + sink merge (~2-3 days) ← D5a+D5b DONE (May 23), D5c/D5d remaining
Per dsv4/ops/decode_sparse.py:
o = (exp(lse_sparse) * o_sparse + exp(attn_sink) * exp(lse_swa) * o_swa)
/ (exp(lse_sparse) + exp(attn_sink) * exp(lse_swa))
With un-normalized O (D5a): o_unnorm = o_norm * exp(lse), so:
o = (o_unnorm_sparse + exp(attn_sink) * o_unnorm_swa)
/ (exp(lse_sparse) + exp(attn_sink) * exp(lse_swa))
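For each branch, exp(lse) * o = Σ exp(s_j) v_j. So the merge above is a single softmax over the union of compressed and SWA keys, with every SWA logit shifted by attn_sink. The scalar sketch below (one token-head row, names ours) evaluates the merge in a max-shifted form so exp() cannot overflow. It checks that equivalence and the top_k=0 case, where the result must be exactly o_swa. One caveat for the un-normalized variant: it needs o_unnorm to equal o_norm * exp(lse) exactly. An accumulator kept relative to a running log2-domain max, as in standard online softmax, differs from that by a factor of 2^row_max.

```python
import math


def attend(logits, values):
    """softmax(logits) @ values for one row, plus its natural-log LSE."""
    m = max(logits)
    w = [math.exp(l - m) for l in logits]
    total = sum(w)
    o = [sum(wi * v[d] for wi, v in zip(w, values)) / total for d in range(len(values[0]))]
    return o, m + math.log(total)


def sink_merge(o_sparse, lse_sparse, o_swa, lse_swa, attn_sink):
    """The D5 merge for one (token, head), with a max shift so exp() cannot overflow.

    An empty compressed branch (lse_sparse = -inf, top_k = 0) returns o_swa exactly.
    """
    a, b = lse_sparse, attn_sink + lse_swa
    m = max(a, b)
    wa, wb = math.exp(a - m), math.exp(b - m)
    return [(wa * x + wb * y) / (wa + wb) for x, y in zip(o_sparse, o_swa)]


sparse_logits, sparse_v = [0.3, -1.2, 2.0], [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
swa_logits, swa_v = [1.0, 0.5], [[-1.0, 1.0], [0.5, 0.5]]
sink = 0.7

merged = sink_merge(*attend(sparse_logits, sparse_v), *attend(swa_logits, swa_v), sink)
# Same thing as one softmax over all keys with the SWA logits shifted by the sink:
joint, _ = attend(sparse_logits + [l + sink for l in swa_logits], sparse_v + swa_v)
print(max(abs(x - y) for x, y in zip(merged, joint)))  # close to 0
```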
D5a DONE (May 23): normalize flag added to FmhaKernel. When False, emits un-normalized O + LSE. LSE formula: lse = ln(row_sum) + row_max * ln(2) (row_max in scale_log2 domain, multiply by ln(2) to convert). LSE err=0.000000 verified.
D5b DONE (May 23): Python SWA+sink merge works end-to-end at hd=64. Run FMHA twice (compressed KV + SWA KV, normalize=False), merge in Python. Merge cos 0.961, individual attention cos 0.963/0.960.
Sub-steps remaining:
- 5c: Fuse the two passes into one kernel launch. Q stays in SMEM, two MMA loops sequentially.
- 5d: Fuse the merge into the kernel epilogue.
Done when: end-to-end kernel produces correct attention against FP32 oracle that does sparse+SWA+sink merge.
D5 (old) paged TMA — REMOVED. The indexer + gather handles all paging upstream.
Kernel Architecture (after D5)
Input: Q [T, n_h, 512], compressed_kv [T, top_k, 512], swa_kv [batch, n_win, 512]
swa_lens [batch], sink_logits [n_h], request_ids [T]
│
├─ Load Q to SMEM (once)
│
├─ Loop 1: compressed KV (top_k tokens)
│ QK → online softmax → PV → O_sparse, lse_sparse in TMEM
│
├─ Loop 2: SWA window (n_win tokens, masked by swa_lens)
│ QK → online softmax → PV → O_swa, lse_swa in TMEM
│
└─ Sink merge epilogue:
O = (exp(lse_sparse) * O_sparse + exp(sink) * exp(lse_swa) * O_swa)
/ (exp(lse_sparse) + exp(sink) * exp(lse_swa))
Reference Files
- Sink merge spec: dsv4/ops/decode_sparse.py (formula)
- SWA decode: dsv4/ops/decode_swa.py
- Attention reference: dsv4/reference/attention.py
- CSA attention: dsv4/reference/csa_attention.py
Stage C Note
When implementing D5a, Stage C's epilogue changes from "multiply by 1/row_sum" to "emit un-normalized o + lse". Defer this until D5. Through D1-D4, keep Stage C normalize as-is and test as standalone dense FMHA.
Stage E: Production Extraction (revised May 23)
E1 — File placement
dsv4/kernels/attention/fmha.py. Currently contains FmhaKernel (migrated from test, hd=64 TMEM-P). Will gain parameterized head_dim and SMEM-P path in D1. Constructor takes all dimensions and dtypes, no module-level constants.
E2 — Constructor signature
from cutlass import BFloat16, Float32
from cutlass.cute.nvgpu import tcgen05

class FmhaKernel:
    def __init__(
        self,
        head_dim: int,            # 512 for DSV4
        num_query_heads: int,     # 128 for Pro, 64 for Flash
        sliding_window: int,      # 128
        top_k: int,               # 512 (Flash) or 1024 (Pro)
        q_dtype=BFloat16,
        kv_dtype=BFloat16,
        o_dtype=BFloat16,
        qk_acc_dtype=Float32,
        pv_acc_dtype=Float32,
        is_causal: bool = False,  # affects SWA mask only
        cta_group: tcgen05.CtaGroup = tcgen05.CtaGroup.ONE,
        cluster_shape_mn: tuple = (1, 1),
    ):
        ...  # body elided
All architecture-level shapes from config flow into the constructor. No FMHA-internal magic numbers.
E3 — Call signature
def __call__(
self,
q: torch.Tensor, # [T, n_h, head_dim] BF16
compressed_kv: torch.Tensor, # [T, top_k, head_dim] BF16 — from indexer gather
swa_kv: torch.Tensor, # [batch, n_win, head_dim] BF16 — from cache prep
swa_lens: torch.Tensor, # [batch] int32
sink_logits: torch.Tensor, # [n_h] FP32
request_ids: torch.Tensor, # [T] int32 — maps query to its SWA slot
o: torch.Tensor, # [T, n_h, head_dim] BF16 — preallocated
stream: cuda.CUstream,
):
Notably absent: block_table, paged KV, inv_scale, FP8 dequant. All handled upstream.
E4 — Kernel cache + warmup
Mirror dsv4/ops/gemm_runner.py's _compiled_kernel_cache. Key on (head_dim, num_query_heads, top_k, is_causal, ...). Pre-allocate at warmup, reuse at call. For DSV4 the cache stays tiny: Flash/Pro × causal/non-causal gives at most 4 keys, and a single deployment (one model variant) needs at most 2.
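We haven't reproduced gemm_runner.py's actual cache here. This is a generic sketch of the same idea, compile once per static configuration and reuse afterwards, with a stand-in compile_fn. In the real code, compile_fn would construct and compile an FmhaKernel. The class name is hypothetical.

```python
from typing import Callable, Dict, Hashable, Tuple


class CompiledKernelCache:
    """Compile once per static configuration, reuse afterwards (illustrative sketch)."""

    def __init__(self, compile_fn: Callable[..., object]):
        self._compile_fn = compile_fn
        self._cache: Dict[Tuple[Hashable, ...], object] = {}

    def get(self, head_dim: int, num_query_heads: int, top_k: int, is_causal: bool, **extra):
        # Every argument that changes the compiled code must be part of the key.
        key = (head_dim, num_query_heads, top_k, is_causal, tuple(sorted(extra.items())))
        if key not in self._cache:
            self._cache[key] = self._compile_fn(
                head_dim=head_dim, num_query_heads=num_query_heads,
                top_k=top_k, is_causal=is_causal, **extra)
        return self._cache[key]

    def __len__(self) -> int:
        return len(self._cache)


compiled = []
cache = CompiledKernelCache(lambda **cfg: compiled.append(cfg) or object())
for n_h, top_k in ((64, 512), (128, 1024)):  # Flash, Pro
    for causal in (False, True):
        cache.get(512, n_h, top_k, causal)
        cache.get(512, n_h, top_k, causal)    # second call hits the cache
print(len(cache), len(compiled))  # 4 4
```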
E5 — torch.library custom op
import torch

@torch.library.custom_op("dsv4::sparse_fmha_with_swa", mutates_args=("o",))
def sparse_fmha_with_swa_op(
    q: torch.Tensor,
    compressed_kv: torch.Tensor,
    swa_kv: torch.Tensor,
    swa_lens: torch.Tensor,
    sink_logits: torch.Tensor,
    request_ids: torch.Tensor,
    o: torch.Tensor,
    runner_id: int,
) -> None:
    # get_runner: the project's lookup into the E4 kernel cache (not shown here).
    runner = get_runner(runner_id)
    runner._run_impl(q, compressed_kv, swa_kv, swa_lens, sink_logits, request_ids, o)
Mutates o (preallocated buffer). Consistent with cudagraphs.
E6 — Reference parity hook
dsv4/reference/attention.py stays as the FP32 oracle. New test: tests/unit/test_fmha_kernel.py.
import torch
from dsv4.kernels.attention.fmha import FmhaKernel
# attention_with_lse_f32 / attention_swa_with_lse_f32: FP32 oracle helpers (import not shown).

def test_sparse_fmha_matches_spec(T=64, n_h=128, top_k=1024, n_win=128, hd=512):
    q = torch.randn(T, n_h, hd, dtype=torch.bfloat16, device='cuda')
    ck = torch.randn(T, top_k, hd, dtype=torch.bfloat16, device='cuda')
    swa = torch.randn(4, n_win, hd, dtype=torch.bfloat16, device='cuda')
    swa_lens = torch.tensor([128, 50, 128, 75], dtype=torch.int32, device='cuda')
    sink = torch.randn(n_h, device='cuda')
    req_ids = torch.randint(0, 4, (T,), dtype=torch.int32, device='cuda')
    # Oracle: pure FP32 spec (K and V are the same latent tensor in both branches)
    o_sparse, lse_sparse = attention_with_lse_f32(q, ck, ck)
    o_swa, lse_swa = attention_swa_with_lse_f32(q, swa, swa, swa_lens, req_ids)
    e_sink = sink.exp()
    num = lse_sparse.exp().unsqueeze(-1) * o_sparse \
        + e_sink[None, :, None] * lse_swa.exp().unsqueeze(-1) * o_swa
    den = lse_sparse.exp() + e_sink[None, :] * lse_swa.exp()
    expected = num / den.unsqueeze(-1)
    # Kernel
    o = torch.empty_like(expected, dtype=torch.bfloat16)
    fmha = FmhaKernel(head_dim=hd, num_query_heads=n_h, sliding_window=n_win, top_k=top_k)
    fmha(q, ck, swa, swa_lens, sink, req_ids, o, stream=...)
    torch.testing.assert_close(o.float(), expected, atol=5e-3, rtol=5e-3)
E7 — Cleanup
Delete all debug test files. test_fmha_v3.py becomes dsv4/kernels/attention/fmha.py. Only tests/unit/test_fmha_kernel.py remains as the attention test.
Environment
- Server: root@45.76.247.107 (B200, 180 GiB HBM3e per GPU)
- venv: source /root/dsv4-nvfp4-workspace/venv/bin/activate
- PYTHONPATH: /root/dsv4-nvfp4-workspace/kernel
- Model: /root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4
- vLLM repo: /root/dsv4-nvfp4-workspace/vllm (modified for Blackwell)
- CUTLASS FMHA reference: /root/cutlass/examples/python/CuTeDSL/cute/blackwell/kernel/attention/fmha/fmha.py
- Local CUTLASS clone: /home/openclaw/dev/cutlass