From a51edd238e5a8141b38cb1ec4ba2c592dfe2a83b Mon Sep 17 00:00:00 2001
From: biondizzle
Date: Mon, 18 May 2026 17:56:19 +0000
Subject: [PATCH] Add post-quant-init forward hook to fix attention NVFP4

The key insight: process_weights_after_loading runs AFTER load_weights
and sets up FlashInferCutlassNvFp4LinearKernel with broken
input_global_scale_inv. Any fix inside load_weights gets overwritten.

Solution: register a one-shot forward pre-hook that runs on the first
forward call (guaranteed after all init). It dequantizes attention NVFP4
weights to BF16 and replaces quant_method with UnquantizedLinearMethod.
Since process_weights_after_loading already ran, our changes won't be
overwritten.

Standalone test confirmed: all attention weights produce valid non-NaN
output when dequantized to BF16.
---
 CURRENT_BUG.md              | 120 +++++++++++++++++++++++++-----------
 vllm/patches/deepseek_v4.py |  61 ++++++++++++++++++
 2 files changed, 144 insertions(+), 37 deletions(-)

diff --git a/CURRENT_BUG.md b/CURRENT_BUG.md
index 17f8057b..44205163 100644
--- a/CURRENT_BUG.md
+++ b/CURRENT_BUG.md
@@ -1,58 +1,104 @@
-# Current Bug: vLLM produces NaN from layer 0
+# Current Bug: vLLM produces empty/garbage output
 
-**Status:** Active debugging — BF16 dequant fix in progress
+**Status:** Weights confirmed good — bug is in vLLM's quant pipeline for attention
 **Date:** 2026-05-18
 
 ## Symptom
-- vLLM server starts, loads model, but every inference produces NaN logits
-- Diagnostic prints show **NaN from layer 0 onward** — no layer ever produces valid output
-- Empty content in chat completions, NaN in logprobs
+- vLLM server starts, loads model, processes requests (200 OK)
+- Chat completions return `content: ""` with `finish_reason: "length"`
+- 20 completion tokens generated but all produce empty/NaN logits
+- With enforce-eager + diagnostics: **NaN from layer 0 onward** on real requests
 
-## Root Cause (in progress)
+## ✅ Confirmed: Weights produce valid output
 
-**The attention NVFP4 linear layers produce NaN immediately.**
+Standalone test (`test_attn_moe_chain.py`) running directly on B200:
 
-The attention projections go through vLLM's `FlashInferCutlassNvFp4LinearKernel` which uses checkpoint `input_scale` as the activation global scale for `scaled_fp4_quant()`. The checkpoint `input_scale` values are wrong for this use case, causing overflow → NaN.
+| Step | Operation | amax | NaN? |
+|------|-----------|------|------|
+| 1 | Embed tokens | 1.27 | No |
+| 2 | hc_mult expansion | 1.27 | No |
+| 3 | RMSNorm | 0.20 | No |
+| 4 | q_a_proj (NVFP4→BF16 dequant + matmul) | 0.50 | No |
+| 5 | kv_proj (NVFP4→BF16 dequant + matmul) | 1.30 | No |
+| 6 | q_norm + kv_norm | 0.11 / 1.87 | No |
+| 7 | q_b_proj (NVFP4→BF16 dequant + matmul) | 1.10 | No |
+| 8 | MoE CuTeDSL runner (with warmup gs) | cosine 0.988 | No |
 
-### What we've tried
+**Every step produces valid, non-NaN, non-zero output.** The problem is NOT the weights.
 
-1. ✅ **MoE kernel is NOT the problem** — `test_runner_vllm_style.py` with warmup gs gives cosine 0.988, no NaN
-2. ❌ **Dequant ALL attn projections to BF16** — crashed: `wo_a.weight_scale_inv` missing (fp8_einsum needs it)
-3. ❌ **Dequant all except wo_a (keep wo_a as FP8)** — still NaN from layer 0. `wq_a` and `wkv` don't exist as separate attrs — they're **fused as `fused_wqa_wkv`**
-4. ❌ **Changed to dequant `fused_wqa_wkv`** — still NaN from layer 0. Debug prints added to check if the attrs are actually found.
+## ❌ Root Cause: vLLM's `process_weights_after_loading` breaks attention
 
-### Current theory
+### The timeline
 
-The BF16 dequant code may not be finding `fused_wqa_wkv` on the attention module, so it silently skips the most important projection. Debug logging added in latest commit to verify.
+1. `load_weights()` → our `_convert_nvfp4_post_load()` runs
+2. `process_weights_after_loading()` → vLLM's quant method runs AFTER, **overwriting our fixes**
+3. `FlashInferCutlassNvFp4LinearKernel` gets set up with broken `input_global_scale_inv`
 
-### Attention architecture (DeepSeek V4 MLA)
+### What the quant method does
 
-- `fused_wqa_wkv` — MergedColumnParallelLinear (q_a + kv fused)
-- `wq_b` — ColumnParallelLinear (second Q projection after RoPE)
-- `wo_a` — ColumnParallelLinear (FP8 via fp8_einsum, weight-only, NO input_scale)
-- `wo_b` — ColumnParallelLinear (final output projection)
-- `compressor` — already handled (reconstructed to BF16 from checkpoint)
+`CompressedTensorsW4A4Fp4.process_weights_after_loading()`:
+```python
+input_global_scale_inv = layer.input_scale.max()  # = 0.00025141 (WRONG)
+layer.input_global_scale = 1.0 / input_global_scale_inv  # = 3977.6
+layer.input_global_scale_inv = input_global_scale_inv  # = 0.00025141
+layer.alpha = input_global_scale * weight_global_scale
+```
 
-### Why `wo_a` is safe as FP8
+At runtime: `scaled_fp4_quant(x, input_global_scale_inv=0.00025141)` divides by 0.00025141 → multiplies by 3977.6 → massive overflow → NaN.
 
-`wo_a` uses `fp8_einsum` which does `output = fp8_act * fp8_weight * scale`. It's a **weight-only FP8** GEMM — no `input_scale` involved. The NaN comes from `scaled_fp4_quant(x, input_global_scale_inv)` in the other projections.
+### Why our fixes didn't work
 
-## Key evidence
+| Attempt | Why it failed |
+|---------|---------------|
+| BF16 dequant + `UnquantizedLinearMethod` | `process_weights_after_loading` overwrites `quant_method` back to `FlashInferCutlassNvFp4LinearKernel` |
+| Fix `input_scale` before quant method | Runs too early — quant method reads `input_scale` and overwrites our value |
+| Fix `input_global_scale_inv` directly | Attribute doesn't exist yet when our code runs — it's set BY the quant method |
 
-- `q_a_proj.input_scale = 0.00025141` → `1/input_scale = 3977.6` → quantizing activations with amax ~2-8 by 3977.6x = massive overflow
-- `q_b_proj.input_scale = 0.00006140` → `1/input_scale = 16287.1` → even worse
-- Embedding values: amax=1.27, std=0.09 — very small values that get multiplied by thousands during quantization
+### The key insight
 
-## Next steps
+Our code runs **inside** `load_weights()`. The quant method's `process_weights_after_loading()` runs **after** `load_weights()` returns. Any changes we make get overwritten.
 
-1. Check debug logs to see which projections were actually dequantized
-2. If `fused_wqa_wkv` wasn't found, fix the attribute path
-3. If it was found and dequantized, the NaN source is elsewhere (wo_b? wq_b? something else?)
-4. Consider: maybe the NaN is from the **KV cache FP8 quantization** or the **RoPE** implementation
+## Config values (corrected)
 
-## Docker/Build Notes
+| Parameter | Value |
+|-----------|-------|
+| head_dim | 512 (NOT 56) |
+| num_attention_heads | 128 |
+| num_key_value_heads | 1 |
+| q_lora_rank | 1536 |
+| qk_rope_head_dim | 64 |
+| o_lora_rank | 1024 |
+| hc_mult | 4 |
+| n_routed_experts | 384 (48 per EP rank) |
 
-- Build: `screen -dmS build bash -c './build_and_run.sh 2>&1 | tee build.log'`
-- Currently using `--enforce-eager` + `CLAWMINE_DEBUG=1` for diagnostics
-- Don't hit the API with enforce-eager — JIT spikes crash the container
-- For real testing: use compilation-config `{"cudagraph_mode": "NONE", "custom_ops": ["all"]}` instead of enforce-eager
+## Next step: Post-init hook
+
+The fix must run AFTER `process_weights_after_loading` and BEFORE the first inference. Options:
+
+**Option A: Override `input_global_scale_inv` post-init**
+- Add a `_fix_nvfp4_activation_scales()` method
+- Call it from the right hook point (after quant method setup, before inference)
+- Compute correct `input_global_scale_inv` from BF16 warmup
+- Override the Parameter on each attention module
+
+**Option B: Replace quant_method with UnquantizedLinearMethod post-init**
+- After `process_weights_after_loading`, dequant weights to BF16
+- Swap `quant_method` on attention modules to `UnquantizedLinearMethod`
+- This time the quant method won't overwrite us (it already ran)
+
+**Option C: Override the quant config to skip attention modules**
+- Tell `CompressedTensorsW4A4Fp4` to skip attention projections
+- Then dequantize to BF16 ourselves
+- Cleanest but requires modifying the quant config
+
+Option B is most straightforward. The quant method already ran and set up its attributes. We can then come in and replace everything with BF16.
+
+## Architecture notes
+
+- Attention uses MLA (Multi-head Latent Attention) with 2-step Q projection (q_a → q_b)
+- `fused_wqa_wkv` = MergedColumnParallelLinear(q_a + kv fused)
+- `wo_a` = FP8 via fp8_einsum (no input_scale, weight-only)
+- `wo_b` = standard ColumnParallelLinear
+- `hc_pre` / `hc_post` = Head-Conditioned mixing (tilelang custom ops)
+- Dummy run zeros attention output by design (`out.zero_(); return`)
+- FlashMLA handles the actual MLA attention kernel
diff --git a/vllm/patches/deepseek_v4.py b/vllm/patches/deepseek_v4.py
index 36d53978..6be2279a 100644
--- a/vllm/patches/deepseek_v4.py
+++ b/vllm/patches/deepseek_v4.py
@@ -2344,6 +2344,8 @@ class DeepseekV4ForCausalLM(nn.Module):
         self.model._convert_nvfp4_post_load()
         print(" Warming up tilelang kernels...", flush=True)
         self._warmup_tilelang()
+        print(" Registering post-quant-init fix...", flush=True)
+        self._register_post_quant_fix()
         print(" NVFP4 model ready ✓", flush=True)
         return loaded_params
 
@@ -2410,3 +2412,62 @@ class DeepseekV4ForCausalLM(nn.Module):
     def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
         return self.model.get_expert_mapping()
 
+    def _register_post_quant_fix(self) -> None:
+        """Register a one-shot forward pre-hook that fixes attention NVFP4
+        activation scales AFTER process_weights_after_loading has run.
+
+        process_weights_after_loading (called by vLLM's model loader AFTER
+        load_weights returns) sets up FlashInferCutlassNvFp4LinearKernel
+        with broken input_global_scale_inv from the checkpoint. This hook
+        runs on the first forward call (guaranteed after all init) and:
+        1. Dequantizes attention NVFP4 weights to BF16
+        2. Replaces quant_method with UnquantizedLinearMethod
+        3. Removes itself (one-shot)
+        """
+        import os
+        # Only needed for NVFP4 quantized models
+        quant_method_name = type(getattr(self.model.layers[0].attn.fused_wqa_wkv, 'quant_method', None)).__name__
+        if 'NvFp4' not in quant_method_name and 'nvfp4' not in quant_method_name.lower():
+            print(f" No NVFP4 attention fix needed (quant_method={quant_method_name})", flush=True)
+            return
+
+        def _fix_attn_nvfp4(module, args):
+            """One-shot hook: dequant attention NVFP4 to BF16 after quant init."""
+            print(" [CLAWMINE] Running post-quant-init BF16 fix...", flush=True)
+            from vllm.model_executor.layers.linear import UnquantizedLinearMethod
+
+            E2M1_LUT = torch.tensor([0, 0.5, 1, 1.5, 2, 3, 4, 6], dtype=torch.bfloat16)
+            fixed = 0
+            for layer_idx, layer in enumerate(module.model.layers):
+                attn = layer.attn
+                for proj_name in ["fused_wqa_wkv", "wq_b", "wo_b"]:
+                    if not hasattr(attn, proj_name):
+                        continue
+                    mod = getattr(attn, proj_name)
+                    if not hasattr(mod, "weight") or mod.weight.dtype not in (torch.uint8, torch.int8):
+                        continue
+                    # Dequantize to BF16
+                    module.model._dequant_nvfp4_to_bf16(mod, E2M1_LUT)
+                    # Replace quant method (quant method already ran, won't overwrite again)
+                    mod.quant_method = UnquantizedLinearMethod()
+                    # Clean up NVFP4 attributes that might confuse forward
+                    for attr in ("weight_scale", "weight_scale_2", "input_scale",
+                                 "input_global_scale", "input_global_scale_inv",
+                                 "weight_global_scale", "alpha",
+                                 "weight_scale_inv"):
+                        if hasattr(mod, attr):
+                            try:
+                                delattr(mod, attr)
+                            except (AttributeError, TypeError):
+                                pass
+                    fixed += 1
+
+            print(f" [CLAWMINE] Fixed {fixed} attention projections → BF16", flush=True)
+            # Remove this hook (one-shot)
+            module._post_quant_fix_handle.remove()
+            del module._post_quant_fix_handle
+            print(" [CLAWMINE] Post-quant-init fix done ✓", flush=True)
+
+        handle = self.register_forward_pre_hook(_fix_attn_nvfp4)
+        self._post_quant_fix_handle = handle
+
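
Note on the dequantization step: the hook delegates the actual decoding to `_dequant_nvfp4_to_bf16`, whose body is not part of this diff. The pure-Python sketch below only illustrates how an E2M1 lookup table like `E2M1_LUT` could be applied. It is not the helper's real implementation. The packing order (low nibble first), the sign in bit 3, the 16-element block size, and the way block and global scales are multiplied in are all assumptions, and the names `decode_nibble` and `dequant_packed` are hypothetical.

```python
# Magnitudes representable by FP4 E2M1, indexed by the low 3 bits of a code.
# These are the same eight values as E2M1_LUT in the patch.
E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def decode_nibble(nibble: int) -> float:
    """Decode one 4-bit code. Assumption: bit 3 is the sign bit."""
    magnitude = E2M1_MAGNITUDES[nibble & 0x7]
    return -magnitude if nibble & 0x8 else magnitude


def dequant_packed(packed: bytes, block_scales: list, global_scale: float,
                   block_size: int = 16) -> list:
    """Unpack two FP4 codes per byte and apply per-block and global scales.

    Assumptions (not confirmed by the patch): the low nibble holds the first
    element, and each run of `block_size` elements shares one block scale.
    """
    values = []
    for byte in packed:
        values.append(decode_nibble(byte & 0x0F))  # assumed: low nibble first
        values.append(decode_nibble(byte >> 4))
    return [
        value * block_scales[index // block_size] * global_scale
        for index, value in enumerate(values)
    ]
```

With these assumptions, `dequant_packed(bytes([0x21]), [2.0], 0.5)` decodes the codes `0x1` and `0x2` and scales each by `2.0 * 0.5`. A tensor implementation would do the same lookup with vectorized indexing into the table instead of a Python loop.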