nvfp4-megamoe-kernel/CURRENT_BUG.md
biondizzle 914d27fee7 Update README + CURRENT_BUG: full CuTeDSL NVFP4 plan, no more PyTorch fallbacks
Mike's directive: build the full thing with NVFP4/CuTeDSL.
No more 'optimize later' or 'just make it work' workarounds.

Key updates:
- README: full architecture docs (CSA/HCA/mHC), current status, NVFP4 coverage
- CURRENT_BUG: detailed plan for CuTeDSL NVFP4 attention, KV cache, RoPE
- Both files document: checkpoint key names, compress ratios, config issues
- Removed all 'TODO: optimize later' hedging — we build it right the first time
2026-05-19 08:26:16 +00:00


CURRENT_BUG.md

Status: vLLM server starts but returns empty output (immediate EOS)

Root Cause

Two of vLLM's compiled CUDA kernels don't work on Blackwell (SM100):

  1. torch.ops._C.fused_deepseek_v4_qnorm_rope_kv_rope_kv_rope_quant_insert — fused RoPE + FP8 KV cache write (C++ kernel)
  2. FlashMLA sparse attention — the attention computation (CUDA kernel)

Both crash or produce garbage on SM100. The model outputs immediate EOS → empty chat completions.

What Works (verified with standalone tests on B200)

CuTeDSL NVFP4 kernels ALL PASS (cosine 0.989 to 0.999 vs BF16):

q_a_proj:     0.995 ✅   kv_proj:      0.995 ✅   q_b_proj:     0.995 ✅
wo_b_proj:    0.995 ✅   comp.kv_proj: 0.994 ✅   comp.gate:    0.995 ✅
shared_expert: 0.990 ✅

Full attention path with SDPA — cosine 0.988 vs BF16, logit std 2.98

Warmup gs is IRRELEVANT: the CuTeDSL runner recomputes the activation global scale (gs) internally on every call. Changing the warmup gs by 10x leaves the output effectively unchanged (cosine 0.9993).
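A minimal pure-Python sketch of NVFP4 fake quantization shows why a per-call global scale makes the warmup value moot. It assumes the common NVFP4 layout: FP4 E2M1 codes, 16-element blocks, FP8 E4M3 block scales, and a per-tensor global scale of 448 × 6 / amax. The function names are hypothetical, and this is not the CuTeDSL runner's code. Because the global scale is derived from each call's own amax, a warmup value passed in never reaches the math.

```python
import math

# Assumed NVFP4 parameters (standard layout, not read from this repo).
FP4_GRID = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)  # E2M1 magnitudes
FP4_MAX = 6.0
E4M3_MAX = 448.0
BLOCK = 16


def round_e4m3(v):
    """Round a non-negative float to the nearest FP8 E4M3 value (saturating at 448)."""
    if v <= 0.0:
        return 0.0
    v = min(v, E4M3_MAX)
    exp = max(math.floor(math.log2(v)), -6)  # below 2^-6, subnormal spacing applies
    step = 2.0 ** (exp - 3)                  # 3 mantissa bits
    return min(round(v / step) * step, E4M3_MAX)


def quantize_nvfp4(x, warmup_global_scale=None):
    """Fake-quantize a 1-D activation to NVFP4 (hypothetical helper).

    warmup_global_scale is accepted but ignored: the global scale is
    recomputed from this call's amax on every call.
    """
    amax = max(abs(v) for v in x) or 1.0
    global_scale = E4M3_MAX * FP4_MAX / amax
    codes, scales = [], []
    for start in range(0, len(x), BLOCK):
        block = x[start:start + BLOCK]
        block_amax = max(abs(v) for v in block)
        scale = round_e4m3(block_amax / FP4_MAX * global_scale)  # FP8 block scale
        scales.append(scale)
        for v in block:
            scaled = (v * global_scale / scale) if scale else 0.0
            mag = min(FP4_GRID, key=lambda g: abs(g - abs(scaled)))  # nearest E2M1 value
            codes.append(math.copysign(mag, scaled))
    return codes, scales, global_scale


def dequantize_nvfp4(codes, scales, global_scale):
    return [c * scales[i // BLOCK] / global_scale for i, c in enumerate(codes)]


def cosine(a, b):
    dot = sum(p * q for p, q in zip(a, b))
    return dot / (math.sqrt(sum(p * p for p in a)) * math.sqrt(sum(q * q for q in b)))


x = [math.sin(0.37 * i) * (1 + i % 7) for i in range(64)]
y1 = dequantize_nvfp4(*quantize_nvfp4(x, warmup_global_scale=1.0))
y2 = dequantize_nvfp4(*quantize_nvfp4(x, warmup_global_scale=10.0))
print(y1 == y2, round(cosine(x, y1), 4))
```

The first printed value is True because the warmup argument is never used. The second value reflects only the FP4/FP8 rounding error.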

Current Workaround (TEMPORARY — needs replacement)

_attention_impl_blackwell() in vllm/patches/deepseek_v4_attention.py:

  • Replaces FlashMLA with full_sdpa_attention(): pure PyTorch matmuls, NO CuTeDSL
  • Replaces C++ fused kernel with fused_qnorm_rope_kv_insert_py() — pure PyTorch RoPE
  • Skips SWA KV cache write (cache uses fp8_ds_mla packed format, shape [slot, 37376] not [slot, 512])
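Before replicating fp8_ds_mla, it may help to check which per-token layouts fit the observed 37376-byte row at all. The sketch below assumes nope_dim = 448 and rope_dim = 64 (so head_dim = 512), with one UE8M0 scale byte per 64 FP8 values padded to 8 bytes. It compares FP8 and BF16 storage for the RoPE part. All of these are hypotheses to verify against vLLM's source, not its documented layout, and the helper name is hypothetical.

```python
# Hypothetical per-token layouts for one fp8_ds_mla SWA cache row.
# Dims, dtypes, scale group size and padding are assumptions to verify.
ROW_BYTES = 37376  # observed width of one cache slot


def bytes_per_token(nope_dim, rope_dim, rope_bytes, scale_group, pad_to):
    nope = nope_dim * 1                        # FP8 values: 1 byte each
    rope = rope_dim * rope_bytes               # 1 byte (FP8) or 2 bytes (BF16) each
    n_scales = -(-nope_dim // scale_group)     # one UE8M0 byte per group (ceil division)
    scales = -(-n_scales // pad_to) * pad_to   # pad scale bytes up to a multiple of pad_to
    return nope + rope + scales


candidates = {
    "rope in FP8": bytes_per_token(448, 64, 1, 64, 8),
    "rope in BF16": bytes_per_token(448, 64, 2, 64, 8),
}
for name, per_tok in candidates.items():
    tokens, rem = divmod(ROW_BYTES, per_tok)
    print(f"{name}: {per_tok} B/token -> {tokens} tokens, remainder {rem}")
```

Under these assumptions, only the BF16-RoPE variant divides the row evenly (584 B/token × 64 tokens = 37376 B). That would hint that one cache slot holds a 64-token page rather than a single token. Treat it as a lead to confirm, not a finding.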

THE PLAN: Replace all pure PyTorch with CuTeDSL/NVFP4

Mike's directive: NO "optimize later" BS. Build the full thing with NVFP4/CuTeDSL.

The attention products Q×K and attn×V are activation×activation matmuls. They CAN be done in NVFP4:

  • Quantize Q and K to NVFP4 (4-bit activation + block scales)
  • Use CuTeDSL grouped GEMM for the sparse attention pattern
  • This is exactly what FlashMLA does with FP8 — we just use NVFP4 instead
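Q×K can run directly on NVFP4 operands because the block scales factor out of each 16-element chunk of the dot product. Per block, the FP4 codes are multiplied and summed, then scaled once by scale_q × scale_k. The sketch below checks that identity in pure Python. It is simplified: block scales stay in float instead of FP8 E4M3, there is no global scale, and head_dim is 64 instead of 512. The helper names are hypothetical, not CuTeDSL APIs.

```python
import math

FP4_GRID = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)  # E2M1 magnitudes
BLOCK = 16


def quant_blocks(x):
    """Simplified NVFP4-style quantization: E2M1 codes plus one float scale per 16 values
    (real NVFP4 stores block scales as FP8 E4M3 and adds a global scale)."""
    codes, scales = [], []
    for s in range(0, len(x), BLOCK):
        blk = x[s:s + BLOCK]
        scale = max(abs(v) for v in blk) / 6.0
        scales.append(scale)
        for v in blk:
            t = (v / scale) if scale else 0.0
            mag = min(FP4_GRID, key=lambda g: abs(g - abs(t)))
            codes.append(math.copysign(mag, t))
    return codes, scales


def block_scaled_dot(qa, ka):
    """Accumulate like a block-scaled MMA: per block, dot the FP4 codes,
    then apply scale_q * scale_k once to the partial sum."""
    (qc, qs), (kc, ks) = qa, ka
    total = 0.0
    for b in range(len(qs)):
        lo, hi = b * BLOCK, (b + 1) * BLOCK
        partial = sum(a * c for a, c in zip(qc[lo:hi], kc[lo:hi]))
        total += qs[b] * ks[b] * partial
    return total


def dequant_dot(qa, ka):
    """Reference: dequantize both operands first, then take a plain dot product."""
    (qc, qs), (kc, ks) = qa, ka
    q = [c * qs[i // BLOCK] for i, c in enumerate(qc)]
    k = [c * ks[i // BLOCK] for i, c in enumerate(kc)]
    return sum(a * b for a, b in zip(q, k))


head_dim = 64
q = [math.cos(0.3 * i) for i in range(head_dim)]
k = [math.sin(0.2 * i + 1.0) for i in range(head_dim)]
qa, ka = quant_blocks(q), quant_blocks(k)
exact = sum(a * b for a, b in zip(q, k))
print(block_scaled_dot(qa, ka), dequant_dot(qa, ka), exact)
```

The first two numbers match to floating-point precision, and both differ from the exact BF16-style score only by quantization error. The same factoring applies to attn×V once both of its operands are quantized.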

Specific replacements needed:

  1. Attention (Q×K, attn×V) → CuTeDSL NVFP4 GEMM

    • Quantize Q and K to NVFP4 per-head
    • Use CuTeDSLNvfp4Linear or raw CuTeDSL GEMM for the matmuls
    • Support both prefill (batched) and decode (single-token) paths
    • Handle CSA sparse gather pattern (only attend to top-k positions)
  2. KV cache write → NVFP4 quant + paged cache insert

    • The SWA cache uses fp8_ds_mla format: packed FP8 values + UE8M0 scales
    • Row width = 37376 bytes (not just head_dim=512)
    • Layout: [nope_dim FP8 values | rope_dim FP8 values | scale blocks]
    • Need to understand the exact fp8_ds_mla layout and replicate in CuTeDSL
    • OR: skip SWA cache entirely and use our own NVFP4 cache format
  3. RoPE → CuTeDSL fused kernel

    • Currently pure PyTorch (works, but slow)
    • Could fuse with Q norm + KV quant into a single CuTeDSL kernel
    • Pattern: Q norm → RoPE → NVFP4 quant, all in one pass
  4. CSA sparse gather → CuTeDSL indexed access

    • Currently uses torch.gather (slow, not GPU-optimal)
    • CuTeDSL can do the gather + GEMM in one fused operation
    • This is the whole point of CSA — sparse KV access
  5. Compressor → already Triton (works on SM100)

  6. Indexer → already Triton (works on SM100)

  7. MHC → already pure PyTorch

  8. MoE → already CuTeDSL
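Taken together, items 1 and 4 amount to four steps: gather the top-k KV rows, run Q×K, apply softmax, and compute attn×V over only those rows. A plain-Python reference like the one below could serve as a correctness oracle for a fused CuTeDSL gather+GEMM kernel. It assumes a single query, standard 1/sqrt(d) scaling, and no attention sink, and the function names are hypothetical. It also shows that gathering is equivalent to dense attention with every non-selected position masked to -inf.

```python
import math


def softmax(xs):
    m = max(xs)
    e = [math.exp(x - m) for x in xs]  # exp(-inf) == 0.0 for masked entries
    s = sum(e)
    return [v / s for v in e]


def sparse_gather_attention(q, K, V, topk_idx):
    """Single-query attention over only the KV rows listed in topk_idx:
    gather -> Q·K^T -> softmax -> attn·V."""
    scale = 1.0 / math.sqrt(len(q))
    k_sel = [K[i] for i in topk_idx]  # the gather step
    v_sel = [V[i] for i in topk_idx]
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in k_sel]
    probs = softmax(scores)
    return [sum(p * v[j] for p, v in zip(probs, v_sel)) for j in range(len(V[0]))]


def masked_dense_attention(q, K, V, topk_idx):
    """Same result computed densely: positions outside topk_idx get a -inf score."""
    scale = 1.0 / math.sqrt(len(q))
    keep = set(topk_idx)
    scores = [scale * sum(a * b for a, b in zip(q, k)) if i in keep else -math.inf
              for i, k in enumerate(K)]
    probs = softmax(scores)
    return [sum(p * v[j] for p, v in zip(probs, V)) for j in range(len(V[0]))]


n_kv, d = 12, 8
K = [[math.sin(0.5 * i + 0.1 * j) for j in range(d)] for i in range(n_kv)]
V = [[math.cos(0.3 * i - 0.2 * j) for j in range(d)] for i in range(n_kv)]
q = [0.1 * j - 0.3 for j in range(d)]
topk = [1, 4, 7, 10]
print(sparse_gather_attention(q, K, V, topk))
print(masked_dense_attention(q, K, V, topk))
```

The two outputs agree to floating-point precision. The gathered version touches only k of the n KV rows, which is where a fused gather+GEMM kernel saves memory traffic compared with torch.gather followed by a separate matmul.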

Config Issues (from config.json)

  • quant_method: modelopt → vLLM uses ModelOpt's NVFP4 handler
  • Our CuTeDSL IS registered in _POSSIBLE_NVFP4_KERNELS (via register_cutedsl_kernel.py)
  • Added VLLM_NVFP4_GEMM_BACKEND=cutedsl env var to force it
  • kv_cache_scheme: {"num_bits": 8, "type": "float"} → FP8 KV cache → FlashMLA
  • Hard assertion issubclass(get_attn_backend(), FlashMLASparseBackend) — patched with _is_blackwell flag

Checkpoint Key Names (different from vLLM names!)

q_a_proj, q_b_proj, kv_proj       (NOT fused_wqa_wkv, wq_b)
q_a_norm                           (NOT q_norm)
attn_hc.fn/base/scale              (MHC attention)
ffn_hc.fn/base/scale               (MHC FFN)
compressor.kv_proj, compressor.gate_proj  (CSA/HCA)
compressor.position_bias
sinks                              (attn_sink)
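A rename table that maps these keys onto vLLM's module names has one trap to avoid: compressor.kv_proj shares its leaf name with the attention kv_proj. A hypothetical sketch follows. The example key prefixes (layers.N.attn.) and the shard order for fused_wqa_wkv (q_a_proj as shard 0, kv_proj as shard 1) are illustrative guesses inferred from the names, not taken from vLLM.

```python
# Hypothetical rename table: checkpoint module name -> (vLLM module name, shard id).
# The fused_wqa_wkv shard order is an assumption.
CKPT_TO_VLLM = {
    "q_a_proj": ("fused_wqa_wkv", 0),
    "kv_proj": ("fused_wqa_wkv", 1),
    "q_b_proj": ("wq_b", None),
    "q_a_norm": ("q_norm", None),
}


def remap_key(ckpt_key):
    """Return (vllm_key, shard_id); shard_id is None for unfused params.

    Only the module name just before the parameter suffix is renamed, and
    keys under compressor.* are left alone so compressor.kv_proj is not
    mistaken for the attention kv_proj.
    """
    parts = ckpt_key.split(".")
    if len(parts) < 2:
        return ckpt_key, None
    module = parts[-2]
    parent = parts[-3] if len(parts) >= 3 else ""
    if parent == "compressor" or module not in CKPT_TO_VLLM:
        return ckpt_key, None
    new_module, shard = CKPT_TO_VLLM[module]
    return ".".join(parts[:-2] + [new_module, parts[-1]]), shard


for key in ("layers.2.attn.q_a_proj.weight",
            "layers.2.attn.kv_proj.weight",
            "layers.2.attn.compressor.kv_proj.weight",
            "layers.2.attn.q_a_norm.weight"):
    print(key, "->", remap_key(key))
```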

Compress Ratios (from config.json compress_ratios)

Layer 0:  128 (HCA)    Layer 1:  128 (HCA)
Layer 2:  4   (CSA)    Layer 3:  128 (HCA)
Layer 4:  4   (CSA)    ...
...alternating 4/128...
Layer 60: 0   (SWA-only)
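A small helper can turn compress_ratios into a per-layer attention type and a rough size for each layer's compressed KV. It assumes one compressed entry per `ratio` tokens and ignores CSA's overlap (coff=2), so the CSA numbers are approximate. It also treats ratio 0 as having no compressed entries. The helper names and the 131072-token context length are hypothetical.

```python
def attn_type(ratio):
    """Map a compress_ratios entry to the attention variant named in this doc."""
    return {0: "SWA", 4: "CSA", 128: "HCA"}[ratio]


def compressed_kv_len(seq_len, ratio):
    """Approximate compressed KV entries for one layer: one per `ratio` tokens
    (CSA overlap ignored); SWA-only layers (ratio 0) get none."""
    return 0 if ratio == 0 else seq_len // ratio


# The first layers listed above, plus the final SWA-only layer.
compress_ratios = {0: 128, 1: 128, 2: 4, 3: 128, 4: 4, 60: 0}
for layer, r in compress_ratios.items():
    print(layer, attn_type(r), compressed_kv_len(131072, r))
```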

Architecture: CSA + HCA + mHC (NOT MLA!)

  • CSA (Compress Ratio 4): Compressed Sparse Attention — KV compressed 4x with overlap (coff=2)
  • HCA (Compress Ratio 128): Heavily Compressed Attention — KV compressed 128x
  • mHC: Manifold-Constrained Hyper-Connections — replaces standard residual connections
  • SWA: Sliding Window Attention (compress_ratio=0, last layer only)

Files

  • Kernel: cutedsl/csa_attention.py — CSA/HCA attention (currently SDPA, needs CuTeDSL)
  • vLLM patch: vllm/patches/deepseek_v4_attention.py (_attention_impl_blackwell())
  • vLLM patch: vllm/patches/layers/csa_attention.py (fused_qnorm_rope_kv_insert_py(), full_sdpa_attention())
  • Standalone tests: tests/test_full_layer_b200.py, tests/test_csa_attention_b200.py, tests/test_model_forward_b200.py
  • CuTeDSL NVFP4 linear: cutedsl/nvfp4_linear.py — CuTeDSLNvfp4Linear runner
  • CuTeDSL bridge: cutedsl/bridge.py — quantize_activation_nvfp4, NVFP4 GEMM wrappers