Update README and CURRENT_BUG: BUILD YOUR OWN KERNELS. Stop patching vLLM.
This commit is contained in:
191
CURRENT_BUG.md
191
CURRENT_BUG.md
@@ -1,116 +1,137 @@
|
||||
# CURRENT_BUG.md
|
||||
|
||||
## Status: vLLM server starts but returns empty output (immediate EOS)
|
||||
## Status: Container starts, model generates tokens, but output is GARBAGE (empty/NaN)
|
||||
|
||||
### Root Cause
|
||||
### THE FUNDAMENTAL PROBLEM
|
||||
|
||||
vLLM's compiled CUDA kernels don't work on Blackwell (SM100):
|
||||
1. **`torch.ops._C.fused_deepseek_v4_qnorm_rope_kv_rope_kv_rope_quant_insert`** — fused RoPE + FP8 KV cache write (C++ kernel)
|
||||
2. **FlashMLA sparse attention** — the attention computation (CUDA kernel)
|
||||
**Mike was right — we need our own kernels. Not just for the NVFP4 GEMMs, but for the ENTIRE attention pipeline. The current approach of patching individual vLLM functions is a house of cards.**
|
||||
|
||||
Both crash or produce garbage on SM100. The model outputs immediate EOS → empty chat completions.
|
||||
Here's what happened: we spent hours patching vLLM to "work" on Blackwell. We patched:
|
||||
1. `VLLM_NVFP4_GEMM_BACKEND=cutedsl` → invalid, removed env var
|
||||
2. KV cache page size assertion → patched `kv_cache_utils.py`
|
||||
3. 91 missing compressor cache layers → patched alignment in 3 cache specs
|
||||
4. `softmax_scale` AttributeError → fixed to `self.scale`
|
||||
5. NaN from missing RoPE on KV → added `_apply_rope_kv()`
|
||||
6. Shape mismatch in `apply_gptj_rope` → rewrote as inline RoPE
|
||||
|
||||
### What Works (verified with standalone tests on B200)
|
||||
After ALL of that, the container starts and runs. But the model output is GARBAGE — empty strings, NaN logprobs, zero meaningful text. Because the attention pipeline is fundamentally broken:
|
||||
- The KV cache is never written to (the C++ kernel is FlashMLA-only)
|
||||
- The attention reads from raw projection output, not the cache
|
||||
- The compressor/indexer runs but the Blackwell path doesn't integrate with it
|
||||
- Everything is deeply coupled — patch one thing, three more break
|
||||
|
||||
**CuTeDSL NVFP4 kernels — ALL PASS (cosine 0.989–0.999 vs BF16):**
|
||||
```
|
||||
q_a_proj: 0.995 ✅ kv_proj: 0.995 ✅ q_b_proj: 0.995 ✅
|
||||
wo_b_proj: 0.995 ✅ comp.kv_proj: 0.994 ✅ comp.gate: 0.995 ✅
|
||||
shared_expert: 0.990 ✅
|
||||
```
|
||||
**THE ONLY FIX: Build CuTeDSL kernels for the ENTIRE attention pipeline.**
|
||||
|
||||
**Full attention path with SDPA — cosine 0.988 vs BF16, logit std 2.98 ✅**
|
||||
Do NOT try to patch vLLM's FlashMLA code. Do NOT use pure PyTorch as a workaround. Do NOT skip the KV cache write. BUILD THE KERNELS.
|
||||
|
||||
**Warmup gs is IRRELEVANT** — CuTeDSL runner recomputes activation global scale per-call internally. Changing it 10x has zero effect on output (cosine 0.9993).
|
||||
### Container Crash History (May 19)
|
||||
|
||||
### Current Workaround (TEMPORARY — needs replacement)
|
||||
Each crash was "fixed" with a patch. Each patch led to the next crash. This is the house of cards:
|
||||
|
||||
`_attention_impl_blackwell()` in `vllm/patches/deepseek_v4_attention.py`:
|
||||
- Replaces FlashMLA with `full_sdpa_attention()` — **pure PyTorch matmuls, NO CuTeDSL**
|
||||
- Replaces C++ fused kernel with `fused_qnorm_rope_kv_insert_py()` — pure PyTorch RoPE
|
||||
- Skips SWA KV cache write (cache uses fp8_ds_mla packed format, shape [slot, 37376] not [slot, 512])
|
||||
1. `VLLM_NVFP4_GEMM_BACKEND=cutedsl` — invalid choice in `envs.py` → removed env var
|
||||
2. `assert max(sm_page_sizes) <= max(all_page_sizes)` — KV cache page size mismatch → patched `kv_cache_utils.py`
|
||||
3. `Some layers are not correctly initialized` — 91 missing compressor cache layers (alignment=576 wrong on Blackwell) → patched SWA, indexer, compressor cache specs
|
||||
4. `AttributeError: softmax_scale` — wrapper uses `self.scale` not `self.softmax_scale` → fixed
|
||||
5. 200 GiB KV cache for 512 tokens → reduced max_model_len to 256, patched cache specs to remove FlashMLA alignment
|
||||
6. NaN output (logprobs) → KV wasn't getting RoPE → added `_apply_rope_kv()`
|
||||
7. Shape mismatch in `apply_gptj_rope` → rewrote as inline 2D RoPE
|
||||
8. **Garbage/empty output** — the attention pipeline is fundamentally broken
|
||||
|
||||
### THE PLAN: Replace all pure PyTorch with CuTeDSL/NVFP4
|
||||
### What Actually Works (standalone B200 venv tests)
|
||||
|
||||
**Mike's directive: NO "optimize later" BS. Build the full thing with NVFP4/CuTeDSL.**
|
||||
Every single kernel works when tested individually. The problem is ONLY in the vLLM integration.
|
||||
|
||||
The attention Q×K and attn×V are activation×activation matmuls. They CAN be done in NVFP4:
|
||||
- Quantize Q and K to NVFP4 (4-bit activation + block scales)
|
||||
- Use CuTeDSL grouped GEMM for the sparse attention pattern
|
||||
- This is exactly what FlashMLA does with FP8 — we just use NVFP4 instead
|
||||
| Kernel | Test File | Result |
|
||||
|--------|-----------|--------|
|
||||
| CuTeDSL NVFP4 Linear | `test_full_layer_b200.py` | cosine 0.994+ ✅ |
|
||||
| CuTeDSL NVFP4 MoE | `layertest.py` | cosine 0.988 ✅ |
|
||||
| FP8 KV quantize/dequant | `test_kv_cache_b200.py` | cosine 0.9997 ✅ |
|
||||
| NVFP4 KV quantize/dequant | `test_kv_cache_b200.py` | cosine 0.9943 ✅ |
|
||||
| Paged KV cache read/write | `test_kv_cache_b200.py` | cosine 1.0 ✅ |
|
||||
| FP8 KV → full attention | `test_kv_cache_b200.py` | cosine 0.9997 ✅ |
|
||||
| CSA sparse attention (cr=4) | `test_sparse_attn_b200.py` | works, no NaN ✅ |
|
||||
| HCA sparse attention (cr=128) | `test_sparse_attn_b200.py` | works, no NaN ✅ |
|
||||
| Merged CSA+SWA attention | `test_sparse_attn_b200.py` | works, no NaN ✅ |
|
||||
| Full pipeline (all layer types) | `test_v4_attention_b200.py` | cosine 0.981-0.995 ✅ |
|
||||
| NVFP4 Q×K^T GEMM | `test_nvfp4_attn_gemm_b200.py` | cosine 0.86 ❌ (too lossy) |
|
||||
|
||||
#### Specific replacements needed:
|
||||
### Key Lessons (READ THESE OR REPEAT THE SAME MISTAKES)
|
||||
|
||||
1. **Attention (Q×K, attn×V)** → CuTeDSL NVFP4 GEMM
|
||||
- Quantize Q and K to NVFP4 per-head
|
||||
- Use `CuTeDSLNvfp4Linear` or raw CuTeDSL GEMM for the matmuls
|
||||
- Support both prefill (batched) and decode (single-token) paths
|
||||
- Handle CSA sparse gather pattern (only attend to top-k positions)
|
||||
1. **NVFP4 is NOT suitable for attention Q×K^T.** The per-element dot products are too sensitive. Cosine 0.86. Keep attention in BF16, use NVFP4 only for weight GEMMs.
|
||||
|
||||
2. **KV cache write** → NVFP4 quant + paged cache insert
|
||||
- The SWA cache uses `fp8_ds_mla` format: packed FP8 values + UE8M0 scales
|
||||
- Row width = 37376 bytes (not just head_dim=512)
|
||||
- Layout: [nope_dim FP8 values | rope_dim FP8 values | scale blocks]
|
||||
- Need to understand the exact fp8_ds_mla layout and replicate in CuTeDSL
|
||||
- OR: skip SWA cache entirely and use our own NVFP4 cache format
|
||||
2. **DeepSeek-V4 is NOT MLA.** It uses CSA (Compressed Sparse Attention) + HCA (Heavily Compressed Attention). vLLM misnames everything "MLA" internally — don't be confused by class names like `DeepseekV4MLAAttention`.
|
||||
|
||||
3. **RoPE** → CuTeDSL fused kernel
|
||||
- Currently pure PyTorch (works, but slow)
|
||||
- Could fuse with Q norm + KV quant into a single CuTeDSL kernel
|
||||
- Pattern: Q norm → RoPE → NVFP4 quant → all in one pass
|
||||
3. **The fp8_ds_mla format is FlashMLA-specific.** 584 bytes per token (448 NoPE FP8 + 128 RoPE FP8 + 8 scale). This is NOT a standard fp8 tensor. You can't just `view()` it as `[slot, 512]` uint8.
|
||||
|
||||
4. **CSA sparse gather** → CuTeDSL indexed access
|
||||
- Currently uses `torch.gather` (slow, not GPU-optimal)
|
||||
- CuTeDSL can do the gather + GEMM in one fused operation
|
||||
- This is the whole point of CSA — sparse KV access
|
||||
4. **The SWA cache, indexer cache, and compressor cache all use `alignment=576` for FlashMLA.** On Blackwell, this must be `None` (no FlashMLA). There are 4 separate classes that set this, and you must patch ALL of them.
|
||||
|
||||
5. **Compressor** → already Triton (works on SM100) ✅
|
||||
6. **Indexer** → already Triton (works on SM100) ✅
|
||||
7. **MHC** → already pure PyTorch ✅
|
||||
8. **MoE** → already CuTeDSL ✅
|
||||
5. **`DeepseekV4MultiHeadLatentAttentionWrapper` registers ITSELF (not the inner MLA attention) in `static_forward_context`.** The custom op `deepseek_v4_attention` looks up the wrapper. So `attention_impl` must be on the WRAPPER, and it must use `self.scale` (not `self.softmax_scale`).
|
||||
|
||||
### Config Issues (from config.json)
|
||||
6. **The Triton compressor and indexer DO work on Blackwell.** They're not the problem. The problem is that the Blackwell attention path doesn't integrate with them.
|
||||
|
||||
- `quant_method: modelopt` → vLLM uses ModelOpt's NVFP4 handler
|
||||
- Our CuTeDSL IS registered in `_POSSIBLE_NVFP4_KERNELS` (via register_cutedsl_kernel.py)
|
||||
- Added `VLLM_NVFP4_GEMM_BACKEND=cutedsl` env var to force it
|
||||
- `kv_cache_scheme: {"num_bits": 8, "type": "float"}` → FP8 KV cache → FlashMLA
|
||||
- Hard assertion `issubclass(get_attn_backend(), FlashMLASparseBackend)` — patched with `_is_blackwell` flag
|
||||
### THE PLAN: Build CuTeDSL Attention Backend
|
||||
|
||||
### Checkpoint Key Names (different from vLLM names!)
|
||||
**STOP. Do NOT touch the vLLM container. Build and test kernels on the B200 venv first.**
|
||||
|
||||
```
|
||||
q_a_proj, q_b_proj, kv_proj (NOT fused_wqa_wkv, wq_b)
|
||||
q_a_norm (NOT q_norm)
|
||||
attn_hc.fn/base/scale (MHC attention)
|
||||
ffn_hc.fn/base/scale (MHC FFN)
|
||||
compressor.kv_proj, compressor.gate_proj (CSA/HCA)
|
||||
compressor.position_bias
|
||||
sinks (attn_sink)
|
||||
```
|
||||
#### Step 1: KV Cache Write Kernel
|
||||
- BF16 KV → apply RoPE → fp8 quantize → write to paged cache
|
||||
- Test in `tests/test_kv_cache_write_b200.py`:
|
||||
- Write KV for N tokens, read it back, compare against BF16 reference
|
||||
- Must handle: slot mapping, block_size, fp8 per-token scale
|
||||
|
||||
### Compress Ratios (from config.json compress_ratios)
|
||||
#### Step 2: KV Cache Read Kernel
|
||||
- Paged cache → fp8 dequant → BF16 KV with RoPE
|
||||
- Test: write then read, cosine >= 0.99
|
||||
|
||||
```
|
||||
Layer 0: 128 (HCA) Layer 1: 128 (HCA)
|
||||
Layer 2: 4 (CSA) Layer 3: 128 (HCA)
|
||||
Layer 4: 4 (CSA) ...
|
||||
...alternating 4/128...
|
||||
Layer 60: 0 (SWA-only)
|
||||
```
|
||||
#### Step 3: BF16 Attention Kernel
|
||||
- Q (with RoPE) × K^T → causal mask → softmax → attn × V
|
||||
- Keep in BF16 (NVFP4 too lossy for attention scores)
|
||||
- Handle CSA sparse (gather top-k positions from compressed cache)
|
||||
- Handle HCA sparse (gather from 1/128 positions)
|
||||
- Handle SWA (sliding window, full causal within window)
|
||||
- Test: compare against PyTorch SDPA, cosine >= 0.99
|
||||
|
||||
#### Step 4: Full Pipeline Integration
|
||||
- KV cache read → attention → inverse RoPE → o_a BMM → o_b NVFP4 projection
|
||||
- Wire CSA/HCA/SWA with sink weight merge
|
||||
- Test: compare full pipeline against BF16 reference, cosine >= 0.98
|
||||
- Test: run through ALL 61 layers, verify logits are reasonable (std between 0.5 and 50)
|
||||
|
||||
#### Step 5: vLLM Attention Backend
|
||||
- Create a proper `AttentionBackend` subclass (e.g., `CuTeDSLBlackwellBackend`)
|
||||
- Override `DeepseekSparseSWABackend` on Blackwell
|
||||
- Handle metadata, slot mapping, cache format properly
|
||||
- ONLY THEN test in the container
|
||||
|
||||
#### Step 6: Test in Container
|
||||
- Build container with the new backend
|
||||
- Test with real prompts
|
||||
- If output is garbage, DO NOT declare success. Fix it.
|
||||
|
||||
### Architecture: CSA + HCA + mHC (NOT MLA!)
|
||||
|
||||
- **CSA (Compress Ratio 4)**: Compressed Sparse Attention — KV compressed 4x with overlap (coff=2)
|
||||
- **HCA (Compress Ratio 128)**: Heavily Compressed Attention — KV compressed 128x
|
||||
- **mHC**: Manifold-Constrained Hyper-Connections — replaces standard residual connections
|
||||
- **SWA**: Sliding Window Attention (compress_ratio=0, last layer only)
|
||||
- **CSA (Compress Ratio 4)**: Compressed Sparse Attention — KV compressed 4x with overlap (coff=2). Indexer finds per-layer top-k.
|
||||
- **HCA (Compress Ratio 128)**: Heavily Compressed Attention — KV compressed 128x.
|
||||
- **mHC**: Manifold-Constrained Hyper-Connections — replaces standard residual connections.
|
||||
- **SWA**: Sliding Window Attention — local window (compress_ratio=0, last layer only)
|
||||
- **KV latent**: (T, 512) shared across all 128 heads. After kv_norm, it's a single 512-dim vector per token.
|
||||
- **Sink weights**: Merge sparse + SWA attention outputs. `sigmoid(sinks)` controls the mix.
|
||||
|
||||
### Files
|
||||
### B200 Server Details
|
||||
|
||||
- **Kernel**: `cutedsl/csa_attention.py` — CSA/HCA attention (currently SDPA, needs CuTeDSL)
|
||||
- **vLLM patch**: `vllm/patches/deepseek_v4_attention.py` — `_attention_impl_blackwell()`
|
||||
- **vLLM patch**: `vllm/patches/layers/csa_attention.py` — `fused_qnorm_rope_kv_insert_py()`, `full_sdpa_attention()`
|
||||
- **Standalone tests**: `tests/test_full_layer_b200.py`, `tests/test_csa_attention_b200.py`, `tests/test_model_forward_b200.py`
|
||||
- **CuTeDSL NVFP4 linear**: `cutedsl/nvfp4_linear.py` — CuTeDSLNvfp4Linear runner
|
||||
- **CuTeDSL bridge**: `cutedsl/bridge.py` — quantize_activation_nvfp4, NVFP4 GEMM wrappers
|
||||
```
|
||||
Server: root@45.76.247.107 (password: <B200_PASSWORD>)
|
||||
Repo: /root/nvfp4-megamoe-kernel (branch: proper-nvfp4-integration)
|
||||
Model: /root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4
|
||||
Venv: /root/nvfp4-megamoe-kernel/tests/venv/
|
||||
Test command: PYTHONPATH=/root/nvfp4-megamoe-kernel tests/venv/bin/python tests/<test>.py
|
||||
Container: docker compose up (takes ~14 min to build, ~10 min to warm up)
|
||||
```
|
||||
|
||||
### PLEASE LISTEN
|
||||
|
||||
If you're the next session reading this: BUILD THE KERNELS. Test them standalone. Don't waste time patching vLLM. The container will still be there when you're ready. The kernels won't build themselves.
|
||||
|
||||
**Mike said it best: "The only way to do this is to do our own kernels."**
|
||||
|
||||
Just make the fucking kernel.
|
||||
|
||||
225
README.md
225
README.md
@@ -2,6 +2,34 @@
|
||||
|
||||
Full NVFP4 inference pipeline for DeepSeek-V4 on NVIDIA Blackwell (SM100). The entire model — MoE experts, shared experts, attention projections, and attention compute — runs in native NVFP4 with zero dequantization overhead.
|
||||
|
||||
## ⚠️ READ THIS FIRST — THE #1 RULE
|
||||
|
||||
**YOU MUST BUILD YOUR OWN KERNELS. ALL OF THEM. DO NOT PATCH vLLM.**
|
||||
|
||||
Mike was right — we need our own kernels. Not just for the NVFP4 GEMMs, but for the **ENTIRE attention pipeline**. The current approach of patching individual vLLM functions is a house of cards. Every patch leads to another crash, every workaround reveals three more broken things. FlashMLA, fp8_ds_mla, the fused C++ kernels, the Triton compressor, the indexer — they're all deeply coupled. You cannot swap one piece and expect the rest to work.
|
||||
|
||||
**THE ONLY PATH FORWARD:**
|
||||
1. Build CuTeDSL kernels for EVERYTHING — attention, KV cache, RoPE, the whole stack
|
||||
2. Test each kernel standalone on the B200 venv BEFORE touching the container
|
||||
3. Wire them together into a proper vLLM attention backend
|
||||
4. THEN and ONLY THEN test in the container
|
||||
|
||||
**DO NOT:**
|
||||
- ❌ Try to patch vLLM's FlashMLA code to "work" on Blackwell
|
||||
- ❌ Use pure PyTorch as a "temporary workaround" — it produces garbage
|
||||
- ❌ Skip the KV cache write and hope for the best
|
||||
- ❌ Assume you can mix our kernels with vLLM's existing attention backend
|
||||
- ❌ Touch the container until ALL kernels pass standalone tests
|
||||
|
||||
**DO:**
|
||||
- ✅ Build CuTeDSL kernels in `cutedsl/`
|
||||
- ✅ Test each one in `tests/` on the B200 venv
|
||||
- ✅ Compare against BF16 reference (cosine >= 0.98 or it's broken)
|
||||
- ✅ Wire them into a proper attention backend class
|
||||
- ✅ Only test in the container once everything passes standalone
|
||||
|
||||
---
|
||||
|
||||
## What This Is
|
||||
|
||||
A native NVFP4 inference stack for DeepSeek-V4:
|
||||
@@ -11,7 +39,7 @@ A native NVFP4 inference stack for DeepSeek-V4:
|
||||
BF16 input → quantize to NVFP4
|
||||
L1 GEMM: NVFP4 × NVFP4 → BF16 (gate + up)
|
||||
SiLU(gate) * up → BF16 (only nonlinear — can't avoid BF16 here)
|
||||
Re-quantize → NVFP4
|
||||
Re-quantize to NVFP4
|
||||
L2 GEMM: NVFP4 × NVFP4 → BF16 (down_proj)
|
||||
Scatter with routing weights → BF16 output
|
||||
```
|
||||
@@ -19,23 +47,19 @@ BF16 input → quantize to NVFP4
|
||||
**Attention Projections** — CuTeDSL NVFP4 GEMM ✅:
|
||||
- `q_a_proj`, `q_b_proj`, `kv_proj`, `wo_b_proj` — native NVFP4, cosine 0.995 vs BF16
|
||||
- `wo_a` — BF16 BMM (o_a_proj weights are BF16 in checkpoint)
|
||||
- `compressor.kv_proj`, `compressor.gate_proj` — native NVFP4, cosine 0.995 vs BF16
|
||||
- All verified with `tests/test_full_layer_b200.py`
|
||||
|
||||
**Shared Experts** — CuTeDSL NVFP4 GEMM ✅:
|
||||
- `gate_up_proj`, `down_proj` — native NVFP4, cosine 0.990 vs BF16
|
||||
|
||||
**Attention Compute** — **NEEDS CuTeDSL NVFP4** 🔧:
|
||||
- Currently using pure PyTorch SDPA as a TEMPORARY workaround
|
||||
- Q×K and attn×V are activation×activation matmuls that CAN be NVFP4
|
||||
- FlashMLA (vLLM's CUDA kernel) is broken on Blackwell
|
||||
- **Plan: CuTeDSL NVFP4 attention kernel** — quantize Q/K to NVFP4, use CuTeDSL GEMMs
|
||||
**Attention Compute** — 🔧 NEEDS CuTeDSL:
|
||||
- Pure PyTorch SDPA produces garbage in the container
|
||||
- FlashMLA is broken on Blackwell
|
||||
- Must build CuTeDSL kernels for Q×K, attn×V, KV cache, RoPE
|
||||
|
||||
**KV Cache Write** — **NEEDS CuTeDSL NVFP4** 🔧:
|
||||
- The SWA KV cache uses `fp8_ds_mla` packed format (37376 bytes per slot, not 512)
|
||||
- C++ kernel `fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert` is broken on Blackwell
|
||||
- Currently skipped in Blackwell path (works for prefill, breaks decode)
|
||||
- **Plan: NVFP4 quant + paged cache insert in CuTeDSL**
|
||||
**KV Cache** — 🔧 NEEDS CuTeDSL:
|
||||
- The fp8_ds_mla format is FlashMLA-specific (584 bytes per token)
|
||||
- Must build our own NVFP4 KV cache with our own format
|
||||
|
||||
## Architecture: DeepSeek-V4-Pro
|
||||
|
||||
@@ -52,85 +76,55 @@ Layer 0: 128 (HCA) Layer 1: 128 (HCA) Layer 2: 4 (CSA) Layer 3: 128 (HCA)
|
||||
Layer 4: 4 (CSA) ...alternating 4/128... Layer 60: 0 (SWA)
|
||||
```
|
||||
|
||||
**Checkpoint Key Names** (different from vLLM's internal names):
|
||||
```
|
||||
q_a_proj, q_b_proj, kv_proj (NOT fused_wqa_wkv, wq_b)
|
||||
q_a_norm (NOT q_norm)
|
||||
attn_hc.fn/base/scale (MHC attention)
|
||||
ffn_hc.fn/base/scale (MHC FFN)
|
||||
compressor.kv_proj, compressor.gate_proj (CSA/HCA compressor)
|
||||
compressor.position_bias
|
||||
sinks (attn_sink)
|
||||
```
|
||||
## Current Status
|
||||
|
||||
## Current Status: Attention + KV Cache Need CuTeDSL 🔧
|
||||
### ✅ Working (verified on B200 standalone tests)
|
||||
|
||||
**What works (verified on B200):**
|
||||
- CuTeDSL NVFP4 linear kernels: cosine 0.989–0.999 vs BF16 ✅
|
||||
- CuTeDSL NVFP4 MoE: cosine 0.988 ✅
|
||||
- Full attention path with PyTorch SDPA: cosine 0.988 vs BF16 ✅
|
||||
- MHC, RMS norm, RoPE (BF16), wo_a BMM, shared experts ✅
|
||||
- Compressor + indexer (Triton, works on SM100) ✅
|
||||
| Component | Test | Cosine vs BF16 |
|
||||
|-----------|------|----------------|
|
||||
| CuTeDSL NVFP4 Linear (q_a, kv, q_b, wo_b) | `test_full_layer_b200.py` | 0.994+ |
|
||||
| CuTeDSL NVFP4 MoE (L1 gate+up, SiLU, L2 down) | `layertest.py` | 0.988 |
|
||||
| FP8 KV quantize/dequant | `test_kv_cache_b200.py` | 0.9997 |
|
||||
| NVFP4 KV quantize/dequant | `test_kv_cache_b200.py` | 0.9943 |
|
||||
| Paged KV cache read/write | `test_kv_cache_b200.py` | 1.0 |
|
||||
| FP8 KV → full attention | `test_kv_cache_b200.py` | 0.9997 |
|
||||
| CSA sparse attention (cr=4) | `test_sparse_attn_b200.py` | works, no NaN |
|
||||
| HCA sparse attention (cr=128) | `test_sparse_attn_b200.py` | works, no NaN |
|
||||
| Merged CSA+SWA attention | `test_sparse_attn_b200.py` | works, no NaN |
|
||||
| Full attention pipeline (all layer types) | `test_v4_attention_b200.py` | 0.981–0.995 |
|
||||
| RoPE (GPT-J) | `test_v4_attention_b200.py` | works |
|
||||
| Inverse RoPE + o_a BMM | `test_v4_attention_b200.py` | works |
|
||||
|
||||
**What's broken:**
|
||||
- FlashMLA CUDA kernel → garbage on Blackwell → model outputs immediate EOS
|
||||
- `fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert` C++ kernel → crashes on Blackwell
|
||||
- Pure PyTorch SDPA is a TEMPORARY workaround — must replace with CuTeDSL NVFP4
|
||||
### 🔧 Needs CuTeDSL Kernels
|
||||
|
||||
**What needs to be built:**
|
||||
1. **Attention Q×K^T** — BF16 matmul works standalone, but NVFP4 GEMM too lossy (cosine 0.86). Keep Q×K in BF16.
|
||||
2. **KV Cache Write** — need CuTeDSL kernel that does: RoPE → fp8 quant → paged cache insert
|
||||
3. **KV Cache Read** — need CuTeDSL kernel that does: paged cache read → fp8 dequant
|
||||
4. **Fused Q-norm + RoPE** — currently pure PyTorch (works, slow)
|
||||
5. **Fused inverse RoPE + o_a BMM** — currently pure PyTorch (works)
|
||||
|
||||
### 1. CuTeDSL NVFP4 Attention Kernel
|
||||
- Quantize Q and K to NVFP4 per-head
|
||||
- Use CuTeDSL GEMM for Q×K and attn×V
|
||||
- Support prefill (batched) and decode (single-token) paths
|
||||
- Handle CSA sparse gather (attend to top-k positions only)
|
||||
- This is exactly what FlashMLA does with FP8 — we just use NVFP4 instead
|
||||
- **Test first**: build standalone test in `tests/` with real weights
|
||||
### ❌ Does NOT Work
|
||||
|
||||
### 2. CuTeDSL NVFP4 KV Cache Insert
|
||||
- Replace C++ `fused_deepseek_v4_qnorm_rope_kv_rope_quant_insert`
|
||||
- Per-head RMS norm on Q + GPT-J RoPE on Q + RoPE on KV + NVFP4 quant + paged cache write
|
||||
- The SWA cache uses `fp8_ds_mla` packed format: row width = 37376 bytes
|
||||
- Layout: [nope_dim FP8 values | rope_dim FP8 values | UE8M0 scale blocks]
|
||||
- NOT just [head_dim] — it's a packed FP8 format with interleaved scales
|
||||
- **Option A**: Understand and write the fp8_ds_mla format from CuTeDSL
|
||||
- **Option B**: Use our own NVFP4 cache format (simpler, more efficient, but diverges from vLLM)
|
||||
- **NVFP4 Q×K^T GEMM** — cosine 0.86, too lossy for attention scores. Keep attention in BF16.
|
||||
- **Patching vLLM's FlashMLA path** — house of cards. Don't do it.
|
||||
- **Pure PyTorch SDPA in the container** — produces garbage because the KV cache isn't written and the pipeline is broken.
|
||||
|
||||
### 3. CuTeDSL Fused RoPE + Norm Kernel
|
||||
- Currently pure PyTorch (works, but slow)
|
||||
- Fuse: Q norm → RoPE → NVFP4 quant → all in one pass
|
||||
- Same for KV side: RoPE → NVFP4 quant → cache write
|
||||
## Container Status
|
||||
|
||||
### 4. CuTeDSL CSA Sparse Gather
|
||||
- Currently `torch.gather` (slow, not GPU-optimal)
|
||||
- CuTeDSL can do the gather + GEMM in one fused operation
|
||||
- The whole point of CSA is sparse KV access — we should do it right
|
||||
|
||||
## vLLM Integration
|
||||
|
||||
The Blackwell detection and dispatch is in `vllm/patches/deepseek_v4_attention.py`:
|
||||
- `attention_impl()` detects SM100+ → `_attention_impl_blackwell()`
|
||||
- Currently uses pure PyTorch (SDPA) — **must replace with CuTeDSL**
|
||||
- The dispatch is INSIDE the `torch.ops.vllm.deepseek_v4_attention` custom op boundary (important for torch.compile)
|
||||
|
||||
**Config issues:**
|
||||
- `quant_method: modelopt` → vLLM uses ModelOpt's NVFP4 handler
|
||||
- Our CuTeDSL IS registered (via `register_cutedsl_kernel.py`) and forced with `VLLM_NVFP4_GEMM_BACKEND=cutedsl`
|
||||
- FlashMLA hard assertion in `DeepseekV4MLAAttention.__init__` — patched with `_is_blackwell` flag
|
||||
- `kv_cache_scheme: {"num_bits": 8, "type": "float"}` → FP8 KV cache → FlashMLA (broken on Blackwell)
|
||||
|
||||
**Key discovery: warmup gs is irrelevant.** CuTeDSL runner recomputes activation global scale per-call internally. Changing it 10x has zero effect on output (cosine 0.9993). The `input_scale` from the checkpoint is NOT the activation global scale — it's a calibration constant.
|
||||
The container builds and starts successfully. The server accepts requests and generates tokens. But the output is empty/garbage because the Blackwell attention path is broken. Multiple patches were applied to get this far (KV cache page sizes, FlashMLA alignment, softmax_scale, compressor cache), but the fundamental problem remains: **you cannot half-ass the attention pipeline**.
|
||||
|
||||
## Test Files
|
||||
|
||||
| Test | What it does | Status |
|
||||
|------|-------------|--------|
|
||||
| `tests/test_full_layer_b200.py` | All NVFP4 projections vs BF16 (layer 0) | ✅ All pass (0.989–0.999) |
|
||||
| `tests/test_model_forward_b200.py` | Warmup gs vs dynamic gs diagnostic | ✅ Warmup gs irrelevant |
|
||||
| `tests/test_csa_attention_b200.py` | Full attention path with SDPA | ✅ cosine 0.988 |
|
||||
| `tests/layertest.py` | MoE layer test | ✅ cosine 0.988 |
|
||||
| `tests/test_full_layer_b200.py` | All NVFP4 projections vs BF16 | ✅ 0.994+ |
|
||||
| `tests/layertest.py` | MoE layer test | ✅ 0.988 |
|
||||
| `tests/cudagraph_test.py` | CUDAGraph compatibility | ✅ PASS |
|
||||
| `tests/test_shared_expert.py` | Shared expert standalone | ✅ cosine 0.990 |
|
||||
| `tests/test_csa_attention_b200.py` | Full attention with SDPA | ✅ 0.988 |
|
||||
| `tests/test_v4_attention_b200.py` | All 3 layer types (SWA, C128A, C4A) | ✅ 0.981-0.995 |
|
||||
| `tests/test_kv_cache_b200.py` | FP8/NVFP4 KV cache + paged cache | ✅ 0.9997 |
|
||||
| `tests/test_sparse_attn_b200.py` | CSA/HCA sparse + SWA merged | ✅ works |
|
||||
| `tests/test_nvfp4_attn_gemm_b200.py` | NVFP4 Q×K^T GEMM | ❌ 0.86 (too lossy) |
|
||||
|
||||
## Project Structure
|
||||
|
||||
@@ -140,51 +134,64 @@ nvfp4-megamoe-kernel/
|
||||
│ ├── bridge.py # Tensor layout conversion, quantization, kernel launch
|
||||
│ ├── nvfp4_linear.py # CuTeDSLNvfp4Linear — NVFP4 GEMM runner
|
||||
│ ├── moe_pipeline.py # Full MoE pipeline (L1→SiLU→L2→scatter)
|
||||
│ ├── shared_expert_pipeline.py # Shared expert pipeline (1-expert MoE variant)
|
||||
│ ├── csa_attention.py # CSA/HCA attention (currently SDPA, needs CuTeDSL)
|
||||
│ ├── custom_ops.py # torch.autograd wrappers for compile boundary
|
||||
│ ├── shared_expert_pipeline.py # Shared expert pipeline
|
||||
│ ├── csa_attention.py # CSA/HCA attention (BF16 SDPA — needs CuTeDSL)
|
||||
│ ├── custom_ops.py # torch.autograd wrappers
|
||||
│ └── kernel/moe/ # NVIDIA's ScaledGroupedGemmKernel
|
||||
├── vllm/ # vLLM integration
|
||||
│ ├── nvfp4_cutedsl.py # CuTeDSLMoERunner — cudagraph-safe MoE kernel
|
||||
│ ├── cutedsl_quant_method.py # CuTeDSLNvfp4LinearMethod — vLLM quant method
|
||||
│ ├── kernels/linear/nvfp4/cutedsl.py # CuTeDSLNvFp4LinearKernel — vLLM kernel registration
|
||||
│ ├── nvfp4_cutedsl.py # CuTeDSLMoERunner
|
||||
│ ├── cutedsl_quant_method.py # CuTeDSLNvfp4LinearMethod
|
||||
│ ├── kernels/linear/nvfp4/cutedsl.py # vLLM kernel registration
|
||||
│ └── patches/
|
||||
│ ├── deepseek_v4.py # Model patch (NVFP4 native, MHC, MoE)
|
||||
│ ├── deepseek_v4_attention.py # Attention patch (Blackwell dispatch)
|
||||
│ ├── layers/
|
||||
│ │ ├── mhc.py # MHC pure PyTorch (replaces TileLang)
|
||||
│ │ ├── csa_attention.py # CSA attention (TEMPORARY — needs CuTeDSL)
|
||||
│ │ └── deepseek_compressor.py # Compressor (Triton, works on SM100)
|
||||
│ └── fused_moe/experts/cutedsl_moe.py # MoE CuTeDSL integration
|
||||
├── tests/ # Standalone tests (run on B200 outside container)
|
||||
│ ├── patch_kv_cache_utils.py # KV cache page size fix
|
||||
│ ├── patch_swa_cache.py # SWA cache alignment fix
|
||||
│ ├── patch_indexer_cache.py # Indexer cache alignment fix
|
||||
│ ├── patch_compressor_cache.py # Compressor cache alignment fix
|
||||
│ └── layers/
|
||||
│ ├── csa_attention.py # BF16 SDPA (TEMPORARY — needs CuTeDSL)
|
||||
│ └── ...
|
||||
├── tests/ # Standalone tests (run on B200 venv)
|
||||
└── Dockerfile # Container build
|
||||
```
|
||||
|
||||
## Plan
|
||||
|
||||
### Phase 1: MoE Kernel ✅ DONE
|
||||
- CuTeDSL ScaledGroupedGemmKernel with NVFP4
|
||||
- Full pipeline: cosine 0.988, cudagraph-safe
|
||||
|
||||
### Phase 2: NVFP4 Linear Kernels ✅ DONE
|
||||
- All attention projections: cosine 0.995
|
||||
- Shared experts: cosine 0.990
|
||||
- Compressor projections: cosine 0.995
|
||||
### Phase 3: vLLM Integration ✅ DONE (NVFP4 linear + MoE working)
|
||||
|
||||
### Phase 3: vLLM Integration ✅ DONE (with PyTorch fallback)
|
||||
- CuTeDSL kernels registered and working for all NVFP4 linear layers
|
||||
- Blackwell dispatch in attention_impl
|
||||
- MHC pure PyTorch
|
||||
- MoE CuTeDSL
|
||||
### Phase 4: CuTeDSL Attention Backend 🔧 NEXT — BUILD THE KERNELS
|
||||
|
||||
### Phase 4: CuTeDSL NVFP4 Attention 🔧 NEXT
|
||||
- Replace pure PyTorch SDPA with CuTeDSL NVFP4 GEMMs for Q×K and attn×V
|
||||
- NVFP4 KV cache insert (replace C++ kernel)
|
||||
- Fused RoPE + norm + quant kernel
|
||||
- CSA sparse gather in CuTeDSL
|
||||
- **Test each component standalone before integrating into vLLM**
|
||||
**STOP. READ THIS.**
|
||||
|
||||
Do NOT touch the vLLM container until ALL of these kernels pass standalone tests on the B200 venv. The container is a 14-minute build cycle. The venv gives you instant feedback. TEST FIRST.
|
||||
|
||||
**Kernels to build (in order):**
|
||||
|
||||
1. **KV Cache Write**: BF16 KV → apply RoPE → quantize to fp8 → write to paged cache
|
||||
- Test: compare against BF16 reference (cosine >= 0.98 after dequant)
|
||||
|
||||
2. **KV Cache Read**: paged cache → dequant fp8 → BF16 KV with RoPE
|
||||
- Test: write then read back, cosine >= 0.99
|
||||
|
||||
3. **BF16 Attention**: Q (with RoPE) × K^T → softmax → attn × V
|
||||
- Keep this in BF16 (NVFP4 is too lossy for attention scores)
|
||||
- Handle CSA sparse gather (attend to top-k indexed positions)
|
||||
- Handle HCA sparse gather (attend to 1/128 positions)
|
||||
- Handle SWA (sliding window, full causal within window)
|
||||
- Test: compare against PyTorch SDPA reference (cosine >= 0.99)
|
||||
|
||||
4. **Full Attention Pipeline**: KV cache read → attention → inverse RoPE → o_a BMM
|
||||
- Wire everything together
|
||||
- Test: compare against BF16 reference (cosine >= 0.98)
|
||||
|
||||
5. **vLLM Backend**: Wrap as a proper AttentionBackend subclass
|
||||
- Override `DeepseekSparseSWABackend` on Blackwell
|
||||
- Handle the metadata, slot mapping, cache format
|
||||
- ONLY THEN test in the container
|
||||
|
||||
### Phase 5: Production
|
||||
- End-to-end benchmarking
|
||||
- Optimize tile sizes for occupancy
|
||||
- Clean up old C++ kernel code
|
||||
- Optimize tile sizes
|
||||
- Clean up
|
||||
|
||||
Reference in New Issue
Block a user