Add MoE NaN reproduction test, update CURRENT_BUG.md with NaN tracing and test plan
CURRENT_BUG.md
@@ -1,31 +1,91 @@
# CURRENT_BUG.md: DeepSeek-V4 Blackwell NVFP4

## Status: KV CACHE PIPELINE VERIFIED ✅
## Status: NaN IN MOE, ROOT CAUSE UNKNOWN

### What's Fixed
- **Root cause identified**: vLLM's `_attention_impl_blackwell` never writes KV to the paged cache, so decode produces garbage because it can't access prior tokens' KV.
- **Solution built and tested**: `cutedsl/blackwell_attention.py` + `vllm/patches/layers/csa_attention.py`, a KV cache write/read pipeline using fp8 quantization.
### Current Symptom
- vLLM container starts, model loads, server accepts requests
- **Output is empty**: the model generates tokens but they decode to nothing
- Debug logs show **NaN in hidden_states** entering the attention from the FIRST forward pass
- NaN propagates through all 61 layers → all outputs are NaN → garbage tokens
- Both C128A (cr=128) and C4A (cr=4) layers have NaN in their inputs

### Test Results (B200 venv, all passing)
### NaN Tracing
```
Layer 0 (C128A): hidden_states input → ??? → NaN in attention input
Layer 1-59 (C4A): NaN in attention input (propagated)
Layer 60 (SWA): NaN in attention input (propagated)
```
The NaN originates BEFORE the attention: it is in the MoE output that feeds into the next layer.

| Test | Result |
|------|--------|
| KV cache roundtrip (fp8 quant → dequant) | 0.999+ cosine |
| Decode attention (1 query vs N cached KVs) | 0.9998 cosine |
| Full pipeline (inv RoPE + o_a + o_b) | 0.996-0.999 cosine |
| All 5 layer types (C128A, C4A, SWA) | ≥0.996 cosine |
| E2E 61-layer model (shared experts) | Healthy logits, consistent tokens |
| Multi-step decode (3 steps) | 0.999+ cosine each step |
### Architecture: DeepSeek-V4 MegaMoE
- **384 experts, top-6 routing**: this is a "MegaMoE" architecture
- DeepGEMM has a specialized `mega_moe.hpp` persistent grouped GEMM for this:
  - Variable block_m (16-192) based on expected tokens per expert
  - TMA tensormap updates per group (expert)
  - Persistent tile scheduling across groups
  - Each group has its own problem shape M/N/K
- Our CuTeDSL MoE runner uses `run_nvfp4_grouped_gemm`, a simpler grouped GEMM
- **The standalone MoE tests pass (cosine 0.988) but may not exercise the same shapes/paths as vLLM**

### What's Next
1. Test in vLLM container (build_and_run.sh)
2. Handle CSA/HCA sparse attention in the Blackwell path (currently using full attention for all layers)
3. Add routed MoE experts (currently shared experts only)
4. Performance optimization (vectorized paged KV, Triton kernels)
### What's Been Verified (B200 venv, all passing)
| Component | Test | Result |
|-----------|------|--------|
| NVFP4 Linear (q_a, kv, q_b, o_b) | cosine per projection | 0.998-1.0 |
| NVFP4 MoE (L1 gate+up, L2 down) | cosine per layer | 0.988 |
| KV cache roundtrip (fp8) | cosine | 0.999 |
| Decode attention (1 query vs N KV) | cosine | 0.9998 |
| Full pipeline (inv RoPE + o_a + o_b) | cosine | 0.996-0.999 |
| All 5 layer types | cosine | ≥0.996 |
| E2E 61-layer (shared experts) | logits std=3.16 | reasonable |
| CSA sparse attention (C4A) | cosine | 0.974 |
| CSA sparse attention (C128A) | cosine | 0.668 (avg-pooled KV) |
| Multi-step decode | cosine | 0.999 |
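
The cosine figures in these tables compare a kernel output against a higher-precision reference. As a minimal sketch of that metric, assuming both tensors have been flattened to plain lists of floats (the helper name `cosine_similarity` is illustrative and not taken from the repository's tests):

```python
import math

def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length float sequences."""
    if len(a) != len(b):
        raise ValueError("length mismatch")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return float("nan")  # undefined for an all-zero vector
    return dot / (norm_a * norm_b)

reference = [0.5, -1.25, 2.0, 0.0]     # stand-in for a BF16 reference output
quantized = [0.5, -1.0, 2.0, 0.125]    # stand-in for a quantized kernel output
score = cosine_similarity(reference, quantized)
print(f"cosine = {score:.4f}")
```

Two properties are worth keeping in mind when reading the tables: cosine ignores magnitude, so an output that is uniformly mis-scaled still scores 1.0, and a single NaN in either input turns the whole score into NaN.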

### Architecture
- KV latent: (T, HD=512) shared across 128 Q heads
- KV Cache: fp8_e4m3 paged cache with per-token inverse scale
- Attention: BF16 (NVFP4 too lossy for Q×K^T)
- Prefill: causal SDPA on raw KV
- Decode: read all cached KV → fp8 dequant → SDPA → output
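
To illustrate why the cache needs a per-token scale alongside the fp8 payload, here is a pure-Python model of the roundtrip. It is a sketch under assumptions: the scale is taken to be `amax / 448` per token (448 is the largest finite `float8_e4m3fn` value), `round_to_e4m3` is a software approximation of the hardware cast, and none of these helper names come from the repository (the real `kv_dequantize_fp8` is not shown in this commit).

```python
import math

E4M3_MAX = 448.0  # largest finite value of float8_e4m3fn

def round_to_e4m3(x):
    """Software model: round x to the nearest float8_e4m3fn value, saturating at +/-448."""
    if x == 0.0 or math.isnan(x):
        return x
    mag = min(abs(x), E4M3_MAX)
    exp = max(math.floor(math.log2(mag)), -6)  # exponents below -6 are subnormal
    step = 2.0 ** (exp - 3)                    # 3 mantissa bits -> 8 steps per binade
    return math.copysign(min(round(mag / step) * step, E4M3_MAX), x)

def quantize_token(kv_row):
    """Return (fp8-valued row, inv_scale) for one token's KV latent."""
    amax = max(abs(v) for v in kv_row)
    inv_scale = amax / E4M3_MAX if amax > 0.0 else 1.0  # guard all-zero rows
    return [round_to_e4m3(v / inv_scale) for v in kv_row], inv_scale

def dequantize_token(q_row, inv_scale):
    """Undo the scaling; a plain dtype cast would skip this multiply."""
    return [q * inv_scale for q in q_row]

kv_row = [0.02 * math.sin(1.7 * i) for i in range(32)]  # synthetic KV latent
q_row, inv_scale = quantize_token(kv_row)
restored = dequantize_token(q_row, inv_scale)
unscaled = list(q_row)  # what a cast without the scale would hand to attention

amax = max(abs(v) for v in kv_row)
print("max abs error with scale:   ", max(abs(r - v) for r, v in zip(restored, kv_row)))
print("max abs error without scale:", max(abs(u - v) for u, v in zip(unscaled, kv_row)))
```

In this model, skipping the multiply by `inv_scale` leaves values several orders of magnitude too large, which is consistent with the `.to(bf16)` fix listed in the vLLM integration notes.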
### What's Been Fixed in vLLM Integration
1. Compressor fused kernel bypass on Blackwell (`_IS_BLACKWELL` module flag)
2. Double Q normalization removed (fused_qnorm only does RoPE now)
3. RoPE sin slice bug fixed (`half:2*half` not `half:`)
4. fp8 dequant fix (use `kv_dequantize_fp8` not `.to(bf16)`)
5. Wrapper attribute access (`self.mla_attn.kv_cache` etc.)
6. Paged KV decode using `decode_swa_indices` from metadata
7. `UnboundLocalError` fix for debug prints

### What's NOT Working
- **Container produces empty/garbage output**
- **NaN in hidden_states** from first forward pass
- The NaN comes from the MoE (routed experts) or from the activation quantization
- The CuTeDSL grouped GEMM may produce NaN for certain expert token distributions

### Test Plan: Finding the NaN

**Phase 1: Reproduce the NaN in the B200 venv (outside container)**
1. Test `CuTeDSLMoERunner.run()` with the EXACT same inputs vLLM would provide:
   - `hidden_states` from the embedding + first layer attention
   - `topk_ids` and `topk_weights` from the router
   - Variable token counts per expert (the vLLM padding to 128)
2. Test with 1 token (decode), 8 tokens (small prefill), and padded shapes
3. Check for NaN after L1 GEMM, after SiLU activation, after L2 GEMM
4. Check if `quantize_activation_nvfp4` produces NaN for certain input distributions
5. Check if `run_nvfp4_grouped_gemm` produces NaN for certain expert offsets
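
The per-step checks in this phase can be sketched as a small harness that runs the stages in order and stops at the first one whose output contains a NaN. Everything below is a stand-in: the stage functions are toy list operations rather than the real kernels, and the zero-amax failure in `amax_normalize` is a hypothetical mechanism chosen only to show how an all-zero (for example, padded) row could turn into NaN. It is not a diagnosis of the actual bug.

```python
import math

def first_nan_stage(x, stages):
    """Run stages in order; return (name, output) of the first stage that emits a NaN,
    or (None, final_output) if every stage stays clean."""
    for name, fn in stages:
        x = fn(x)
        if any(math.isnan(v) for v in x):
            return name, x
    return None, x

def toy_gemm(x):
    return [2.0 * v for v in x]  # placeholder for the L1 grouped GEMM

def silu(x):
    return [v / (1.0 + math.exp(-v)) for v in x]

def amax_normalize(x):
    """Naive amax scaling: a zero amax gives an infinite scale, and 0 * inf is NaN."""
    amax = max(abs(v) for v in x)
    inv = math.inf if amax == 0.0 else 1.0 / amax
    return [v * inv for v in x]

stages = [
    ("l1_gemm", toy_gemm),
    ("silu", silu),
    ("quantize_activated", amax_normalize),
]

print(first_nan_stage([0.5, -1.0], stages)[0])  # clean input
print(first_nan_stage([0.0, 0.0], stages)[0])   # all-zero row
```

With the real runner, each stage function would wrap one call from the list above and the check would use `torch.isnan(...).any()` instead of the list comprehension.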

**Phase 2: Verify the grouped GEMM with expert-parallel shapes**
1. Test with 48 experts (EP8, 384/8), 1-8 tokens, top-6
2. Test with padding to 128 rows per expert
3. Check if the GEMM handles zero-token experts correctly
4. Check if `expert_offsets` and `padded_expert_offsets` are correct for MegaMoE shapes
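
A reference computation of the two offset arrays makes these checks concrete. The sketch below assumes a prefix-sum layout with one leading zero and assumes that an expert with no tokens occupies zero padded rows; the runner's real convention is not shown in this commit and may differ (for instance, it could reserve a full 128-row block per expert). `build_expert_offsets` is a hypothetical helper.

```python
import random

def build_expert_offsets(topk_ids, num_experts, pad_to=128):
    """Count tokens per expert and build exclusive prefix sums, raw and padded."""
    counts = [0] * num_experts
    for row in topk_ids:
        for expert in row:
            counts[expert] += 1
    expert_offsets = [0]
    padded_expert_offsets = [0]
    for c in counts:
        padded = -(-c // pad_to) * pad_to  # round up to a multiple of pad_to; 0 stays 0
        expert_offsets.append(expert_offsets[-1] + c)
        padded_expert_offsets.append(padded_expert_offsets[-1] + padded)
    return counts, expert_offsets, padded_expert_offsets

rng = random.Random(0)
num_experts, top_k, num_tokens = 48, 6, 8  # EP8 shard, decode-sized batch
# random.sample gives distinct experts per token, like a real top-k router
topk_ids = [rng.sample(range(num_experts), top_k) for _ in range(num_tokens)]
counts, offsets, padded = build_expert_offsets(topk_ids, num_experts)

print("experts with zero tokens:", sum(1 for c in counts if c == 0))
print("total routed rows:", offsets[-1], "| total padded rows:", padded[-1])
```

One difference from the reproduction script is worth noting: `torch.randint` draws with replacement, so it can route a token to the same expert twice, which a top-k router would not do. That may or may not matter for the NaN, but it is a shape the production path never produces.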

**Phase 3: Test the full layer forward (attention + MoE)**
1. Run layer 0 (C128A) with real weights, check output for NaN
2. Run layer 2 (C4A) with real weights, check output for NaN
3. If NaN appears, bisect: which component produces it?
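
Because a NaN, once present, propagates to everything downstream, "contains NaN" is monotone along the sequence of intermediate outputs, so the bisect step can be an ordinary binary search. The following is a sketch that assumes the intermediates have already been captured in execution order; the component names are illustrative placeholders, not the actual module names.

```python
import math

def has_nan(values):
    return any(math.isnan(v) for v in values)

def first_nan_index(outputs):
    """Index of the first NaN-containing output, or None if all are clean.
    Assumes every output after the first NaN one also contains NaN."""
    lo, hi = 0, len(outputs)
    while lo < hi:
        mid = (lo + hi) // 2
        if has_nan(outputs[mid]):
            hi = mid        # first NaN is at mid or earlier
        else:
            lo = mid + 1    # first NaN is after mid
    return lo if lo < len(outputs) else None

names = ["attn_norm", "attention", "ffn_norm", "router",
         "moe_l1", "swiglu", "moe_l2", "combine"]
nan = float("nan")
captured = [[0.1], [0.2], [0.3], [0.4], [nan], [nan], [nan], [nan]]

idx = first_nan_index(captured)
culprit = names[idx] if idx is not None else None
print("first NaN produced by:", culprit)
```

The binary search only pays off when checking an intermediate is expensive, for example when each probe means re-running the layer up to that point; with all tensors already in memory a linear scan is just as good.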

**Phase 4: Fix and verify**
1. Fix the NaN source
2. Run all B200 venv tests
3. Build container, test with real inference
4. Verify output is actual text (not empty, not garbage)

### Key References
- [Grouped Blockscaled GEMM on B200](https://veitner.bearblog.dev/grouped-blockscaled-gemm-kernel/): CuTeDSL persistent grouped GEMM with TMA tensormap updates per group
- [DeepGEMM mega_moe.hpp](https://github.com/deepseek-ai/DeepGEMM/blob/main/csrc/jit_kernels/heuristics/mega_moe.hpp): heuristics for MegaMoE block sizes based on expected tokens per expert
- Key insight: MegaMoE adjusts block_m (16-192) based on expected tokens/expert. For decode (few tokens), block_m=16-32. For prefill, block_m=192.
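
The "expected tokens per expert" quantity is simple arithmetic: with uniform routing it is `num_tokens * top_k / num_experts`. The sketch below computes it for the 384-expert, top-6 configuration and then applies a purely illustrative rule (round up to a multiple of 16, clamp to 16-192) to show how a block size could follow from it. That rule is an assumption made for this example and is not DeepGEMM's actual heuristic; consult `mega_moe.hpp` for the real logic.

```python
import math

def expected_tokens_per_expert(num_tokens, top_k=6, num_experts=384):
    """Average rows each expert receives if routing were perfectly uniform."""
    return num_tokens * top_k / num_experts

def illustrative_block_m(expected, lo=16, hi=192, multiple=16):
    """Hypothetical rule: round up to a multiple of 16, then clamp to [lo, hi]."""
    rounded = -(-math.ceil(expected) // multiple) * multiple
    return max(lo, min(hi, rounded))

for num_tokens in (1, 8, 8192, 16384):
    exp_tokens = expected_tokens_per_expert(num_tokens)
    print(f"{num_tokens:>6} tokens -> {exp_tokens:8.4f} per expert, "
          f"block_m ~ {illustrative_block_m(exp_tokens)}")
```

The decode case is the notable one: a single token gives 1/64 of a row per expert on average, so nearly every expert in a group has zero rows, which is exactly the regime Phase 2 sets out to test.
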
tests/test_moe_nan_b200.py
@@ -0,0 +1,231 @@
#!/usr/bin/env python3
"""
DeepSeek-V4 MoE NaN Reproduction Test

Finds where NaN originates in the MoE forward pass.
Tests the EXACT CuTeDSLMoERunner code path used by vLLM.

This test is the FIRST step: if the MoE produces NaN, the entire model
produces garbage. We need to find the NaN source before anything else matters.

Test plan:
1. Load MoE weights for a single layer
2. Run the CuTeDSLMoERunner with various token counts and routing patterns
3. Check for NaN at each step: quantize → L1 GEMM → SiLU → L2 GEMM → combine
4. Specifically test with MegaMoE shapes: 48 experts (EP8), padded to 128 rows

Usage (on B200):
    cd /root/nvfp4-megamoe-kernel
    PYTHONPATH=/root/nvfp4-megamoe-kernel tests/venv/bin/python tests/test_moe_nan_b200.py
"""

import sys, os, json, torch, torch.nn.functional as F
from safetensors import safe_open

REPO = "/root/nvfp4-megamoe-kernel"
sys.path.insert(0, REPO)
MODEL = "/root/nvidia-meeting/DeepSeek-V4-Pro-NVFP4"
DEV = "cuda:0"

H = 7168; NH = 128; HD = 512; NOPE = 448; ROPE = 64
QL = 1536; OL = 1024; OG = 16; HPG = NH // OG
INTERMEDIATE = 18432 # DeepSeek-V4 MoE intermediate size
NUM_EXPERTS = 48 # EP8: 384/8
TOPK = 6
EPS = 1e-6; WINDOW = 128; SCALE = HD ** -0.5

_cache = {}
def P(k, wm, md):
    """Load tensor `k` from the safetensors shard named in weight map `wm`, with caching."""
    if k in _cache: return _cache[k]
    with safe_open(os.path.join(md, wm[k]), framework="pt") as f:
        t = f.get_tensor(k)
    _cache[k] = t
    return t

def rms(x, w, eps=1e-6):
    """RMSNorm computed in float32, returned in the dtype of `x`."""
    v = x.float().pow(2).mean(-1, keepdim=True)
    return (w.float() * (x * torch.rsqrt(v+eps)).float()).to(x.dtype)


def test_moe_layer(layer_id=2):
    """Test the MoE forward pass for a single layer, checking for NaN at each step."""
    from cutedsl.runner import CuTeDSLMoERunner

    torch.cuda.set_device(0)
    torch.manual_seed(42)
    torch.cuda.empty_cache()

    with open(os.path.join(MODEL, "model.safetensors.index.json")) as f:
        wm = json.load(f)["weight_map"]
    G = lambda k: P(k, wm, MODEL).to(DEV)

    p = f"model.layers.{layer_id}"
    m = f"{p}.mlp"

    # Load embedding for input
    emb = G("model.embed_tokens.weight")
    fnorm = G(f"{p}.post_attention_layernorm.weight")

    # MoE weights
    # Gate/up (w13): (E, 2*intermediate, hidden//2) uint8
    # Down (w2): (E, hidden, intermediate//2) uint8
    w13_w = G(f"{m}.experts.w13_weight") # or gate_proj + up_proj
    w13_sf = G(f"{m}.experts.w13_weight_scale")
    w13_gs = G(f"{m}.experts.w13_weight_scale_2")
    w2_w = G(f"{m}.experts.w2_weight")
    w2_sf = G(f"{m}.experts.w2_weight_scale")
    w2_gs = G(f"{m}.experts.w2_weight_scale_2")
    swiglu_limit = None

    # Shared expert
    se_gate_w = G(f"{m}.shared_experts.gate_proj.weight")
    se_gate_sf = G(f"{m}.shared_experts.gate_proj.weight_scale")
    se_gate_gs = G(f"{m}.shared_experts.gate_proj.weight_scale_2")
    se_up_w = G(f"{m}.shared_experts.up_proj.weight")
    se_up_sf = G(f"{m}.shared_experts.up_proj.weight_scale")
    se_up_gs = G(f"{m}.shared_experts.up_proj.weight_scale_2")
    se_down_w = G(f"{m}.shared_experts.down_proj.weight")
    se_down_sf = G(f"{m}.shared_experts.down_proj.weight_scale")
    se_down_gs = G(f"{m}.shared_experts.down_proj.weight_scale_2")

    print(f" w13_weight shape: {w13_w.shape}, dtype: {w13_w.dtype}")
    print(f" w2_weight shape: {w2_w.shape}, dtype: {w2_w.dtype}")
    print(f" w13_gs shape: {w13_gs.shape}")
    print(f" w2_gs shape: {w2_gs.shape}")
    print(f" w13_gs sample: {w13_gs[:5].tolist()}")
    print(f" w2_gs sample: {w2_gs[:5].tolist()}")

    # Check for NaN in weights
    # If the packed FP4 weights are stored as uint8 (as noted above), the two
    # weight checks can never be True; the scale checks are the meaningful ones.
    print(f" w13 NaN: {torch.isnan(w13_w.float()).any()}")
    print(f" w2 NaN: {torch.isnan(w2_w.float()).any()}")
    print(f" w13_sf NaN: {torch.isnan(w13_sf.float()).any()}")
    print(f" w2_sf NaN: {torch.isnan(w2_sf.float()).any()}")
    print(f" w13_gs NaN: {torch.isnan(w13_gs).any()}")
    print(f" w2_gs NaN: {torch.isnan(w2_gs).any()}")

    # Create the MoE runner
    num_local_experts = w13_w.shape[0]
    hidden_size = w13_w.shape[2] * 2 # hidden//2 packed → *2 for fp4
    intermediate_size = w13_w.shape[1] // 2 # 2*intermediate // 2

    print(f"\n num_local_experts: {num_local_experts}")
    print(f" hidden_size: {hidden_size}")
    print(f" intermediate_size: {intermediate_size}")

    runner = CuTeDSLMoERunner(
        num_experts=num_local_experts,
        hidden_size=hidden_size,
        intermediate_size=intermediate_size,
        max_num_tokens=8192,
        top_k=TOPK,
        device=str(DEV),
    )

    # Prepare weights
    l1_fp4 = w13_w.view(torch.float4_e2m1fn_x2)
    l2_fp4 = w2_w.view(torch.float4_e2m1fn_x2)
    l1_sf = w13_sf.to(torch.float8_e4m3fn) if w13_sf.dtype != torch.float8_e4m3fn else w13_sf
    l2_sf = w2_sf.to(torch.float8_e4m3fn) if w2_sf.dtype != torch.float8_e4m3fn else w2_sf

    runner.prepare_weights_from_stacked(
        l1_fp4, l1_sf, w13_gs.tolist(),
        l2_fp4, l2_sf, w2_gs.tolist(),
    )

    # Test with various token counts
    test_cases = [
        ("1 token (decode)", 1),
        ("4 tokens", 4),
        ("8 tokens", 8),
        ("16 tokens", 16),
    ]

    for desc, num_tokens in test_cases:
        print(f"\n --- {desc} ---")
        token_ids = torch.randint(1, 1000, (num_tokens,), dtype=torch.long, device=DEV)
        hidden = emb[token_ids]
        normed = rms(hidden, fnorm, EPS)

        print(f" Input: amax={normed.amax():.4f} NaN={torch.isnan(normed).any()}")

        # Create routing (random top-6 from num_local_experts)
        topk_ids = torch.randint(0, num_local_experts, (num_tokens, TOPK), device=DEV)
        topk_weights = torch.softmax(torch.randn(num_tokens, TOPK, device=DEV), dim=-1)

        with torch.no_grad():
            result = runner.run(normed, topk_weights, topk_ids)

        print(f" Output: amax={result.amax():.4f} NaN={torch.isnan(result).any()}")
        if torch.isnan(result).any():
            # Count NaN rows
            nan_rows = torch.isnan(result).any(dim=1).sum().item()
            print(f" NaN rows: {nan_rows}/{num_tokens}")

        # Check if shared expert also produces NaN
        from cutedsl.nvfp4_linear import CuTeDSLNvfp4Linear
        def make_runner(w, sf, gs_t, inf, outf):
            fp4 = w.view(torch.float4_e2m1fn_x2).permute(1,0).contiguous()
            s = sf.to(torch.float8_e4m3fn) if sf.dtype != torch.float8_e4m3fn else sf
            s = s.permute(1,0).contiguous()
            gs = gs_t.max().item() if gs_t.numel() > 1 else gs_t.item()
            r = CuTeDSLNvfp4Linear(in_features=inf, out_features=outf, max_num_tokens=8192, device=str(w.device))
            r.fp4 = [fp4]; r.sf = [s]; r.gs = [gs]
            r.finalize_weights(); r._ensure_initialized()
            return r

        # Shared expert only
        r_gate = make_runner(se_gate_w, se_gate_sf, se_gate_gs, H, se_gate_w.shape[0])
        r_up = make_runner(se_up_w, se_up_sf, se_up_gs, H, se_up_w.shape[0])
        r_down = make_runner(se_down_w, se_down_sf, se_down_gs, INTERMEDIATE, se_down_w.shape[0])

        with torch.no_grad():
            gate_out = r_gate.run(normed)
            up_out = r_up.run(normed)
            activated = F.silu(gate_out) * up_out
            se_result = r_down.run(activated)

        print(f" Shared expert: amax={se_result.amax():.4f} NaN={torch.isnan(se_result).any()}")

        del r_gate, r_up, r_down

    # Test with exactly the vLLM padding pattern
    print(f"\n --- vLLM padding test (8 tokens, top-6, expert offsets) ---")
    num_tokens = 8
    token_ids = torch.randint(1, 1000, (num_tokens,), dtype=torch.long, device=DEV)
    hidden = emb[token_ids]
    normed = rms(hidden, fnorm, EPS)
    topk_ids = torch.randint(0, num_local_experts, (num_tokens, TOPK), device=DEV)
    topk_weights = torch.softmax(torch.randn(num_tokens, TOPK, device=DEV), dim=-1)

    with torch.no_grad():
        result = runner.run(normed, topk_weights, topk_ids)

    print(f" Output: amax={result.amax():.4f} NaN={torch.isnan(result).any()}")
    print(f" Output sample (first 10): {result[0, :10].tolist()}")

    del runner
    torch.cuda.empty_cache()
    _cache.clear()


def main():
    print("=" * 70)
    print(" DeepSeek-V4 MoE NaN Reproduction Test")
    print(" Finds where NaN originates in the MoE forward pass")
    print("=" * 70)

    test_moe_layer(layer_id=2) # C4A layer

    print(f"\n{'='*70}")
    print(f" If NaN is found, bisect by testing each step:")
    print(f" 1. quantize_activation_nvfp4(input)")
    print(f" 2. run_nvfp4_grouped_gemm(L1)")
    print(f" 3. SiLU(gate) * up")
    print(f" 4. quantize_activation_nvfp4(activated)")
    print(f" 5. run_nvfp4_grouped_gemm(L2)")
    print(f" 6. scatter_add combine")
    print(f"{'='*70}")


if __name__ == "__main__":
    main()