biondizzle 914d27fee7 Update README + CURRENT_BUG: full CuTeDSL NVFP4 plan, no more PyTorch fallbacks
Mike's directive: build the full thing with NVFP4/CuTeDSL.
No more 'optimize later' or 'just make it work' workarounds.

Key updates:
- README: full architecture docs (CSA/HCA/mHC), current status, NVFP4 coverage
- CURRENT_BUG: detailed plan for CuTeDSL NVFP4 attention, KV cache, RoPE
- Both files document: checkpoint key names, compress ratios, config issues
- Removed all 'TODO: optimize later' hedging — we build it right the first time
2026-05-19 08:26:16 +00:00

NVFP4 MegaMoE Kernel

Full NVFP4 inference pipeline for DeepSeek-V4 on NVIDIA Blackwell (SM100). The entire model — MoE experts, shared experts, attention projections, and attention compute — runs in native NVFP4 with zero dequantization overhead.

What This Is

A native NVFP4 inference stack for DeepSeek-V4:

MoE Experts — CuTeDSL ScaledGroupedGemmKernel :

BF16 input → quantize to NVFP4
  L1 GEMM: NVFP4 × NVFP4 → BF16 (gate + up)
  SiLU(gate) * up → BF16 (only nonlinear — can't avoid BF16 here)
  Re-quantize → NVFP4
  L2 GEMM: NVFP4 × NVFP4 → BF16 (down_proj)
  Scatter with routing weights → BF16 output
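
The "quantize to NVFP4" and "re-quantize" steps above can be illustrated with a simulated round trip. This is only a sketch in plain Python, not the CuTeDSL kernel: it assumes the usual NVFP4 layout of FP4 E2M1 values sharing one scale per 16-element block, and it keeps the block scale in full precision instead of storing it as FP8 with a separate per-tensor global scale. The function names are hypothetical.

```python
# Simulated ("fake") NVFP4 quantization: quantize, then immediately dequantize.
# Assumption: scales are kept in full precision here; the real format stores
# them in FP8 and applies an additional per-tensor global scale.

E2M1_GRID = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)  # magnitudes FP4 E2M1 can hold
BLOCK = 16  # elements that share one scale


def fake_quant_block(block):
    """Snap one block onto the scaled E2M1 grid and return the dequantized values."""
    amax = max(abs(v) for v in block)
    if amax == 0.0:
        return [0.0] * len(block)
    scale = amax / 6.0  # largest magnitude maps onto the top grid value
    out = []
    for v in block:
        nearest = min(E2M1_GRID, key=lambda g: abs(abs(v) / scale - g))
        out.append((nearest if v >= 0 else -nearest) * scale)
    return out


def fake_quant(values):
    """Apply block-wise fake quantization to a flat list of floats."""
    assert len(values) % BLOCK == 0, "length must be a multiple of the block size"
    out = []
    for start in range(0, len(values), BLOCK):
        out.extend(fake_quant_block(values[start:start + BLOCK]))
    return out
```

The widest gap on the grid is between 4 and 6, so the round-trip error per element is at most one scale unit, which is the kind of error the cosine figures in this README are measuring.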

Attention Projections — CuTeDSL NVFP4 GEMM :

  • q_a_proj, q_b_proj, kv_proj, wo_b_proj — native NVFP4, cosine 0.995 vs BF16
  • wo_a — BF16 BMM (o_a_proj weights are BF16 in checkpoint)
  • compressor.kv_proj, compressor.gate_proj — native NVFP4, cosine 0.995 vs BF16
  • All verified with tests/test_full_layer_b200.py

Shared Experts — CuTeDSL NVFP4 GEMM :

  • gate_up_proj, down_proj — native NVFP4, cosine 0.990 vs BF16
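
The cosine figures quoted above compare an NVFP4 output against a BF16 reference. A minimal sketch of that metric follows, assuming both outputs are flattened to one vector each; the README does not show how the tests compute it, so the helper name and the flattening are assumptions.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length flat vectors."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same length")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        raise ValueError("cosine similarity is undefined for a zero vector")
    return dot / (norm_a * norm_b)
```

Cosine similarity ignores overall magnitude, so a kernel whose output is off by a constant factor would still score 1.0; it measures direction agreement only.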

Attention Compute (NEEDS CuTeDSL NVFP4) 🔧:

  • Currently using pure PyTorch SDPA as a TEMPORARY workaround
  • Q×K and attn×V are activation×activation matmuls that CAN be NVFP4
  • FlashMLA (vLLM's CUDA kernel) is broken on Blackwell
  • Plan: CuTeDSL NVFP4 attention kernel — quantize Q/K to NVFP4, use CuTeDSL GEMMs

KV Cache Write (NEEDS CuTeDSL NVFP4) 🔧:

  • The SWA KV cache uses fp8_ds_mla packed format (37376 bytes per slot, not 512)
  • C++ kernel fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert is broken on Blackwell
  • Currently skipped in Blackwell path (works for prefill, breaks decode)
  • Plan: NVFP4 quant + paged cache insert in CuTeDSL

Architecture: DeepSeek-V4-Pro

CSA + HCA + mHC (NOT MLA — vLLM misnames it "MLA" in code):

  • CSA (Compress Ratio 4): Compressed Sparse Attention — KV compressed 4x with overlap (coff=2). Indexer finds per-layer top-k.
  • HCA (Compress Ratio 128): Heavily Compressed Attention — KV compressed 128x. Top-k indices pre-computed during metadata build.
  • mHC: Manifold-Constrained Hyper-Connections — replaces standard residual connections. Learned mixing with Sinkhorn normalization.
  • SWA: Sliding Window Attention — local window (compress_ratio=0, last layer only)
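
The Sinkhorn normalization mentioned for mHC can be sketched as alternating row and column normalization of a positive matrix, which drives it toward a doubly stochastic mixing matrix. This is the textbook iteration only; how mHC parameterizes the matrix and how many iterations it runs are not stated here, so `iters` and the function name are assumptions.

```python
def sinkhorn(matrix, iters=20):
    """Alternately normalize rows and columns of a square matrix of positive floats."""
    m = [row[:] for row in matrix]  # work on a copy
    n = len(m)
    for _ in range(iters):
        # Make every row sum to 1.
        for r in range(n):
            row_sum = sum(m[r])
            m[r] = [v / row_sum for v in m[r]]
        # Make every column sum to 1.
        for c in range(n):
            col_sum = sum(m[r][c] for r in range(n))
            for r in range(n):
                m[r][c] /= col_sum
    return m
```

For strictly positive inputs the iteration converges quickly, so a small fixed iteration count is usually enough in practice.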

Compress Ratios (from config.json):

Layer 0: 128 (HCA)   Layer 1: 128 (HCA)   Layer 2: 4 (CSA)   Layer 3: 128 (HCA)
Layer 4: 4 (CSA)     ...alternating 4/128...                  Layer 60: 0 (SWA)
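
The mapping from compress ratio to attention type can be written down directly. The sketch below uses only the ratios and layers spelled out above; the layers in the elided "alternating" range are deliberately left out, and the helper name is hypothetical.

```python
# Compress ratio -> attention type, as listed in this README.
ATTENTION_KIND = {0: "SWA", 4: "CSA", 128: "HCA"}

# Only the layers that are written out explicitly above.
KNOWN_RATIOS = {0: 128, 1: 128, 2: 4, 3: 128, 4: 4, 60: 0}


def attention_kind(compress_ratio):
    """Return the attention type for a layer's compress ratio."""
    try:
        return ATTENTION_KIND[compress_ratio]
    except KeyError:
        raise ValueError(f"unexpected compress ratio: {compress_ratio}") from None


for layer, ratio in sorted(KNOWN_RATIOS.items()):
    print(f"layer {layer}: ratio {ratio} -> {attention_kind(ratio)}")
```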

Checkpoint Key Names (different from vLLM's internal names):

q_a_proj, q_b_proj, kv_proj       (NOT fused_wqa_wkv, wq_b)
q_a_norm                           (NOT q_norm)
attn_hc.fn/base/scale              (MHC attention)
ffn_hc.fn/base/scale               (MHC FFN)
compressor.kv_proj, compressor.gate_proj  (CSA/HCA compressor)
compressor.position_bias
sinks                              (attn_sink)
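
Because the checkpoint names differ from vLLM's internal ones, a quick check on a list of keys can catch a loader that was written against the wrong naming. The sketch below only uses names listed above; the key prefixes in the usage example are hypothetical, and the function is an illustration rather than part of this repo.

```python
# Names vLLM uses internally that should NOT appear in checkpoint keys.
VLLM_INTERNAL_NAMES = {"fused_wqa_wkv", "wq_b", "q_norm", "attn_sink"}


def find_internal_names(keys):
    """Return the keys that contain a vLLM-internal name as a dotted component."""
    flagged = []
    for key in keys:
        if any(part in VLLM_INTERNAL_NAMES for part in key.split(".")):
            flagged.append(key)
    return flagged


# Hypothetical key prefixes, for illustration only.
example_keys = [
    "layers.0.attn.q_a_proj.weight",
    "layers.0.attn.q_a_norm.weight",
    "layers.0.attn.wq_b.weight",
]
print(find_internal_names(example_keys))
```

Matching on whole dotted components avoids false positives such as `q_a_norm` being flagged because it contains `norm`.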

Current Status: Attention + KV Cache Need CuTeDSL 🔧

What works (verified on B200):

  • CuTeDSL NVFP4 linear kernels: cosine 0.989–0.999 vs BF16
  • CuTeDSL NVFP4 MoE: cosine 0.988
  • Full attention path with PyTorch SDPA: cosine 0.988 vs BF16
  • MHC, RMS norm, RoPE (BF16), wo_a BMM, shared experts
  • Compressor + indexer (Triton, works on SM100)

What's broken:

  • FlashMLA CUDA kernel → garbage on Blackwell → model outputs immediate EOS
  • fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert C++ kernel → crashes on Blackwell
  • Pure PyTorch SDPA is a TEMPORARY workaround — must replace with CuTeDSL NVFP4

What needs to be built:

1. CuTeDSL NVFP4 Attention Kernel

  • Quantize Q and K to NVFP4 per-head
  • Use CuTeDSL GEMM for Q×K and attn×V
  • Support prefill (batched) and decode (single-token) paths
  • Handle CSA sparse gather (attend to top-k positions only)
  • This is exactly what FlashMLA does with FP8 — we just use NVFP4 instead
  • Test first: build standalone test in tests/ with real weights
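
A standalone test needs a slow but obviously correct reference to compare against. The following is a sketch of such a reference for the decode case: one query attending only to the top-k positions. It is plain full-precision Python and leaves out NVFP4 quantization, KV compression, and attention sinks; all names are hypothetical.

```python
import math


def sparse_attention(q, keys, values, topk_idx, scale):
    """Single-query softmax attention restricted to the positions in topk_idx."""
    # Q x K: one score per selected position.
    scores = [scale * sum(qi * ki for qi, ki in zip(q, keys[i])) for i in topk_idx]
    # Numerically stable softmax over the selected positions only.
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    # attn x V: weighted sum of the selected value rows.
    dim = len(values[0])
    out = [0.0] * dim
    for w, i in zip(weights, topk_idx):
        for d in range(dim):
            out[d] += (w / total) * values[i][d]
    return out
```

Positions outside `topk_idx` never enter the softmax, which is what distinguishes this from masking a dense score matrix after the fact.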

2. CuTeDSL NVFP4 KV Cache Insert

  • Replace C++ fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert
  • Per-head RMS norm on Q + GPT-J RoPE on Q + RoPE on KV + NVFP4 quant + paged cache write
  • The SWA cache uses fp8_ds_mla packed format: row width = 37376 bytes
    • Layout: [nope_dim FP8 values | rope_dim FP8 values | UE8M0 scale blocks]
    • NOT just [head_dim] — it's a packed FP8 format with interleaved scales
  • Option A: Understand and write the fp8_ds_mla format from CuTeDSL
  • Option B: Use our own NVFP4 cache format (simpler, more efficient, but diverges from vLLM)
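
For Option A, the scale blocks are the least familiar part of the layout. UE8M0 is an exponent-only 8-bit format, so each stored scale is a power of two. The sketch below assumes the standard E8M0 bias of 127 and rounds the scale up to the next power of two; the rounding direction is an assumption, not something this README specifies.

```python
import math

E8M0_BIAS = 127  # stored byte = exponent + bias


def ue8m0_encode(scale):
    """Encode a positive scale as a biased power-of-two exponent (rounding up)."""
    if scale <= 0.0:
        raise ValueError("scale must be positive")
    exponent = math.ceil(math.log2(scale))
    # 255 is reserved for NaN in E8M0, so clamp to 0..254.
    return max(0, min(254, exponent + E8M0_BIAS))


def ue8m0_decode(byte):
    """Decode a stored exponent byte back to its power-of-two scale."""
    return 2.0 ** (byte - E8M0_BIAS)
```

Rounding up means the decoded scale is never smaller than the requested one, so values divided by it cannot overflow the FP8 range.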

3. CuTeDSL Fused RoPE + Norm Kernel

  • Currently pure PyTorch (works, but slow)
  • Fuse: Q norm → RoPE → NVFP4 quant → all in one pass
  • Same for KV side: RoPE → NVFP4 quant → cache write
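
As a reference for the RoPE step that the fused kernel would have to reproduce, here is a sketch of the GPT-J (interleaved) convention, where adjacent element pairs are rotated together. The base of 10000 and the absence of any frequency scaling are assumptions; the README does not give the model's RoPE parameters.

```python
import math


def rope_gptj(x, position, base=10000.0):
    """Rotate adjacent pairs (x[0], x[1]), (x[2], x[3]), ... by position-dependent angles."""
    dim = len(x)
    assert dim % 2 == 0, "RoPE needs an even dimension"
    out = [0.0] * dim
    for i in range(dim // 2):
        theta = position * base ** (-2.0 * i / dim)
        c, s = math.cos(theta), math.sin(theta)
        a, b = x[2 * i], x[2 * i + 1]
        out[2 * i] = a * c - b * s
        out[2 * i + 1] = a * s + b * c
    return out
```

Each pair undergoes a pure rotation, so the vector norm is preserved; that makes a convenient sanity check when comparing a fused kernel against this reference.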

4. CuTeDSL CSA Sparse Gather

  • Currently torch.gather (slow, not GPU-optimal)
  • CuTeDSL can do the gather + GEMM in one fused operation
  • The whole point of CSA is sparse KV access — we should do it right

vLLM Integration

The Blackwell detection and dispatch is in vllm/patches/deepseek_v4_attention.py:

  • attention_impl() detects SM100+ → _attention_impl_blackwell()
  • Currently uses pure PyTorch (SDPA) — must replace with CuTeDSL
  • The dispatch is INSIDE the torch.ops.vllm.deepseek_v4_attention custom op boundary (important for torch.compile)
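
The SM100+ check behind the dispatch can be sketched as a test on the device's compute capability. The patch's actual detection code is not shown in this README, so the helpers below are hypothetical; they take a `(major, minor)` tuple of the kind `torch.cuda.get_device_capability()` returns, which keeps the sketch runnable without a GPU.

```python
def is_blackwell(capability):
    """True for SM100 and newer, given a (major, minor) compute capability."""
    major, _minor = capability
    return major >= 10


def pick_attention_path(capability):
    """Name of the attention path to use for the given device capability."""
    if is_blackwell(capability):
        return "_attention_impl_blackwell"
    return "default"


print(pick_attention_path((10, 0)))  # B200 reports compute capability 10.0
print(pick_attention_path((9, 0)))   # Hopper reports 9.0
```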

Config issues:

  • quant_method: modelopt → vLLM uses ModelOpt's NVFP4 handler
  • Our CuTeDSL IS registered (via register_cutedsl_kernel.py) and forced with VLLM_NVFP4_GEMM_BACKEND=cutedsl
  • FlashMLA hard assertion in DeepseekV4MLAAttention.__init__ — patched with _is_blackwell flag
  • kv_cache_scheme: {"num_bits": 8, "type": "float"} → FP8 KV cache → FlashMLA (broken on Blackwell)

Key discovery: the warmup global scale (gs) is irrelevant. The CuTeDSL runner recomputes the activation global scale internally on every call, so changing the warmup value by 10x leaves the output essentially unchanged (cosine 0.9993). The input_scale from the checkpoint is NOT the activation global scale; it is a calibration constant.
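
Why a per-call recompute makes the warmup value irrelevant can be shown with a toy runner. The sketch assumes the common NVFP4 convention of deriving the global scale from the tensor's absolute maximum as (448 × 6) / amax, where 448 and 6 are the largest FP8 E4M3 and FP4 E2M1 magnitudes; the class is hypothetical and is not the CuTeDSL runner.

```python
FP8_E4M3_MAX = 448.0
FP4_E2M1_MAX = 6.0


def dynamic_global_scale(x):
    """Global scale derived from the current activations (assumed convention)."""
    amax = max(abs(v) for v in x)
    if amax == 0.0:
        return 1.0
    return (FP8_E4M3_MAX * FP4_E2M1_MAX) / amax


class ToyRunner:
    """Hypothetical runner that accepts a warmup scale but never uses it at call time."""

    def __init__(self, warmup_gs):
        self.warmup_gs = warmup_gs  # stored only; not consulted below

    def global_scale_for(self, x):
        # Recomputed from the input on every call, so warmup_gs has no effect.
        return dynamic_global_scale(x)
```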

Test Files

Test                               What it does                              Status
tests/test_full_layer_b200.py      All NVFP4 projections vs BF16 (layer 0)   All pass (0.989–0.999)
tests/test_model_forward_b200.py   Warmup gs vs dynamic gs diagnostic        Warmup gs irrelevant
tests/test_csa_attention_b200.py   Full attention path with SDPA             cosine 0.988
tests/layertest.py                 MoE layer test                            cosine 0.988
tests/cudagraph_test.py            CUDAGraph compatibility                   PASS
tests/test_shared_expert.py        Shared expert standalone                  cosine 0.990

Project Structure

nvfp4-megamoe-kernel/
├── cutedsl/                          # CuTeDSL kernel + bridge layer
│   ├── bridge.py                     # Tensor layout conversion, quantization, kernel launch
│   ├── nvfp4_linear.py              # CuTeDSLNvfp4Linear — NVFP4 GEMM runner
│   ├── moe_pipeline.py              # Full MoE pipeline (L1→SiLU→L2→scatter)
│   ├── shared_expert_pipeline.py    # Shared expert pipeline (1-expert MoE variant)
│   ├── csa_attention.py             # CSA/HCA attention (currently SDPA, needs CuTeDSL)
│   ├── custom_ops.py                # torch.autograd wrappers for compile boundary
│   └── kernel/moe/                   # NVIDIA's ScaledGroupedGemmKernel
├── vllm/                             # vLLM integration
│   ├── nvfp4_cutedsl.py             # CuTeDSLMoERunner — cudagraph-safe MoE kernel
│   ├── cutedsl_quant_method.py      # CuTeDSLNvfp4LinearMethod — vLLM quant method
│   ├── kernels/linear/nvfp4/cutedsl.py  # CuTeDSLNvFp4LinearKernel — vLLM kernel registration
│   └── patches/
│       ├── deepseek_v4.py           # Model patch (NVFP4 native, MHC, MoE)
│       ├── deepseek_v4_attention.py # Attention patch (Blackwell dispatch)
│       ├── layers/
│       │   ├── mhc.py               # MHC pure PyTorch (replaces TileLang)
│       │   ├── csa_attention.py     # CSA attention (TEMPORARY — needs CuTeDSL)
│       │   └── deepseek_compressor.py  # Compressor (Triton, works on SM100)
│       └── fused_moe/experts/cutedsl_moe.py  # MoE CuTeDSL integration
├── tests/                            # Standalone tests (run on B200 outside container)
└── Dockerfile                        # Container build

Plan

Phase 1: MoE Kernel DONE

  • CuTeDSL ScaledGroupedGemmKernel with NVFP4
  • Full pipeline: cosine 0.988, cudagraph-safe

Phase 2: NVFP4 Linear Kernels DONE

  • All attention projections: cosine 0.995
  • Shared experts: cosine 0.990
  • Compressor projections: cosine 0.995

Phase 3: vLLM Integration DONE (with PyTorch fallback)

  • CuTeDSL kernels registered and working for all NVFP4 linear layers
  • Blackwell dispatch in attention_impl
  • MHC pure PyTorch
  • MoE CuTeDSL

Phase 4: CuTeDSL NVFP4 Attention 🔧 NEXT

  • Replace pure PyTorch SDPA with CuTeDSL NVFP4 GEMMs for Q×K and attn×V
  • NVFP4 KV cache insert (replace C++ kernel)
  • Fused RoPE + norm + quant kernel
  • CSA sparse gather in CuTeDSL
  • Test each component standalone before integrating into vLLM

Phase 5: Production

  • End-to-end benchmarking
  • Optimize tile sizes for occupancy
  • Clean up old C++ kernel code