DSV4 Inference Kernel

Architecture

DSV4 is not MLA. It uses CSA (Compressed Sparse Attention, compression ratio m=4) and HCA (Heavily Compressed Attention, m=128). The KV latent has shape (T, 512) and is shared across all 128 heads. Sink weights merge the sparse and SWA (sliding-window attention) branches. vLLM misnames this as "MLA"; the architecture is fundamentally different.
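One plausible reading of the sink merge is a log-sum-exp combination of two partial attention results. This is a minimal pure-Python sketch for one query row of one head. It assumes (a) each branch (sparse and SWA) returns its row max `m`, its sum of exponentials `l`, and its unnormalized weighted sum of V rows, and (b) the sink is a single learned logit per head that adds mass to the softmax denominator but contributes no value vector (a gpt-oss-style attention sink). Both assumptions, and the helper names `branch_partial` and `merge_with_sink`, are illustrative and not taken from the DSV4 code.

```python
import math
from typing import List, Sequence, Tuple

Partial = Tuple[float, float, List[float]]  # (row max m, sum of exp l, unnormalized acc)


def branch_partial(scores: Sequence[float], values: Sequence[Sequence[float]]) -> Partial:
    """Unnormalized attention for one branch, shifted by the branch max for stability."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    dim = len(values[0])
    acc = [sum(w * v[i] for w, v in zip(weights, values)) for i in range(dim)]
    return m, sum(weights), acc


def merge_with_sink(branches: Sequence[Partial], sink_logit: float) -> List[float]:
    """Combine branch partials under one shared softmax; the sink only enlarges the denominator."""
    m = max([sink_logit] + [b[0] for b in branches])  # common shift across all branches
    denom = math.exp(sink_logit - m)
    out = [0.0] * len(branches[0][2])
    for bm, bl, bacc in branches:
        scale = math.exp(bm - m)  # re-express this branch relative to the common max
        denom += bl * scale
        for i, x in enumerate(bacc):
            out[i] += x * scale
    return [x / denom for x in out]
```

Because the merge only needs `(m, l, acc)` per branch, the two branches can run independently and be combined at the end, the same trick split-K decode kernels use.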

DSV4 inference pipeline — component status
==========================================

Legend:
 [✓] built and tested
 [~] partial — reference or seam exists, native pending
 [✗] to build


 ┌────────────────────────────────────┐
 │ [✗] Embedding + mHC init          │
 │ token embed + n_hc=4 streams      │
 └────────────────┬───────────────────┘
                  │
                  ▼
┌─ Transformer layer × L ──────────────────────────────────────────────┐
│ HCA on layers 0-1 of Pro, alternating CSA / HCA after             │
│                                                                      │
│ ┌─ Attention sub-block ──────────────────────────────────────────┐  │
│ │ [✓] Residual mHC pre + post mix                               │  │
│ │ [~] Norms + RoPE             RMSNorm + partial RoPE           │  │
│ │ [✓] Q / KV projection        NVFP4 linears + LoRA             │  │
│ │ [~] Token compressor         CSA m=4 / HCA m=128             │  │
│ │ [✗] Indexer + top-k          CSA only, FP4 QK                 │  │
│ │ [~] FMHA core                QK → online softmax → PV         │  │
│ │                              + SWA branch + sink merge         │  │
│ │ [✓] Output projection        inv RoPE + wo_a grouped + wo_b   │  │
│ └────────────────────────────────────────────────────────────────┘  │
│                                                                      │
│ ┌─ FFN sub-block ────────────────────────────────────────────────┐  │
│ │ [✓] Residual mHC pre + post mix                               │  │
│ │ [~] Pre-FFN norm              RMSNorm                          │  │
│ │ [✗] Router                    sqrt(softplus) + topk + hash     │  │
│ │ [✓] Routed MoE               fused SwiGLU L1 + L2             │  │
│ │ [✓] Shared expert            NVFP4 single-group GEMM          │  │
│ └────────────────────────────────────────────────────────────────┘  │
└──────────────────────────────────┬───────────────────────────────────┘
                                  │
                                  ▼
┌──────────────────────────────────────────────────────────────────────┐
│ [✗] Final RMSNorm → [✗] LM head → [✗] MTP (depth=1) → [✗] Sampler │
└──────────────────────────────────────────────────────────────────────┘

┌─ Supporting infrastructure ──────────────────────────────────────────┐
│ [✗] KV cache management                                             │
│ • state cache: SWA window + uncompressed tail per layer             │
│ • classical paged cache: lcm(m_CSA, m_HCA) = 128 tokens per block │
│ • heterogeneous layout per layer                                    │
└──────────────────────────────────────────────────────────────────────┘


Summary
-------
 Built  [✓] : 6 — mHC ×2, Q/KV proj, output proj, routed MoE,
               shared expert
 Partial [~] : 4 — norms+RoPE, token compressor, FMHA core,
               pre-FFN norm
 To build [✗] : 8 — embedding+init, indexer+top-k, router,
               final norm, LM head, MTP, sampler, KV cache
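The paged-cache rule in the infrastructure box, and the split between compressed entries and the uncompressed tail, can be checked with a few lines of Python. This is a minimal sketch assuming that only complete groups of m tokens are compressed and that the remaining T mod m tokens stay in the per-layer state cache; the names `split_tokens` and `entries_per_block` are hypothetical.

```python
import math
from typing import Tuple

M_CSA = 4    # CSA: tokens per compressed KV entry
M_HCA = 128  # HCA: tokens per compressed KV entry

# A paged block must hold a whole number of compressed entries for both layer types.
BLOCK_TOKENS = math.lcm(M_CSA, M_HCA)


def split_tokens(num_tokens: int, m: int) -> Tuple[int, int]:
    """Return (compressed entries, tokens left uncompressed in the tail)."""
    return num_tokens // m, num_tokens % m


def entries_per_block(m: int) -> int:
    """Compressed entries that one BLOCK_TOKENS-token block holds for ratio m."""
    return BLOCK_TOKENS // m
```

For a 1000-token sequence, a CSA layer holds 250 compressed entries and no tail, while an HCA layer holds 7 entries plus a 104-token tail; one 128-token block stores 32 CSA entries or a single HCA entry.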

Status (May 22, 2026 — 16:30 UTC)

| Stage | Status | Description |
|---|---|---|
| A | COMPLETE | Q@K^T via tcgen05.mma → TMEM → GMEM |
| B | COMPLETE | QK → identity softmax → P@V pipeline (TMEM alias, KV-tile interleaving) |
| C | ⚠️ SINGLE-TILE ONLY | Real online softmax works for n=128 (cosine 0.993-0.996). Multi-tile (n>128) is broken. |
| C' | 🔨 IN PROGRESS | Multi-tile TMA indexing fix + correction warps. See below. |
| D | TODO | Full decode attention: paged KV cache, multi-query, causal mask |
| E | TODO | Production kernel: extract into dsv4/kernels/attention/, PyTorch custom op, vLLM bridge |

Package Structure

dsv4/
├── kernels/          Pure GPU code (CuTeDSL @cute.jit, .cu files)
│   ├── gemm/           NVFP4 MoE GEMM kernels (grouped, fused_swiglu, dense, scheduler)
│   ├── attention/      FMHA kernel (stub — extraction is Stage E)
│   ├── compressor/     CSA/HCA token-level compressor
│   ├── decode/         Decode-time attention (sparse, SWA — future)
│   └── cuda/           Raw .cu files (deinterleave_quantize, sparse_topk_metadata)
├── ops/              PyTorch ↔ kernel bridges
│   ├── quantize.py      BF16 ↔ NVFP4 conversion, scale factors
│   ├── layouts.py       Scale swizzle, gate/up interleave, K-major, offsets
│   ├── gemm_runner.py   Warmup, compile, run grouped/fused GEMMs
│   ├── custom_ops.py    torch.library.custom_op registrations
│   ├── decode_sparse.py native_sparse_decode dispatcher
│   ├── decode_swa.py    native_swa_decode dispatcher
│   ├── rope.py          Forward + inverse RoPE
│   └── topk.py          Python wrapper for sparse_topk_metadata.cu
├── layers/           nn.Module-style components
│   ├── linear.py        Nvfp4Linear
│   ├── grouped_linear.py Nvfp4GroupedLinear
│   ├── moe.py           Nvfp4MoE
│   ├── shared_expert.py Nvfp4SharedExpert
│   ├── mhc.py           mHCLayer
│   └── (stubs: attention, ffn, router, norm, embedding)
├── model/            Model assembly (stubs — Phase 1)
├── cache/            KV cache infra (stubs — Phase 3)
├── loader/           Checkpoint I/O (stubs — Phase 1)
└── reference/        Slow PyTorch oracles (never imported by production code)
    ├── attention.py     RoPE, KV cache, causal attention, SWA
    ├── csa_attention.py CSA/HCA sparse attention
    ├── compressor.py    Compressor PyTorch example
    └── moe_pipeline.py  MoE pipeline reference

Mental model: kernels/ → ops/ → layers/ → model/ (dependencies flow left to right: each package builds on the one to its left). reference/ and loader/ are sidecars.
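The dependency direction can be enforced mechanically. This is a hypothetical check (not part of the repo) that uses Python's standard `ast` module to list the `dsv4.*` subpackages a source file imports, and tests them against the mental model above: a package may import itself or anything to its left, and production code never imports reference/. loader/ and cache/ are left out because the README does not state their rules.

```python
import ast
from typing import Set

CHAIN = ("kernels", "ops", "layers", "model")  # left to right, lowest level first


def dsv4_subpackages(source: str) -> Set[str]:
    """Return the dsv4 subpackages (e.g. 'ops') named by absolute imports in source."""
    found = set()
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.Import):
            names = [alias.name for alias in node.names]
        elif isinstance(node, ast.ImportFrom) and node.level == 0 and node.module:
            if node.module == "dsv4":  # e.g. "from dsv4 import ops"
                names = ["dsv4." + alias.name for alias in node.names]
            else:
                names = [node.module]
        else:
            continue  # relative imports and non-import nodes are ignored
        for name in names:
            parts = name.split(".")
            if parts[0] == "dsv4" and len(parts) > 1:
                found.add(parts[1])
    return found


def import_allowed(importer: str, imported: str) -> bool:
    """Apply the layering rule; raise for packages the rule does not cover."""
    if imported == "reference":
        return importer == "reference"  # oracles are never imported by production code
    if importer in CHAIN and imported in CHAIN:
        return CHAIN.index(imported) <= CHAIN.index(importer)
    raise ValueError(f"no layering rule for {importer} -> {imported}")
```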


Active Test Files

FMHA (Stages A/B/C) — in tests/unit/

| File | Stage | Status |
|---|---|---|
| test_fmha_v3.py | A+B | Full QK → identity softmax → PV, cosine 0.999999 |
| test_fmha_v3_12w.py | A+B | 12-warp QK → PV, cosine 0.999999 |
| test_fmha_v3_stage_c_full.py | C | Real online softmax + O normalization, cosine 0.993-0.996 |
| test_fmha_v3_stage_c_min.py | C | 🔨 Early 12-warp pipeline (broken pipeline state) |
| test_pv64_with_softmax.py | B | (128,64) PV, single AB pipeline |
| test_128_128_vdiag.py | A+B | (128,128) PV baseline |
| test_qkonly.py | A | QK with split Q/KV pipelines |
| test_qk_softmax.py | A+B | QK + identity softmax, no PV |

MoE / GEMM — in tests/unit/

| File | What it tests |
|---|---|
| test_cutedsl.py | NVFP4 grouped GEMM kernel |
| cudagraph_test.py | CUDA graph capture + replay |
| layertest.py | Per-layer correctness |
| test_custom_op.py | torch.library custom ops |
| test_compile_custom_op.py | Compile + warmup |
| test_fp4_roundtrip.py | BF16 → NVFP4 → BF16 roundtrip |
| test_interleave.py | Gate/up weight interleaving |
| test_interleave_gemm.py | Interleaved GEMM correctness |
| test_fused_step1.py | Fused SwiGLU GEMM |

Archived Tests

tests/archive/ contains ~190 debug files from Stages A/B. Not maintained. Can be deleted.


Test Harness

Scripts in tests/ for running tests on the B200 (root@45.76.247.107):

run_test.sh — Run a test in a screen session

# On the B200:
cd /root/dsv4-nvfp4-workspace/kernel
bash tests/run_test.sh tests/unit/test_fmha_v3.py

What it does:

  1. Kills any existing kernel-test screen and SIGKILLs all child processes (handles deadlocked GPU procs that ignore SIGHUP)
  2. Deletes the old log file
  3. Starts a new screen -dmS kernel-test running the test
  4. Logs output to /tmp/kernel-test.log
  5. Verifies the screen started

check_log.sh — Check test progress

bash tests/check_log.sh

Shows the log contents and whether the screen is still running.

Local → B200 workflow

# 1. Edit locally, commit, push
cd ~/dev/nvfp4-megamoe-kernel
git add -A && git commit -m "my change" && git push

# 2. SSH to B200, pull, run
ssh root@45.76.247.107
cd /root/dsv4-nvfp4-workspace/kernel && git pull
bash tests/run_test.sh tests/unit/test_fmha_v3_stage_c_full.py

# 3. Check results
bash tests/check_log.sh

Stage C: Online Softmax — SINGLE-TILE ONLY

What We Have

test_fmha_v3_stage_c_full.py has a working real softmax for a single KV tile (n=128): cosine 0.993-0.996. Multi-tile (n>128) is broken; see the blocker below.

Multi-Tile Blocker: TMA GMEM Tile Indexing

The original TMA partition slices tBgK with (None, 0, None, 0), which hardcodes the GMEM iteration dimension to tile 0. As a result, TMA always loads K/V from the first 128 tokens regardless of kt, and the output is identical for all n>128.
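The "identical output" symptom follows from how softmax treats repeated keys. If every iteration kt reads tile 0, the kernel effectively attends over k identical copies of the first tile, which multiplies both the softmax numerator and denominator by k and leaves the result unchanged. A minimal pure-Python sketch, assuming no masking and n a multiple of the tile size (shrunk from 128 to 4 so the example stays small); `gather_tiles` is a hypothetical stand-in for the TMA loads.

```python
import math
from typing import List, Sequence, Tuple

TILE = 4  # stand-in for the 128-token KV tile


def attention(scores: Sequence[float], values: Sequence[Sequence[float]]) -> List[float]:
    """Reference softmax(scores) @ values for a single query row."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v[i] for w, v in zip(weights, values)) / total for i in range(len(values[0]))]


def gather_tiles(scores, values, n: int, hardcode_tile0: bool) -> Tuple[list, list]:
    """Concatenate the KV tiles that iterations kt = 0 .. n // TILE - 1 would load."""
    s_out, v_out = [], []
    for kt in range(n // TILE):
        lo = (0 if hardcode_tile0 else kt) * TILE  # the bug pins the GMEM tile coordinate to 0
        s_out += scores[lo:lo + TILE]
        v_out += values[lo:lo + TILE]
    return s_out, v_out
```

With `hardcode_tile0=True` the result for n=8 or n=12 matches attention over the first tile alone, which is exactly the symptom described above.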

Why you can't just index with kt: CuTeDSL's TMA copy API accepts pipeline state values (like kh.count) as TMA coordinates but does NOT accept a Python int from range(). Indexing with kt fails at operation creation.

Fix (Mike): Combined K+V barrier — one acquire_and_advance per kt, two cute.copy calls sharing kvh.barrier. With no interleaving, kvh.count naturally equals kt and stays a first-class pipeline state value. See fmha_v3_stage_c_example2.py.

Current status of fix: Compiles but deadlocks at runtime (even n=128). The 3-way sync between acc_pipe, softmax_done_bar, and final_o_bar needs debugging. Fallback: kh.count // 2 in the original interleaved kernel (CuTeDSL Int32 overloads __floordiv__ in recent versions).
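A toy model of the pipeline counter shows why `kh.count // 2` recovers kt and why the combined barrier needs no division. This sketch assumes the interleaved kernel advances one shared pipeline state once for the K load and once for the V load of every tile, while the combined kernel advances it once per tile. `count` here is a plain Python int standing in for the CuTeDSL pipeline-state value; the function names are hypothetical.

```python
from typing import List, Tuple


def interleaved_coords(num_tiles: int) -> List[Tuple[str, int]]:
    """K and V each take one acquire_and_advance, so count runs 0, 1, 2, 3, ..."""
    count, coords = 0, []
    for _kt in range(num_tiles):
        coords.append(("K", count // 2))  # count == 2*kt here
        count += 1
        coords.append(("V", count // 2))  # count == 2*kt + 1; floor division still gives kt
        count += 1
    return coords


def combined_coords(num_tiles: int) -> List[Tuple[str, int]]:
    """One acquire_and_advance per tile; both copies share the barrier, so count == kt."""
    count, coords = 0, []
    for _kt in range(num_tiles):
        coords.append(("K", count))
        coords.append(("V", count))
        count += 1
    return coords
```

Either way the GMEM tile coordinate is derived from the pipeline state rather than from the Python loop variable, which is what the TMA copy API accepts (constraint 4 below).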

Files

| File | Status | Notes |
|---|---|---|
| test_fmha_v3_stage_c_full.py | OK, n=128 only | Working real softmax + O normalization |
| fmha_v3_stage_c_example1.py | BROKEN multi-tile | First fix attempt, TMA still loads tile 0 |
| fmha_v3_stage_c_example2.py | DEADLOCK | Combined K+V barrier, compiles but deadlocks |
| test_fmha_v3_stage_c2.py | DEADLOCK | 12-warp pipeline, compiles but deadlocks |
| test_fmha_v3_12w.py | OK, n=128 only | Identity softmax baseline |

Current Architecture (6-warp)

Warps 0-3: Softmax + Epilogue, Warp 4: MMA (QK, PV), Warp 5: TMA (Q/K/V load)

Target Architecture (12-warp, production)

Warps 0-3: Softmax, Warps 4-7: Correction, Warp 8: MMA, Warp 9: TMA, Warp 10: Epilogue, Warp 11: Empty
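The two warp layouts can be written down as lookup tables, which makes thread counts and role dispatch easy to check. A minimal sketch with hypothetical table names. The `tmem_lane_quarter` helper encodes our understanding of the Blackwell rule that a warp can access only the 32 TMEM lanes selected by its warp index mod 4; if that holds, it would explain why softmax and correction each get four consecutive warps, since together they cover all 128 lanes.

```python
from typing import List

WARP_SIZE = 32

LAYOUT_6 = ["softmax+epilogue"] * 4 + ["mma", "tma"]
LAYOUT_12 = ["softmax"] * 4 + ["correction"] * 4 + ["mma", "tma", "epilogue", "empty"]


def threads_per_cta(layout: List[str]) -> int:
    """Threads launched per CTA for a given warp layout."""
    return len(layout) * WARP_SIZE


def role_of_thread(tid: int, layout: List[str]) -> str:
    """Role of a thread, dispatching on warp index = tid // 32."""
    return layout[tid // WARP_SIZE]


def tmem_lane_quarter(warp_idx: int) -> range:
    """TMEM lanes a warp may access (assumed rule: warp index mod 4 picks a 32-lane quarter)."""
    q = warp_idx % 4
    return range(32 * q, 32 * q + 32)
```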

CuTeDSL Constraints (hard-won)

  1. vectorize=True loops: ONLY load/store/print
  2. .reduce(cute.ReductionOp.MAX): reduces ENTIRE C-fragment to scalar — global max, not per-row
  3. cute.arch.fmax: impure for vectorizer — use plain range() loop
  4. TMA cute.copy accepts pipeline state values as coordinates but NOT Python int
  5. tBgK[(None, 0, None, 0)] hardcodes GMEM iteration to tile 0
  6. softmax_done_bar NamedBarrier is reusable across tiles
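Constraint 2 matters because softmax needs a per-row maximum. Subtracting the global fragment max is mathematically legal (softmax is shift-invariant within a row), but a row whose own maximum sits far below the global one underflows to all zeros. A minimal pure-Python illustration; Python raises ZeroDivisionError where an FP32 kernel would silently produce NaN, and FP32 underflows at a far smaller gap (roughly 100) than Python's float64 does. The function name is hypothetical.

```python
import math
from typing import List


def softmax_rows(rows: List[List[float]], per_row_max: bool = True) -> List[List[float]]:
    """Row-wise softmax; per_row_max=False mimics subtracting one global max."""
    global_max = max(max(r) for r in rows)
    out = []
    for row in rows:
        shift = max(row) if per_row_max else global_max
        weights = [math.exp(s - shift) for s in row]
        total = sum(weights)  # 0.0 if every weight in the row underflowed
        out.append([w / total for w in weights])
    return out
```

With rows `[[1000.0, 999.0], [0.0, -1.0]]`, the per-row version normalizes both rows, while the global-max version underflows every weight of the second row.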

Remaining for C' (Production Stage C)

  1. Fix multi-tile TMA — combined K+V barrier or kh.count // 2
  2. Fix runtime deadlock in example2 (acc_pipe + final_o_bar sync)
  3. Cross-warp reduction for row_max and row_sum
  4. Correction warps for multi-tile KV (online O rescale in TMEM)
  5. 12-warp layout with separate softmax/correction/epilogue warps
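Items 3 and 4 are the two halves of the standard online-softmax recurrence used by FlashAttention-style kernels: keep a running row max and row sum, and whenever a new KV tile raises the max, rescale the partial O by exp(m_old - m_new) before adding the new tile's contribution. A minimal pure-Python sketch for one query row; in the target kernel the rescale would be the correction warps' job on O in TMEM, and the final division is the O normalization step. Names are illustrative.

```python
import math
from typing import List, Sequence


def online_softmax_attention(scores: Sequence[float], values: Sequence[Sequence[float]],
                             tile: int) -> List[float]:
    """softmax(scores) @ values for one query row, visiting the KV sequence tile by tile."""
    dim = len(values[0])
    row_max = -math.inf   # running row max (item 3)
    row_sum = 0.0         # running sum of exp(s - row_max) (item 3)
    o = [0.0] * dim       # unnormalized output accumulator
    for lo in range(0, len(scores), tile):
        s_tile = scores[lo:lo + tile]
        v_tile = values[lo:lo + tile]
        new_max = max(row_max, max(s_tile))
        alpha = math.exp(row_max - new_max)  # correction factor (item 4); exp(-inf) == 0.0 on tile 0
        row_sum *= alpha
        o = [x * alpha for x in o]
        for s, v in zip(s_tile, v_tile):
            p = math.exp(s - new_max)
            row_sum += p
            for i in range(dim):
                o[i] += p * v[i]
        row_max = new_max
    return [x / row_sum for x in o]  # final O normalization
```

The result is independent of the tile size, which is the property a multi-tile test at n>128 should check against the single-pass reference.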

TMEM Layout

Col 0-127: S (QK acc, 128 FP32) | Col 32-95: P (64 FP32) | Col 128+: O (PV acc, 64 FP32)
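The column map can be sanity-checked in a few lines. This sketch encodes the ranges above as Python ranges, confirms that P deliberately aliases the middle of S while O stays clear of both, and rounds the total column count up to a power of two (Key Lesson 3 below). The 32-column minimum and 512-column maximum in `tmem_alloc_cols` reflect our reading of the PTX `tcgen05.alloc` rules and are assumptions; the 64-column width of O comes from the layout line above.

```python
S_COLS = range(0, 128)    # S: QK accumulator, 128 FP32 columns
P_COLS = range(32, 96)    # P: 64 FP32 columns aliased inside S
O_COLS = range(128, 192)  # O: PV accumulator, 64 FP32 columns


def overlaps(a: range, b: range) -> bool:
    """True if two half-open column ranges share at least one column."""
    return a.start < b.stop and b.start < a.stop


def tmem_alloc_cols(needed: int) -> int:
    """Smallest legal allocation: a power of two, assumed to lie in [32, 512] columns."""
    cols = max(32, 1 << (needed - 1).bit_length())
    if cols > 512:
        raise ValueError(f"{needed} columns exceed TMEM capacity")
    return cols
```

With this layout the kernel touches 192 columns and would therefore allocate 256.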


Key Lessons

  1. NEVER use find_tmem_tensor_col_offset() as TMEM placement. It returns footprint size, not a safe offset.
  2. FMHA never trusts DLPack tensor layouts. Reconstruct V as (hd, s_k) MN-major inside CuTe.
  3. TMEM allocation must be power of 2.
  4. Square hides bugs. (128,128) worked for every wrong approach. Always test non-square.
  5. St32x32bOp MUST use Float32, NOT BFloat16. BFloat16 causes illegal memory access.
  6. The first PV MMA must use ACCUMULATE=False. Otherwise it adds uninitialized TMEM to the output.
  7. FMHA P store uses QK C-fragment composition, NOT PV A-fragment. Two aliases, same TMEM.
  8. Register bridge: FP32 backing (store partition) + BF16 view (QK-load layout). Do not skip this.
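Lesson 6 in miniature: an MMA with ACCUMULATE=True computes D = A·B + D, so on the first KV tile it folds whatever D already holds (uninitialized TMEM in the kernel) into the result. A minimal pure-Python sketch with deliberately non-square shapes, in the spirit of Lesson 4; `mma` and `pv_over_tiles` are illustrative stand-ins, not the CuTeDSL API.

```python
from typing import List

Matrix = List[List[float]]


def mma(acc: Matrix, a: Matrix, b: Matrix, accumulate: bool) -> Matrix:
    """acc = a @ b (+ acc when accumulate is True), updated in place."""
    inner, cols = len(b), len(b[0])
    for r in range(len(a)):
        for c in range(cols):
            prod = sum(a[r][k] * b[k][c] for k in range(inner))
            acc[r][c] = acc[r][c] + prod if accumulate else prod
    return acc


def pv_over_tiles(p_tiles, v_tiles, acc: Matrix, always_accumulate: bool = False) -> Matrix:
    """O = sum over kt of P_kt @ V_kt; the first tile must overwrite acc, not add to it."""
    for kt, (p, v) in enumerate(zip(p_tiles, v_tiles)):
        mma(acc, p, v, accumulate=always_accumulate or kt > 0)
    return acc
```

Starting from a "stale" accumulator of `[[7.0, 7.0, 7.0]]`, the correct schedule returns only the sum of the tile products, while `always_accumulate=True` leaks the stale 7.0 into every output element.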

Environment

  • Server: root@45.76.247.107 (B200, 180 GiB HBM3e per GPU)
  • venv: source /root/dsv4-nvfp4-workspace/venv/bin/activate
  • PYTHONPATH: /root/dsv4-nvfp4-workspace/kernel
  • Model: /root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4
  • vLLM repo: /root/dsv4-nvfp4-workspace/vllm (modified for Blackwell)
  • CUTLASS FMHA reference: /root/cutlass/examples/python/CuTeDSL/cute/blackwell/kernel/attention/fmha/fmha.py