biondizzle 43f0b5d1e8 D1.5: Fix O rescale with paired atoms (incremental approach)
Keep epilogue_tma_store for final output (proven path).
Only fix the multi-KV-tile O rescale using paired atoms from
epilogue_tmem_copy_and_partition. The paired atoms share addressing,
making the TMEM->REGS->modify->TMEM cycle lossless.

Guarded by const_expr(n_kv_tiles > 1) so single-tile path (n=128)
is completely unaffected — zero regression risk.

Full correction epilogue (one-way TMEM->REGS->SMEM->GMEM) deferred
until we can address the MLIR compilation time issue.
2026-05-26 19:34:26 +00:00


DSV4 Inference Kernel

⚠️⚠️⚠️ CRITICAL: TMA Partition Tensor Mode Ordering ⚠️⚠️⚠️

THIS BUG COST US AN ENTIRE DAY. READ THIS. BURN IT INTO YOUR BRAIN.

After cpasync.tma_partition(), the output GMEM tensor has 4 modes (verified on B200):

tBgK shape: (((64, 128), 1), ?, KV_tiles, ?)
                 mode 0      1  2        3

Mode 2 is the GMEM tile dimension: the one you index with kt to load different K/V tiles.

THE WRONG WAY (what we did — silently loads from tile 0 forever):

# ❌❌❌ (None,None,0,0) KEEPS MODES 0,1 FREE, SETS MODE 2 TO 0 ❌❌❌
# Mode 2 (the KV tile dim) gets collapsed to coordinate 0.
# TMA ALWAYS reads from tile 0.
tBgK = tBgK[(None, None, 0, 0)]  # ← WRONG! Mode 2 pinned to 0!

# The copy "works" but kv_coord indexes mode 1 (inner GEMM K, not KV tiles).
cute.copy(tma_k, tBgK[(None, kv_coord)], ...)  # ← kv_coord indexes wrong mode!

THE RIGHT WAY (verified on B200 at n=128 and n=256):

# ✅ (None,0,None,0) keeps modes 0 and 2 free → 2D tensor
# Mode 2 (KV tiles) survives as the second mode.
tBgK = tBgK[(None, 0, None, 0)]

# ✅ [None, kt] indexes the surviving mode 1 (originally mode 2 = KV tiles)
cute.copy(tma_k, tBgK[None, kt], ...)
#                       ^^ THIS IS THE KV TILE DIM

Verified shapes on B200 (May 22, n=256, inside @cute.kernel):

Before slice: tBgK = (((64,128),1), Int32(?), Int32(?), Int32(?))  — 4 modes
After (None,0,None,0): tBgK = (((64,128),1), Int32(?))             — 2 modes

WHY THIS IS SO INSIDIOUS

  1. No error, no warning. The slice tBgK[(None,None,0,0)] silently sets mode 2 to 0.
  2. Single-tile (n=128) works perfectly. With only 1 KV tile, mode 2 is size 1, so the bug is invisible.
  3. Multi-tile tests produce "reasonable" output. The TMA loads from tile 0 every time, so you get a valid (but wrong) attention computation. Cosine similarity is 0.7-0.9, not NaN.
  4. The strides are all 0. Printing tBgK.layout.stride shows all zeros for TMA tensors. You can't detect the bug from strides alone.
  5. cute.printf shows kv_coord=0. We thought the JIT was constant-folding the variable. It wasn't — the variable was fine, but it was indexing the wrong mode.
  6. The 8-mode theory was wrong. We assumed tma_partition produced 8 TMA coordinate dimensions. It produces 4. The 8-None no-op slice fails with "weakly congruent" at JIT compile.

THE LESSON

PRINT THE SHAPES. ALWAYS. Run print(f"tBgK: shape={cute.shape(tBgK)}") inside @cute.kernel at trace time. The shapes are your ground truth. Reasoning about mode counts without evidence is how we wasted a day.

The correct pre-slice depends on which mode is the GMEM tile iteration axis. For our local_tile + partition_B + group_modes(0,3) pattern, mode 2 is the KV tile axis. (None,0,None,0) keeps it free. (None,None,0,0) collapses it to 0.

# ALWAYS verify the shape at trace time:
print(f"tBgK shape: {cute.shape(tBgK)}")  # 4 modes
print(f"tBgK after slice: {cute.shape(tBgK[(None,0,None,0)])}")  # 2 modes

# Then index the 2D tensor:
cute.copy(tma_k, tBgK[None, kt], ...)

IF YOU USE (None,None,0,0) INSTEAD OF (None,0,None,0), MULTI-TILE TMA WILL BE SILENTLY BROKEN.
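
The bookkeeping behind the two slices can be modelled in plain Python. This is only a sketch of which modes survive a slice, assuming the 4-mode shape printed above; slice_modes is a hypothetical helper and does not call CuTeDSL.

```python
def slice_modes(shape, coord):
    """Model a CuTe-style slice: None keeps a mode free, an int pins it.

    Returns (surviving_modes, pinned), where surviving_modes lists the
    original mode indices in the order they appear in the sliced tensor.
    Plain-Python illustration only, not the CuTeDSL API.
    """
    if len(shape) != len(coord):
        raise ValueError("coordinate must have one entry per mode")
    surviving = [i for i, c in enumerate(coord) if c is None]
    pinned = {i: c for i, c in enumerate(coord) if c is not None}
    return surviving, pinned


# Mode labels follow the shape printed above; "?" is a runtime extent.
tBgK_shape = ("((64,128),1)", "?", "KV_tiles", "?")
KV_TILE_MODE = 2

wrong_free, wrong_pinned = slice_modes(tBgK_shape, (None, None, 0, 0))
right_free, right_pinned = slice_modes(tBgK_shape, (None, 0, None, 0))

# Wrong slice: the KV tile mode is pinned to 0, so kt can never reach it.
print("wrong:", "free", wrong_free, "pinned", wrong_pinned)
# Right slice: the KV tile mode survives as the second mode of the 2D tensor.
print("right:", "free", right_free, "pinned", right_pinned)
```

With the right slice, original mode 2 sits at position 1 of the sliced tensor, which is why tBgK[None, kt] reaches the KV tiles.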


Architecture

DSV4 is not MLA. It uses CSA (Compressed Sparse Attention, m=4) and HCA (Heavily Compressed Attention, m=128). KV latent is (T, 512) shared across all 128 heads. Sink weights merge sparse + SWA attention. vLLM misnames this as "MLA" — it is not. The architecture is fundamentally different.

DSV4 inference pipeline — component status
==========================================

Legend:
 [✓] built and tested
 [~] partial — reference or seam exists, native pending
 [✗] to build


 ┌────────────────────────────────────┐
 │ [✗] Embedding + mHC init          │
 │ token embed + n_hc=4 streams      │
 └────────────────┬───────────────────┘
                  │
                  ▼
┌─ Transformer layer × L ──────────────────────────────────────────────┐
│ HCA on layers 0-1 of Pro, alternating CSA / HCA after             │
│                                                                      │
│ ┌─ Attention sub-block ──────────────────────────────────────────┐  │
│ │ [✓] Residual mHC pre + post mix                               │  │
│ │ [~] Norms + RoPE             RMSNorm + partial RoPE           │  │
│ │ [✓] Q / KV projection        NVFP4 linears + LoRA             │  │
│ │ [✓] Token compressor         CSA m=4 / HCA m=128             │  │
│ │ [✓] Indexer + top-k          CSA, FP32 dot + top-k             │  │
│ │ [~] FMHA core                QK → online softmax → PV         │  │
│ │                              + SWA branch + sink merge         │  │
│ │ [✓] Output projection        inv RoPE + wo_a grouped + wo_b   │  │
│ └────────────────────────────────────────────────────────────────┘  │
│                                                                      │
│ ┌─ FFN sub-block ────────────────────────────────────────────────┐  │
│ │ [✓] Residual mHC pre + post mix                               │  │
│ │ [✓] Pre-FFN norm              RMSNorm                          │  │
│ │ [✓] Router                    sqrt(softplus) + topk + hash     │  │
│ │ [✓] Routed MoE               fused SwiGLU L1 + L2             │  │
│ │ [✓] Shared expert            NVFP4 single-group GEMM          │  │
│ └────────────────────────────────────────────────────────────────┘  │
└──────────────────────────────────┬───────────────────────────────────┘
                                  │
                                  ▼
┌──────────────────────────────────────────────────────────────────────┐
│ [✗] Final RMSNorm → [✗] LM head → [✗] MTP (depth=1) → [✗] Sampler │
└──────────────────────────────────────────────────────────────────────┘

┌─ Supporting infrastructure ──────────────────────────────────────────┐
│ [✗] KV cache management                                             │
│ • state cache: SWA window + uncompressed tail per layer             │
│ • classical paged cache: lcm(4, 128) = 128 tokens per block       │
│ • heterogeneous layout per layer                                    │
└──────────────────────────────────────────────────────────────────────┘


Summary
-------
 Built  [✓] : 10 (mHC ×2, Q/KV proj, output proj, routed MoE,
                shared expert, token compressor, indexer+topk,
                router, pre-FFN norm)
 Partial [~] : 2 (norms+RoPE, FMHA core)
 To build [✗] : 6 (embedding+init, final norm, LM head, MTP, sampler, KV cache)

Status (May 26, 2026 — 18:40 UTC)

| Stage | Status | Description |
| --- | --- | --- |
| A | COMPLETE | Q@K^T via tcgen05.mma → TMEM → GMEM |
| B | COMPLETE | QK → identity softmax → P@V pipeline (TMEM alias, KV-tile interleaving) |
| C | COMPLETE | Real online softmax. Kernel outputs un-norm O + LSE (no TMEM round-trip). Migrated to dsv4/kernels/attention/fmha.py as FmhaKernel. |
| D1 | 🟡 hd≤256 DONE | Parameterized HEAD_DIM. qk_mma_tiler fix (hd=64/128/256 cos 0.999998). hd=512 SMEM fits but MLIR compilation hangs (>3hr). External k_sub merge proven impossible. O rescale TMEM round-trip BROKEN (Ld32x32bOp/St32x32bOp corrupt data). Python KV merge workaround works. |
| D1.5 | BLOCKER | O rescale for multi-KV-tile (kt>0). TMEM round-trip corruption (even NO-OP round-trip fails). Python KV merge workaround: cos 0.999994. Production: 5-9 kernel launches per decode. Fix requires correction epilog (one-way TMEM→regs→SMEM→GMEM). |
| D2 | 🟡 Per-head DONE | Head-packed M-dimension launch (cos 0.999995, n_h=1-128). Multi-CTA grid blocked: flat_divide + epilogue_tma_store layout mismatch. |
| D3 | DONE | SWA sequence length mask (in-kernel post-QK via tTMEM_LOADcS coordinates, swa_len Int32 scalar, offset by n_comp for D5c) |
| D4 | DONE | Causal mask on SWA branch (SWA-relative position > m_coord → -inf, combined with D3 via OR logic) |
| D5 | D5a+D5b+D5c DONE | D5a: normalize flag + LSE + row_sums output. D5b: Per-row LSE + Python KV merge (cos 0.999994). D5c: Sink bias as logit modification, mathematically equivalent to separate merge, single pass over combined KV (cos 0.999996 single-tile AND multi-tile). D5d (fused in-kernel merge) NOT NEEDED: sink bias approach supersedes it. |
| E1-E7 | TODO | Production extraction (class, custom op, cache, cleanup) |
| NVFP4-3 | DONE | use_2cta_instrs conditional in gemm_runner.py. 1.7-1.9× throughput at prefill shapes. |

Package Structure

dsv4/
├── kernels/          Pure GPU code (CuTeDSL @cute.jit, .cu files)
│   ├── gemm/           NVFP4 MoE GEMM kernels (grouped, fused_swiglu, dense, scheduler)
│   ├── attention/      FMHA kernel — FmhaKernel (hd=64, TMEM-P proven; SMEM-P stub for hd>64)
│   ├── compressor/     CSA/HCA token-level compressor (CuTeDSL, 419 lines)
│   ├── indexer/        CSA indexer — score+topk (FP32 dot products, top-k selection)
│   ├── router/         Dense router decode kernel (warp-specialized persistent GEMM)
│   ├── cache/          Cache kernels — append_swa (write KV to split state cache layout)
│   ├── decode/         Decode-time attention (sparse, SWA — future)
│   └── cuda/           Raw .cu files (deinterleave_quantize, sparse_topk_metadata)
├── ops/              PyTorch ↔ kernel bridges
│   ├── quantize.py      BF16 ↔ NVFP4 conversion, scale factors
│   ├── layouts.py       Scale swizzle, gate/up interleave, K-major, offsets
│   ├── gemm_runner.py   Warmup, compile, run grouped/fused GEMMs
│   ├── custom_ops.py    torch.library.custom_op registrations
│   ├── decode_sparse.py native_sparse_decode dispatcher
│   ├── decode_swa.py    native_swa_decode dispatcher
│   ├── rope.py          Forward + inverse RoPE
│   ├── topk.py          Python wrapper for sparse_topk_metadata.cu
│   ├── topk_select.py   Top-k selection wrapper
│   └── router.py        Router op bridge
├── layers/           nn.Module-style components
│   ├── linear.py        Nvfp4Linear
│   ├── grouped_linear.py Nvfp4GroupedLinear
│   ├── moe.py           Nvfp4MoE
│   ├── shared_expert.py Nvfp4SharedExpert
│   ├── mhc.py           mHCLayer
│   ├── attention.py     DSV4 attention sub-block (CSA/HCA/SWA variants, 245 lines)
│   ├── norm.py          RMSNorm (PyTorch ref, fused kernel later)
│   ├── router.py        Router — token-to-expert assignment (273 lines)
│   ├── embedding.py     Token embedding + mHC init (stub)
│   └── ffn.py           FFN sub-block
├── model/            Model assembly
│   ├── config.py        Model config
│   ├── layer.py         Transformer layer
│   ├── layer_schedule.py Layer scheduling
│   ├── mtp.py           Multi-token prediction
│   ├── sampler.py       Token sampler
│   └── dsv4.py          Full model (stub — Phase 1)
├── cache/            KV cache infra
│   ├── allocator.py     Cache memory allocator
│   ├── block_table.py   Paged cache block table
│   ├── flush.py         Cache flush
│   ├── handle.py        Cache handle
│   ├── manager.py       Cache manager
│   ├── paged_cache.py   Paged KV cache
│   ├── prepare_forward.py Forward prep
│   ├── schema.py        Cache schema
│   └── state_cache.py   State cache (SWA ring buffer)
├── loader/           Checkpoint I/O
│   ├── hf_checkpoint.py HuggingFace checkpoint loader
│   └── layout_convert.py Weight layout conversion
└── reference/        Slow PyTorch oracles (never imported by production code)
    ├── attention.py     RoPE, KV cache, causal attention, SWA
    ├── csa_attention.py CSA/HCA sparse attention
    ├── compressor.py    Compressor PyTorch example
    └── moe_pipeline.py  MoE pipeline reference

Mental model: kernels/ → ops/ → layers/ → model/ (dependency flows left to right). reference/ and loader/ are sidecars.


Active Test Files

FMHA (Stages A/B/C/D1) — in tests/unit/

| File | Stage | Status |
| --- | --- | --- |
| test_fmha_v3.py | A+B | Full QK→identity softmax→PV, cosine 0.999999 |
| test_fmha_v3_12w.py | A+B | 12-warp QK→PV, cosine 0.999999 |
| test_fmha_v3_stage_c.py | C | Real online softmax + normalize, n=128 cos 0.973. Also in module as FmhaKernel. |
| test_fmha_v3_stage_d1.py | D1 | hd=64/128/256 PASS (cos 0.999998, TMEM-P). hd=512 SMEM overflow. |
| test_fmha_v3_stage_d5b.py | D5b | Python SWA+sink merge (cos 0.999994, LSE err=0.0) |
| test_d5c_fused.py | D5c | Single-tile combined KV + sink bias (cos 0.999996) |
| test_d5c_multitile.py | D5c | Multi-tile with Python KV merge + sink bias (cos 0.999996) |
| test_d1_*.py | D1 | 🔨 Debug/diagnostic variants (hd512, regression, sweep, raw, debug) |
| test_paired_epilog.py | C | Paired atom epilogue experiments |
| test_pv64_with_softmax.py | B | (128,64) PV, single AB pipeline |
| test_128_128_vdiag.py | A+B | (128,128) PV baseline |
| test_qkonly.py | A | QK with split Q/KV pipelines |
| test_qk_softmax.py | A+B | QK + identity softmax, no PV |

MoE / GEMM — in tests/unit/

| File | What |
| --- | --- |
| test_cutedsl.py | NVFP4 grouped GEMM kernel |
| cudagraph_test.py | Cudagraph capture + replay |
| layertest.py | Per-layer correctness |
| test_custom_op.py | torch.library custom ops |
| test_compile_custom_op.py | Compile + warmup |
| test_fp4_roundtrip.py | BF16 → NVFP4 → BF16 roundtrip |
| test_interleave.py | Gate/up weight interleaving |
| test_interleave_gemm.py | Interleaved GEMM correctness |
| test_fused_step1.py | Fused SwiGLU GEMM |

Test Harness

Scripts in tests/ for running tests on the B200 (root@45.76.247.107):

run_test.sh — Run a test in a screen session

# On the B200:
cd /root/dsv4-nvfp4-workspace/kernel
bash tests/run_test.sh tests/unit/test_fmha_v3.py

What it does:

  1. Kills any existing kernel-test screen and SIGKILLs all child processes (handles deadlocked GPU procs that ignore SIGHUP)
  2. Deletes the old log file
  3. Starts a new screen -dmS kernel-test running the test
  4. Logs output to /tmp/kernel-test.log
  5. Verifies the screen started

check_log.sh — Check test progress

bash tests/check_log.sh

Shows the log contents and whether the screen is still running.

Local → B200 workflow

# 1. Edit locally, commit, push
cd ~/dev/nvfp4-megamoe-kernel
git add -A && git commit -m "my change" && git push

# 2. SSH to B200, pull, run
ssh root@45.76.247.107
cd /root/dsv4-nvfp4-workspace/kernel && git pull
bash tests/run_test.sh tests/unit/test_fmha_v3_stage_c_full.py

# 3. Check results
bash tests/check_log.sh

fire_b200_test — One-command local test runner

Lives in ~/.openclaw/workspace/fire_b200_test (NOT in the repo — project-specific tooling).

# From your local machine, one command to push, run, and get results:
~/.openclaw/workspace/fire_b200_test tests/unit/test_fmha_v3.py

What it does:

  1. Auto-commits and pushes any local changes
  2. SSHes to the B200, pulls, and starts run_test.sh in a screen
  3. Polls every 15s until the screen exits
  4. Dumps the full test log to your terminal

This is strictly for the DSV4 NVFP4 kernel project. It hardcodes the B200 IP, repo paths, and git remote.


Stage C: Online Softmax — TMEM Layout Mismatch Issue

Current Results (test_fmha_v3_stage_c.py)

| n | cos | Status |
| --- | --- | --- |
| 128 | 0.973 | ⚠️ 3% error from TMEM layout mismatch |
| 256 | 0.793 | ⚠️ Two TMEM round-trips compound the error |
| 384+ | N/A | Pipeline doesn't cycle past 2 KV tiles |

Root Cause: TMEM Layout Mismatch

The MMA instruction writes O to TMEM using the C-fragment layout. The epilogue_tma_store helper reads O from TMEM using get_tmem_load_op, which uses the correct C-fragment-compatible layout. Raw PV output is perfect (cos 0.999998) when epilogue_tma_store reads directly without any round-trip.

The problem appears when we do a TMEM round-trip (load O → modify → store back) using hand-constructed Ld32x32bOp/St32x32bOp atoms. These atoms use a different column mapping than the MMA's C-fragment layout, causing ~3% data corruption per round-trip. Both the NO-OP round-trip (previously used to "fix" layout) and the normalize round-trip (multiply by 1/row_sum) suffer from this error.

Fix proven but not yet integrated: The epilogue_tmem_copy_and_partition + epilogue_smem_copy_and_partition pattern from CUTLASS's cutlass.utils.gemm.sm100 reads O from TMEM using the correct get_tmem_load_op layout and writes to SMEM using get_smem_store_op. This is a one-way trip (TMEM→reg→SMEM→GMEM) that eliminates the layout mismatch entirely. Integration requires proper flat_divide and tma_partition handling inside the kernel's warp-specific if blocks.

Key Bug Fix: tOrP0 TMEM Column Offset (May 23)

The softmax warps store P at tmem_p0_offset=32 FP32 columns (64 BF16 elements). PV MMA must read from the same offset. tOrP0 was missing this offset, causing PV to read from TMEM column 0 (where S is) instead of column 32 (where P is). This was the root cause of NaN/zeros in D1 tests. Fixed with:

if const_expr(self.tOrP0_offset > 0):
    tOrP0 = cute.make_tensor(tOrP.iterator + self.tOrP0_offset, tOrP.layout)
else:
    tOrP0 = tOrP

Must use const_expr conditional (not Python if) because CuTeDSL compiles both branches, and tOrP.iterator + 0 fails with MLIR type error.

Architecture (6-warp, current)

Warps 0-3: Softmax + Epilogue (row_max, row_sum, P store, O rescale, final normalize)
Warp 4: MMA (QK, PV)
Warp 5: TMA (Q/K/V load)

TMEM Layout

Col 0-31:   S (QK acc, 128 FP32 via Ld32x32bOp Repetition(32))
Col 32-95:  P (64 FP32 via St32x32bOp Repetition(32), register bridge BF16 view)
Col 128+:   O (PV acc, 64 FP32, rescale via Ld32x32bOp Repetition(16))

Remaining for Multi-Tile Production

  1. Fix TMEM layout mismatch — replace hand-constructed atom round-trips with correction_epilog pattern
  2. Pipeline state cycling for n≥384 — kv_stage=2 can only buffer 2 tiles
  3. 12-warp layout — separate softmax/correction/epilogue warps
  4. O rescale for kt > 0 — must also use paired atoms or correction_epilog

CuTeDSL Constraints (hard-won)

  1. vectorize=True loops: ONLY load/store/print — no fmax, no cmpf, no inner loops, no carry
  2. .reduce(cute.ReductionOp.MAX): reduces ENTIRE C-fragment to scalar — global max, not per-row
  3. cute.arch.fmax: impure for vectorizer — use plain range() loop
  4. TMA partition tensors have 4 modes: (((64,128),1), ?, KV_tiles, ?). (None,0,None,0) keeps mode 2 (KV tiles) free, and [None, kt] then indexes it.
  5. tBgK[(None, None, 0, 0)] pins mode 2 to 0 — silently reads tile 0 forever. Use (None,0,None,0) instead.
  6. softmax_done_bar NamedBarrier is reusable across tiles
  7. Hand-constructed TMEM atoms corrupt data on round-trip: Ld32x32bOp + St32x32bOp built independently introduce ~3% error. Use get_tmem_load_op + get_smem_store_op paired atoms for one-way trips.
  8. CuTeDSL region isolation: flat_divide and tma_partition can't be called inside if warp_idx blocks. Do partitioning outside if blocks or in regular (non-@cute.kernel) helper functions.
  9. composition vs logical_divide: Both re-tile a tensor, but produce different layouts. The CUTLASS correction_rescale uses composition, correction_epilog uses logical_divide. The copy atoms must match the tensor layout they were created with.
  10. Variables in CuTeDSL if blocks are NOT visible in other if blocks. Even when the condition is a compile-time constant (self.use_smem_p), CuTeDSL's MLIR lowering creates separate regions. Variables must be defined unconditionally before the first if that uses them. This applies across if warp_idx == X blocks, for loops, and nested branches. If a variable is set in if not use_smem_p: and read in another if not use_smem_p: inside a for loop inside an if warp_idx < mma_warp_id:, it won't be visible. Define all such variables before any branching.
  11. tOrP0 MUST include the tmem_p0_offset column offset. The softmax warps store P at tmem_p0_offset=32 (FP32 columns = 64 BF16 elements). PV MMA must read from the same offset. Missing this causes NaN/zeros (MMA reads S from column 0, not P from column 32). Use const_expr conditional: if const_expr(self.tOrP0_offset > 0): tOrP0 = cute.make_tensor(tOrP.iterator + self.tOrP0_offset, tOrP.layout) else: tOrP0 = tOrP. Cannot use tOrP.iterator + 0 (MLIR OpResult + int fails).
  12. LSE formula: lse = ln(row_sum) + row_max * ln(2). row_max is in the scale_log2 domain (max(S * scale * log2(e))), so multiplying by ln(2) converts it to the natural-log domain: attn_max = row_max * ln(2). Verified: LSE err=0.000000.
  13. CuTeDSL MLIR backend cannot handle complex pipeline loops. The MLIR→PTX optimizer has exponential-or-worse behavior for kernels with TMA pipeline acquire/release inside loops. Both Python range() (unrolled) and cutlass.range(unroll=1) (runtime) trigger 3+ hour compilation for hd=512. Consider raw CUDA C++ for complex kernels. Pre-compilation + cubin caching is a viable workaround if the optimizer eventually finishes.
  14. Guard dead code with const_expr. CuTeDSL compiles BOTH branches of Python if statements. Use const_expr(condition) to eliminate dead code at compile time. Critical for: O rescale (only when n_kv_tiles>1), LSE (only when normalize=False), SMEM-P path (only when use_smem_p=True), k_sub path (only when n_k_sub_tiles>1).
  15. External k_sub merge is mathematically impossible. k_sub segments are additive in LOGIT space (S = S_0 + S_1), not attention weight space. You cannot recover softmax(S_0+S_1)@V from softmax(S_0)@V and softmax(S_1)@V. The D5 merge formula works for different token sets (additive in weight space), NOT for partial dot products. In-kernel k_sub accumulation before softmax is the only correct approach.
  16. pv_n_tile reduction is the easiest SMEM knob. At hd>256, reducing pv_n_tile from 256 to 128 shrinks sV and sC by 2× each. Cost: 4 PV GEMM passes instead of 2. But PV is typically not the bottleneck, and this is simpler than SMEM overlap or Q tiling.
  17. O rescale TMEM round-trip with Ld32x32bOp/St32x32bOp is BROKEN. Even a NO-OP round-trip (load O, multiply by 1.0, store back) corrupts data (cos 0.804 at s_k=256). The hand-constructed atoms don't preserve the C-fragment layout during round-trips. CUTLASS correction_rescale uses the same pattern — unclear why theirs works. Workaround: Python KV merge with per-segment LSE (cos 0.999998 for s_k up to 1024).
  18. KV merge formula uses NORMALIZED outputs, not un-normalized. The correct D5 merge for different token sets: O = sum_i [exp(lse_i) * O_i_norm] / sum_i [exp(lse_i)]. Using O_i_unnorm instead of O_i_norm gives cos ~0.91. The un-norm merge only works when both segments share the same row_max (global max), which isn't the case for separate KV segments.
  19. flat_divide + epilogue_tma_store layout mismatch. When using cute.flat_divide to create per-CTA GMEM views with runtime block coordinates (for multi-CTA grid), the resulting tensor layout is incompatible with CUTLASS's epilogue_tma_store pipeline, which expects the layout from local_tile. The tma_partition and epilogue must be refactored together to support multi-CTA grids.
  20. local_tile does not support runtime coordinates. cute.local_tile(mQ, tiler, (runtime_val, None)) fails at trace time. Must use cute.flat_divide(mQ, tiler) instead, which creates a tiled view with all rest dimensions accessible via runtime indexing.
  21. Sink bias domain correction. Adding attn_sink directly to raw logits is wrong — it gets scaled by scale_log2. Fix: add attn_sink / scale to raw logits, so after * scale_log2 it becomes attn_sink * log2(e), correctly multiplying attention weights by exp(attn_sink).
  22. O normalization uses row_sum, NOT LSE. O_norm = O_unnorm / row_sum is correct. O_unnorm * exp(-LSE) is WRONG because O_unnorm is max-shifted (divided by 2^row_max), not raw exp(S) @ V. The kernel now outputs row_sum alongside LSE.
  23. n_comp is compile-time, swa_len is runtime. The n_comp parameter controls const_expr guards in the kernel and cannot vary between segments of the same kernel instance. swa_len is an Int32 scalar and can vary per request. For multi-tile production, use a kernel cache keyed on (n_comp, apply_sink_bias, head_dim, s_k).
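
Constraints 12 and 22 can be checked numerically on a single row. The sketch below is a pure-Python model under assumed toy values (logits, values and scale are made up, and softmax_row_exp2 is a hypothetical helper); it mirrors the exp2-domain bookkeeping described above rather than the kernel itself.

```python
import math


def softmax_row_exp2(logits, scale):
    """Model one row of online-softmax state in the exp2 domain."""
    scale_log2 = scale * math.log2(math.e)
    scaled = [s * scale_log2 for s in logits]
    row_max = max(scaled)                        # lives in the log2 domain
    p = [2.0 ** (x - row_max) for x in scaled]   # max-shifted, un-normalized
    row_sum = sum(p)
    # Constraint 12: convert row_max back to the natural-log domain.
    lse = math.log(row_sum) + row_max * math.log(2.0)
    return p, row_sum, lse


logits = [0.3, -1.2, 2.5, 0.9]   # toy raw QK logits for one row
values = [1.0, 2.0, 3.0, 4.0]    # toy single-column V
scale = 0.125

p, row_sum, lse = softmax_row_exp2(logits, scale)
o_unnorm = sum(pi * vi for pi, vi in zip(p, values))

# Natural-log reference computed directly.
ref_lse = math.log(sum(math.exp(s * scale) for s in logits))
ref_w = [math.exp(s * scale - ref_lse) for s in logits]
ref_o = sum(wi * vi for wi, vi in zip(ref_w, values))

o_row_sum = o_unnorm / row_sum          # constraint 22: correct
o_exp_lse = o_unnorm * math.exp(-lse)   # wrong: O_unnorm is max-shifted

print(f"lse error          = {abs(lse - ref_lse):.3e}")
print(f"row_sum normalize  = {abs(o_row_sum - ref_o):.3e}")
print(f"exp(-LSE) normalize= {abs(o_exp_lse - ref_o):.3e}")
```

The exp(-LSE) variant is off by the factor 2^row_max, which is exactly the shift that O_unnorm already carries.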

Key Lessons

  1. NEVER use find_tmem_tensor_col_offset() as TMEM placement. It returns footprint size, not a safe offset.
  2. FMHA never trusts DLPack tensor layouts. Reconstruct V as (hd, s_k) MN-major inside CuTe.
  3. TMEM allocation must be power of 2.
  4. Square hides bugs. (128,128) worked for every wrong approach. Always test non-square.
  5. St32x32bOp MUST use Float32, NOT BFloat16. BFloat16 causes illegal memory access.
  6. First PV ACCUMULATE=False. Otherwise adds uninitialized TMEM to output.
  7. FMHA P store uses QK C-fragment composition, NOT PV A-fragment. Two aliases, same TMEM.
  8. Register bridge: FP32 backing (store partition) + BF16 view (QK-load layout). Do not skip this.
  9. PRINT THE SHAPES. ALWAYS. Reasoning about TMEM layouts without evidence is how we waste days.
  10. Never assume TMEM round-trips are safe. Verify with NO-OP tests before adding logic.

Stage D: Full Decode Attention (revised May 23)

Key Insight: The Indexer Solves Paging Upstream

The indexer now hands the kernel selected_kv: [T, top_k, head_dim] BF16 — a dense, materialized, dequantized K/V tile. FMHA sees a dense [T, top_k, 512] tile, exactly like Stage A/B's existing k and v inputs. The kernel doesn't need to know it's sparse. Paged TMA, scattered HBM reads, FP8 dequantization — all handled by gather_selected_kv upstream.

The SWA branch is the only "irregular" thing: it reads from the state cache's ring buffer with a position mask. SWA is small (n_win=128 per query), so it's a separate fused branch with a sink-weighted merge.

One FMHA kernel serves all three DSV4 attention types:

  • CSA: compressed_kv = top-k from indexer, swa_kv from cache → sink merge
  • HCA: compressed_kv = all classical pool entries (gather-all mode), swa_kv from cache → sink merge
  • SWA-only (Flash layers 0-1): compressed_kv = empty (top_k=0), only SWA runs. Sink merge degenerates to just o_swa after renormalization.

Build Order

D1 — Parameterize HEAD_DIM + SMEM-P (~1 day, MOSTLY DONE)

Currently hardcoded at 64. Promote to constructor arg, thread through _setup. Test at 64, then 512 (DSV4's real value).

hd≤256: DONE. cos 0.999998 at hd=64/128/256. Both TMEM-P and SMEM-P paths work.

hd=512: BLOCKED. SMEM budget fixed (192KB, fits 232KB limit). Kernel structurally correct (tracer 0.8s). But CuTeDSL's MLIR→PTX backend optimizer hangs for 3+ hours when compiling the k_sub loop. External k_sub merge is mathematically impossible (k_sub segments additive in logit space, not weight space). Need either: (a) pre-compile offline + cache cubin, (b) add no-softmax mode for S accumulation in Python, or (c) write hd=512 path in raw CUDA C++.
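
A small numeric illustration of why the external k_sub merge fails. It is a sketch with made-up toy vectors, not a proof: the head dimension is split into two halves, each half is run through its own softmax, and the LSE merge of the two partial results is compared with attention over the summed logits.

```python
import math


def attend(scores, values):
    """Softmax attention for one row; returns (normalized output, lse)."""
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    z = sum(w)
    out = sum(wi * vi for wi, vi in zip(w, values)) / z
    return out, m + math.log(z)


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


# Toy data: one query, three keys, head_dim=4 split into two k_sub halves.
q = [1.0, -0.5, 2.0, 0.5]
keys = [[0.5, 1.0, -1.0, 2.0], [1.5, -0.5, 0.5, -1.0], [-1.0, 2.0, 1.0, 0.5]]
values = [1.0, 2.0, 3.0]
half = 2

s0 = [dot(q[:half], k[:half]) for k in keys]   # partial logits, first half
s1 = [dot(q[half:], k[half:]) for k in keys]   # partial logits, second half

# Correct: accumulate partial logits BEFORE softmax (S = S_0 + S_1).
o_accum, _ = attend([a + b for a, b in zip(s0, s1)], values)
o_direct, _ = attend([dot(q, k) for k in keys], values)

# Incorrect: softmax each partial, then merge with the token-set formula.
o0, lse0 = attend(s0, values)
o1, lse1 = attend(s1, values)
w0, w1 = math.exp(lse0), math.exp(lse1)
o_external = (w0 * o0 + w1 * o1) / (w0 + w1)

print(f"in-kernel accumulate: {o_accum:.6f} (direct {o_direct:.6f})")
print(f"external merge:       {o_external:.6f}")
```

The merge formula assumes the segments cover different tokens; here both partials cover the same tokens with partial dot products, so the weights do not compose.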

Done when: identical result at HEAD_DIM=64 (regression), passes at HEAD_DIM=512 against FP32 oracle.

D2 — Multi-query grid with head packing (~1 day)

Grid changes from (1, 1, 1) to (num_q_blocks, 1, batch). DSV4 is MQA — all n_h=128 query heads share the same K/V. The query-head axis is folded into the M dimension of the Q tile: M_tile = 128 covers M = T * n_h rows. At decode T is small (1-16), so packing heads into M fills the MMA. At prefill T=64, M is already 8192 with heads packed.
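
The row arithmetic for head packing can be sketched as below. The packing order (head index varying fastest within a token) is an assumption made for illustration; the source only states that M = T * n_h and M_tile = 128. All function names are hypothetical.

```python
M_TILE = 128  # M tile size stated above


def packed_rows(T, n_h):
    """Total rows in the M dimension after folding heads into M."""
    return T * n_h


def num_m_tiles(T, n_h, m_tile=M_TILE):
    """Number of M tiles needed to cover all packed rows (ceiling division)."""
    return -(-packed_rows(T, n_h) // m_tile)


def pack(t, h, n_h):
    """Map (token, head) to a packed row. ASSUMED order: head varies fastest."""
    return t * n_h + h


def unpack(m, n_h):
    """Inverse of pack: packed row back to (token, head)."""
    return divmod(m, n_h)


for T in (1, 16, 64):
    print(f"T={T:3d}: M={packed_rows(T, 128):5d}, tiles={num_m_tiles(T, 128)}")
```

Under this assumed order, a single decode token with n_h=128 fills exactly one 128-row M tile.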

Done when: batch=4, T=64, n_h=128, num_kv_heads=1 produces correct attention against FP32 oracle.

D3 — SWA sequence length mask (~½ day)

The indexer's top_k is fixed (512 for Flash, 1024 for Pro). Compressed-K input is always [T, top_k, head_dim] with the same top_k at compile time.

What varies: the SWA window holds up to n_win=128 tokens but starts with fewer. Add swa_lens: [batch] int32 as kernel input. Mask SWA-branch logits to -inf where swa_idx >= swa_lens[b].

Done when: batched input with varying SWA fill levels (some requests at position 50, some at 5000) produces correct masked output.

D4 — Causal mask on SWA branch (~½ day)

The compressed K the indexer selects is already from s < floor(t/m) (paper eq. 17). The indexer enforces causality at selection time. FMHA sees only causally-valid candidates. The main path has no mask.

The SWA branch needs a causal mask within the window. Add is_causal: bool constructor flag, apply swa_idx > q_pos masking to -inf in the SWA pass.

Done when: prefill mode produces correct output with the causal mask applied to SWA.
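
The D3 and D4 masks combine with OR logic. A minimal reference for one query row, assuming swa_idx and q_pos are both SWA-relative positions (an assumption drawn from the status table wording); mask_swa_logits is a hypothetical helper, not the in-kernel implementation.

```python
NEG_INF = float("-inf")


def mask_swa_logits(logits, swa_len, q_pos, is_causal):
    """Apply the D3 fill-level mask and the D4 causal mask to one SWA row.

    logits   : SWA-branch logits, one per window slot
    swa_len  : number of valid slots in the window (D3)
    q_pos    : SWA-relative position of the query (D4), assumed convention
    """
    masked = []
    for swa_idx, s in enumerate(logits):
        beyond_fill = swa_idx >= swa_len             # D3: slot not yet written
        in_future = is_causal and swa_idx > q_pos    # D4: key after the query
        masked.append(NEG_INF if (beyond_fill or in_future) else s)
    return masked


row = [1.0] * 6
print(mask_swa_logits(row, swa_len=4, q_pos=2, is_causal=True))
print(mask_swa_logits(row, swa_len=4, q_pos=2, is_causal=False))
```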

D5 — SWA + sink merge COMPLETE (May 26)

Per dsv4/ops/decode_sparse.py:

o = (exp(lse_sparse) * o_sparse + exp(attn_sink) * exp(lse_swa) * o_swa)
    / (exp(lse_sparse) + exp(attn_sink) * exp(lse_swa))

Key insight (May 26): This merge is mathematically identical to a single attention pass over concatenated KV with a logit bias on SWA positions:

S = [S_comp, S_swa + attn_sink]
O = softmax(S) @ [V_comp; V_swa]

This means D5c is a logit bias addition, not a two-pass + merge kernel. D5d (fused in-kernel merge epilogue) is NO LONGER NEEDED.
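
The equivalence can be verified on toy numbers. The sketch below uses made-up logits and values for one row, and also applies the attn_sink / scale domain correction so the single pass runs in the exp2 domain. It is an illustration of the identity, not the kernel code.

```python
import math


def attend(scores, values):
    """Softmax attention for one row; returns (normalized output, lse)."""
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    z = sum(w)
    return sum(wi * vi for wi, vi in zip(w, values)) / z, m + math.log(z)


scale = 0.125
raw_comp, v_comp = [2.0, -1.0, 0.5], [1.0, 2.0, 3.0]   # toy compressed branch
raw_swa, v_swa = [1.0, 3.0], [4.0, 5.0]                # toy SWA branch
attn_sink = 0.7

# Two passes followed by the sink-weighted merge.
o_sparse, lse_sparse = attend([s * scale for s in raw_comp], v_comp)
o_swa, lse_swa = attend([s * scale for s in raw_swa], v_swa)
w_sparse = math.exp(lse_sparse)
w_swa = math.exp(attn_sink) * math.exp(lse_swa)
o_merge = (w_sparse * o_sparse + w_swa * o_swa) / (w_sparse + w_swa)


def single_pass(bias_on_raw):
    """One pass over concatenated KV with a bias added to raw SWA logits."""
    scale_log2 = scale * math.log2(math.e)
    raw = raw_comp + [s + bias_on_raw for s in raw_swa]
    scaled = [s * scale_log2 for s in raw]
    row_max = max(scaled)
    p = [2.0 ** (x - row_max) for x in scaled]
    return sum(pi * vi for pi, vi in zip(p, v_comp + v_swa)) / sum(p)


o_single = single_pass(attn_sink / scale)   # domain-corrected bias
o_naive = single_pass(attn_sink)            # bias gets shrunk by scale

print(f"merge={o_merge:.6f} single={o_single:.6f} naive={o_naive:.6f}")
```

The naive variant shows why the bias must be divided by scale before it enters the raw logits.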

D5a : normalize flag + LSE + row_sums output. When False, emits un-normalized O + LSE + row_sum.

D5b : Per-row LSE output (all 128 rows now write). Python KV merge with per-row LSE: cos 0.999994.

D5c : Sink bias as logit modification. Parameters: n_comp (compressed KV length, compile-time), apply_sink_bias (compile-time flag), sink_bias (runtime FP32 tensor). Sink bias added to raw logits as attn_sink / scale so after * scale_log2 it correctly becomes attn_sink * log2(e) in the exp2 domain. Multi-tile via Python KV merge: cos 0.999996.

D5d: NOT NEEDED. The sink bias approach makes a fused merge epilogue unnecessary.
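
For the multi-tile path, the Python KV merge over segments with per-row LSE can be sketched as follows. It uses normalized per-segment outputs and subtracts the largest LSE before exponentiating for numerical stability (a choice made in this sketch, not something the source specifies). The segment length of 2 stands in for the 128-token segments; all data is made up.

```python
import math


def attend(scores, values):
    """Softmax attention for one row; returns (normalized output, lse)."""
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    z = sum(w)
    return sum(wi * vi for wi, vi in zip(w, values)) / z, m + math.log(z)


def merge_segments(outs_norm, lses):
    """O = sum_i exp(lse_i) * O_i_norm / sum_i exp(lse_i), computed stably."""
    m = max(lses)
    w = [math.exp(l - m) for l in lses]
    return sum(wi * oi for wi, oi in zip(w, outs_norm)) / sum(w)


scores = [0.2, 1.5, -0.7, 0.9, 2.1, -1.3]   # toy scaled logits for one row
values = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
seg = 2                                      # stand-in for a 128-token segment

parts = [attend(scores[i:i + seg], values[i:i + seg])
         for i in range(0, len(scores), seg)]
o_merged = merge_segments([o for o, _ in parts], [l for _, l in parts])
o_full, _ = attend(scores, values)

print(f"merged={o_merged:.12f} full={o_full:.12f}")
```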

Done when: End-to-end kernel produces correct attention against FP32 oracle.

D5 (old) paged TMA — REMOVED. The indexer + gather handles all paging upstream.

Kernel Architecture (after D5 — COMPLETE)

Input: Q [T, n_h, 512], compressed_kv [T, top_k, 512], swa_kv [batch, n_win, 512]
       swa_lens [batch], sink_logits [n_h], request_ids [T]
  │
  └─ Single pass over concatenated KV [compressed_kv; swa_kv]:
      QK → online softmax (with sink bias on SWA, D3/D4 masking) → PV
      → O_unnorm + LSE + row_sum
      → External normalize: O = O_unnorm / row_sum
      → (Multi-tile: Python KV merge across 128-token segments)

Reference Files

  • Sink merge spec: dsv4/ops/decode_sparse.py (formula)
  • SWA decode: dsv4/ops/decode_swa.py
  • Attention reference: dsv4/reference/attention.py
  • CSA attention: dsv4/reference/csa_attention.py

Stage C Note

When implementing D5a, Stage C's epilogue changes from "multiply by 1/row_sum" to "emit un-normalized o + lse". Defer this until D5. Through D1-D4, keep Stage C normalize as-is and test as standalone dense FMHA.


Stage E: Production Extraction (revised May 23)

E1 — File placement

dsv4/kernels/attention/fmha.py. Currently contains FmhaKernel (migrated from test, hd=64 TMEM-P). Will gain parameterized head_dim and SMEM-P path in D1. Constructor takes all dimensions and dtypes, no module-level constants.

E2 — Constructor signature

class FmhaKernel:
    def __init__(
        self,
        head_dim: int,           # 512 for DSV4
        num_query_heads: int,     # 128 for Pro, 64 for Flash
        sliding_window: int,      # 128
        top_k: int,               # 512 (Flash) or 1024 (Pro)
        q_dtype=BFloat16,
        kv_dtype=BFloat16,
        o_dtype=BFloat16,
        qk_acc_dtype=Float32,
        pv_acc_dtype=Float32,
        is_causal: bool = False,  # affects SWA mask only
        cta_group: tcgen05.CtaGroup = tcgen05.CtaGroup.ONE,
        cluster_shape_mn: tuple = (1, 1),
    ):

All architecture-level shapes from config flow into the constructor. No FMHA-internal magic numbers.

E3 — Call signature

def __call__(
    self,
    q: torch.Tensor,            # [T, n_h, head_dim] BF16
    compressed_kv: torch.Tensor, # [T, top_k, head_dim] BF16 — from indexer gather
    swa_kv: torch.Tensor,        # [batch, n_win, head_dim] BF16 — from cache prep
    swa_lens: torch.Tensor,      # [batch] int32
    sink_logits: torch.Tensor,   # [n_h] FP32
    request_ids: torch.Tensor,   # [T] int32 — maps query to its SWA slot
    o: torch.Tensor,             # [T, n_h, head_dim] BF16 — preallocated
    stream: cuda.CUstream,
):

Notably absent: block_table, paged KV, inv_scale, FP8 dequant. All handled upstream.

E4 — Kernel cache + warmup

Mirror dsv4/ops/gemm_runner.py's _compiled_kernel_cache. Key on (head_dim, num_query_heads, top_k, is_causal, ...). Pre-allocate at warmup, reuse at call. For DSV4, the cache has at most ~2 entries (Flash/Pro × causal/non).
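
A minimal sketch of a compile-once cache keyed on constructor parameters. KernelCache and fake_compile are hypothetical names, and the four-field key is an assumption covering only the fields named above; the real key has further elided fields and the layout of _compiled_kernel_cache in gemm_runner.py is not shown here.

```python
from typing import Callable, Dict, Iterable, Tuple

# Assumed key: (head_dim, num_query_heads, top_k, is_causal).
CacheKey = Tuple[int, int, int, bool]


class KernelCache:
    """Compile each kernel variant once and reuse it on later calls."""

    def __init__(self, compile_fn: Callable[..., object]):
        self._compile_fn = compile_fn
        self._cache: Dict[CacheKey, object] = {}
        self.compile_count = 0

    def get(self, head_dim: int, num_query_heads: int, top_k: int,
            is_causal: bool) -> object:
        key = (head_dim, num_query_heads, top_k, is_causal)
        if key not in self._cache:
            self._cache[key] = self._compile_fn(*key)
            self.compile_count += 1
        return self._cache[key]

    def warmup(self, keys: Iterable[CacheKey]) -> None:
        """Pre-compile every expected variant before serving requests."""
        for key in keys:
            self.get(*key)


def fake_compile(head_dim, num_query_heads, top_k, is_causal):
    # Stand-in for the expensive JIT compile step.
    return ("compiled", head_dim, num_query_heads, top_k, is_causal)


cache = KernelCache(fake_compile)
cache.warmup([(512, 128, 1024, False), (512, 128, 1024, True)])
kernel = cache.get(512, 128, 1024, False)   # served from the cache
print(cache.compile_count, kernel)
```

Warming up at startup keeps the compile cost out of the request path, which matters given the long compile times noted above.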

E5 — torch.library custom op

@torch.library.custom_op("dsv4::sparse_fmha_with_swa", mutates_args=("o",))
def sparse_fmha_with_swa_op(
    q: torch.Tensor,
    compressed_kv: torch.Tensor,
    swa_kv: torch.Tensor,
    swa_lens: torch.Tensor,
    sink_logits: torch.Tensor,
    request_ids: torch.Tensor,
    o: torch.Tensor,
    runner_id: int,
) -> None:
    runner = get_runner(runner_id)
    runner._run_impl(q, compressed_kv, swa_kv, swa_lens, sink_logits, request_ids, o)

Mutates o (preallocated buffer). Consistent with cudagraphs.

E6 — Reference parity hook

dsv4/reference/attention.py stays as the FP32 oracle. New test: tests/unit/test_fmha_kernel.py.

def test_sparse_fmha_matches_spec(T=64, n_h=128, top_k=1024, n_win=128, hd=512):
    q = torch.randn(T, n_h, hd, dtype=torch.bfloat16, device='cuda')
    ck = torch.randn(T, top_k, hd, dtype=torch.bfloat16, device='cuda')
    swa = torch.randn(4, n_win, hd, dtype=torch.bfloat16, device='cuda')
    swa_lens = torch.tensor([128, 50, 128, 75], dtype=torch.int32, device='cuda')
    sink = torch.randn(n_h, device='cuda')
    req_ids = torch.randint(0, 4, (T,), dtype=torch.int32, device='cuda')

    # Oracle: pure FP32 spec
    o_sparse, lse_sparse = attention_with_lse_f32(q, ck, ck)
    o_swa, lse_swa = attention_swa_with_lse_f32(q, swa, swa, swa_lens, req_ids)
    e_sink = sink.exp()
    num = lse_sparse.exp().unsqueeze(-1) * o_sparse \
        + e_sink[None, :, None] * lse_swa.exp().unsqueeze(-1) * o_swa
    den = lse_sparse.exp() + e_sink[None, :] * lse_swa.exp()
    expected = num / den.unsqueeze(-1)

    # Kernel
    o = torch.empty_like(expected, dtype=torch.bfloat16)
    fmha = FmhaKernel(head_dim=hd, num_query_heads=n_h, sliding_window=n_win, top_k=top_k)
    fmha(q, ck, swa, swa_lens, sink, req_ids, o, stream=...)

    torch.testing.assert_close(o.float(), expected, atol=5e-3, rtol=5e-3)

E7 — Cleanup

Delete all debug test files. test_fmha_v3.py becomes dsv4/kernels/attention/fmha.py. Only tests/unit/test_fmha_kernel.py remains as the attention test.


Environment

  • Server: root@45.76.247.107 (B200, 180 GiB HBM3e per GPU)
  • venv: source /root/dsv4-nvfp4-workspace/venv/bin/activate
  • PYTHONPATH: /root/dsv4-nvfp4-workspace/kernel
  • Model: /root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4
  • vLLM repo: /root/dsv4-nvfp4-workspace/vllm (modified for Blackwell)
  • CUTLASS FMHA reference: /root/cutlass/examples/python/CuTeDSL/cute/blackwell/kernel/attention/fmha/fmha.py
  • Local CUTLASS clone: /home/openclaw/dev/cutlass