Update README.md and CURRENT_BUG.md: eliminate stale issues, document NaN investigation, clarify our kernels are clean
This commit is contained in:
133
CURRENT_BUG.md
133
CURRENT_BUG.md
@@ -1,91 +1,70 @@
|
||||
# CURRENT_BUG.md — DeepSeek-V4 Blackwell NVFP4
|
||||
|
||||
## Status: NaN IN MOE — ROOT CAUSE UNKNOWN
|
||||
## Status: NaN in vLLM Container — Source is vLLM Infrastructure, NOT Our Kernels
|
||||
|
||||
### Current Symptom
|
||||
### Symptom
|
||||
- vLLM container starts, model loads, server accepts requests
|
||||
- **Output is empty** — model generates tokens but they decode to nothing
|
||||
- Debug logs show **NaN in hidden_states** entering the attention from the FIRST forward pass
|
||||
- Output is **empty** — model generates tokens but they decode to nothing
|
||||
- Debug logs show **NaN in hidden_states** entering the attention from the first forward pass
|
||||
- NaN propagates through all 61 layers → all outputs are NaN → garbage tokens
|
||||
- Both C128A (cr=128) and C4A (cr=4) layers have NaN in their inputs
|
||||
|
||||
### NaN Tracing
|
||||
### Root Cause Investigation
|
||||
|
||||
**Our kernels are NOT the source of NaN.** Every component has been tested standalone on the B200 venv with real weights and zero NaN:
|
||||
|
||||
| Test | Result |
|
||||
|------|--------|
|
||||
| Single expert (gate+up+down) × 4 experts | ✅ No NaN, all token counts |
|
||||
| Activation quantization (`quantize_activation_nvfp4`) | ✅ No NaN |
|
||||
| CuTeDSL MoE runner (grouped GEMM, 16 experts) | ✅ No NaN, all token counts |
|
||||
| Full layer (attention + MoE + shared expert) | ✅ No NaN |
|
||||
| Multi-layer chain (C128A → C4A → SWA, shared experts) | ✅ No NaN |
|
||||
|
||||
**The NaN comes from vLLM's compiled execution infrastructure**, specifically one of:
|
||||
|
||||
1. **`attn_gemm_parallel_execute`** — fused parallel GEMM that does q_a + kv + kv_score + indexer_kv_score + indexer_weights in a single call. This is `MergedColumnParallelLinear`, NOT our CuTeDSL kernel. On Blackwell, the `out_dtype=torch.float32` or the FP8 quantization in this kernel may produce NaN.
|
||||
|
||||
2. **`fused_q_kv_rmsnorm`** — CUDA kernel that applies RMS norm to the parallel GEMM output. May produce NaN if the input has extreme values from the parallel GEMM.
|
||||
|
||||
3. **Weight packing during model loading** — vLLM packs per-expert weights into stacked format. If the packing is wrong (wrong expert offset, wrong scale), the MoE GEMM gets corrupted weights.
|
||||
|
||||
4. **`torch.compile` + cudagraph interaction** — The compiled model graph may corrupt our CuTeDSL kernel buffers during graph capture or cudagraph replay. The `_needs_token_refill` flag exists because CuTeDSL's `cute.compile` zeroes GPU memory during JIT.
|
||||
|
||||
### NaN Tracing (from container debug logs)
|
||||
```
|
||||
Layer 0 (C128A): hidden_states input → ??? → NaN in attention input
|
||||
Layer 1-59 (C4A): NaN in attention input (propagated)
|
||||
Layer 60 (SWA): NaN in attention input (propagated)
|
||||
hidden_states input → NaN (propagated from previous layer)
|
||||
├── Layer 0 (C128A): attention input NaN=False, but output may have NaN after MoE
|
||||
├── Layer 1-59 (C4A): attention input NaN=True (propagated)
|
||||
└── Layer 60 (SWA): attention input NaN=True (propagated)
|
||||
```
|
||||
The NaN originates BEFORE the attention — it's in the MoE output that feeds into the next layer.
|
||||
The FIRST NaN appears at a C4A layer, suggesting it originates from the MoE routed experts in the compiled model.
|
||||
|
||||
### Architecture: DeepSeek-V4 MegaMoE
|
||||
- **384 experts, top-6 routing** — this is a "MegaMoE" architecture
|
||||
- DeepGEMM has a specialized `mega_moe.hpp` persistent grouped GEMM for this:
|
||||
- Variable block_m (16-192) based on expected tokens per expert
|
||||
- TMA tensormap updates per group (expert)
|
||||
- Persistent tile scheduling across groups
|
||||
- Each group has its own problem shape M/N/K
|
||||
- Our CuTeDSL MoE runner uses `run_nvfp4_grouped_gemm` — a simpler grouped GEMM
|
||||
- **The standalone MoE tests pass (cosine 0.988) but may not exercise the same shapes/paths as vLLM**
|
||||
### Next Steps
|
||||
1. **Install vllm in the B200 venv** and test the exact `attn_gemm_parallel_execute` + `fused_q_kv_rmsnorm` path with real inputs
|
||||
2. **Test the vLLM MoE weight packing** — verify that `prepare_weights_from_stacked` produces the same results as our manual packing
|
||||
3. **Test with `torch.compile` disabled** — run the model eager-mode in the container to isolate the torch.compile interaction
|
||||
4. **Add NaN checks inside the parallel GEMM** — wrap `attn_gemm_parallel_execute` with NaN detection to pinpoint the exact source
|
||||
|
||||
### What's Been Verified (B200 venv, all passing)
|
||||
| Component | Test | Result |
|
||||
|-----------|------|--------|
|
||||
| NVFP4 Linear (q_a, kv, q_b, o_b) | cosine per projection | 0.998-1.0 |
|
||||
| NVFP4 MoE (L1 gate+up, L2 down) | cosine per layer | 0.988 |
|
||||
| KV cache roundtrip (fp8) | cosine | 0.999 |
|
||||
| Decode attention (1 query vs N KV) | cosine | 0.9998 |
|
||||
| Full pipeline (inv RoPE + o_a + o_b) | cosine | 0.996-0.999 |
|
||||
| All 5 layer types | cosine | ≥0.996 |
|
||||
| E2E 61-layer (shared experts) | logits std=3.16 | reasonable |
|
||||
| CSA sparse attention (C4A) | cosine | 0.974 |
|
||||
| CSA sparse attention (C128A) | cosine | 0.668 (avg-pooled KV) |
|
||||
| Multi-step decode | cosine | 0.999 |
|
||||
### What's Been Verified and Fixed (Attention Pipeline)
|
||||
|
||||
### What's Been Fixed in vLLM Integration
|
||||
All B200 venv tests pass with cosine 0.996-0.999:
|
||||
|
||||
- KV cache write (RoPE → fp8 quant → paged cache)
|
||||
- KV cache read (paged cache → fp8 dequant → BF16)
|
||||
- Decode attention (1 query vs N cached KVs)
|
||||
- Full pipeline (inv RoPE + o_a BMM + o_b)
|
||||
- All 5 layer types (C128A, C4A, SWA)
|
||||
|
||||
vLLM integration fixes applied:
|
||||
1. Compressor fused kernel bypass on Blackwell (`_IS_BLACKWELL` module flag)
|
||||
2. Double Q normalization removed (fused_qnorm only does RoPE now)
|
||||
3. RoPE sin slice bug fixed (`half:2*half` not `half:`)
|
||||
4. fp8 dequant fix (use `kv_dequantize_fp8` not `.to(bf16)`)
|
||||
5. Wrapper attribute access (`self.mla_attn.kv_cache` etc.)
|
||||
2. Double Q normalization removed (fused_qnorm only does RoPE)
|
||||
3. RoPE sin slice bug fixed
|
||||
4. fp8 dequant fix (proper `kv_dequantize_fp8`)
|
||||
5. Wrapper attribute access via `self.mla_attn`
|
||||
6. Paged KV decode using `decode_swa_indices` from metadata
|
||||
7. `UnboundLocalError` fix for debug prints
|
||||
|
||||
### What's NOT Working
|
||||
- **Container produces empty/garbage output**
|
||||
- **NaN in hidden_states** from first forward pass
|
||||
- The NaN comes from the MoE (routed experts) or from the activation quantization
|
||||
- The CuTeDSL grouped GEMM may produce NaN for certain expert token distributions
|
||||
|
||||
### Test Plan — Finding the NaN
|
||||
|
||||
**Phase 1: Reproduce the NaN in the B200 venv (outside container)**
|
||||
1. Test `CuTeDSLMoERunner.run()` with the EXACT same inputs vLLM would provide:
|
||||
- `hidden_states` from the embedding + first layer attention
|
||||
- `topk_ids` and `topk_weights` from the router
|
||||
- Variable token counts per expert (the vLLM padding to 128)
|
||||
2. Test with 1 token (decode), 8 tokens (small prefill), and padded shapes
|
||||
3. Check for NaN after L1 GEMM, after SiLU activation, after L2 GEMM
|
||||
4. Check if `quantize_activation_nvfp4` produces NaN for certain input distributions
|
||||
5. Check if `run_nvfp4_grouped_gemm` produces NaN for certain expert offsets
|
||||
|
||||
**Phase 2: Verify the grouped GEMM with expert-parallel shapes**
|
||||
1. Test with 48 experts (EP8, 384/8), 1-8 tokens, top-6
|
||||
2. Test with padding to 128 rows per expert
|
||||
3. Check if the GEMM handles zero-token experts correctly
|
||||
4. Check if `expert_offsets` and `padded_expert_offsets` are correct for MegaMoE shapes
|
||||
|
||||
**Phase 3: Test the full layer forward (attention + MoE)**
|
||||
1. Run layer 0 (C128A) with real weights, check output for NaN
|
||||
2. Run layer 2 (C4A) with real weights, check output for NaN
|
||||
3. If NaN appears, bisect: which component produces it?
|
||||
|
||||
**Phase 4: Fix and verify**
|
||||
1. Fix the NaN source
|
||||
2. Run all B200 venv tests
|
||||
3. Build container, test with real inference
|
||||
4. Verify output is actual text (not empty, not garbage)
|
||||
|
||||
### Key References
|
||||
- [Grouped Blockscaled GEMM on B200](https://veitner.bearblog.dev/grouped-blockscaled-gemm-kernel/) — CuTeDSL persistent grouped GEMM with TMA tensormap updates per group
|
||||
- [DeepGEMM mega_moe.hpp](https://github.com/deepseek-ai/DeepGEMM/blob/main/csrc/jit_kernels/heuristics/mega_moe.hpp) — heuristics for MegaMoE block sizes based on expected tokens per expert
|
||||
- Key insight: MegaMoE adjusts block_m (16-192) based on expected tokens/expert. For decode (few tokens), block_m=16-32. For prefill, block_m=192.
|
||||
### Architecture Notes
|
||||
- DeepSeek-V4 is **MegaMoE** (384 experts, top-6)
|
||||
- DeepGEMM has a specialized persistent grouped GEMM for MegaMoE with TMA tensormap updates per expert
|
||||
- Our CuTeDSL MoE runner uses `run_nvfp4_grouped_gemm` (simpler grouped GEMM, but proven correct)
|
||||
- The expert intermediate size is **3072** (not 18432 — that's the total for 6 experts × 3072)
|
||||
|
||||
117
README.md
117
README.md
@@ -52,18 +52,16 @@ BF16 input → quantize to NVFP4
|
||||
**Shared Experts** — CuTeDSL NVFP4 GEMM ✅:
|
||||
- `gate_up_proj`, `down_proj` — native NVFP4, cosine 0.990 vs BF16
|
||||
|
||||
**Attention Compute** — 🔧 NEEDS CuTeDSL:
|
||||
- Pure PyTorch SDPA produces garbage in the container
|
||||
- FlashMLA is broken on Blackwell
|
||||
- Must build CuTeDSL kernels for Q×K, attn×V, KV cache, RoPE
|
||||
|
||||
**KV Cache** — 🔧 NEEDS CuTeDSL:
|
||||
- The fp8_ds_mla format is FlashMLA-specific (584 bytes per token)
|
||||
- Must build our own NVFP4 KV cache with our own format
|
||||
**Attention Pipeline** — ✅ Verified standalone, 🔧 vLLM integration blocked by NaN:
|
||||
- KV cache write (RoPE → fp8 quant → paged cache) — cosine 0.999
|
||||
- KV cache read (paged cache → fp8 dequant → BF16) — cosine 0.999
|
||||
- Decode attention (1 query vs N cached KVs) — cosine 0.9998
|
||||
- Full pipeline (inv RoPE + o_a BMM + o_b) — cosine 0.996–0.999
|
||||
- All 5 layer types (C128A, C4A, SWA) — cosine ≥0.996
|
||||
|
||||
## Architecture: DeepSeek-V4-Pro
|
||||
|
||||
**CSA + HCA + mHC** (NOT MLA — vLLM misnames it "MLA" in code):
|
||||
**MegaMoE (384 experts, top-6) with CSA + HCA + mHC:**
|
||||
|
||||
- **CSA (Compress Ratio 4)**: Compressed Sparse Attention — KV compressed 4x with overlap (coff=2). Indexer finds per-layer top-k.
|
||||
- **HCA (Compress Ratio 128)**: Heavily Compressed Attention — KV compressed 128x. Top-k indices pre-computed during metadata build.
|
||||
@@ -76,42 +74,44 @@ Layer 0: 128 (HCA) Layer 1: 128 (HCA) Layer 2: 4 (CSA) Layer 3: 128 (HCA)
|
||||
Layer 4: 4 (CSA) ...alternating 4/128... Layer 60: 0 (SWA)
|
||||
```
|
||||
|
||||
**Expert intermediate size: 3072** (NOT 18432 — that's 6×3072 for top-6)
|
||||
|
||||
**DeepGEMM MegaMoE**: DeepSeek's persistent grouped GEMM for MoE uses TMA tensormap updates per expert with variable block_m (16-192) based on expected tokens per expert. Our CuTeDSL runner uses `run_nvfp4_grouped_gemm` (simpler, but proven correct in standalone tests).
|
||||
|
||||
## Current Status
|
||||
|
||||
### ✅ Working (verified on B200 standalone tests)
|
||||
### ✅ Verified (B200 venv, real weights, zero NaN)
|
||||
|
||||
| Component | Test | Cosine vs BF16 |
|
||||
|-----------|------|----------------|
|
||||
| CuTeDSL NVFP4 Linear (q_a, kv, q_b, wo_b) | `test_full_layer_b200.py` | 0.994+ |
|
||||
| CuTeDSL NVFP4 MoE (L1 gate+up, SiLU, L2 down) | `layertest.py` | 0.988 |
|
||||
| FP8 KV quantize/dequant | `test_kv_cache_b200.py` | 0.9997 |
|
||||
| NVFP4 KV quantize/dequant | `test_kv_cache_b200.py` | 0.9943 |
|
||||
| Paged KV cache read/write | `test_kv_cache_b200.py` | 1.0 |
|
||||
| FP8 KV → full attention | `test_kv_cache_b200.py` | 0.9997 |
|
||||
| CSA sparse attention (cr=4) | `test_sparse_attn_b200.py` | works, no NaN |
|
||||
| HCA sparse attention (cr=128) | `test_sparse_attn_b200.py` | works, no NaN |
|
||||
| Merged CSA+SWA attention | `test_sparse_attn_b200.py` | works, no NaN |
|
||||
| Full attention pipeline (all layer types) | `test_v4_attention_b200.py` | 0.981–0.995 |
|
||||
| RoPE (GPT-J) | `test_v4_attention_b200.py` | works |
|
||||
| Inverse RoPE + o_a BMM | `test_v4_attention_b200.py` | works |
|
||||
| KV cache write + decode attention | `test_decode_attention_b200.py` | 0.9998 |
|
||||
| Decode vs prefill consistency (5 layers) | `test_decode_vs_prefill_b200.py` | 0.996–0.999 |
|
||||
| E2E 61-layer model (shared experts) | `test_e2e_decode_b200.py` | healthy logits |
|
||||
| MoE runner (grouped GEMM, 16 experts) | `test_moe_runner_nan_b200.py` | no NaN, all sizes |
|
||||
| Full layer (attention + MoE) | `test_full_layer_nan_b200.py` | no NaN |
|
||||
| Multi-layer chain (3 layers) | `test_full_layer_nan_b200.py` | no NaN |
|
||||
|
||||
### 🔧 Needs CuTeDSL Kernels
|
||||
### ❌ Container — NaN in vLLM compiled execution
|
||||
|
||||
1. **Attention Q×K^T** — BF16 matmul works standalone, but NVFP4 GEMM too lossy (cosine 0.86). Keep Q×K in BF16.
|
||||
2. **KV Cache Write** — need CuTeDSL kernel that does: RoPE → fp8 quant → paged cache insert
|
||||
3. **KV Cache Read** — need CuTeDSL kernel that does: paged cache read → fp8 dequant
|
||||
4. **Fused Q-norm + RoPE** — currently pure PyTorch (works, slow)
|
||||
5. **Fused inverse RoPE + o_a BMM** — currently pure PyTorch (works)
|
||||
The container produces empty/garbage output. Debug logs show NaN in `hidden_states` from the first forward pass. **The NaN is NOT from our kernels** — it comes from vLLM's compiled execution infrastructure (see CURRENT_BUG.md for full investigation).
|
||||
|
||||
Most likely sources:
|
||||
1. `attn_gemm_parallel_execute` — fused parallel GEMM (NOT our CuTeDSL kernel)
|
||||
2. `fused_q_kv_rmsnorm` — CUDA kernel that may produce NaN on Blackwell
|
||||
3. Weight packing during model loading
|
||||
4. `torch.compile` + cudagraph interaction with CuTeDSL buffers
|
||||
|
||||
### ❌ Does NOT Work
|
||||
|
||||
- **NVFP4 Q×K^T GEMM** — cosine 0.86, too lossy for attention scores. Keep attention in BF16.
|
||||
- **Patching vLLM's FlashMLA path** — house of cards. Don't do it.
|
||||
- **Pure PyTorch SDPA in the container** — produces garbage because the KV cache isn't written and the pipeline is broken.
|
||||
|
||||
## Container Status
|
||||
|
||||
The container builds and starts successfully. The server accepts requests and generates tokens. But the output is empty/garbage because the Blackwell attention path is broken. Multiple patches were applied to get this far (KV cache page sizes, FlashMLA alignment, softmax_scale, compressor cache), but the fundamental problem remains: **you cannot half-ass the attention pipeline**.
|
||||
|
||||
## Test Files
|
||||
|
||||
@@ -120,11 +120,15 @@ The container builds and starts successfully. The server accepts requests and ge
|
||||
| `tests/test_full_layer_b200.py` | All NVFP4 projections vs BF16 | ✅ 0.994+ |
|
||||
| `tests/layertest.py` | MoE layer test | ✅ 0.988 |
|
||||
| `tests/cudagraph_test.py` | CUDAGraph compatibility | ✅ PASS |
|
||||
| `tests/test_csa_attention_b200.py` | Full attention with SDPA | ✅ 0.988 |
|
||||
| `tests/test_v4_attention_b200.py` | All 3 layer types (SWA, C128A, C4A) | ✅ 0.981-0.995 |
|
||||
| `tests/test_kv_cache_b200.py` | FP8/NVFP4 KV cache + paged cache | ✅ 0.9997 |
|
||||
| `tests/test_sparse_attn_b200.py` | CSA/HCA sparse + SWA merged | ✅ works |
|
||||
| `tests/test_nvfp4_attn_gemm_b200.py` | NVFP4 Q×K^T GEMM | ❌ 0.86 (too lossy) |
|
||||
| `tests/test_decode_attention_b200.py` | Prefill + decode with KV cache | ✅ 0.9998 |
|
||||
| `tests/test_decode_vs_prefill_b200.py` | Decode vs prefill consistency | ✅ 0.996-0.999 |
|
||||
| `tests/test_e2e_decode_b200.py` | 61-layer E2E (shared experts) | ✅ healthy logits |
|
||||
| `tests/test_moe_nan_b200.py` | Single expert NaN check | ✅ no NaN |
|
||||
| `tests/test_moe_runner_nan_b200.py` | MoE grouped GEMM NaN check | ✅ no NaN |
|
||||
| `tests/test_full_layer_nan_b200.py` | Full layer + multi-layer NaN check | ✅ no NaN |
|
||||
|
||||
## Project Structure
|
||||
|
||||
@@ -133,63 +137,42 @@ nvfp4-megamoe-kernel/
|
||||
├── cutedsl/ # CuTeDSL kernel + bridge layer
|
||||
│ ├── bridge.py # Tensor layout conversion, quantization, kernel launch
|
||||
│ ├── nvfp4_linear.py # CuTeDSLNvfp4Linear — NVFP4 GEMM runner
|
||||
│ ├── moe_pipeline.py # Full MoE pipeline (L1→SiLU→L2→scatter)
|
||||
│ ├── shared_expert_pipeline.py # Shared expert pipeline
|
||||
│ ├── csa_attention.py # CSA/HCA attention (BF16 SDPA — needs CuTeDSL)
|
||||
│ ├── runner.py # CuTeDSLMoERunner — grouped GEMM MoE
|
||||
│ ├── blackwell_attention.py # KV cache + attention (standalone, works)
|
||||
│ ├── csa_attention.py # CSA/HCA attention (BF16 SDPA)
|
||||
│ ├── custom_ops.py # torch.autograd wrappers
|
||||
│ └── kernel/moe/ # NVIDIA's ScaledGroupedGemmKernel
|
||||
├── vllm/ # vLLM integration
|
||||
│ ├── nvfp4_cutedsl.py # CuTeDSLMoERunner
|
||||
│ ├── nvfp4_cutedsl.py # CuTeDSLMoERunner (vLLM wrapper)
|
||||
│ ├── cutedsl_quant_method.py # CuTeDSLNvfp4LinearMethod
|
||||
│ ├── kernels/linear/nvfp4/cutedsl.py # vLLM kernel registration
|
||||
│ └── patches/
|
||||
│ ├── deepseek_v4_attention.py # Attention patch (Blackwell dispatch)
|
||||
│ ├── deepseek_compressor.py # Compressor patch (skip fused kernel on Blackwell)
|
||||
│ ├── patch_kv_cache_utils.py # KV cache page size fix
|
||||
│ ├── patch_swa_cache.py # SWA cache alignment fix
|
||||
│ ├── patch_indexer_cache.py # Indexer cache alignment fix
|
||||
│ ├── patch_compressor_cache.py # Compressor cache alignment fix
|
||||
│ └── layers/
|
||||
│ ├── csa_attention.py # BF16 SDPA (TEMPORARY — needs CuTeDSL)
|
||||
│ └── ...
|
||||
│ ├── csa_attention.py # BF16 SDPA + KV cache (our Blackwell path)
|
||||
│ └── deepseek_compressor.py # Skip fused kernel on Blackwell
|
||||
├── tests/ # Standalone tests (run on B200 venv)
|
||||
└── Dockerfile # Container build
|
||||
├── Dockerfile # Container build
|
||||
├── README.md # This file
|
||||
└── CURRENT_BUG.md # Current bug investigation
|
||||
```
|
||||
|
||||
## Plan
|
||||
|
||||
### Phase 1: MoE Kernel ✅ DONE
|
||||
### Phase 2: NVFP4 Linear Kernels ✅ DONE
|
||||
### Phase 3: vLLM Integration ✅ DONE (NVFP4 linear + MoE working)
|
||||
### Phase 3: Attention Pipeline ✅ DONE (standalone, all tests pass)
|
||||
### Phase 4: vLLM Integration 🔧 IN PROGRESS — blocked by NaN from vLLM infrastructure
|
||||
|
||||
### Phase 4: CuTeDSL Attention Backend 🔧 NEXT — BUILD THE KERNELS
|
||||
**Current blocker:** NaN in the vLLM container's compiled execution. Our kernels produce zero NaN standalone. The NaN comes from vLLM's `attn_gemm_parallel_execute` or `fused_q_kv_rmsnorm` CUDA kernels, weight packing, or torch.compile interaction.
|
||||
|
||||
**STOP. READ THIS.**
|
||||
|
||||
Do NOT touch the vLLM container until ALL of these kernels pass standalone tests on the B200 venv. The container is a 14-minute build cycle. The venv gives you instant feedback. TEST FIRST.
|
||||
|
||||
**Kernels to build (in order):**
|
||||
|
||||
1. **KV Cache Write**: BF16 KV → apply RoPE → quantize to fp8 → write to paged cache
|
||||
- Test: compare against BF16 reference (cosine >= 0.98 after dequant)
|
||||
|
||||
2. **KV Cache Read**: paged cache → dequant fp8 → BF16 KV with RoPE
|
||||
- Test: write then read back, cosine >= 0.99
|
||||
|
||||
3. **BF16 Attention**: Q (with RoPE) × K^T → softmax → attn × V
|
||||
- Keep this in BF16 (NVFP4 is too lossy for attention scores)
|
||||
- Handle CSA sparse gather (attend to top-k indexed positions)
|
||||
- Handle HCA sparse gather (attend to 1/128 positions)
|
||||
- Handle SWA (sliding window, full causal within window)
|
||||
- Test: compare against PyTorch SDPA reference (cosine >= 0.99)
|
||||
|
||||
4. **Full Attention Pipeline**: KV cache read → attention → inverse RoPE → o_a BMM
|
||||
- Wire everything together
|
||||
- Test: compare against BF16 reference (cosine >= 0.98)
|
||||
|
||||
5. **vLLM Backend**: Wrap as a proper AttentionBackend subclass
|
||||
- Override `DeepseekSparseSWABackend` on Blackwell
|
||||
- Handle the metadata, slot mapping, cache format
|
||||
- ONLY THEN test in the container
|
||||
**Next:**
|
||||
1. Install vllm in B200 venv, test the exact parallel GEMM path
|
||||
2. Test with torch.compile disabled in the container
|
||||
3. Add NaN checks inside the parallel GEMM wrapper
|
||||
4. If the parallel GEMM is the source, replace it with our CuTeDSL kernels (path of least resistance)
|
||||
|
||||
### Phase 5: Production
|
||||
- End-to-end benchmarking
|
||||
|
||||
Reference in New Issue
Block a user