Add post-quant-init forward hook to fix attention NVFP4

The key insight: process_weights_after_loading runs AFTER load_weights
and sets up FlashInferCutlassNvFp4LinearKernel with broken
input_global_scale_inv. Any fix inside load_weights gets overwritten.

Solution: register a one-shot forward pre-hook that runs on the first
forward call (guaranteed after all init). It dequantizes attention
NVFP4 weights to BF16 and replaces quant_method with
UnquantizedLinearMethod. Since process_weights_after_loading already
ran, our changes won't be overwritten.

Standalone test confirmed: all attention weights produce valid
non-NaN output when dequantized to BF16.
This commit is contained in:
2026-05-18 17:56:19 +00:00
parent 2835cb040b
commit a51edd238e
2 changed files with 144 additions and 37 deletions

View File

@@ -1,58 +1,104 @@
# Current Bug: vLLM produces NaN from layer 0
# Current Bug: vLLM produces empty/garbage output
**Status:** Active debugging — BF16 dequant fix in progress
**Status:** Weights confirmed good — bug is in vLLM's quant pipeline for attention
**Date:** 2026-05-18
## Symptom
- vLLM server starts, loads model, but every inference produces NaN logits
- Diagnostic prints show **NaN from layer 0 onward** — no layer ever produces valid output
- Empty content in chat completions, NaN in logprobs
- vLLM server starts, loads model, processes requests (200 OK)
- Chat completions return `content: ""` with `finish_reason: "length"`
- 20 completion tokens generated but all produce empty/NaN logits
- With enforce-eager + diagnostics: **NaN from layer 0 onward** on real requests
## Root Cause (in progress)
## ✅ Confirmed: Weights produce valid output
**The attention NVFP4 linear layers produce NaN immediately.**
Standalone test (`test_attn_moe_chain.py`) running directly on B200:
The attention projections go through vLLM's `FlashInferCutlassNvFp4LinearKernel` which uses checkpoint `input_scale` as the activation global scale for `scaled_fp4_quant()`. The checkpoint `input_scale` values are wrong for this use case, causing overflow → NaN.
| Step | Operation | amax | NaN? |
|------|-----------|------|------|
| 1 | Embed tokens | 1.27 | No |
| 2 | hc_mult expansion | 1.27 | No |
| 3 | RMSNorm | 0.20 | No |
| 4 | q_a_proj (NVFP4→BF16 dequant + matmul) | 0.50 | No |
| 5 | kv_proj (NVFP4→BF16 dequant + matmul) | 1.30 | No |
| 6 | q_norm + kv_norm | 0.11 / 1.87 | No |
| 7 | q_b_proj (NVFP4→BF16 dequant + matmul) | 1.10 | No |
| 8 | MoE CuTeDSL runner (with warmup gs) | cosine 0.988 | No |
### What we've tried
**Every step produces valid, non-NaN, non-zero output.** The problem is NOT the weights.
1.**MoE kernel is NOT the problem**`test_runner_vllm_style.py` with warmup gs gives cosine 0.988, no NaN
2.**Dequant ALL attn projections to BF16** — crashed: `wo_a.weight_scale_inv` missing (fp8_einsum needs it)
3.**Dequant all except wo_a (keep wo_a as FP8)** — still NaN from layer 0. `wq_a` and `wkv` don't exist as separate attrs — they're **fused as `fused_wqa_wkv`**
4.**Changed to dequant `fused_wqa_wkv`** — still NaN from layer 0. Debug prints added to check if the attrs are actually found.
## ❌ Root Cause: vLLM's `process_weights_after_loading` breaks attention
### Current theory
### The timeline
The BF16 dequant code may not be finding `fused_wqa_wkv` on the attention module, so it silently skips the most important projection. Debug logging added in latest commit to verify.
1. `load_weights()` → our `_convert_nvfp4_post_load()` runs
2. `process_weights_after_loading()` → vLLM's quant method runs AFTER, **overwriting our fixes**
3. `FlashInferCutlassNvFp4LinearKernel` gets set up with broken `input_global_scale_inv`
### Attention architecture (DeepSeek V4 MLA)
### What the quant method does
- `fused_wqa_wkv` — MergedColumnParallelLinear (q_a + kv fused)
- `wq_b` — ColumnParallelLinear (second Q projection after RoPE)
- `wo_a` — ColumnParallelLinear (FP8 via fp8_einsum, weight-only, NO input_scale)
- `wo_b` — ColumnParallelLinear (final output projection)
- `compressor` — already handled (reconstructed to BF16 from checkpoint)
`CompressedTensorsW4A4Fp4.process_weights_after_loading()`:
```python
input_global_scale_inv = layer.input_scale.max() # = 0.00025141 (WRONG)
layer.input_global_scale = 1.0 / input_global_scale_inv # = 3977.6
layer.input_global_scale_inv = input_global_scale_inv # = 0.00025141
layer.alpha = input_global_scale * weight_global_scale
```
### Why `wo_a` is safe as FP8
At runtime: `scaled_fp4_quant(x, input_global_scale_inv=0.00025141)` divides by 0.00025141 → multiplies by 3977.6 → massive overflow → NaN.
`wo_a` uses `fp8_einsum` which does `output = fp8_act * fp8_weight * scale`. It's a **weight-only FP8** GEMM — no `input_scale` involved. The NaN comes from `scaled_fp4_quant(x, input_global_scale_inv)` in the other projections.
### Why our fixes didn't work
## Key evidence
| Attempt | Why it failed |
|---------|---------------|
| BF16 dequant + `UnquantizedLinearMethod` | `process_weights_after_loading` overwrites `quant_method` back to `FlashInferCutlassNvFp4LinearKernel` |
| Fix `input_scale` before quant method | Runs too early — quant method reads `input_scale` and overwrites our value |
| Fix `input_global_scale_inv` directly | Attribute doesn't exist yet when our code runs — it's set BY the quant method |
- `q_a_proj.input_scale = 0.00025141``1/input_scale = 3977.6` → quantizing activations with amax ~2-8 by 3977.6x = massive overflow
- `q_b_proj.input_scale = 0.00006140``1/input_scale = 16287.1` → even worse
- Embedding values: amax=1.27, std=0.09 — very small values that get multiplied by thousands during quantization
### The key insight
## Next steps
Our code runs **inside** `load_weights()`. The quant method's `process_weights_after_loading()` runs **after** `load_weights()` returns. Any changes we make get overwritten.
1. Check debug logs to see which projections were actually dequantized
2. If `fused_wqa_wkv` wasn't found, fix the attribute path
3. If it was found and dequantized, the NaN source is elsewhere (wo_b? wq_b? something else?)
4. Consider: maybe the NaN is from the **KV cache FP8 quantization** or the **RoPE** implementation
## Config values (corrected)
## Docker/Build Notes
| Parameter | Value |
|-----------|-------|
| head_dim | 512 (NOT 56) |
| num_attention_heads | 128 |
| num_key_value_heads | 1 |
| q_lora_rank | 1536 |
| qk_rope_head_dim | 64 |
| o_lora_rank | 1024 |
| hc_mult | 4 |
| n_routed_experts | 384 (48 per EP rank) |
- Build: `screen -dmS build bash -c './build_and_run.sh 2>&1 | tee build.log'`
- Currently using `--enforce-eager` + `CLAWMINE_DEBUG=1` for diagnostics
- Don't hit the API with enforce-eager — JIT spikes crash the container
- For real testing: use compilation-config `{"cudagraph_mode": "NONE", "custom_ops": ["all"]}` instead of enforce-eager
## Next step: Post-init hook
The fix must run AFTER `process_weights_after_loading` and BEFORE the first inference. Options:
**Option A: Override `input_global_scale_inv` post-init**
- Add a `_fix_nvfp4_activation_scales()` method
- Call it from the right hook point (after quant method setup, before inference)
- Compute correct `input_global_scale_inv` from BF16 warmup
- Override the Parameter on each attention module
**Option B: Replace quant_method with UnquantizedLinearMethod post-init**
- After `process_weights_after_loading`, dequant weights to BF16
- Swap `quant_method` on attention modules to `UnquantizedLinearMethod`
- This time the quant method won't overwrite us (it already ran)
**Option C: Override the quant config to skip attention modules**
- Tell `CompressedTensorsW4A4Fp4` to skip attention projections
- Then dequantize to BF16 ourselves
- Cleanest but requires modifying the quant config
Option B is most straightforward. The quant method already ran and set up its attributes. We can then come in and replace everything with BF16.
## Architecture notes
- Attention uses MLA (Multi-head Latent Attention) with 2-step Q projection (q_a → q_b)
- `fused_wqa_wkv` = MergedColumnParallelLinear(q_a + kv fused)
- `wo_a` = FP8 via fp8_einsum (no input_scale, weight-only)
- `wo_b` = standard ColumnParallelLinear
- `hc_pre` / `hc_post` = Head-Conditioned mixing (tilelang custom ops)
- Dummy run zeros attention output by design (`out.zero_(); return`)
- FlashMLA handles the actual MLA attention kernel

View File

@@ -2344,6 +2344,8 @@ class DeepseekV4ForCausalLM(nn.Module):
self.model._convert_nvfp4_post_load()
print(" Warming up tilelang kernels...", flush=True)
self._warmup_tilelang()
print(" Registering post-quant-init fix...", flush=True)
self._register_post_quant_fix()
print(" NVFP4 model ready ✓", flush=True)
return loaded_params
@@ -2410,3 +2412,62 @@ class DeepseekV4ForCausalLM(nn.Module):
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
return self.model.get_expert_mapping()
def _register_post_quant_fix(self) -> None:
"""Register a one-shot forward pre-hook that fixes attention NVFP4
activation scales AFTER process_weights_after_loading has run.
process_weights_after_loading (called by vLLM's model loader AFTER
load_weights returns) sets up FlashInferCutlassNvFp4LinearKernel
with broken input_global_scale_inv from the checkpoint. This hook
runs on the first forward call (guaranteed after all init) and:
1. Dequantizes attention NVFP4 weights to BF16
2. Replaces quant_method with UnquantizedLinearMethod
3. Removes itself (one-shot)
"""
import os
# Only needed for NVFP4 quantized models
quant_method_name = type(getattr(self.model.layers[0].attn.fused_wqa_wkv, 'quant_method', None)).__name__
if 'NvFp4' not in quant_method_name and 'nvfp4' not in quant_method_name.lower():
print(" No NVFP4 attention fix needed (quant_method={quant_method_name})", flush=True)
return
def _fix_attn_nvfp4(module, args):
"""One-shot hook: dequant attention NVFP4 to BF16 after quant init."""
print(" [CLAWMINE] Running post-quant-init BF16 fix...", flush=True)
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
E2M1_LUT = torch.tensor([0, 0.5, 1, 1.5, 2, 3, 4, 6], dtype=torch.bfloat16)
fixed = 0
for layer_idx, layer in enumerate(module.model.layers):
attn = layer.attn
for proj_name in ["fused_wqa_wkv", "wq_b", "wo_b"]:
if not hasattr(attn, proj_name):
continue
mod = getattr(attn, proj_name)
if not hasattr(mod, "weight") or mod.weight.dtype not in (torch.uint8, torch.int8):
continue
# Dequantize to BF16
module.model._dequant_nvfp4_to_bf16(mod, E2M1_LUT)
# Replace quant method (quant method already ran, won't overwrite again)
mod.quant_method = UnquantizedLinearMethod()
# Clean up NVFP4 attributes that might confuse forward
for attr in ("weight_scale", "weight_scale_2", "input_scale",
"input_global_scale", "input_global_scale_inv",
"weight_global_scale", "alpha",
"weight_scale_inv"):
if hasattr(mod, attr):
try:
delattr(mod, attr)
except (AttributeError, TypeError):
pass
fixed += 1
print(f" [CLAWMINE] Fixed {fixed} attention projections → BF16", flush=True)
# Remove this hook (one-shot)
module._post_quant_fix_handle.remove()
del module._post_quant_fix_handle
print(" [CLAWMINE] Post-quant-init fix done ✓", flush=True)
handle = self.register_forward_pre_hook(_fix_attn_nvfp4)
self._post_quant_fix_handle = handle