biondizzle eef0ef76af Fix NVFP4 compressor scale loading: buffer and concatenate scale shards
The stacked params mapping (wkv + wgate → fused_wkv_wgate) uses
weight_loader(param, weight, shard_id), but PerTensorScaleParameter
and ModelWeightParameter for NVFP4 scale params don't support shard_id
in load_column_parallel_weight (asserts shape equality).

Fix: buffer input_scale, weight_scale, weight_scale_2 for fused_wkv_wgate
shards, then concatenate along dim 0 and copy_ into the param after all
weights are loaded.
2026-05-18 23:24:08 +00:00

NVFP4 MegaMoE Kernel

Full NVFP4 inference pipeline for DeepSeek-V4 on NVIDIA Blackwell (SM100). The goal is for the entire model (MoE experts, shared experts, and attention projections) to run in native NVFP4 with zero dequantization overhead; the remaining exceptions, wo_a and the compressor, are noted below.

What This Is

A native NVFP4 inference stack for DeepSeek-V4:

MoE Experts — CuTeDSL ScaledGroupedGemmKernel (our work):

BF16 input → quantize to NVFP4
  L1 GEMM: NVFP4 × NVFP4 → BF16 (gate + up)
  SiLU(gate) * up → BF16 (only nonlinear — can't avoid BF16 here)
  Re-quantize → NVFP4
  L2 GEMM: NVFP4 × NVFP4 → BF16 (down_proj)
  Scatter with routing weights → BF16 output

Attention Projections — CuTeDSL NVFP4 GEMM (our work, in progress):

  • wq_b, wo_b, fused_wqa_wkv — native NVFP4, no conversion
  • wo_a — NVFP4→FP8 for fp8_einsum (only attention weight that needs conversion)
  • Compressor — BF16 (weight_loader stacking issue, small matmul)

Shared Experts — CuTeDSL NVFP4 GEMM (our work, in progress):

  • gate_up_proj, down_proj — native NVFP4

Both GEMM types use float4_e2m1fn_x2 for weights, float8_e4m3fn for block scales, float32 for global scales. BF16 is used only for SiLU activation, the final MoE scatter, and the compressor — the minimum possible.
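
To make the format concrete, here is a minimal pure-Python sketch of block-scaled FP4 quantization. It assumes 16-element blocks (matching the K//16 block-scale shape described later) and round-to-nearest onto the E2M1 value set. The names quantize_block and dequantize_block are illustrative, not functions from this repo, and the sketch keeps the block scale as a Python float instead of rounding it to float8_e4m3fn or applying a float32 global scale.

```python
# Magnitudes representable in E2M1 (4-bit float: 1 sign, 2 exponent, 1 mantissa bit).
E2M1_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
BLOCK = 16  # elements that share one block scale (assumed block size)


def quantize_block(block):
    """Quantize one block to (scale, codes); each code is sign bit (0x8) | 3-bit index."""
    assert len(block) == BLOCK
    amax = max(abs(v) for v in block)
    scale = amax / 6.0 if amax > 0 else 1.0  # largest magnitude maps to 6.0
    codes = []
    for v in block:
        mag = abs(v) / scale
        # Index of the nearest representable magnitude (ties go to the smaller one).
        idx = min(range(8), key=lambda i: abs(E2M1_VALUES[i] - mag))
        codes.append(idx | (0x8 if v < 0 else 0))
    return scale, codes


def dequantize_block(scale, codes):
    """Invert quantize_block up to rounding error."""
    return [(-1.0 if c & 0x8 else 1.0) * E2M1_VALUES[c & 0x7] * scale for c in codes]


block = [float(i - 8) for i in range(BLOCK)]  # -8.0 .. 7.0
scale, codes = quantize_block(block)
restored = dequantize_block(scale, codes)
```

The widest gap between adjacent E2M1 magnitudes is 2.0 (from 4.0 to 6.0), so the per-element error in this sketch is bounded by one block scale. This is the kind of rounding that the activation quantization loss reported below comes from.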

Current Status: Building Our Own Kernels 🔧

vLLM's built-in FlashInferCutlassNvFp4LinearKernel is broken on B200. The same class of C++ CUTLASS FP4 bugs we hit with MoE (documented in "How We Got Here") affects the attention and shared expert paths. After a full day of debugging (broken input_scale, process_weights_after_loading timing, forward hook failures, BF16 dequant workarounds), we're replacing ALL vLLM NVFP4 kernels with our own CuTeDSL implementations.

Test Results:

  • tests/layertest.py: cosine 0.988 vs BF16 reference
  • tests/cudagraph_test.py: capture + replay PASS
  • vLLM inference: produces empty/garbage output — vLLM's pipeline is broken, not our kernel

What works:

  • MoE expert CuTeDSL kernel — production-ready, cosine 0.988, cudagraph-safe
  • All NVFP4 weight dequantization — valid BF16 output confirmed in standalone tests

What's in progress:

  • Shared expert CuTeDSL kernel — runner WIP, scale assembly for num_groups=1
  • Attention projection CuTeDSL kernel — planned after shared experts

Why we're building our own: vLLM's FlashInferCutlassNvFp4LinearKernel uses the same C++ CUTLASS FP4 path that was broken for MoE. The CuTeDSL approach (Python-based CUTLASS via MLIR→PTX) is what NVIDIA's CUTLASS team recommends for Blackwell. Our MoE kernel proves it works. Time to apply the same approach to the rest of the model.

vLLM now serves DeepSeek-V4-Pro NVFP4 with cudagraph enabled: the model loads, cudagraph capture succeeds, and inference runs. The output is still garbage tokens (see the Phase 3 root cause under Plan), but this is the first time every stage of the pipeline (model loading, kernel compilation, cudagraph capture, and inference) executes end-to-end.

How We Got Here

The C++ CUTLASS Kernel Was Broken

The original kernel was a C++ .cu file using CUTLASS's C++ API directly. It passed all the simple tests (uniform data → exact output, scale-factor (SF) remap verifier → 0 errors) but produced cosine 0.05 on non-uniform random data. After weeks of debugging the SF remap (8+ iterations, all producing the same 0.2 cosine against a wrong reference), we discovered:

  1. The BF16 reference comparison was wrong — our Python dequantization didn't match CUTLASS's internal FP4 handling. A wrong reference is worse than no reference. We chased ghosts through 8+ SF remap rewrites because the 0.2 cosine was never about the remap.

  2. The C++ CUTLASS kernel misinterpreted FP4 data — even with SF remap verified correct (0 byte errors), the GEMM produced garbage with non-uniform data. The issue was in how CUTLASS's C++ API handles FP4 packing/tiling internally — something we couldn't easily debug or fix.

  3. The checkpoint input_scale was a red herring: we tried using the checkpoint's calibration scale as the activation normalization scale. It saturated all block scales to 448.0 (the float8_e4m3fn maximum). The input_scale is a calibration constant for alpha computation, not a normalization scale.

The CuTeDSL Kernel Works

NVIDIA's CuTeDSL approach (Python-based CUTLASS kernels compiled via MLIR → PTX) is what the CUTLASS team recommends for Blackwell. Their official MoE scaled grouped GEMM example (torch_scaled_grouped_mm.py) supports NVFP4 out of the box. We adapted it.

Results with real DeepSeek-V4 layer 0 weights:

  • L1 GEMM alone: cosine 0.995
  • Full MoE pipeline (L1→SiLU→L2→scatter): cosine 0.989
  • Weight loading: 0% loss — direct uint8→float4_e2m1fn_x2 view-cast, bit-identical to checkpoint
  • Activation quantization: ~1.1% cosine loss (dynamic BF16→NVFP4 — inherent to the format, unavoidable)
  • GEMM kernel: 0% loss (CuTeDSL is correct)

The gap between 0.989 and 1.0 comes entirely from activation quantization. The weights are bit-identical to the checkpoint: no BF16 round-trip, no precision loss.
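
For readers unfamiliar with the metric, the cosine figures quoted above compare a kernel output against a reference output as flattened vectors. A minimal sketch, assuming plain Python lists (the helper name cosine_similarity is ours for illustration; the test scripts may compute it differently):

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


reference = [1.0, -2.0, 3.0, 0.5]
noisy = [1.1, -1.9, 2.8, 0.6]          # reference plus small per-element error
scaled = [2.0 * v for v in reference]  # reference with a uniform scale error

score_noisy = cosine_similarity(reference, noisy)
score_scaled = cosine_similarity(reference, scaled)
```

Note that score_scaled is 1.0 up to rounding: cosine similarity ignores a uniform scale error, so a wrong global scale would need a separate magnitude check to detect.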

The Dequant→Requant Anti-Pattern

Early versions dequantized all NVFP4 weights to BF16, then let vLLM's FlashInferCutlassNvFp4LinearKernel requantize them back to NVFP4 at inference time. This:

  • Wasted 5 minutes on load doing NVFP4→BF16 conversion
  • Lost precision on the double round-trip
  • Caused vLLM to hang — the NVFP4 attention kernel expects native NVFP4 weights, not BF16 weights with an NVFP4 quant_method attached

The fix: keep everything in NVFP4. The checkpoint stores NVFP4. The kernels consume NVFP4. No conversion needed.

CUDAGraph Compatibility

vLLM uses CUDA graphs to eliminate kernel launch overhead in the decode path. CUDA graphs record the entire forward pass once, then replay it — but they require fixed tensor shapes, fixed memory addresses, and zero CPU-GPU syncs.

Our original runner was not cudagraph-safe. We had to fix several classes of issues:

1. CPU↔CUDA Tensor Copies

torch.tensor([0,1,...], device=x.device) creates the tensor on CPU first, then copies to CUDA. This copy is forbidden during graph capture. The fix: cache tensors per device on first use, outside the graph.

# BAD — CPU→CUDA copy inside graph
step_to_idx = torch.tensor([0,1,2,3,4,4,5,5,6,6,6,7,7], device=x.device)

# GOOD — cached on first use, reused in graph
step_to_idx = _get_step_to_idx_lut(x.device)  # returns cached CUDA tensor

Similarly, torch.zeros and torch.rand don't support float4_e2m1fn_x2 or float8_e4m3fn dtypes. The fix: create as uint8 or float16, then .view() or .to() the target dtype.
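
The per-device caching used for the lookup table can be sketched without PyTorch. This is a minimal illustration, not the repo's _get_step_to_idx_lut: the build callback is a hypothetical stand-in for the host-to-device copy, which in the real runner would presumably be something like torch.tensor(..., device=d) executed during warmup.

```python
_LUT_CACHE = {}  # device key -> prebuilt lookup table

STEP_TO_IDX = [0, 1, 2, 3, 4, 4, 5, 5, 6, 6, 6, 7, 7]  # values from the snippet above


def get_step_to_idx_lut(device, build):
    """Return the LUT for `device`, calling `build(device)` only on first use.

    The first call must happen during warmup, before graph capture starts,
    because `build` performs the forbidden CPU-to-device copy.
    """
    key = str(device)
    if key not in _LUT_CACHE:
        _LUT_CACHE[key] = build(device)
    return _LUT_CACHE[key]


build_calls = []


def fake_build(device):
    # Stand-in for the host-to-device copy; records how often it runs.
    build_calls.append(device)
    return list(STEP_TO_IDX)


first = get_step_to_idx_lut("cuda:0", fake_build)   # warmup: builds and caches
second = get_step_to_idx_lut("cuda:0", fake_build)  # later calls: cache hit, no copy
```

Keying on the device keeps tensor-parallel ranks from sharing a tensor that lives on the wrong GPU.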

2. GPU Scalar Slicing

buf[:gpu_scalar, :] requires the runtime to query the GPU scalar's value to determine the output shape. This triggers an implicit CPU-GPU sync, which invalidates the graph. The fix: always use full pre-allocated buffers. Extra rows are zeros that contribute nothing to the computation.

# BAD — GPU scalar as slice index (implicit sync)
total_padded_rows = padded_expert_offsets[-1]  # GPU scalar
padded_scales = buf[:total_padded_rows, :padded_cols]  # sync!

# GOOD — full pre-allocated buffer, zero out before use
padded_scales = self._padded_scales_buf  # always max size
padded_scales.zero_()

Design decision: Padding to max size wastes a few rows of compute on zero data, but:

  • The extra rows are zeros → zero GEMM output → no accuracy impact
  • At decode batch sizes the GEMMs are memory-bandwidth bound → multiplying a few extra zero rows is nearly free
  • VRAM cost is negligible (~350KB for activation intermediates across all MoE layers)
  • vLLM already does this everywhere (attention, FFN, etc.)
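
The claim that padded zero rows do not affect the result can be checked with a toy GEMM. The sketch below uses nested Python lists and tiny made-up dimensions; MAX_ROWS stands in for the max_num_tokens * top_k buffer size, and run_padded is a hypothetical name, not the runner's API.

```python
MAX_ROWS, K, N = 6, 4, 3  # tiny stand-ins for the real buffer and GEMM dims


def matmul(a, b):
    """Plain row-major matrix product of nested lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


weights = [[float(k + j) for j in range(N)] for k in range(K)]
padded_buf = [[0.0] * K for _ in range(MAX_ROWS)]  # allocated once, reused every step


def run_padded(rows):
    """Copy `rows` into the fixed buffer, zero the tail, and run the full-size GEMM."""
    for i in range(MAX_ROWS):
        for k in range(K):
            padded_buf[i][k] = rows[i][k] if i < len(rows) else 0.0
    return matmul(padded_buf, weights)  # output shape is always (MAX_ROWS, N)


tokens = [[1.0, 2.0, 3.0, 4.0], [0.5, 0.0, -1.0, 2.0]]
out = run_padded(tokens)
```

The first len(tokens) rows of out match the unpadded product and the rest are zero, so every shape stays fixed regardless of how many tokens are live.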

3. Kernel Compilation in the Forward Path

cute.compile() is a host-side JIT operation that generates PTX and compiles a CUDA kernel. It cannot be called inside cudagraph capture. The fix: compile once during warmup, cache the compiled kernel, then only invoke compiled() on subsequent calls.

The compiled kernel uses separate_tensormap_init=True, which handles TMA descriptor re-initialization for new tensor data. We create new mark_layout_dynamic CuTe tensor views for each forward call, and the compiled kernel accepts them.

Critical lesson: Caching the compiled kernel across different tensor allocations initially produced wrong results (cosine 0.5 instead of 0.99). The issue was NOT that caching is fundamentally broken — it was that our bridge had other bugs (wrong make_b_k_major stride check, quantize_weight_to_nvfp4 packing N instead of K). Once those were fixed, cached compilation works correctly.
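
The compile-once pattern can be sketched as a small cache with a guard that fails loudly if compilation is attempted during capture. Everything here is a stand-in: fake_compile plays the role of the one-time cute.compile() call, the cache key is illustrative, and the _capturing flag is a hypothetical simplification of however the runner detects capture.

```python
_compiled = {}       # problem key -> compiled kernel (a plain callable here)
_capturing = False   # set True while a graph is being captured


def get_kernel(key, compile_fn):
    """Return the cached kernel for `key`, compiling only outside graph capture."""
    if key not in _compiled:
        if _capturing:
            raise RuntimeError(f"kernel {key} was not compiled during warmup")
        _compiled[key] = compile_fn(key)  # one-time host-side JIT
    return _compiled[key]


compile_log = []


def fake_compile(key):
    # Hypothetical stand-in for the JIT: returns a callable "kernel".
    compile_log.append(key)
    return lambda x: [2 * v for v in x]


key = ("nvfp4", 7168, 6144)    # illustrative cache key: dtype and GEMM dims
get_kernel(key, fake_compile)  # warmup pass compiles once
_capturing = True              # from here on, compilation is forbidden
result = get_kernel(key, fake_compile)([1, 2, 3])  # cache hit during capture
```

Raising on a cache miss during capture turns a silent graph invalidation into an error that names the missing kernel configuration.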

4. Weight Quantization: K is the Packed Dimension

quantize_weight_to_nvfp4 packs K (dim 0), unlike quantize_to_nvfp4, which packs the last dimension. For a weight matrix (K, N):

  • K=7168 is the packed dimension (7168 → 3584 in float4)
  • N=6144 stays as-is
  • Block scales are computed along K blocks: (K//16, N) not (K//2, N//16)
  • The nibble packing uses [:, ::2, :] and [:, 1::2, :] (along the K block dim)

Confusing the two quantization functions produces wrong tensor shapes that crash or produce garbage.
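
The shape bookkeeping for K-packed weights can be sketched in pure Python. This is an illustration only, not quantize_weight_to_nvfp4 itself: it works on a single (K, N) grid of 4-bit codes, and the nibble order (even K rows in the low nibble, odd K rows in the high nibble) is an assumption.

```python
def pack_along_k(codes):
    """Pack a (K, N) grid of 4-bit codes into (K // 2, N) bytes, two K-rows per byte.

    Assumption: even K rows go in the low nibble, odd K rows in the high nibble.
    """
    assert len(codes) % 2 == 0, "K must be even to pack two codes per byte"
    even, odd = codes[0::2], codes[1::2]  # same idea as the ::2 and 1::2 slices above
    return [[lo | (hi << 4) for lo, hi in zip(row_lo, row_hi)]
            for row_lo, row_hi in zip(even, odd)]


def unpack_along_k(packed):
    """Inverse of pack_along_k: (K // 2, N) bytes back to (K, N) codes."""
    codes = []
    for row in packed:
        codes.append([b & 0xF for b in row])
        codes.append([b >> 4 for b in row])
    return codes


K, N, BLOCK = 32, 3, 16
codes = [[(k * N + n) % 16 for n in range(N)] for k in range(K)]
packed = pack_along_k(codes)
packed_shape = (len(packed), len(packed[0]))  # (K // 2, N)
scale_shape = (K // BLOCK, N)                 # one block scale per 16 K-elements per column
```

With the real dimensions, K=7168 gives 3584 packed rows and 448 block-scale rows, while N=6144 is unchanged in both.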

5. B Tensor K-Major Layout

The CuTeDSL kernel expects B tensors in K-major memory layout (K elements contiguous in memory). torch.stack produces N-major layout. The fix: double-permute trick — transpose, make contiguous, transpose back. Same shape, different strides.

# Double-permute: (E,K,N) → (E,N,K) → contiguous → (E,K,N)
# Same shape, but K-contiguous memory layout
return b_tensor.permute(0, 2, 1).contiguous().permute(0, 2, 1)

A single permute changes the tensor SHAPE (swapping K and N), which breaks everything downstream.
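
To see why the double permute keeps the shape but changes the layout, it helps to track strides by hand. The sketch below models shape and stride tuples in pure Python (no PyTorch); contiguous_strides and permute are illustrative helpers that mimic what a row-major allocation and a permute view do.

```python
def contiguous_strides(shape):
    """Row-major strides (in elements) for a freshly allocated tensor of `shape`."""
    strides, step = [], 1
    for dim in reversed(shape):
        strides.append(step)
        step *= dim
    return tuple(reversed(strides))


def permute(shape, strides, dims):
    """Reorder shape and strides together, as a permute view does (no data movement)."""
    return tuple(shape[d] for d in dims), tuple(strides[d] for d in dims)


E, K, N = 2, 8, 4
shape, strides = (E, K, N), contiguous_strides((E, K, N))  # N-major: stride 1 on N
shape1, strides1 = permute(shape, strides, (0, 2, 1))       # view as (E, N, K)
strides1 = contiguous_strides(shape1)                       # .contiguous() copies to row-major
shape2, strides2 = permute(shape1, strides1, (0, 2, 1))     # view back as (E, K, N)
```

The final shape is (2, 8, 4) again, but the strides go from (32, 4, 1) to (32, 1, 8): K now has stride 1, which is the K-major layout the kernel expects.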

Key Lessons

  1. A wrong reference is worse than no reference — the 0.2 cosine against a broken BF16 dequant sent us chasing SF remap bugs for weeks
  2. The C++ CUTLASS API is a footgun for FP4 — CuTeDSL handles tensor layouts, tiling, and SF construction correctly by construction
  3. Test with real data early — uniform tests pass even with broken kernels; random data reveals real bugs
  4. Separate the GEMM from the pipeline — our layertest.py runs without vLLM, Docker, or tensor parallelism. It caught the kernel bug that vLLM's integration layers masked.
  5. Don't dequant what's already quantized — if the kernel expects NVFP4 and the checkpoint is NVFP4, leave it alone. No BF16 round-trips.
  6. GPU scalar slicing is a silent cudagraph killer: nothing flags the offending line, capture just fails with cudaErrorStreamCaptureInvalidated and no pointer to the cause. The test harness catches it.
  7. Weight vs activation quantization are different — K-packed (weights) vs last-dim-packed (activations). Mixing them up produces wrong shapes and garbage output.
  8. Double-permute for memory layout changes — single permute changes shape, double-permute changes layout. The kernel cares about both.

Project Structure

nvfp4-megamoe-kernel/
├── cutedsl/                          # CuTeDSL kernel + bridge layer
│   ├── bridge.py                     # Tensor layout conversion, quantization, kernel launch
│   ├── moe_pipeline.py              # Full MoE pipeline (L1→SiLU→L2→scatter)
│   └── kernel/moe/                   # NVIDIA's ScaledGroupedGemmKernel (untouched)
│       ├── torch_scaled_grouped_mm.py   # The working kernel (3900 lines)
│       ├── moe_utils.py
│       ├── moe_persistent_scheduler.py
│       └── moe_sched_extension.py
├── vllm/                             # vLLM integration
│   ├── nvfp4_cutedsl.py             # CuTeDSLMoERunner — cudagraph-safe MoE kernel interface
│   └── patches/
│       ├── deepseek_v4.py           # DeepSeek-V4 model patch (NVFP4 native)
│       └── deepseek_v4_attention.py # Attention patch (NVFP4 native)
├── tests/
│   ├── cudagraph_test.py            # CUDAGraph compatibility test (✅ PASS)
│   ├── layertest.py                 # Layer 0 comparison: CuTeDSL vs BF16 (✅ cosine 0.988)
│   ├── test_cutedsl.py              # Small standalone CuTeDSL test (✅ cosine 0.991)
│   ├── test_uniform_fp4.py          # Uniform data GEMM test
│   ├── test_b_layout.py             # B matrix column layout test
│   └── test_quick_rand.py           # Quick random GEMM sanity check
└── reference/                        # Reference files for study

The Bridge Layer (cutedsl/bridge.py)

Handles all tensor layout conversion from our pipeline to what the CuTeDSL kernel expects:

| Function | What it does |
|---|---|
| quantize_to_nvfp4() | BF16 → float4_e2m1fn_x2 + float8_e4m3fn block scales (NOT cudagraph-safe: uses .max()) |
| quantize_activation_nvfp4() | Same, but cudagraph-safe (fixed global_scale, no .max()) |
| quantize_weight_to_nvfp4() | Same, but K is the packed dimension (different block scale shape) |
| assemble_scales_2d_side() | Pad and swizzle activation scale factors (2Dx3D A side) |
| assemble_scales_3d_side() | Pad and swizzle weight scale factors (2Dx3D B side) |
| make_b_k_major() | Convert B tensor from N-major to K-major strides (double-permute trick) |
| run_nvfp4_grouped_gemm() | Kernel launch with cached compilation (cudagraph-safe) |

Running Tests

On the B200 (host venv, no container):

cd /root/nvfp4-megamoe-kernel
source tests/.venv/bin/activate
export CUDA_TOOLKIT_PATH=/usr/local/cuda

# CUDAGraph compatibility test
python3 tests/cudagraph_test.py

# Small standalone test
python3 tests/test_cutedsl.py

# Full layer 0 comparison with real weights
python3 tests/layertest.py

NVFP4 Coverage

| Component | Format | Kernel | Status |
|---|---|---|---|
| MoE experts (L1+L2) | NVFP4 native | CuTeDSL ScaledGroupedGemm | Working, cosine 0.988 |
| Shared experts | NVFP4 native | CuTeDSL GEMM (1 group) | 🔧 In progress |
| wq_b, wo_b, fused_wqa_wkv | NVFP4 native | CuTeDSL GEMM | 📋 Planned |
| wo_a | NVFP4 → FP8 | fp8_einsum | 📋 May stay FP8 or go native NVFP4 |
| Compressor | NVFP4 → BF16 | torch.mm | Done (weight_loader stacking issue) |
| KV cache | FP8 | FlashInfer MLA | Works |

Plan

Phase 1: Kernel DONE

  • CuTeDSL ScaledGroupedGemmKernel works with NVFP4
  • Bridge layer handles all tensor layout conversion
  • Full MoE pipeline (L1→SiLU→L2→scatter) produces cosine 0.988 vs BF16

Phase 2: vLLM Integration DONE

  • CuTeDSLMoERunner wires CuTeDSL kernel into vLLM
  • Weight loading: checkpoint uint8 → float4_e2m1fn_x2 view-cast (bit-preserving)
  • Block scales (float8_e4m3fn) and global scales (float32) pass through directly
  • L1 dual global scale handling: normalize to max(gate_gs, up_gs), fold ratio into block scales
  • Attention projections stay native NVFP4 (FlashInferCutlassNvFp4LinearKernel)
  • CuTeDSL kernel warmup during model load (prevents RPC timeout)
  • Removed all debug prints and env var gates from vLLM serving path
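
The L1 dual global scale handling listed above can be sketched as simple arithmetic. This assumes the effective per-block scale is block_scale * global_scale, which is our reading of the checkpoint format and not something this README states; fold_global_scales is a hypothetical name, and plain floats stand in for the float8_e4m3fn block scales.

```python
def fold_global_scales(gate_block_scales, gate_gs, up_block_scales, up_gs):
    """Rescale block scales so gate and up share the global scale max(gate_gs, up_gs)."""
    shared_gs = max(gate_gs, up_gs)
    gate_folded = [s * (gate_gs / shared_gs) for s in gate_block_scales]
    up_folded = [s * (up_gs / shared_gs) for s in up_block_scales]
    return shared_gs, gate_folded, up_folded


gate_scales, gate_gs = [1.0, 2.0, 4.0], 0.5
up_scales, up_gs = [3.0, 1.5, 0.75], 0.125
shared_gs, gate_folded, up_folded = fold_global_scales(gate_scales, gate_gs, up_scales, up_gs)
```

Normalizing to the larger global scale means each fold ratio is at most 1, so folded block scales can only shrink and cannot overflow the float8_e4m3fn range. The trade-off, under the same assumption, is that very small block scales may lose precision once rounded back to float8.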

Phase 2.5: CUDAGraph Compatibility DONE

  • CuTeDSLMoERunner is fully cudagraph-safe
  • Zero CPU-GPU syncs, zero dynamic shapes, zero GPU scalar slicing
  • All intermediate buffers pre-allocated at max_num_tokens * top_k
  • quantize_activation_nvfp4 uses cached LUT (no CPU→CUDA copy)
  • torch.zeros/rand for float4/float8 → uint8→view or float16→cast
  • Test harness validates capture + replay
  • VRAM overhead: ~350KB (negligible)
  • Compute overhead: padding rows of zeros still pass through the GEMM (memory-bound, so nearly free)
  • Kernel compilation cached: cute.compile() once during warmup, compiled() on forward calls

Phase 3: Output Quality 🔧 IN PROGRESS

  • vLLM serves the model with cudagraph, but output is garbage tokens
  • Layer 0 cosine is 0.988 in isolation, so the GEMM math is correct
  • Root cause: vLLM's FlashInferCutlassNvFp4LinearKernel is broken on B200
    • Same class of C++ CUTLASS FP4 bugs we hit with MoE
    • input_scale from checkpoint causes NaN during activation quantization
    • BF16 dequant workaround doesn't fix the underlying pipeline issues
  • Solution: Replace ALL vLLM NVFP4 kernels with our own CuTeDSL implementations

Phase 3.5: Shared Expert CuTeDSL Kernel 🔧 IN PROGRESS

  • Replacing FlashInferCutlassNvFp4LinearKernel with CuTeDSL GEMM
  • Shared expert = MoE with 1 expert, no routing
  • cutedsl/shared_expert_pipeline.py — dedicated runner (scale assembly needs fixing)
  • tests/test_shared_expert.py — standalone test (ready)
  • Target: cosine ≥ 0.98 vs BF16 reference

Phase 3.6: Attention CuTeDSL Kernel 📋 PLANNED

  • Replace attention NVFP4 path with CuTeDSL GEMMs
  • Each projection (fused_wqa_wkv, wq_b, wo_a, wo_b) = standard NVFP4 GEMM
  • fused_wqa_wkv has dual weight_scale_2 (same pattern as MoE gate+up)
  • Test each projection individually, then integrate

Phase 4: Optimization

  • Replace wo_a FP8 conversion with native NVFP4 GEMM (eliminate last dequant)
  • Fix compressor weight_loader so it stays NVFP4 native
  • Explore larger tile sizes for better occupancy
  • Profile end-to-end inference on full model
  • Proper TMA descriptor management for kernel caching across different tensor pools

Phase 5: Production

  • Clean up old C++ kernel code (tagged the-last-of-cutlass)
  • Add proper error handling and logging
  • Benchmark vs BF16 baseline