Stage D — Parameterized FMHA for DSV4

⚠️ IKEA INSTRUCTIONS — READ EVERY TIME BEFORE CODING

The Workflow (DO NOT SKIP STEPS)

  1. Edit code in ~/dev/nvfp4-megamoe-kernel/dsv4/kernels/attention/fmha.py — this is the ONLY file for the FMHA kernel.
  2. Commit and push:
    cd ~/dev/nvfp4-megamoe-kernel
    git add -A && git commit -m "description" && git push origin master
    
  3. Pull on B200:
    sshpass -p '<B200_PASSWORD>' ssh -o StrictHostKeyChecking=no root@45.76.247.107 \
      "cd /root/dsv4-nvfp4-workspace/kernel && git pull origin master"
    
  4. Test on B200:
    sshpass -p '<B200_PASSWORD>' ssh -o StrictHostKeyChecking=no root@45.76.247.107 \
      "cd /root/dsv4-nvfp4-workspace/kernel && source /root/dsv4-nvfp4-workspace/venv/bin/activate && python3 -c '...'"
    
  5. Regression check: After every change, verify that the hd=64 cosine similarity (cos) is still 0.972537. If it isn't, the change is WRONG. Revert.
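The regression check in step 5 can be sketched as a plain cosine similarity between the flattened kernel output and a reference. This is a minimal stdlib sketch; the names `cosine_similarity` and `regression_ok` and the tolerance are assumptions for illustration, not taken from the test suite.

```python
import math

def cosine_similarity(a, b):
    """Cosine similarity between two equal-length sequences of floats."""
    if len(a) != len(b):
        raise ValueError("length mismatch")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        raise ValueError("zero-norm input")
    return dot / (norm_a * norm_b)

def regression_ok(cos, expected=0.972537, tol=1e-6):
    """True when the measured cosine matches the recorded baseline.

    The tolerance is an assumed value; tighten or loosen it to taste.
    """
    return abs(cos - expected) <= tol
```

A fixed tolerance keeps the check strict enough to catch any change to the proven hd=64 path while allowing for print rounding of the baseline.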

The Rules (BURNED INTO THIS FILE BECAUSE WE BURNED THEM INTO PRODUCTION)

  • NEVER edit files directly on the B200. Edit locally, commit, push, pull, test. Every time.
  • NEVER delete or modify the test files in tests/unit/. They are the regression oracle.
  • NEVER touch drivers, kernels, firmware, or system packages on the B200.
  • CuTeDSL variables defined in if blocks are NOT visible in other if blocks. Even compile-time constants. Define all variables unconditionally before any branching.
  • Always test at hd=64 FIRST. If the proven path (TMEM-P) regresses, nothing else matters.
  • p_cols_fp32 uses pv_mma_tiler[2] (K-dim), NOT pv_mma_tiler[1] (N-dim). We got this wrong twice.
  • PV A-operand major mode is OperandMajorMode.K for TMEM-P. Not a_major from Q.
  • tOrP0 uses 3-dim indexing (None, None, kb), NOT 4-dim (None, None, kb, 0). The 4th mode was already sliced away by tOrP_base[(None,None,None,0)].
  • After every P store to TMEM, call cute.arch.fence_view_async_tmem_store(). Missing this produces NaN.
  • PRINT THE SHAPES. ALWAYS. Run print(f"tensor: shape={cute.shape(tensor)}") inside @cute.kernel at trace time. Reasoning about layouts without evidence is how we waste days.

What We Have Now (Starting Point)

File: dsv4/kernels/attention/fmha.py
Class: FmhaKernel
State: Parameterized head_dim (D1.0 done). TMEM-P path works at hd=64 (cos 0.972537). SMEM-P path is a stub that zeros sP.

What it does:

  • 6-warp kernel: warps 0-3 (softmax + epilogue), warp 4 (MMA), warp 5 (TMA)
  • QK GEMM → S in TMEM → online softmax → P stored to TMEM via register bridge → PV GEMM → O in TMEM
  • O rescale (per KV tile, kt>0) + O normalization (1/row_sum) via TMEM round-trip
  • Epilogue: TMEM → SMEM → GMEM via TMA store
  • SMEM-P flag wired (use_smem_p), PV source switches between TMEM/SMEM, but register→SMEM copy not implemented

The Problem at hd>64

At hd=64, the QK C-fragment TMEM layout and the PV A-fragment TMEM layout agree — the same threads map to the same columns. P can be written to TMEM using the QK partition and read by PV using the same partition. This is why the register bridge (FP32 backing + BF16 view) works.

At hd=512, P is (128, 128) per KV tile (P's columns = number of KV positions, NOT head_dim). But the PV MMA expects P laid out with 512 columns in its A-operand. The QK C-fragment and PV A-fragment TMEM layouts disagree — different threads own different columns. The register bridge can't write P in a layout that PV can read.

The fix: SMEM-P path. P goes through SMEM instead of TMEM:

  1. Softmax computes P in registers (QK C-fragment partition)
  2. Write P to SMEM using the p_smem_s layout (PV A-operand SMEM layout)
  3. MMA warp reads P from SMEM via tCrP = pv_mma.make_fragment_A(sP)
  4. PV GEMM uses tcgen05.OperandSource.SMEM instead of OperandSource.TMEM

The SMEM rendezvous: SMEM is the meeting point. Softmax threads write at logical (row, col) addresses. MMA reads at the same addresses. A barrier in between. No cross-warp message passing needed — just write-to-address, barrier, read-from-address.

The missing piece (the D1 work): The register→SMEM copy. The softmax warps have P values in QK C-fragment partition. They need to write to SMEM with PV A-operand layout. This requires a TiledCopy that partitions threads by QK's C-fragment and targets the P SMEM layout.

# The correct approach:
store_atom = cute.make_copy_atom(tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), Float32)
tiled_p_copy = cute.make_tiled_copy_C(store_atom, qk_mma)  # NOT pv_mma!
# This gives threads partitioned by QK C-fragment, writing to the P SMEM layout

Then: softmax threads write their P values through this copy → barrier → MMA reads from SMEM.


TMEM Column Budget at hd=512

This MUST be calculated before writing a single line of SMEM-P code.

Tensor Memory (TMEM) has 512 columns per CTA. Each column is 32 bits wide.

At hd=64 (TMEM-P path):

  • S (QK acc): 128 cols FP32
  • P (softmax output): 64 cols FP32
    • p_cols_fp32 = pv_mma_tiler[2] * q_dtype.width // qk_acc_dtype.width
    • pv_mma_tiler = (128, 64, 128), so pv_mma_tiler[2] = 128
    • p_cols_fp32 = 128 * 16 // 32 = 64
    • tmem_p0_offset = 32, so P occupies TMEM cols 32-95 and OVERLAPS S (cols 0-127). This works because S is consumed before P is written: softmax reads S, then writes P into the same TMEM region.
  • After P: o_after = max(s_cols=128, p_end=32+64=96) = 128. tmem_o0_offset = ((128 + 31) // 32) * 32 = 128
  • O (PV acc): find_tmem_tensor_col_offset(tOtO) at hd=64 ≈ 128 cols FP32
  • Total: 128 (O offset) + 128 (O size) = 256 cols. Fits in 512.
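The hd=64 arithmetic above can be replayed as a few lines of plain Python. This is only a sketch of the bookkeeping: the variable names mirror the text, `round_up` is a hypothetical helper, and `o_cols = 128` is the approximate figure quoted above rather than a measured value.

```python
def round_up(x, multiple):
    """Round x up to the next multiple (hypothetical helper)."""
    return ((x + multiple - 1) // multiple) * multiple

TMEM_COLS = 512                 # columns available, each 32 bits wide
s_cols = 128                    # S (QK accumulator), FP32
pv_mma_tiler = (128, 64, 128)
q_width, acc_width = 16, 32     # BF16 and FP32 bit widths

p_cols_fp32 = pv_mma_tiler[2] * q_width // acc_width   # K-dim, not N-dim
tmem_p0_offset = 32
p_end = tmem_p0_offset + p_cols_fp32                   # P ends inside S's region

o_after = max(s_cols, p_end)
tmem_o0_offset = round_up(o_after, 32)
o_cols = 128                    # approximate; take the real value from the probe
total = tmem_o0_offset + o_cols

print(f"p_cols_fp32={p_cols_fp32} p_end={p_end} "
      f"tmem_o0_offset={tmem_o0_offset} total={total} fits={total <= TMEM_COLS}")
```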

At hd=512 (SMEM-P path):

  • P is NOT in TMEM. S and O share TMEM (sequential, not concurrent).
  • S (QK acc): 128 cols FP32 (same as hd=64 — QK is always (128, 128))
  • O (PV acc): at hd=512, PV is (128, 512). The naive estimate for a (128, 512) FP32 C-fragment is 512 cols, but the real footprint depends on how tcgen05 lays out the accumulator and has NOT been measured.
    • tOtO = pv_thr.make_fragment_C(pv_as) where pv_as = pv_thr.partition_shape_C((128, 512))
    • Working notes on the C-fragment for a tcgen05 MMA with shape (128, 512) in FP32 (unverified):
      • M=128 → 4 warps × 32 threads = 128 rows, each thread owns 1 row
      • N=512 → the MMA writes (32, 32) tiles, so a (128, 512) MMA has 4 M-tiles × 16 N-tiles = 64 (32×32) subtiles
      • Each subtile uses 32 TMEM columns, but the subtiles are distributed across warps
      • find_tmem_tensor_col_offset(tOtO) gives the actual footprint.
    • MUST PRINT THIS ON THE B200. Do not guess. Run a shape probe.
  • If O needs ~512 cols: S (128) + O (512) = 640 > 512. DOES NOT FIT.
    • Fix options:
      1. Drop kv_stage from 2 to 1: frees SMEM but loses K/V double-buffering. The TMEM budget is unchanged, so this does not solve the column overflow on its own.
      2. Split O into halves: process (128, 256) PV twice, each O tile is 256 cols. S(128) + O(256) = 384 < 512.
      3. Process S and O sequentially: after softmax consumes S, O can reuse S's TMEM region. O at offset 0, 512 cols. Total = 512. But only if we don't need S anymore when writing O (true — softmax is done before PV starts per KV tile).

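Plan: SMEM-P path reuses S's TMEM for O. After softmax reads S and writes P to SMEM, S's TMEM region (cols 0-127) is dead. PV writes O starting at col 0. O at hd=512 needs ~256-512 cols (must measure). If O fits in cols 0-511 with S gone, the budget closes. One open question to settle alongside the probe: O accumulates across KV tiles (the per-tile rescale for kt>0), so the next tile's S must not land on columns that O still occupies.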
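The three layouts discussed above can be compared with a small budget check. This is a sketch only: `layout_totals` is a hypothetical helper, and the `o_cols` values passed in are placeholders spanning the "~256-512" range, not measurements from the B200.

```python
TMEM_COLS = 512
S_COLS = 128

def layout_totals(o_cols):
    """Column totals for each candidate layout, given a measured O footprint."""
    return {
        "concurrent_s_and_o": S_COLS + o_cols,       # S and O live side by side
        "split_o_in_halves": S_COLS + o_cols // 2,   # two PV passes, half of O each
        "o_reuses_s_region": max(S_COLS, o_cols),    # O starts at col 0 once S is dead
    }

def fits(o_cols):
    """Map each layout name to whether it fits in TMEM."""
    return {name: total <= TMEM_COLS for name, total in layout_totals(o_cols).items()}

for candidate in (256, 512):   # placeholder footprints, replace with the probe result
    print(candidate, layout_totals(candidate), fits(candidate))
```

Once the probe reports the real `o_cols`, plugging it in shows directly which of the options remain viable.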

Action item: Run shape probe on B200 before coding SMEM-P at hd=512.

# Shape probe script to run on B200:
import torch, math, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.nvgpu.tcgen05 as tcgen05
from cutlass import BFloat16, Float32, LayoutEnum

a_major = LayoutEnum.ROW_MAJOR  # adjust to match
b_major = LayoutEnum.ROW_MAJOR
pv_mma = utils.sm100.make_trivial_tiled_mma(BFloat16, BFloat16, a_major, b_major, Float32, tcgen05.CtaGroup.ONE, (128,512), tcgen05.OperandSource.SMEM)
pv_thr = pv_mma.get_slice(0)
pv_as = pv_thr.partition_shape_C((128, 512))
tOtO = pv_thr.make_fragment_C(pv_as)
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
o_cols = find_tmem_tensor_col_offset(tOtO)
print(f"hd=512 PV C-fragment: pv_as={pv_as}, tOtO.layout={tOtO.layout}, o_cols={o_cols}")
# Also print tOtO shape
print(f"tOtO shape: {cute.shape(tOtO)}")

Correctness Gaps — Must Close Before Production

These are NOT optimization gaps. They are cases where the current code produces numerically wrong outputs vs the trained checkpoint.

CG-1: SwiGLU Clamping Missing from Fused Kernel ⚠️ CRITICAL

What: FusedSwiGLUScaledGroupedGemmKernel.__init__ stores self.swiglu_limit but the SwiGLU compute block (lines 2185-2200 in fused_swiglu.py) never references it. The reference path in dsv4/reference/moe_pipeline.py correctly applies clamp(max=swiglu_limit) to gate and clamp(min=-limit, max=+limit) to up. The fused kernel silently skips it.

Why it matters: Paper §4.2.3 explicitly says weights were trained with the gate component capped at 10 and the linear component clamped to [-10, 10]. Without clamping, the fused kernel produces different outputs than the reference at large activation values.

Fix (2 lines in the fused kernel):

# After computing silu_result (gate subtile):
silu_result = cute.math.fmin(silu_result, swiglu_limit)

# Before the gate*up multiply (up subtile):
acc_vec = cute.math.fmin(cute.math.fmax(acc_vec, -swiglu_limit), swiglu_limit)

Where: dsv4/kernels/gemm/fused_swiglu.py, lines ~2192 (gate branch) and ~2198 (up branch).

Status: 🔴 NOT FIXED. Do this IMMEDIATELY — it's a 2-line fix and affects all MoE layer outputs.
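A scalar reference is handy when testing the fix. The sketch below is plain Python with hypothetical function names. It shows two orderings: clamping the gate before SiLU (one reading of "clamp(max=swiglu_limit) to gate") and clamping the SiLU result afterwards (what the two-line kernel fix does). Which one the reference path uses is an assumption to confirm against dsv4/reference/moe_pipeline.py; the two differ slightly once the gate exceeds the limit.

```python
import math

def silu(x):
    """SiLU activation, x * sigmoid(x)."""
    return x / (1.0 + math.exp(-x))

def clamp_up(up, limit):
    """Clamp the linear (up) component to [-limit, +limit]."""
    return max(-limit, min(up, limit))

def swiglu_clamp_pre(gate, up, limit=10.0):
    """Cap the gate BEFORE SiLU (assumed reading of the reference path)."""
    return silu(min(gate, limit)) * clamp_up(up, limit)

def swiglu_clamp_post(gate, up, limit=10.0):
    """Cap SiLU(gate) AFTER the activation, as in the kernel fix above."""
    return min(silu(gate), limit) * clamp_up(up, limit)

# Below the limit the two orderings agree; above it they drift apart slightly.
small = (swiglu_clamp_pre(1.0, 2.0), swiglu_clamp_post(1.0, 2.0))
large = (swiglu_clamp_pre(20.0, 1.0), swiglu_clamp_post(20.0, 1.0))
print(small, large, abs(large[0] - large[1]))
```

If the fused test compares against the reference with a tight tolerance at large gate values, this ordering difference is the first thing to rule out.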

CG-2: FMHA at hd=512 SMEM-P is a Stub ⚠️ CRITICAL

What: FmhaKernel with use_smem_p=True zeros sP and comments "PV will produce garbage." DSV4 head_dim is 512 (§4.2.1). The kernel literally cannot produce correct output at the production head dimension.

Why it matters: This is the D1 work. The path forward is correct (make_tiled_copy_C(store_atom, qk_mma) to partition P registers for SMEM staging). But TMEM column budget at hd=512 must be verified first (see budget section above).

Status: 🔴 D1.2-D1.3 TODO. This document IS the plan.

CG-3: SWA + Sink Merge Not Fused in FMHA ⚠️ CRITICAL

What: The DSV4 attention design (§2.3.3) requires merging compressed top-k attention with sliding window attention via sink weights. Currently the FMHA kernel only does dense attention (one KV source, one softmax, one PV, normalize). The sink merge is implemented in Python fallback (decode_sparse.py) but NOT in the production kernel path.

Why it matters: Without SWA+sink merge, the compressed branch alone cannot capture local dependencies. The paper is explicit: "Additional Branch of Sliding Window Attention." Every CSA and HCA layer produces wrong output without this.

Fix plan (D5, ordered by priority):

  1. D5a: Emit un-normalized o + lse instead of normalized o. This is the SINGLE MOST IMPORTANT structural change — once the kernel can output (o_unnorm, lse), even a Python merge gives end-to-end correctness. Keep normalize as a flag so standalone tests still work.
  2. D5b: Run kernel twice externally (compressed_kv + swa_kv), merge in Python. End-to-end correctness without touching kernel structure. This is the correctness baseline.
  3. D5c: Fuse two passes into one kernel launch (Q stays in SMEM, two sequential MMA loops). Pure optimization.
  4. D5d: Fuse sink merge into kernel epilogue. Pure optimization.

Status: 🔴 D5 TODO. D5a must be done FIRST — it unblocks D5b which gives us correctness.

CG-4: Inverse RoPE Verification ⚠️ HIGH

What: inverse_rope_bf16 in dsv4/ops/rope.py applies the conjugate rotation to the last rope_dim=64 dims of each head output. The math looks correct: inv[2i] = x[2i] * cos + x[2i+1] * sin, inv[2i+1] = -x[2i] * sin + x[2i+1] * cos. This is the standard inverse rotation for interleaved (GPT-J) RoPE.

What needs verifying:

  1. The positions argument must be the same positions used for the forward RoPE on Q and K. The inverse looks up cos(θ) and sin(θ) at the SAME position as the forward pass (position = +position, not -position) and negates the sin terms; that is the conjugate rotation. The code reads cos_sin_cache[positions, :], the same table as forward RoPE, and computes inv_odd = -o_even * sin_all + o_odd * cos_all, which matches the formula above.
  2. The nope_dim=448 / rope_dim=64 split must match the model's actual split. If a layer uses a different split, the inverse RoPE would rotate the wrong dims.
  3. The cos_sin_cache must be the same cache used for forward RoPE. If there's any offset or indexing difference, the angles won't match.

Action: Write a unit test that: (1) applies forward RoPE to random input, (2) applies inverse RoPE, (3) verifies the result matches the original. This is a round-trip test and catches both sign and indexing errors.

Status: 🟡 Code looks correct but UNTESTED. Add a round-trip unit test.
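The round-trip test can be prototyped without the kernel. The sketch below is plain Python for a single head vector, using the interleaved (GPT-J) pairing and the 448/64 split described above. `rope_angles`, `rotate_tail`, and the base of 10000 are assumptions for illustration, not the actual dsv4/ops/rope.py API or cache layout.

```python
import math
import random

def rope_angles(position, rope_dim=64, base=10000.0):
    """One angle per (even, odd) pair; base=10000 is an assumed default."""
    return [position * base ** (-2.0 * i / rope_dim) for i in range(rope_dim // 2)]

def rotate_tail(head, theta, sign):
    """Rotate the last 2*len(theta) dims pairwise; sign=+1 forward, -1 inverse."""
    split = len(head) - 2 * len(theta)          # leading NoPE dims stay untouched
    out = list(head)
    for i, t in enumerate(theta):
        c, s = math.cos(t), sign * math.sin(t)  # inverse only flips the sin sign
        e, o = head[split + 2 * i], head[split + 2 * i + 1]
        out[split + 2 * i] = e * c - o * s
        out[split + 2 * i + 1] = e * s + o * c
    return out

random.seed(0)
head = [random.uniform(-1.0, 1.0) for _ in range(512)]  # nope_dim=448 + rope_dim=64
theta = rope_angles(position=37)
rotated = rotate_tail(head, theta, +1)
restored = rotate_tail(rotated, theta, -1)
max_err = max(abs(a - b) for a, b in zip(head, restored))
print(f"round-trip max error: {max_err:.3e}")
```

The real unit test would call the forward RoPE and inverse_rope_bf16 in place of `rotate_tail` and use a BF16-appropriate tolerance.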

CG-5: Mixed-Precision KV (BF16 RoPE + FP8 NoPE) — FMHA Load Path ⚠️ HIGH

What: Paper §2.3.4: KV cache stores dims 0..447 as FP8 and dims 448..511 as BF16. The PagedKVPool already implements this split: entries_fp8 (uint8) + entries_rope (BF16) + inv_scale (FP32). The current decode_sparse.py fallback dequantizes in Python before calling the kernel.

Why it matters for FMHA: The FmhaKernel currently takes contiguous BF16 K/V tensors. At production, the kernel must handle the mixed-precision KV directly — reading FP8 + BF16 from the paged cache and dequantizing on the fly during TMA→SMEM transfer. This is the proper Blackwell pattern: TMA loads FP8 to SMEM, on-the-fly dequant in the SMEM→register path, then MMA.

The proper approach:

  1. TMA loads FP8 NoPE dims to SMEM slot 0
  2. TMA loads BF16 RoPE dims to SMEM slot 1 (or separate TMA)
  3. Dequantize FP8 → BF16 in SMEM (vectorized, per-entry inv_scale multiply)
  4. Concatenate [NoPE, RoPE] in SMEM (or use two separate SMEM regions with strided MMA)
  5. MMA reads contiguous BF16 from SMEM

Prerequisite: This requires D1 (SMEM-P) and D5 (sink merge) to be working first. The mixed-precision load path replaces the current "all BF16" K/V input with the real paged cache format.

Status: 🔴 NOT IMPLEMENTED. Plan as D6 (after D5). The current test harness passes contiguous BF16 K/V, which is fine for correctness testing. The FP8 dequant in SMEM is a performance + memory optimization that doesn't affect numerical correctness (FP8 dequant is well-defined).

CG-6: Per-Token valid_lens in Indexer for Prefill ⚠️ MEDIUM

What: score_topk.py has a TODO that broadcasts request 0's valid_lens for prefill (T > B). For batched prefill, different requests have different numbers of compressed entries in the pool. Broadcasting the first request's count means other requests either score garbage entries (too many) or miss valid ones (too few).

Why it matters: Prefill correctness blocker. The indexer will select wrong entries for all requests except the first in a batch.

Fix: Map each query token to its request ID, then look up valid_lens[request_id]. The request_ids: [T] int32 tensor already exists in the cache handle. The indexer kernel needs this as an input.

Status: 🔴 NOT FIXED. This is indexer scope, not FMHA scope. Track separately.
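The difference between the current broadcast and the per-token lookup can be shown with plain lists. This is a sketch with hypothetical helper names; in the kernel, valid_lens and request_ids would be int32 tensors.

```python
def broadcast_first(valid_lens, num_tokens):
    """Current behaviour: every token uses request 0's count."""
    return [valid_lens[0]] * num_tokens

def per_token_valid_lens(valid_lens, request_ids):
    """Fixed behaviour: each token looks up its own request's count."""
    return [valid_lens[r] for r in request_ids]

valid_lens = [3, 7]              # compressed entries per request
request_ids = [0, 0, 1, 1, 1]    # T=5 query tokens across B=2 requests

print(broadcast_first(valid_lens, len(request_ids)))    # tokens of request 1 see 3
print(per_token_valid_lens(valid_lens, request_ids))    # tokens of request 1 see 7
```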


Performance Soft Spots — Important But Not Correctness

These affect throughput but not numerical correctness. Tracked for Stage F+.

PS-1: Indexer Score+TopK is Scalar CUDA — Not Blackwell Native 🔴

What: indexer_score_topk.cu is a CUDA-core scalar implementation:

  • Triple loop: for h in n_heads, for g in n_groups, for b in 8
  • FP4 nibble dequant to FP32, FP32 dot product
  • Shared-memory min-heap protected by single s_lock atomicCAS spinlock
  • For 1M-context: ~250K compressed entries scored per query token

Why it's the biggest perf leak: The dot products should use tensor cores. The heap spinlock won't scale to top_k=1024 with hundreds of thousands of candidates.

The correct approach: DeepGEMM's fp8_mqa_logits / fp8_paged_mqa_logits pattern (Sept 2025 PR for V3.2 indexer). Weighted ReLU MQA logits computed with tensor cores, paged variant for decode. Our V4 NVFP4 variant should be that pattern with FP4 inputs and tcgen05 MMA. Beyond the MMA, the heap needs replacing with per-warp partial top-k merged via reduction tree, or radix-select.

Status: Tracked for Stage F (post Stage E). Not blocking D1-D5.

PS-2: decode_sparse.py BlackwellSparseDecodeKernel is Misleading 🟡

What: dsv4/ops/decode_sparse.py contains BlackwellSparseDecodeKernel — a CuTeDSL kernel that does scalar for d in range(HD): dot += q_val * k_val with no tensor cores. It also has a _fallback_sparse_sdp Python path that uses F.scaled_dot_product_attention.

Why it's misleading: The class name says "Blackwell" but it uses zero Blackwell tensor acceleration. Anyone reading the codebase would assume this is the production kernel. It's a stale early-exploration kernel superseded by FmhaKernel.

Action: Delete BlackwellSparseDecodeKernel and its CuTeDSL code. Keep _fallback_sparse_sdp as a reference implementation (rename to _reference_sparse_sdp_attention). The FMHA kernel in dsv4/kernels/attention/fmha.py is the real path. Do this cleanup as part of E7.

Status: Low urgency. Track for E7 cleanup.

PS-3: mHC Mixing Uses torch.bmm with n_hc=4 🟢

What: mHC mixing operations (A_l @ X_l, B_l @ X_l, C_l ⊗ F_out) use torch.bmm with tiny n_hc=4 inner dimension.

Why it matters: For decode (T=1) this is fine — tiny matmul. For prefill it leaves throughput on the floor. But prefill is not the immediate priority.

Status: Lowest priority of the soft spots. Track for Stage G (prefill optimization).


Stage D Build Order (REVISED)

Priority Principle: Correctness First, Then Performance

D1 (hd=512) and D5 (SWA+sink merge) are both correctness-critical. But D5 depends on D1 (can't merge SWA if the kernel can't even run at hd=512). CG-1 (SwiGLU clamping) is a 2-line fix with no dependencies — do it first.

D0 — SwiGLU Clamping Fix (CG-1) DO THIS FIRST

  • Add clamping to fused SwiGLU in dsv4/kernels/gemm/fused_swiglu.py
  • Gate subtile: silu_result = cute.math.fmin(silu_result, swiglu_limit) after SiLU compute
  • Up subtile: acc_vec = cute.math.fmin(cute.math.fmax(acc_vec, -swiglu_limit), swiglu_limit) before gate*up multiply
  • Verify: cute.math.fmin / cute.math.fmax work with CuTeDSL vectorized code (they should — they're elementwise)
  • Test: fused MoE output matches reference with clamping at swiglu_limit=10.0
  • Commit with clear message: "fix: add SwiGLU clamping to fused kernel (paper §4.2.3)"

D1 — Parameterized HEAD_DIM + SMEM-P (CG-2)

D1.0 — Replace HEAD_DIM constant with constructor parameter DONE

Already in the kernel. head_dim is a constructor arg. TMEM-P path works at hd=64.

D1.1 — Add SMEM-P path behind use_smem_p flag WIRED (stub)

The use_smem_p flag exists. PV source switches between TMEM/SMEM. TMEM layout adjusts. But the register→SMEM copy is a stub that zeros sP.

D1.2 — TMEM Column Budget Verification 🔨 DO THIS BEFORE CODING

  • Run shape probe on B200: find_tmem_tensor_col_offset(tOtO) at hd=512
  • Print pv_as, tOtO.layout, o_cols at hd=128, 256, 512
  • Calculate: can S(128) and O(???) share TMEM at hd=512?
  • If O > 384 cols: plan for split-PV (two (128, 256) passes)
  • Document the budget numbers HERE in this file

D1.3 — Implement register→SMEM copy for P (THE HARD PART)

  • Build tiled_p_copy = cute.make_tiled_copy_C(store_atom, qk_mma) — QK MMA partitions threads
  • Print the shapes: cute.shape(tiled_p_copy), partition source/dest shapes
  • Partition sP with tiled_p_copy as destination
  • In softmax warps: after computing P in registers, write to SMEM via tiled_p_copy
  • Add p_smem_ready_bar NamedBarrier: softmax arrives after write + fence, MMA waits before PV GEMM
  • In MMA warp: read P from SMEM via tCrP = pv_mma.make_fragment_A(sP)
  • Test: hd=64, n=128, use_smem_p=True → compare against TMEM-P result
  • Test: hd=128, n=128 → test against FP32 oracle
  • Test: hd=256, n=128 → test against FP32 oracle

D1.4 — Multi-PV-tile for hd>256

  • Add pv_n_tile = min(head_dim, 256) and n_pv_tiles = head_dim // pv_n_tile to __init__
  • For hd=512: 2 PV tiles of (128, 256) each
  • Strategy: kernel processes one PV N-tile per launch. Python orchestrates the tiles.
    • Pass 0: V[:, 0:256] → output[:, 0:256], QK + softmax + PV for cols 0-255
    • Pass 1: V[:, 256:512] → output[:, 256:512], QK + softmax + PV for cols 256-511
    • QK and softmax run identically both passes (P is the same). Only PV changes.
  • Alternative (if SMEM-P allows): keep P in SMEM between PV tiles. Run QK+softmax once, PV twice.
  • Test: hd=512, n=128 → correct output against FP32 oracle
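The tile arithmetic for the Python-side orchestration can be sketched as follows. `pv_tiles` is a hypothetical helper; it returns half-open column slices, which is what `V[:, start:stop]` expects.

```python
def pv_tiles(head_dim, max_tile=256):
    """Return (pv_n_tile, n_pv_tiles, [(start, stop), ...]) for a given head_dim."""
    pv_n_tile = min(head_dim, max_tile)
    if head_dim % pv_n_tile != 0:
        raise ValueError("head_dim must be a multiple of pv_n_tile")
    n_pv_tiles = head_dim // pv_n_tile
    slices = [(i * pv_n_tile, (i + 1) * pv_n_tile) for i in range(n_pv_tiles)]
    return pv_n_tile, n_pv_tiles, slices

for hd in (64, 256, 512):
    print(hd, pv_tiles(hd))
```

Each slice corresponds to one kernel launch in the one-tile-per-launch strategy, or to one PV GEMM if P stays resident in SMEM.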

D1.5 — Correction Epilogue: Fix TMEM Layout Mismatch (3% Error)

The current TMEM round-trip (Ld32x32bOp + St32x32bOp hand-constructed atoms) introduces 3% error at hd=64 (cos 0.973). The proper fix is the CUTLASS correction_epilog pattern:

TMEM --get_tmem_load_op--> reg (normalize + FP32→BF16) --get_smem_store_op--> SMEM --TMA--> GMEM

This is a one-way trip. No TMEM round-trip. No layout mismatch.

  • Investigate: can we use get_tmem_load_op + get_smem_store_op paired atoms?
  • Investigate: can we inject inv_row_sum into epilogue_tma_store pipeline?
  • Investigate: pre-compute TMA partitioning outside if warp_idx blocks (region isolation workaround)
  • Test: hd=64, n=128 → cos should jump from 0.973 → ~0.9999
  • Test: hd=64, n=256 → cos should jump from 0.793 → ~0.9999

Note: This is NOT blocking for D2-D5. The 3% error is a precision issue, not a correctness issue (the attention math is right, the epilogue just introduces rounding). Fix it properly rather than hacking it. But don't let it block the D2-D5 pipeline.

D2 — Multi-Query Grid with Head Packing

  • Grid changes from (1, 1, 1) to (num_q_blocks, 1, batch)
  • DSV4 is MQA: all 128 query heads share same K/V
  • Head axis folded into M dimension of Q tile: M_tile = 128 covers M = T * n_h rows
  • At decode T=1: M = 1 × 128 = 128 — one Q block covers all heads.
  • At prefill T=64: M = 64 × 128 = 8192 — 64 Q blocks. Needs grid loop.
  • Test: batch=4, T=64, n_h=128, num_kv_heads=1 produces correct attention
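The grid sizing above reduces to a ceiling division. This is a minimal sketch with hypothetical helper names, assuming the head axis is folded into M exactly as described.

```python
def num_q_blocks(num_tokens, n_heads=128, m_tile=128):
    """Q blocks needed when M = T * n_h rows are tiled by m_tile."""
    rows = num_tokens * n_heads
    return (rows + m_tile - 1) // m_tile

def launch_grid(num_tokens, batch, n_heads=128, m_tile=128):
    """Grid shape (num_q_blocks, 1, batch)."""
    return (num_q_blocks(num_tokens, n_heads, m_tile), 1, batch)

print(launch_grid(num_tokens=1, batch=4))    # decode
print(launch_grid(num_tokens=64, batch=4))   # prefill
```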

D3 — SWA Sequence Length Mask

  • Add swa_lens: [batch] int32 kernel input
  • Mask SWA-branch logits to -inf where swa_idx >= swa_lens[b]
  • Test: batched input with varying SWA fill levels (position 50 vs 5000)

D4 — Causal Mask on SWA Branch

  • Add is_causal: bool constructor flag
  • Apply swa_idx > q_pos masking to -inf in SWA pass
  • Main path has NO mask (indexer enforces causality upstream)
  • Test: prefill mode produces correct output with causal mask
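The two SWA masks from D3 and D4 can be expressed together as a reference for one query row. This sketch is plain Python; `mask_swa_logits` is a hypothetical name, and it assumes swa_idx and q_pos share the same coordinate system, as the D4 condition implies.

```python
NEG_INF = float("-inf")

def mask_swa_logits(logits, swa_len, q_pos=None):
    """Mask one row of SWA-branch logits.

    logits  : list indexed by swa_idx
    swa_len : number of valid SWA entries for this batch element (D3)
    q_pos   : query position; None disables the causal mask (D4)
    """
    masked = []
    for swa_idx, value in enumerate(logits):
        beyond_fill = swa_idx >= swa_len
        in_future = q_pos is not None and swa_idx > q_pos
        masked.append(NEG_INF if (beyond_fill or in_future) else value)
    return masked

row = [0.1, 0.2, 0.3, 0.4]
print(mask_swa_logits(row, swa_len=3))            # length mask only
print(mask_swa_logits(row, swa_len=3, q_pos=1))   # length mask + causal mask
```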

D5 — SWA + Sink Merge (CG-3) ⚠️ THE WHOLE POINT OF V4 ATTENTION

D5a — Emit un-normalized o + lse DO THIS IMMEDIATELY AFTER D1

This is the single most important structural change. Once the kernel can output (o_unnorm, lse), even a Python merge gives end-to-end correctness.

  • Change epilogue: instead of O *= 1/row_sum, emit O un-normalized and lse = log(row_sum) + row_max as a separate output
  • Add normalize: bool constructor flag (default: True for backward compat, False for merge mode)
  • When normalize=False: skip the TMEM round-trip for normalize. O stays as PV @ V (un-normalized). lse written to a separate GMEM buffer.
  • Test: normalize=True → identical to current behavior (regression)
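  • Test: normalize=False → o_unnorm / exp(lse).unsqueeze(-1) ≈ o_normalized (verify math: if o_unnorm is accumulated relative to row_max, as in online softmax, the divisor is row_sum = exp(lse - row_max) rather than exp(lse))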
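The D5a output contract can be checked on one query row in plain Python. The sketch assumes the accumulator is max-shifted, as online softmax usually keeps it; under that assumption the normalizer is row_sum = exp(lse - row_max). If the kernel instead emits O without the max shift, dividing by exp(lse) is the matching check. Function names are hypothetical.

```python
import math

def attention_row(scores, values):
    """Return (o_unnorm, lse, row_max) for one query row.

    o_unnorm is max-shifted: sum_j exp(s_j - row_max) * v_j.
    """
    row_max = max(scores)
    p = [math.exp(s - row_max) for s in scores]
    row_sum = sum(p)
    dim = len(values[0])
    o_unnorm = [sum(p[j] * values[j][d] for j in range(len(scores))) for d in range(dim)]
    lse = math.log(row_sum) + row_max
    return o_unnorm, lse, row_max

def normalize(o_unnorm, lse, row_max):
    """Recover the normalized output: exp(row_max - lse) equals 1 / row_sum."""
    scale = math.exp(row_max - lse)
    return [x * scale for x in o_unnorm]

scores = [1.0, 2.0, 3.0]
values = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
o_unnorm, lse, row_max = attention_row(scores, values)
print(normalize(o_unnorm, lse, row_max))
```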

D5b — Python merge (correctness baseline)

  • Run FmhaKernel twice: once with compressed_kv, once with swa_kv
  • Merge in Python:
    exp_lse_sparse = lse_sparse.exp()
    exp_lse_swa = lse_swa.exp()
    exp_sink = sink_logits.exp()
    o = (exp_lse_sparse * o_sparse + exp_sink * exp_lse_swa * o_swa) / (exp_lse_sparse + exp_sink * exp_lse_swa)
    
  • Test against FP32 oracle that does sparse+SWA+sink merge
  • This gives us end-to-end correctness. Everything after is optimization.
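The merge formula above calls exp() on raw lse values, which can overflow for large logits. A minimal sketch of the same formula with a max-shift follows, written for one row in plain Python. It assumes o_sparse and o_swa are already normalized per branch (if the kernel emits a max-shifted un-normalized O, divide by row_sum first) and keeps the sink weighting exactly as written above; `merge_row` is a hypothetical name.

```python
import math

def merge_row(o_sparse, lse_sparse, o_swa, lse_swa, sink_logit):
    """Stable form of (e^a * o_sparse + e^b * o_swa) / (e^a + e^b),
    with a = lse_sparse and b = sink_logit + lse_swa."""
    a = lse_sparse
    b = sink_logit + lse_swa          # exp(sink) * exp(lse_swa) == exp(b)
    m = max(a, b)                     # shift so the largest exponent is 0
    w_sparse = math.exp(a - m)
    w_swa = math.exp(b - m)
    denom = w_sparse + w_swa
    return [(w_sparse * x + w_swa * y) / denom for x, y in zip(o_sparse, o_swa)]

# With equal weights the result is the midpoint, even at lse values where exp() overflows.
print(merge_row([1.0, 0.0], 1000.0, [3.0, 2.0], 1000.0, 0.0))
```

The same shift applies unchanged to a batched torch version: subtract the elementwise maximum of the two exponents before calling exp().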

D5c — Fuse two passes into one kernel launch

  • Q loaded once to SMEM, used by both compressed and SWA MMA loops
  • Two sequential QK→softmax→PV passes in one kernel invocation
  • K/V have two sources: compressed (contiguous BF16) and SWA (from cache)
  • For now: dequantize SWA in a small prep kernel before FMHA, FMHA sees two contiguous BF16 sources
  • Test: output matches D5b Python merge

D5d — Fuse sink merge into kernel epilogue

  • TMEM holds two O accumulators + two row_max/row_sum per row
  • Verify TMEM column budget: two O + two (row_max, row_sum) at hd=512
  • Sink merge in TMEM: O = (exp(lse1) * O1 + exp(sink) * exp(lse2) * O2) / (exp(lse1) + exp(sink) * exp(lse2))
  • Test: output matches D5b Python merge

D6 — Mixed-Precision KV Load Path (CG-5)

  • TMA loads FP8 NoPE dims to SMEM slot 0
  • TMA loads BF16 RoPE dims to SMEM slot 1
  • Dequantize FP8 → BF16 in SMEM (vectorized * inv_scale)
  • Concatenate [NoPE, RoPE] in SMEM
  • MMA reads contiguous BF16 from SMEM
  • Test: FP8+BF16 split input matches pure BF16 input (dequant is transparent)
  • Prerequisite: D1 (SMEM-P) and D5 (sink merge) working first

Inverse RoPE Verification (CG-4) — Separate from D1-D6

  • Write unit test: tests/unit/test_inverse_rope.py
  • Round-trip test: forward RoPE → inverse RoPE → verify ≈ original
  • Multi-head test: verify only last 64 dims are rotated
  • Position test: verify cos_sin_cache indexing is correct for positions > 0
  • This is a standalone test, not a kernel change. Can be done anytime.

Per-Token valid_lens (CG-6) — Indexer Scope, Not FMHA

  • Add request_ids: [T] int32 to indexer kernel input
  • Look up valid_lens[request_ids[t]] per query token
  • Replace the current broadcast of valid_lens[:1]
  • Test: batched prefill with different sequence lengths per request
  • Tracked separately from FMHA Stage D. This is indexer work.

Correctness Gap NOT in This Project

CG-7: Indexer Rewrite (PS-1) — Stage F

The indexer needs a full rewrite from scalar CUDA to tcgen05 MMA + radix-select. This is a major work item (2-3 weeks) that is out of scope for Stage D.

Reference: DeepGEMM's fp8_mqa_logits / fp8_paged_mqa_logits (Sept 2025 PR for V3.2 indexer). Our V4 variant: same pattern with FP4 inputs and tcgen05 MMA.


Execution Order (Top to Bottom)

| # | Task | Blocks | Est. |
|---|---|---|---|
| D0 | SwiGLU clamping (CG-1) | Nothing (do first) | 30 min |
| D1.2 | TMEM budget probe at hd=512 | D1.3 | 1 hr |
| D1.3 | Register→SMEM copy for P | D1.4, D2 | 1-2 days |
| D1.4 | Multi-PV-tile hd>256 | D2 | 1 day |
| D1.5 | Correction epilog fix (3% → 0.01%) | Nothing (can parallel) | 1-2 days |
| D2 | Multi-query grid + head packing | D3 | 1 day |
| D3 | SWA sequence length mask | D5 | ½ day |
| D4 | Causal mask on SWA | D5 | ½ day |
| D5a | Emit un-normalized o + lse | D5b | 1 day |
| D5b | Python merge (correctness) | D5c | ½ day |
| D5c | Fuse two passes in one launch | D5d | 2 days |
| D5d | Fuse sink merge in epilogue | D6 | 2 days |
| D6 | Mixed-precision KV load | E1 | 2 days |
| CG-4 | Inverse RoPE round-trip test | Nothing | 2 hrs |
| CG-6 | Per-token valid_lens (indexer) | Nothing | ½ day |

Critical path: D0 → D1.2 → D1.3 → D1.4 → D5a → D5b (end-to-end correctness)

D1.5 (correction epilog) and CG-4 (RoPE test) can happen in parallel with D2-D4.