biondizzle 07817ae82e FIX: use unsliced tBgK with (None, kt, None, 0) for proper GMEM tile indexing
The pre-slice (None,0,None,0) hardcoded GMEM iteration to tile 0.
Instead, keep the original tBgK and index with (None, kt, None, 0)
inside the TMA loop, where kt selects the correct GMEM tile.
This preserves 2D rank matching with the SMEM tensor.
2026-05-22 15:52:56 +00:00
2026-05-19 09:37:38 +00:00

DSV4 Inference Kernel

Architecture

DSV4 is not MLA. It uses CSA (Compressed Sparse Attention, m=4) and HCA (Heavily Compressed Attention, m=128). The KV latent is (T, 512) and is shared across all 128 heads. Sink weights merge the sparse and SWA (sliding-window attention) branches. vLLM labels this architecture "MLA", but it is not MLA; the design is fundamentally different.
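One common way to combine two attention branches over disjoint key sets is to merge their normalized outputs through their log-sum-exp (LSE) values, with a sink logit that adds mass only to the softmax denominator. The sketch below assumes that this is what "sink weights merge sparse + SWA attention" means. The exact DSV4 formulation, such as whether the sink is learned per head, is not stated here. `branch` and `merge_with_sink` are hypothetical names, and the example covers one query using plain Python lists.

```python
import math

def branch(logits, values):
    """Softmax attention for one query over one branch; returns (output, lse)."""
    m = max(logits)
    exps = [math.exp(s - m) for s in logits]
    z = sum(exps)
    out = [sum(e * v[d] for e, v in zip(exps, values)) / z for d in range(len(values[0]))]
    return out, m + math.log(z)          # lse = log(sum(exp(logits)))

def merge_with_sink(o_a, lse_a, o_b, lse_b, sink_logit):
    """Merge two normalized branch outputs; the sink adds mass to the denominator only."""
    m = max(lse_a, lse_b, sink_logit)    # shift all three terms for numerical stability
    w_a = math.exp(lse_a - m)            # proportional to branch A's softmax denominator
    w_b = math.exp(lse_b - m)
    w_s = math.exp(sink_logit - m)       # the sink has no value vector
    denom = w_a + w_b + w_s
    return [(w_a * xa + w_b * xb) / denom for xa, xb in zip(o_a, o_b)]

# Example: 3 sparse (top-k) keys and 2 SWA keys for one query, head dim 2.
sparse_logits, sparse_vals = [0.3, -1.2, 2.0], [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
swa_logits, swa_vals = [1.1, 0.4], [[2.0, -1.0], [0.0, 3.0]]
sink = 0.7

o_a, lse_a = branch(sparse_logits, sparse_vals)
o_b, lse_b = branch(swa_logits, swa_vals)
merged = merge_with_sink(o_a, lse_a, o_b, lse_b, sink)
```

The merge needs only each branch's normalized output and its LSE. The sparse and SWA branches can therefore run as separate kernels and be combined afterwards.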

DSV4 inference pipeline — component status
==========================================

Legend:
 [✓] built and tested
 [~] partial — reference or seam exists, native pending
 [✗] to build


 ┌────────────────────────────────────┐
 │ [✗] Embedding + mHC init          │
 │ token embed + n_hc=4 streams      │
 └────────────────┬───────────────────┘
                  │
                  ▼
┌─ Transformer layer × L ──────────────────────────────────────────────┐
│ HCA on layers 0-1 of Pro, alternating CSA / HCA after             │
│                                                                      │
│ ┌─ Attention sub-block ──────────────────────────────────────────┐  │
│ │ [✓] Residual mHC pre + post mix                               │  │
│ │ [~] Norms + RoPE             RMSNorm + partial RoPE           │  │
│ │ [✓] Q / KV projection        NVFP4 linears + LoRA             │  │
│ │ [~] Token compressor         CSA m=4 / HCA m=128             │  │
│ │ [✗] Indexer + top-k          CSA only, FP4 QK                 │  │
│ │ [~] FMHA core                QK → online softmax → PV         │  │
│ │                              + SWA branch + sink merge         │  │
│ │ [✓] Output projection        inv RoPE + wo_a grouped + wo_b   │  │
│ └────────────────────────────────────────────────────────────────┘  │
│                                                                      │
│ ┌─ FFN sub-block ────────────────────────────────────────────────┐  │
│ │ [✓] Residual mHC pre + post mix                               │  │
│ │ [~] Pre-FFN norm              RMSNorm                          │  │
│ │ [✗] Router                    sqrt(softplus) + topk + hash     │  │
│ │ [✓] Routed MoE               fused SwiGLU L1 + L2             │  │
│ │ [✓] Shared expert            NVFP4 single-group GEMM          │  │
│ └────────────────────────────────────────────────────────────────┘  │
└──────────────────────────────────┬───────────────────────────────────┘
                                  │
                                  ▼
┌──────────────────────────────────────────────────────────────────────┐
│ [✗] Final RMSNorm → [✗] LM head → [✗] MTP (depth=1) → [✗] Sampler │
└──────────────────────────────────────────────────────────────────────┘
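The router row lists sqrt(softplus) scoring followed by top-k. Below is a minimal sketch of that scoring for one token. It assumes the selected scores are renormalized to sum to 1, which is an assumption, and it leaves out the hash-based routing that the diagram also mentions. `route` is a hypothetical name.

```python
import math

def softplus(x):
    """Numerically stable log(1 + e^x)."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))

def route(logits, k):
    """Return [(expert_id, weight)] for the k highest-scoring experts of one token."""
    scores = [math.sqrt(softplus(x)) for x in logits]
    top = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]
    total = sum(scores[i] for i in top)   # assumption: renormalize over selected experts
    return [(i, scores[i] / total) for i in top]

selected = route([0.5, -2.0, 3.0, 1.0], k=2)
```

sqrt(softplus(x)) is strictly increasing in x, so the top-k set matches the top-k of the raw logits. The transform only changes the mixing weights.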

┌─ Supporting infrastructure ──────────────────────────────────────────┐
│ [✗] KV cache management                                             │
│ • state cache: SWA window + uncompressed tail per layer             │
│ • classical paged cache: lcm(4, 128) = 128 tokens per block       │
│ • heterogeneous layout per layer                                    │
└──────────────────────────────────────────────────────────────────────┘
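This sketch reads m as the number of raw tokens folded into one compressed KV entry (m=4 for CSA, m=128 for HCA). Under that reading, a block size of lcm(4, 128) = 128 tokens makes every block boundary a compression boundary for both layer types, so each block holds a whole number of compressed entries. Tokens that do not yet fill a group of m would stay in the per-layer uncompressed tail of the state cache. The helper names below are hypothetical.

```python
import math

CSA_M = 4      # raw tokens per compressed entry on CSA layers
HCA_M = 128    # raw tokens per compressed entry on HCA layers
BLOCK_TOKENS = math.lcm(CSA_M, HCA_M)   # 128 raw tokens per paged block (Python 3.9+)

def entries_per_block(m):
    """Compressed KV entries one block holds on a layer with compression ratio m."""
    assert BLOCK_TOKENS % m == 0, "block size must be a multiple of m"
    return BLOCK_TOKENS // m

def blocks_needed(num_tokens):
    """Paged blocks needed to cover num_tokens raw tokens (ceiling division)."""
    return -(-num_tokens // BLOCK_TOKENS)

def tail_tokens(num_tokens, m):
    """Raw tokens that do not yet form a full group of m (kept uncompressed)."""
    return num_tokens % m
```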


Summary
-------
 Built  [✓] : 6 — mHC ×2, Q/KV proj, output proj, routed MoE,
               shared expert
 Partial [~] : 4 — norms+RoPE, token compressor, FMHA core,
               pre-FFN norm
 To build [✗] : 8 — embedding+init, indexer+top-k, router,
               final norm, LM head, MTP, sampler, KV cache

Status (May 22, 2026 — 09:40 UTC)

Stage  Status    Description
A      COMPLETE  Q@K^T via tcgen05.mma → TMEM → GMEM
B      COMPLETE  QK → identity softmax → P@V pipeline (TMEM alias, KV-tile interleaving)
C      WORKING   Real online softmax: row_max (fmax), exp2 scaling, P store, row_sum, O normalization. Cosine 0.993-0.996
C'     🔨 NEXT   Cross-warp reduction, correction warps, 12-warp production pipeline, multi-tile KV
D      TODO      Full decode attention: paged KV cache, multi-query, causal mask
E      TODO      Production kernel: extract into dsv4/kernels/attention/, PyTorch custom op, vLLM bridge

Package Structure

dsv4/
├── kernels/          Pure GPU code (CuTeDSL @cute.jit, .cu files)
│   ├── gemm/           NVFP4 MoE GEMM kernels (grouped, fused_swiglu, dense, scheduler)
│   ├── attention/      FMHA kernel (stub — extraction is Stage E)
│   ├── compressor/     CSA/HCA token-level compressor
│   ├── decode/         Decode-time attention (sparse, SWA — future)
│   └── cuda/           Raw .cu files (deinterleave_quantize, sparse_topk_metadata)
├── ops/              PyTorch ↔ kernel bridges
│   ├── quantize.py      BF16 ↔ NVFP4 conversion, scale factors
│   ├── layouts.py       Scale swizzle, gate/up interleave, K-major, offsets
│   ├── gemm_runner.py   Warmup, compile, run grouped/fused GEMMs
│   ├── custom_ops.py    torch.library.custom_op registrations
│   ├── decode_sparse.py native_sparse_decode dispatcher
│   ├── decode_swa.py    native_swa_decode dispatcher
│   ├── rope.py          Forward + inverse RoPE
│   └── topk.py          Python wrapper for sparse_topk_metadata.cu
├── layers/           nn.Module-style components
│   ├── linear.py        Nvfp4Linear
│   ├── grouped_linear.py Nvfp4GroupedLinear
│   ├── moe.py           Nvfp4MoE
│   ├── shared_expert.py Nvfp4SharedExpert
│   ├── mhc.py           mHCLayer
│   └── (stubs: attention, ffn, router, norm, embedding)
├── model/            Model assembly (stubs — Phase 1)
├── cache/            KV cache infra (stubs — Phase 3)
├── loader/           Checkpoint I/O (stubs — Phase 1)
└── reference/        Slow PyTorch oracles (never imported by production code)
    ├── attention.py     RoPE, KV cache, causal attention, SWA
    ├── csa_attention.py CSA/HCA sparse attention
    ├── compressor.py    Compressor PyTorch example
    └── moe_pipeline.py  MoE pipeline reference

Mental model: kernels/ → ops/ → layers/ → model/ (dependency flows left to right). reference/ and loader/ are sidecars.
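One way to keep this layering honest is a small static check over import statements. The sketch below assumes the mental model means that a package may import from itself or from packages to its left, never from packages to its right, and that nothing outside reference/ imports reference/. `layering_violations` and `dsv4_subpackages` are hypothetical helpers, not part of the repo.

```python
import ast

# Assumed reading: kernels <- ops <- layers <- model; imports point only leftward.
LAYER_ORDER = ["kernels", "ops", "layers", "model"]

def dsv4_subpackages(source):
    """Yield the dsv4 subpackage named by each absolute import in source."""
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.Import):
            names = [alias.name for alias in node.names]
        elif isinstance(node, ast.ImportFrom) and node.level == 0 and node.module:
            names = [node.module]
        else:
            continue                     # relative imports are ignored in this sketch
        for name in names:
            parts = name.split(".")
            if parts[0] == "dsv4" and len(parts) > 1:
                yield parts[1]

def layering_violations(package, source):
    """List imports in a module of dsv4/<package> that break the dependency flow."""
    bad = []
    for dep in dsv4_subpackages(source):
        if dep == "reference" and package != "reference":
            bad.append(dep)              # production code never imports the oracles
        elif (package in LAYER_ORDER and dep in LAYER_ORDER
              and LAYER_ORDER.index(dep) > LAYER_ORDER.index(package)):
            bad.append(dep)              # importing from a package to the right
    return bad
```

This version does not catch `from dsv4 import ops`. Resolving the imported names as well would close that gap.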


Active Test Files

FMHA (Stages A/B/C) — in tests/unit/

File                          Stage  Status
test_fmha_v3.py               A+B    Full QK→identity softmax→PV, cosine 0.999999
test_fmha_v3_12w.py           A+B    12-warp QK→PV, cosine 0.999999
test_fmha_v3_stage_c_full.py  C      Real online softmax + O normalization, cosine 0.993-0.996
test_fmha_v3_stage_c_min.py   C      🔨 Early 12-warp pipeline (broken pipeline state)
test_pv64_with_softmax.py     B      (128,64) PV, single AB pipeline
test_128_128_vdiag.py         A+B    (128,128) PV baseline
test_qkonly.py                A      QK with split Q/KV pipelines
test_qk_softmax.py            A+B    QK + identity softmax, no PV
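The Status column reports cosine similarity against a reference. How the tests flatten or cast tensors before comparing is not stated here. The sketch below uses the standard definition on flattened values, with hypothetical helper names. It also shows why cosine alone can miss a uniform scale error, which is one reason to pair it with an absolute-error check.

```python
import math

def cosine_similarity(out, ref):
    """Cosine of the angle between two flattened tensors (1.0 = same direction)."""
    dot = sum(x * y for x, y in zip(out, ref))
    norm_out = math.sqrt(sum(x * x for x in out))
    norm_ref = math.sqrt(sum(y * y for y in ref))
    return dot / (norm_out * norm_ref)

def max_abs_error(out, ref):
    return max(abs(x - y) for x, y in zip(out, ref))

ref = [0.25, -1.0, 0.5, 2.0]
scaled = [3.0 * x for x in ref]   # uniform scale error, e.g. every row divided by a wrong constant
```

Here `scaled` has a cosine of essentially 1.0 against `ref`, yet a maximum absolute error of 4.0.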

MoE / GEMM — in tests/unit/

File                       What
test_cutedsl.py            NVFP4 grouped GEMM kernel
cudagraph_test.py          Cudagraph capture + replay
layertest.py               Per-layer correctness
test_custom_op.py          torch.library custom ops
test_compile_custom_op.py  Compile + warmup
test_fp4_roundtrip.py      BF16 → NVFP4 → BF16 roundtrip
test_interleave.py         Gate/up weight interleaving
test_interleave_gemm.py    Interleaved GEMM correctness
test_fused_step1.py        Fused SwiGLU GEMM
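test_fp4_roundtrip.py checks the BF16 → NVFP4 → BF16 roundtrip. The plain-Python quantizer below is simplified to show what that roundtrip loses:

- Values are grouped into 16-element blocks.
- Each block gets one scale that maps its absolute max to 6.0, the largest FP4 E2M1 magnitude.
- Every element is rounded to the nearest E2M1 value.

It deliberately simplifies two things. The block scale stays a Python float, whereas real NVFP4 stores it as FP8 E4M3 alongside a per-tensor FP32 scale. Ties break toward the smaller magnitude rather than using round-to-nearest-even. This is not the dsv4/ops/quantize.py implementation, and the function names are hypothetical.

```python
import math

E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]  # non-negative FP4 E2M1 values
BLOCK = 16                                                  # elements sharing one scale

def quantize_block(xs):
    """Return (scale, fp4_values) for one block; scale maps the block's amax to 6.0."""
    amax = max(abs(x) for x in xs)
    scale = amax / 6.0 if amax > 0 else 1.0
    q = []
    for x in xs:
        target = abs(x) / scale
        mag = min(E2M1_MAGNITUDES, key=lambda v: abs(target - v))  # nearest; ties -> smaller
        q.append(math.copysign(mag, x))
    return scale, q

def roundtrip(xs):
    """Quantize to the simplified NVFP4 format and dequantize back to floats."""
    out = []
    for start in range(0, len(xs), BLOCK):
        scale, q = quantize_block(xs[start:start + BLOCK])
        out.extend(scale * v for v in q)
    return out
```

The largest gap between adjacent E2M1 values is 2.0 (between 4 and 6). Per-element error is therefore at most one block scale, which is amax / 6 of that block.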

Archived Tests

tests/archive/ contains ~190 debug files from Stages A/B. Not maintained. Can be deleted.


Stage C: Online Softmax — WORKING

What We Have

Working real softmax in test_fmha_v3_stage_c_full.py: cosine 0.993-0.996 across 3 seeds.
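The online softmax steps in the stage table can be written as a scalar reference for a single query row: running row_max, exp2-based scaling, P, running row_sum, a correction factor for O, and a final normalization. This is plain Python over lists, not the kernel. The names are hypothetical, and "tiles" here are just Python sublists standing in for KV tiles.

```python
import math

LOG2E = 1.4426950408889634   # log2(e), so exp(x) == 2 ** (x * LOG2E)

def online_softmax_attention(score_tiles, value_tiles):
    """Softmax(S) @ V for one query row, visiting KV tiles one at a time."""
    row_max = -math.inf          # carries get initial values before the loop
    row_sum = 0.0
    acc = None                   # unnormalized O accumulator
    for scores, values in zip(score_tiles, value_tiles):
        new_max = max(row_max, max(scores))
        # correction factor: rescales everything accumulated under the old max
        correction = 2.0 ** ((row_max - new_max) * LOG2E)
        p = [2.0 ** ((s - new_max) * LOG2E) for s in scores]   # P for this tile
        row_sum = row_sum * correction + sum(p)
        pv = [sum(pi * v[d] for pi, v in zip(p, values)) for d in range(len(values[0]))]
        acc = pv if acc is None else [a * correction + x for a, x in zip(acc, pv)]
        row_max = new_max
    return [a / row_sum for a in acc]   # final O normalization by row_sum

tiles_s = [[0.5, 2.0, -1.0], [3.0, 0.1], [-0.5, 1.5, 2.5]]
tiles_v = [[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
           [[2.0, 2.0], [-1.0, 0.5]],
           [[0.0, 3.0], [1.0, -1.0], [0.5, 0.5]]]
result = online_softmax_attention(tiles_s, tiles_v)
```

On the first tile row_max is -inf, so the correction factor is 2 ** -inf = 0.0 and the empty accumulators contribute nothing.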

Current Architecture (6-warp)

Warps 0-3: Softmax + Epilogue (load S, real softmax, P store, O normalize, epilogue)
Warp 4: MMA (QK→S, PV→O)
Warp 5: TMA (Q/K/V load)

Target Architecture (12-warp, production)

Warps 0-3: Softmax (S→softmax→P, broadcast vec=[old_max, new_max])
Warps 4-7: Correction (O rescale in TMEM, final normalization, SMEM write)
Warp 8: MMA (QK→S, PV→O with pipeline chaining)
Warp 9: TMA (Q/K/V load)
Warp 10: Epilogue (O SMEM→GMEM via TMA)
Warp 11: Empty (TMEM dealloc, mbarrier init)

Pipeline chain: MMA → Softmax → Correction → Epilogue (plus MMA → Correction)
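Below is a Python mirror of the target 12-warp layout and its pipeline chain. It is useful as a quick sanity check of warp-index arithmetic before the kernel's warp dispatch exists. It assumes role selection by warp index (thread index // 32), which is the usual warp-specialization pattern. `ROLES`, `PIPELINE`, and the helpers are hypothetical names.

```python
WARP_SIZE = 32

# Target 12-warp layout, warp index -> role (from the Target Architecture list).
ROLES = {w: "softmax" for w in range(0, 4)}
ROLES.update({w: "correction" for w in range(4, 8)})
ROLES.update({8: "mma", 9: "tma", 10: "epilogue", 11: "empty"})

# Producer -> consumer hand-offs of the pipeline chain.
PIPELINE = [("mma", "softmax"), ("softmax", "correction"),
            ("correction", "epilogue"), ("mma", "correction")]

def role_of_thread(tid):
    """All 32 threads of a warp share one role, chosen by warp index."""
    return ROLES[tid // WARP_SIZE]

def warps_for(role):
    return sorted(w for w, r in ROLES.items() if r == role)

NUM_THREADS = len(ROLES) * WARP_SIZE   # 12 warps x 32 threads = 384
```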

CuTeDSL Constraints (hard-won)

  1. vectorize=True loops: ONLY load/store/print — no fmax, no cmpf, no inner loops, no carry
  2. .reduce(cute.ReductionOp.MAX): reduces ENTIRE C-fragment to scalar — global max, not per-row. Use cute.arch.fmax element-wise instead
  3. Dynamic control flow: variables need initial values BEFORE the flow starts
  4. cute.arch.fmax: impure for vectorizer — use plain range() loop
  5. Carry variables (row_max, row_sum): cannot use vectorize=True
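Constraints 2, 3 and 5 can be illustrated with a plain-Python analogy. This is not CuTeDSL code, and the "fragment" is just a list of rows. A whole-fragment MAX gives one scalar for every row, while the per-row version needs a carry that is initialized before the loop and updated element by element. The practical cost of the shared max shows up in the second row: all of its entries sit far below the global max, so every exp underflows and its row_sum becomes 0. The helper names are hypothetical.

```python
import math

frag = [[1.0, 5.0, 0.5],                 # toy 2-row "fragment": one list per row
        [-1000.0, -1001.0, -999.0]]

def fragment_global_max(frag):
    """What a whole-fragment MAX reduction returns: one scalar for all rows."""
    return max(x for row in frag for x in row)

def fragment_row_max(frag):
    """Per-row max: a carry initialized before the loop, updated element by element."""
    maxes = []
    for row in frag:
        row_max = -math.inf              # initial value set before the control flow starts
        for x in row:                    # plain loop; the carry rules out vectorization
            row_max = max(row_max, x)    # element-wise fmax analogue
        maxes.append(row_max)
    return maxes

def row_sums(frag, maxes):
    return [sum(math.exp(x - m) for x in row) for row, m in zip(frag, maxes)]

per_row = row_sums(frag, fragment_row_max(frag))
shared = row_sums(frag, [fragment_global_max(frag)] * len(frag))
```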

Remaining for C' (Production Stage C)

  1. Cross-warp reduction for row_max and row_sum
  2. Correction warps for multi-tile KV (online O rescale in TMEM)
  3. 12-warp layout with separate softmax/correction/epilogue warps
  4. Per-row O normalization
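The cross-warp reduction in item 1 can combine partial (row_max, row_sum) pairs with an associative merge. Each partial's sum is rescaled to the joint max before adding. The sketch below shows only the arithmetic, in plain Python. How warps exchange their partials (SMEM, or the TMEM-backed vec buffer) is not modeled. The split into four 2-column chunks is an arbitrary example, and the names are hypothetical.

```python
import math
from functools import reduce

def partial_stats(scores):
    """(row_max, row_sum) over the columns one warp owns."""
    m = max(scores)
    return m, sum(math.exp(s - m) for s in scores)

def merge_stats(a, b):
    """Combine two partials over disjoint columns; rescale each sum to the joint max."""
    m_a, s_a = a
    m_b, s_b = b
    m = max(m_a, m_b)
    return m, s_a * math.exp(m_a - m) + s_b * math.exp(m_b - m)

row = [0.3, 2.2, -1.0, 4.1, 0.0, 1.7, -2.5, 3.3]
chunks = [row[i:i + 2] for i in range(0, len(row), 2)]   # e.g. 4 warps, 2 columns each
merged = reduce(merge_stats, (partial_stats(c) for c in chunks))
```

Because the merge is associative, a tree reduction across warps gives the same result as a sequential one, up to floating-point rounding.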

TMEM Layout

Cols 0-127:  S (QK accumulator, 128 FP32 columns)
Cols 32-95:  P (softmax output, 64 FP32 columns; aliases part of S)
Cols 128+:   O (PV accumulator, 64 FP32 columns)

row_max and row_sum are per-thread FP32 scalars. The correction warps will use a TMEM-backed vec buffer.
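The TMEM layout can be checked with explicit column arithmetic instead of a footprint helper. The sketch reads "Col 128+" as the half-open range [128, 192) for a 64-column O. It rounds the highest used column up to a legal allocation, assuming the PTX rule that tcgen05.alloc takes a power-of-2 column count between 32 and 512. All names are hypothetical.

```python
# Column ranges are half-open [start, end), read from the layout above.
S_COLS = (0, 128)     # S: QK accumulator, 128 FP32 columns
P_COLS = (32, 96)     # P: softmax output, overlaps S (alias)
O_COLS = (128, 192)   # O: PV accumulator, 64 FP32 columns ("Col 128+")

def contains(outer, inner):
    return outer[0] <= inner[0] and inner[1] <= outer[1]

def overlaps(a, b):
    return a[0] < b[1] and b[0] < a[1]

def tmem_alloc_columns(used):
    """Round a column count up to a legal allocation: a power of 2 in [32, 512]."""
    n = 32
    while n < used:
        n *= 2
    if n > 512:
        raise ValueError(f"{used} columns exceed the 512 TMEM columns per SM")
    return n

columns_used = max(S_COLS[1], P_COLS[1], O_COLS[1])   # 192
alloc = tmem_alloc_columns(columns_used)
```

192 used columns round up to a 256-column allocation.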


Stage E: Production Kernel Extraction

When ready, extract from test_fmha_v3.py → dsv4/kernels/attention/fmha.py:

  1. Clean FmhaKernel class with @cute.jit __call__, no hardcoded dimensions
  2. Add real softmax (Stage C)
  3. Add paged KV cache (Stage D)
  4. Wrap as torch.library.custom_op in dsv4/ops/
  5. Integrate with vLLM

Key Lessons

  1. NEVER use find_tmem_tensor_col_offset() as TMEM placement. It returns footprint size, not a safe offset.
  2. FMHA never trusts DLPack tensor layouts. Reconstruct V as (hd, s_k) MN-major inside CuTe.
  3. TMEM allocation (column count) must be a power of 2, at least 32 columns.
  4. Square hides bugs. (128,128) worked for every wrong approach. Always test non-square.
  5. St32x32bOp MUST use Float32, NOT BFloat16. BFloat16 causes illegal memory access.
  6. The first PV MMA must use ACCUMULATE=False. Otherwise it adds uninitialized TMEM contents to the output.
  7. FMHA P store uses QK C-fragment composition, NOT PV A-fragment. Two aliases, same TMEM.
  8. Register bridge: FP32 backing (store partition) + BF16 view (QK-load layout). Do not skip this.
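Lesson 4 can be reproduced in miniature with a plain-Python toy. The flat row-major matmul below has a stride bug: B's row stride uses k instead of m. When k == m the wrong stride equals the right one, so a square test passes, and the first non-square shape exposes the bug. The functions are hypothetical illustrations, not repo code.

```python
def flat(x):
    return [v for row in x for v in row]

def matmul_ref(a, b):
    n, k, m = len(a), len(b), len(b[0])
    return [[sum(a[i][t] * b[t][j] for t in range(k)) for j in range(m)] for i in range(n)]

def matmul_flat_buggy(a_flat, b_flat, n, k, m):
    """Row-major flat matmul (A is n x k, B is k x m) with a stride bug on B."""
    out = []
    for i in range(n):
        row = []
        for j in range(m):
            # BUG: B's row stride should be m; using k only works when k == m
            row.append(sum(a_flat[i * k + t] * b_flat[t * k + j] for t in range(k)))
        out.append(row)
    return out

sq_a, sq_b = [[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]
ns_a, ns_b = [[1.0, 2.0], [3.0, 4.0]], [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
square_ok = matmul_flat_buggy(flat(sq_a), flat(sq_b), 2, 2, 2) == matmul_ref(sq_a, sq_b)
nonsquare_ok = matmul_flat_buggy(flat(ns_a), flat(ns_b), 2, 2, 3) == matmul_ref(ns_a, ns_b)
```

`square_ok` is True and `nonsquare_ok` is False.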

Environment

  • Server: root@45.76.247.107 (B200, 180 GiB HBM3e per GPU)
  • venv: source /root/dsv4-nvfp4-workspace/venv/bin/activate
  • PYTHONPATH: /root/dsv4-nvfp4-workspace/kernel
  • Model: /root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4
  • vLLM repo: /root/dsv4-nvfp4-workspace/vllm (modified for Blackwell)
  • CUTLASS FMHA reference: /root/cutlass/examples/python/CuTeDSL/cute/blackwell/kernel/attention/fmha/fmha.py