Mike's directive: build the full thing with NVFP4/CuTeDSL. No more 'optimize later' or 'just make it work' workarounds.
Key updates:
- README: full architecture docs (CSA/HCA/mHC), current status, NVFP4 coverage
- CURRENT_BUG: detailed plan for CuTeDSL NVFP4 attention, KV cache, RoPE
- Both files document: checkpoint key names, compress ratios, config issues
- Removed all 'TODO: optimize later' hedging; we build it right the first time
CURRENT_BUG.md
Status: vLLM server starts but returns empty output (immediate EOS)
Root Cause
vLLM's compiled CUDA kernels don't work on Blackwell (SM100):
- torch.ops._C.fused_deepseek_v4_qnorm_rope_kv_rope_kv_rope_quant_insert: fused RoPE + FP8 KV cache write (C++ kernel)
- FlashMLA sparse attention: the attention computation (CUDA kernel)
Both crash or produce garbage on SM100, so the model emits EOS immediately and chat completions come back empty.
What Works (verified with standalone tests on B200)
CuTeDSL NVFP4 kernels: ALL PASS (cosine 0.989–0.999 vs BF16):
- q_a_proj: 0.995 ✅
- kv_proj: 0.995 ✅
- q_b_proj: 0.995 ✅
- wo_b_proj: 0.995 ✅
- comp.kv_proj: 0.994 ✅
- comp.gate: 0.995 ✅
- shared_expert: 0.990 ✅
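The pass criterion above is cosine similarity between the NVFP4 output and a BF16 reference, taken over the flattened tensors. Below is a minimal pure-Python sketch of that metric. The helper name cosine_similarity is ours and does not come from the test files.

```python
import math

def cosine_similarity(a, b):
    """Cosine similarity between two equal-length flat sequences of floats."""
    if len(a) != len(b):
        raise ValueError("inputs must have the same length")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        raise ValueError("cosine similarity is undefined for a zero vector")
    return dot / (norm_a * norm_b)

# Example: a candidate output that differs slightly from the reference.
reference = [0.5, -1.25, 2.0, 0.75]
candidate = [0.5, -1.0, 2.0, 0.75]
print(round(cosine_similarity(reference, candidate), 4))
```

Cosine similarity does not change when the output is rescaled uniformly, so on its own it cannot catch a wrong global scale. A magnitude check, such as the logit std reported for the full attention path, covers that gap.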
Full attention path with SDPA — cosine 0.988 vs BF16, logit std 2.98 ✅
Warmup gs (global scale) is IRRELEVANT: the CuTeDSL runner recomputes the activation global scale internally on every call. Changing it 10x has no meaningful effect on the output (cosine 0.9993).
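Why a stale warmup value cannot matter: if the runner derives the activation global scale from the current tensor's amax on every call, the stored warmup scale is never read. The sketch below assumes the common NVFP4 convention global_scale = (448 × 6) / amax. Here 448 is the FP8 E4M3 maximum, which bounds the block scales, and 6 is the FP4 E2M1 maximum. Both function names are hypothetical and are not the CuTeDSL runner's API.

```python
FP8_E4M3_MAX = 448.0  # largest finite FP8 E4M3 value (format of the block scales)
FP4_E2M1_MAX = 6.0    # largest FP4 E2M1 magnitude

def activation_global_scale(x):
    """Per-tensor scale recomputed from the current activations (assumed convention)."""
    amax = max(abs(v) for v in x)
    if amax == 0.0:
        return 1.0  # arbitrary: an all-zero tensor quantizes to zero under any scale
    return (FP8_E4M3_MAX * FP4_E2M1_MAX) / amax

def run_with_warmup_scale(x, warmup_gs):
    """Hypothetical runner: warmup_gs is intentionally unused; the scale is recomputed."""
    return activation_global_scale(x)

x = [0.1, -3.0, 2.5, 0.0]
# A 10x different warmup value yields the same per-call scale.
assert run_with_warmup_scale(x, 1.0) == run_with_warmup_scale(x, 10.0)
```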
Current Workaround (TEMPORARY — needs replacement)
_attention_impl_blackwell() in vllm/patches/deepseek_v4_attention.py:
- Replaces FlashMLA with full_sdpa_attention(): pure PyTorch matmuls, NO CuTeDSL
- Replaces the C++ fused kernel with fused_qnorm_rope_kv_insert_py(): pure PyTorch RoPE
- Skips the SWA KV cache write (the cache uses the fp8_ds_mla packed format, shape [slot, 37376] not [slot, 512])
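For reference, the math full_sdpa_attention() has to reproduce is standard scaled dot-product attention, softmax(q·Kᵀ / sqrt(d))·V. Below is a minimal single-query sketch in pure Python. It has no masking and no attention sinks, and it is not the patched function's actual signature.

```python
import math

def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def sdpa_single_query(q, keys, values):
    """Attention output for one query vector over lists of key and value vectors."""
    d = len(q)
    scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in keys]
    probs = softmax(scores)
    out = [0.0] * len(values[0])
    for p, v in zip(probs, values):
        for j, vj in enumerate(v):
            out[j] += p * vj
    return out

q = [1.0, 0.0]
keys = [[1.0, 0.0], [0.0, 1.0]]
values = [[1.0, 2.0], [3.0, 4.0]]
print(sdpa_single_query(q, keys, values))
```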
THE PLAN: Replace all pure PyTorch with CuTeDSL/NVFP4
Mike's directive: NO "optimize later" BS. Build the full thing with NVFP4/CuTeDSL.
The attention Q×K and attn×V are activation×activation matmuls. They CAN be done in NVFP4:
- Quantize Q and K to NVFP4 (4-bit activation + block scales)
- Use CuTeDSL grouped GEMM for the sparse attention pattern
- This is what FlashMLA does with FP8; we just use NVFP4 instead
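The quantization step above can be sketched as fake quantization in pure Python. Values are split into blocks of 16, and each block gets a scale that maps its amax to 6.0. Each element is then rounded to the nearest FP4 E2M1 magnitude.

Simplifications (assumptions):
- The block scale stays a Python float. Real NVFP4 stores it as FP8 E4M3, alongside an FP32 per-tensor scale.
- Ties round toward the smaller magnitude rather than to even.

The function names are hypothetical and unrelated to quantize_activation_nvfp4 in cutedsl/bridge.py.

```python
# Non-negative magnitudes representable in FP4 E2M1 (the sign is stored separately).
E2M1_GRID = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
BLOCK_SIZE = 16

def round_to_e2m1(x):
    """Round |x| to the nearest E2M1 magnitude (ties go to the smaller one), keep the sign."""
    mag = min(E2M1_GRID, key=lambda g: (abs(abs(x) - g), g))
    return mag if x >= 0 else -mag

def fake_quant_nvfp4(values):
    """Quantize then dequantize block by block; results lie on the scaled NVFP4 grid."""
    out = []
    for start in range(0, len(values), BLOCK_SIZE):
        block = values[start:start + BLOCK_SIZE]
        amax = max(abs(v) for v in block)
        scale = amax / E2M1_GRID[-1] if amax > 0 else 1.0
        out.extend(round_to_e2m1(v / scale) * scale for v in block)
    return out

# Q·K with both operands quantized, compared against the unquantized dot product.
q = [0.1 * i for i in range(32)]
k = [1.0 - 0.05 * i for i in range(32)]
exact = sum(a * b for a, b in zip(q, k))
approx = sum(a * b for a, b in zip(fake_quant_nvfp4(q), fake_quant_nvfp4(k)))
print(f"exact={exact:.3f} nvfp4={approx:.3f}")
```

Per-head quantization, as the plan calls for, would apply the same block procedure along each head's head_dim.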
Specific replacements needed:
- Attention (Q×K, attn×V) → CuTeDSL NVFP4 GEMM
  - Quantize Q and K to NVFP4 per-head
  - Use CuTeDSLNvfp4Linear or raw CuTeDSL GEMM for the matmuls
  - Support both prefill (batched) and decode (single-token) paths
  - Handle CSA sparse gather pattern (only attend to top-k positions)
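The CSA pattern restricts each query to the top-k KV positions. Below is a pure-Python sketch for a single query. It assumes one relevance score per compressed position is already available. In the real path the indexer produces these scores; here they are just a list.

```python
import heapq
import math

def topk_indices(scores, k):
    """Indices of the k largest scores, returned in ascending index order."""
    k = min(k, len(scores))
    return sorted(heapq.nlargest(k, range(len(scores)), key=scores.__getitem__))

def sparse_attention_single_query(q, keys, values, index_scores, k):
    """Softmax attention restricted to the top-k positions selected by index_scores."""
    selected = topk_indices(index_scores, k)
    d = len(q)
    logits = [sum(a * b for a, b in zip(q, keys[i])) / math.sqrt(d) for i in selected]
    m = max(logits)
    weights = [math.exp(x - m) for x in logits]
    total = sum(weights)
    out = [0.0] * len(values[0])
    for w, i in zip(weights, selected):
        for j, v in enumerate(values[i]):
            out[j] += (w / total) * v
    return out
```

With k equal to the number of positions, this reduces to dense attention. The per-position reads of keys[i] and values[i] are the gather that the plan wants fused into the GEMM.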
- KV cache write → NVFP4 quant + paged cache insert
  - The SWA cache uses fp8_ds_mla format: packed FP8 values + UE8M0 scales
  - Row width = 37376 bytes (not just head_dim=512)
  - Layout: [nope_dim FP8 values | rope_dim FP8 values | scale blocks]
  - Need to understand the exact fp8_ds_mla layout and replicate it in CuTeDSL
  - OR: skip SWA cache entirely and use our own NVFP4 cache format
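One way to pin down the fp8_ds_mla layout is to test candidate layouts against the observed row width. The sketch below does only the byte arithmetic. Every layout parameter is an assumption to be checked against vLLM's source:
- A "slot" row holds a 64-token page.
- The candidate per-token layout is 448 FP8 nope bytes + 64 rope values in BF16 + one UE8M0 scale byte per 64 nope values, padded to 8 bytes.

This candidate stores rope in BF16, not in FP8 as the layout line above says. It is shown only because its numbers happen to reproduce 37376.

```python
def bytes_per_token(nope_dim, rope_dim, rope_bytes, scale_group, scale_pad_to):
    """Bytes for one token: FP8 nope values + rope values + padded UE8M0 scale bytes."""
    nope_bytes = nope_dim * 1                  # FP8: one byte per value
    rope_total = rope_dim * rope_bytes         # e.g. 2 bytes per value if BF16
    n_scales = -(-nope_dim // scale_group)     # ceil division: one UE8M0 byte per group
    scale_bytes = -(-n_scales // scale_pad_to) * scale_pad_to  # pad to a multiple
    return nope_bytes + rope_total + scale_bytes

# Hypothetical candidate: 448 FP8 nope + 64 BF16 rope + 7 scale bytes padded to 8.
per_token = bytes_per_token(nope_dim=448, rope_dim=64, rope_bytes=2, scale_group=64, scale_pad_to=8)
tokens_per_slot = 64  # assumption: one "slot" row is a 64-token page
print(per_token, per_token * tokens_per_slot)  # prints: 584 37376
```

The same helper with FP8 rope (rope_bytes=1) gives 520 bytes per token, or 33280 per 64-token row, which does not match. Either the rope segment is wider than FP8, or the page-size assumption is wrong.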
- RoPE → CuTeDSL fused kernel
  - Currently pure PyTorch (works, but slow)
  - Could fuse with Q norm + KV quant into a single CuTeDSL kernel
  - Pattern: Q norm → RoPE → NVFP4 quant, all in one pass
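The fused pattern is a row-wise step followed by an elementwise step, so its reference semantics fit in a few lines. Below is a pure-Python sketch of RMSNorm followed by RoPE on the trailing rope_dim channels; the NVFP4 quantization would then run on its output.

Assumptions, none confirmed against the DeepSeek V4 config:
- The rope channels are the last rope_dim of the head.
- Pairs are rotated in the interleaved (x[2i], x[2i+1]) convention.
- The RoPE base is 10000.
- The norm has a learned weight.

```python
import math

def rms_norm(x, weight, eps=1e-6):
    """RMSNorm: x / sqrt(mean(x^2) + eps), scaled elementwise by weight."""
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]

def apply_rope_interleaved(x, position, base=10000.0):
    """Rotate pairs (x[i], x[i+1]) for even i by angle position * base^(-i/d)."""
    d = len(x)
    out = list(x)
    for i in range(0, d, 2):
        theta = position * base ** (-i / d)
        c, s = math.cos(theta), math.sin(theta)
        out[i] = x[i] * c - x[i + 1] * s
        out[i + 1] = x[i] * s + x[i + 1] * c
    return out

def qnorm_rope(q, weight, position, rope_dim):
    """Reference for the first two stages of the fused kernel (norm, then RoPE)."""
    if rope_dim <= 0 or rope_dim % 2:
        raise ValueError("rope_dim must be a positive even number")
    normed = rms_norm(q, weight)
    nope, rope = normed[:-rope_dim], normed[-rope_dim:]
    return nope + apply_rope_interleaved(rope, position)
```

A fused kernel computes the same thing in a single pass over each row, instead of writing normed values back to memory between stages.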
- CSA sparse gather → CuTeDSL indexed access
  - Currently uses torch.gather (slow, not GPU-optimal: it materializes the gathered KV as a separate tensor before the matmul)
  - CuTeDSL can do the gather + GEMM in one fused operation
  - This is the whole point of CSA: sparse KV access
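Here is what "gather + GEMM in one fused operation" means, in miniature. Instead of copying the selected KV rows into a new tensor and then running the matmul, each selected logical position is translated to its physical cache row and consumed immediately.

The sketch assumes a vLLM-style paged layout. block_table maps a logical block to a physical block, and slot = physical_block * block_size + offset. The function names are illustrative.

```python
def physical_slot(position, block_table, block_size):
    """Map a logical token position to its physical row in a paged KV cache."""
    logical_block, offset = divmod(position, block_size)
    return block_table[logical_block] * block_size + offset

def fused_gather_scores(q, kv_cache, block_table, block_size, selected_positions):
    """q·k for each selected position, reading cache rows in place (no gathered copy)."""
    scores = []
    for pos in selected_positions:
        row = kv_cache[physical_slot(pos, block_table, block_size)]
        scores.append(sum(a * b for a, b in zip(q, row)))
    return scores

# Two logical blocks of size 2, stored in physical blocks 1 and 0 respectively.
block_table = [1, 0]
kv_cache = [[0.0, 1.0], [0.0, 2.0], [1.0, 0.0], [2.0, 0.0]]
print(fused_gather_scores([1.0, 1.0], kv_cache, block_table, 2, [0, 3]))
```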
- Compressor → already Triton (works on SM100) ✅
- Indexer → already Triton (works on SM100) ✅
- mHC → already pure PyTorch ✅
- MoE → already CuTeDSL ✅
Config Issues (from config.json)
- quant_method: modelopt → vLLM uses ModelOpt's NVFP4 handler
  - Our CuTeDSL kernel IS registered in _POSSIBLE_NVFP4_KERNELS (via register_cutedsl_kernel.py)
  - Added the VLLM_NVFP4_GEMM_BACKEND=cutedsl env var to force it
- kv_cache_scheme: {"num_bits": 8, "type": "float"} → FP8 KV cache → FlashMLA
  - Hard assertion issubclass(get_attn_backend(), FlashMLASparseBackend), patched with an _is_blackwell flag
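Two pieces of the config workaround, sketched in pure Python. The environment variable has to be set before vLLM reads it; exporting it in the launching shell works equally well. The assertion patch uses stand-in classes.

Two things here are assumptions, not vLLM's code:
- is_blackwell() keys off compute-capability major 10 (SM100); this is a guess at how the _is_blackwell flag is derived.
- The Fake* class names are placeholders.

```python
import os

# Force the CuTeDSL NVFP4 GEMM backend; set this before vLLM is imported and initialized.
os.environ["VLLM_NVFP4_GEMM_BACKEND"] = "cutedsl"

def is_blackwell(capability):
    """capability is a (major, minor) tuple; SM100 (B200) reports (10, 0)."""
    return capability[0] == 10

class FakeFlashMLASparseBackend:  # placeholder for vLLM's FlashMLASparseBackend
    pass

class FakeSDPABackend:  # placeholder for the SDPA-based Blackwell path
    pass

def check_attn_backend(backend_cls, capability):
    """The original hard assertion, skipped on Blackwell."""
    if not is_blackwell(capability):
        assert issubclass(backend_cls, FakeFlashMLASparseBackend), (
            f"{backend_cls.__name__} is not a FlashMLA sparse backend"
        )
```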
Checkpoint Key Names (different from vLLM names!)
q_a_proj, q_b_proj, kv_proj (NOT fused_wqa_wkv, wq_b)
q_a_norm (NOT q_norm)
attn_hc.fn/base/scale (MHC attention)
ffn_hc.fn/base/scale (MHC FFN)
compressor.kv_proj, compressor.gate_proj (CSA/HCA)
compressor.position_bias
sinks (attn_sink)
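Below is a sketch of the name translation a weight loader needs, built only from the pairs listed above.

Assumptions:
- Names are matched as dotted path components.
- fused_wqa_wkv is the row-wise concatenation of q_a_proj and kv_proj, in that order. The fusion order is not confirmed and must be checked against vLLM's weight loader.
- The "layers.0.attn." prefix in the tests is a hypothetical example.

compressor.kv_proj is deliberately left alone, and so are the mHC, compressor and sink names, since no vLLM-side name is given for them.

```python
RENAMES = {
    "q_b_proj": "wq_b",
    "q_a_norm": "q_norm",
}
FUSED_INTO_WQA_WKV = ("q_a_proj", "kv_proj")  # assumed concatenation order

def map_checkpoint_key(key):
    """Return (vllm_key, shard_index); shard_index is None for unfused weights."""
    parts = key.split(".")
    for i, part in enumerate(parts):
        if part in RENAMES:
            parts[i] = RENAMES[part]
            return ".".join(parts), None
        # compressor.kv_proj belongs to the CSA/HCA compressor, not the fused projection.
        if part in FUSED_INTO_WQA_WKV and (i == 0 or parts[i - 1] != "compressor"):
            shard = FUSED_INTO_WQA_WKV.index(part)
            parts[i] = "fused_wqa_wkv"
            return ".".join(parts), shard
    return key, None
```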
Compress Ratios (from config.json compress_ratios)
Layer 0: 128 (HCA)
Layer 1: 128 (HCA)
Layer 2: 4 (CSA)
Layer 3: 128 (HCA)
Layer 4: 4 (CSA)
...alternating 4/128...
Layer 60: 0 (SWA-only)
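The ratio-to-attention-type rule implied by the list above can be written as a tiny helper. The function names are hypothetical. Only the three ratios that appear in compress_ratios are handled; anything else is rejected rather than guessed.

```python
ATTN_TYPE_BY_RATIO = {0: "SWA", 4: "CSA", 128: "HCA"}

def attention_type(compress_ratio):
    """Map a compress ratio from config.json to its attention type."""
    try:
        return ATTN_TYPE_BY_RATIO[compress_ratio]
    except KeyError:
        raise ValueError(f"unexpected compress ratio: {compress_ratio}") from None

def summarize(compress_ratios):
    """Map each layer index to its attention type."""
    return {layer: attention_type(r) for layer, r in enumerate(compress_ratios)}

# First five layers as listed above (the full list lives in config.json).
print(summarize([128, 128, 4, 128, 4]))
```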
Architecture: CSA + HCA + mHC (NOT MLA!)
- CSA (Compress Ratio 4): Compressed Sparse Attention — KV compressed 4x with overlap (coff=2)
- HCA (Compress Ratio 128): Heavily Compressed Attention — KV compressed 128x
- mHC: Manifold-Constrained Hyper-Connections — replaces standard residual connections
- SWA: Sliding Window Attention (compress_ratio=0, last layer only)
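The "manifold constraint" in mHC, as we understand the mHC paper, works as follows. The residual-mixing matrix between the n hyper-connection streams is projected onto doubly stochastic matrices with Sinkhorn-Knopp iterations, so every row and every column sums to 1. With rows summing to 1, each output stream is a convex combination of the input streams. With columns summing to 1, the total across streams is preserved.

The sketch below covers only that projection and the mixing step. How attn_hc.fn/base/scale parameterize the logits is not covered and would need checking against the checkpoint.

```python
import math

def sinkhorn_knopp(logits, iters=20):
    """Turn a square matrix of logits into an (approximately) doubly stochastic matrix."""
    m = [[math.exp(v) for v in row] for row in logits]  # strictly positive entries
    n = len(m)
    for _ in range(iters):
        for row in m:  # normalize each row to sum to 1
            s = sum(row)
            for j in range(n):
                row[j] /= s
        for j in range(n):  # normalize each column to sum to 1
            s = sum(m[i][j] for i in range(n))
            for i in range(n):
                m[i][j] /= s
    return m

def mix_streams(mixing, streams):
    """Output stream i is sum_k mixing[i][k] * streams[k], channel by channel."""
    n = len(streams)
    width = len(streams[0])
    return [[sum(mixing[i][k] * streams[k][c] for k in range(n)) for c in range(width)]
            for i in range(n)]
```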
Files
- Kernel: cutedsl/csa_attention.py (CSA/HCA attention; currently SDPA, needs CuTeDSL)
- vLLM patch: vllm/patches/deepseek_v4_attention.py (_attention_impl_blackwell())
- vLLM patch: vllm/patches/layers/csa_attention.py (fused_qnorm_rope_kv_insert_py(), full_sdpa_attention())
- Standalone tests: tests/test_full_layer_b200.py, tests/test_csa_attention_b200.py, tests/test_model_forward_b200.py
- CuTeDSL NVFP4 linear: cutedsl/nvfp4_linear.py (CuTeDSLNvfp4Linear runner)
- CuTeDSL bridge: cutedsl/bridge.py (quantize_activation_nvfp4, NVFP4 GEMM wrappers)