Add post-quant-init forward hook to fix attention NVFP4
The key insight: process_weights_after_loading runs AFTER load_weights and sets up FlashInferCutlassNvFp4LinearKernel with broken input_global_scale_inv. Any fix inside load_weights gets overwritten. Solution: register a one-shot forward pre-hook that runs on the first forward call (guaranteed after all init). It dequantizes attention NVFP4 weights to BF16 and replaces quant_method with UnquantizedLinearMethod. Since process_weights_after_loading already ran, our changes won't be overwritten. Standalone test confirmed: all attention weights produce valid non-NaN output when dequantized to BF16.
120
CURRENT_BUG.md
@@ -1,58 +1,104 @@
|
||||
# Current Bug: vLLM produces NaN from layer 0
|
||||
# Current Bug: vLLM produces empty/garbage output
|
||||
|
||||
**Status:** Active debugging — BF16 dequant fix in progress
|
||||
**Status:** Weights confirmed good — bug is in vLLM's quant pipeline for attention
|
||||
**Date:** 2026-05-18
|
||||
|
||||
## Symptom
|
||||
- vLLM server starts, loads model, but every inference produces NaN logits
|
||||
- Diagnostic prints show **NaN from layer 0 onward** — no layer ever produces valid output
|
||||
- Empty content in chat completions, NaN in logprobs
|
||||
- vLLM server starts, loads model, processes requests (200 OK)
|
||||
- Chat completions return `content: ""` with `finish_reason: "length"`
|
||||
- 20 completion tokens generated but all produce empty/NaN logits
|
||||
- With enforce-eager + diagnostics: **NaN from layer 0 onward** on real requests
|
||||
|
||||
## Root Cause (in progress)
|
||||
## ✅ Confirmed: Weights produce valid output
|
||||
|
||||
**The attention NVFP4 linear layers produce NaN immediately.**
|
||||
Standalone test (`test_attn_moe_chain.py`) running directly on B200:
|
||||
|
||||
The attention projections go through vLLM's `FlashInferCutlassNvFp4LinearKernel` which uses checkpoint `input_scale` as the activation global scale for `scaled_fp4_quant()`. The checkpoint `input_scale` values are wrong for this use case, causing overflow → NaN.
|
||||
| Step | Operation | amax | NaN? |
|------|-----------|------|------|
| 1 | Embed tokens | 1.27 | No |
| 2 | hc_mult expansion | 1.27 | No |
| 3 | RMSNorm | 0.20 | No |
| 4 | q_a_proj (NVFP4→BF16 dequant + matmul) | 0.50 | No |
| 5 | kv_proj (NVFP4→BF16 dequant + matmul) | 1.30 | No |
| 6 | q_norm + kv_norm | 0.11 / 1.87 | No |
| 7 | q_b_proj (NVFP4→BF16 dequant + matmul) | 1.10 | No |
| 8 | MoE CuTeDSL runner (with warmup gs) | cosine 0.988 | No |
|
||||
|
||||
### What we've tried
|
||||
**Every step produces valid, non-NaN, non-zero output.** The problem is NOT the weights.
|
||||
|
||||
1. ✅ **MoE kernel is NOT the problem** — `test_runner_vllm_style.py` with warmup gs gives cosine 0.988, no NaN
|
||||
2. ❌ **Dequant ALL attn projections to BF16** — crashed: `wo_a.weight_scale_inv` missing (fp8_einsum needs it)
|
||||
3. ❌ **Dequant all except wo_a (keep wo_a as FP8)** — still NaN from layer 0. `wq_a` and `wkv` don't exist as separate attrs — they're **fused as `fused_wqa_wkv`**
|
||||
4. ❌ **Changed to dequant `fused_wqa_wkv`** — still NaN from layer 0. Debug prints added to check if the attrs are actually found.
|
||||
## ❌ Root Cause: vLLM's `process_weights_after_loading` breaks attention
|
||||
|
||||
### Current theory
|
||||
### The timeline
|
||||
|
||||
The BF16 dequant code may not be finding `fused_wqa_wkv` on the attention module, so it silently skips the most important projection. Debug logging added in latest commit to verify.
|
||||
1. `load_weights()` → our `_convert_nvfp4_post_load()` runs
|
||||
2. `process_weights_after_loading()` → vLLM's quant method runs AFTER, **overwriting our fixes**
|
||||
3. `FlashInferCutlassNvFp4LinearKernel` gets set up with broken `input_global_scale_inv`
|
||||
|
||||
### Attention architecture (DeepSeek V4 MLA)
|
||||
### What the quant method does
|
||||
|
||||
- `fused_wqa_wkv` — MergedColumnParallelLinear (q_a + kv fused)
|
||||
- `wq_b` — ColumnParallelLinear (second Q projection after RoPE)
|
||||
- `wo_a` — ColumnParallelLinear (FP8 via fp8_einsum, weight-only, NO input_scale)
|
||||
- `wo_b` — ColumnParallelLinear (final output projection)
|
||||
- `compressor` — already handled (reconstructed to BF16 from checkpoint)
|
||||
`CompressedTensorsW4A4Fp4.process_weights_after_loading()`:
|
||||
```python
input_global_scale_inv = layer.input_scale.max()         # = 0.00025141 (WRONG)
layer.input_global_scale = 1.0 / input_global_scale_inv  # = 3977.6
layer.input_global_scale_inv = input_global_scale_inv    # = 0.00025141
layer.alpha = layer.input_global_scale * layer.weight_global_scale
```
|
||||
|
||||
### Why `wo_a` is safe as FP8
|
||||
At runtime: `scaled_fp4_quant(x, input_global_scale_inv=0.00025141)` divides by 0.00025141 → multiplies by 3977.6 → massive overflow → NaN.
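
The arithmetic can be checked by hand. The sketch below is illustrative only: it assumes the usual NVFP4 layout (E2M1 values with a maximum magnitude of 6, and per-block scales stored as FP8 E4M3 with a maximum of 448), which is an assumption and not something read from vLLM's source. `scaled_amax` and `representable` are hypothetical names, and the `input_scale` value is the one quoted above.

```python
# Worked arithmetic for the overflow described above (illustrative only).
FP4_E2M1_MAX = 6.0      # largest magnitude an E2M1 code can hold
FP8_E4M3_MAX = 448.0    # ASSUMPTION: per-block scales are stored as FP8 E4M3

input_global_scale_inv = 0.00025141           # checkpoint value quoted above
input_global_scale = 1.0 / input_global_scale_inv


def scaled_amax(amax: float) -> float:
    """Largest activation magnitude after the global scale is applied."""
    return amax * input_global_scale


# Largest scaled magnitude one block can represent: max code * max block scale.
representable = FP4_E2M1_MAX * FP8_E4M3_MAX

for amax in (1.27, 2.0, 8.0):
    s = scaled_amax(amax)
    print(f"amax={amax:5.2f} -> scaled={s:10.1f} overflow={s > representable}")
```

Under these assumptions even the embedding amax of 1.27 lands above the representable range once multiplied by the global scale, which is consistent with NaN appearing at layer 0.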
|
||||
|
||||
`wo_a` uses `fp8_einsum` which does `output = fp8_act * fp8_weight * scale`. It's a **weight-only FP8** GEMM — no `input_scale` involved. The NaN comes from `scaled_fp4_quant(x, input_global_scale_inv)` in the other projections.
|
||||
### Why our fixes didn't work
|
||||
|
||||
## Key evidence
|
||||
| Attempt | Why it failed |
|---------|---------------|
| BF16 dequant + `UnquantizedLinearMethod` | `process_weights_after_loading` overwrites `quant_method` back to `FlashInferCutlassNvFp4LinearKernel` |
| Fix `input_scale` before quant method | Runs too early — quant method reads `input_scale` and overwrites our value |
| Fix `input_global_scale_inv` directly | Attribute doesn't exist yet when our code runs — it's set BY the quant method |
|
||||
|
||||
- `q_a_proj.input_scale = 0.00025141` → `1/input_scale = 3977.6` → activations with amax ~2-8 are multiplied by 3977.6 before quantization = massive overflow
|
||||
- `q_b_proj.input_scale = 0.00006140` → `1/input_scale = 16287.1` → even worse
|
||||
- Embedding values: amax=1.27, std=0.09 — very small values that get multiplied by thousands during quantization
|
||||
### The key insight
|
||||
|
||||
## Next steps
|
||||
Our code runs **inside** `load_weights()`. The quant method's `process_weights_after_loading()` runs **after** `load_weights()` returns. Any changes we make get overwritten.
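
The ordering problem, and why a one-shot pre-forward hook avoids it, can be sketched with a toy model. Nothing here is vLLM code: `ToyModel`, `ToyLinear`, and the string-valued `quant_method` are hypothetical stand-ins, and the hook list imitates what `torch.nn.Module.register_forward_pre_hook` provides in the real model.

```python
# Toy model of the init ordering; none of these classes are vLLM's.
class ToyLinear:
    def __init__(self) -> None:
        self.quant_method = "nvfp4"


def one_shot_fix(model: "ToyModel") -> None:
    """Runs before the first forward, applies the fix, then unregisters."""
    model.proj.quant_method = "bf16"
    model._pre_hooks.remove(one_shot_fix)


class ToyModel:
    def __init__(self) -> None:
        self.proj = ToyLinear()
        self._pre_hooks = []                 # callables run before each forward

    def load_weights(self) -> None:
        # Too early: the loader step below will undo this assignment.
        self.proj.quant_method = "bf16"
        # Deferring the fix to the first forward call survives the loader.
        self._pre_hooks.append(one_shot_fix)

    def forward(self) -> str:
        for hook in list(self._pre_hooks):   # copy: hooks remove themselves
            hook(self)
        return self.proj.quant_method


def process_weights_after_loading(model: ToyModel) -> None:
    # Stand-in for the loader step that re-installs the quantized kernel.
    model.proj.quant_method = "nvfp4"


model = ToyModel()
model.load_weights()
process_weights_after_loading(model)
after_loader = model.proj.quant_method       # the early fix was overwritten

first = model.forward()                      # hook fires and removes itself
second = model.forward()                     # no hooks left, fix persists
print(after_loader, first, second, len(model._pre_hooks))
```

The point of the design is that the hook is registered early but executes late, so it does not need to know where in the loader `process_weights_after_loading` is called.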
|
||||
|
||||
1. Check debug logs to see which projections were actually dequantized
|
||||
2. If `fused_wqa_wkv` wasn't found, fix the attribute path
|
||||
3. If it was found and dequantized, the NaN source is elsewhere (wo_b? wq_b? something else?)
|
||||
4. Consider: maybe the NaN is from the **KV cache FP8 quantization** or the **RoPE** implementation
|
||||
## Config values (corrected)
|
||||
|
||||
## Docker/Build Notes
|
||||
| Parameter | Value |
|-----------|-------|
| head_dim | 512 (NOT 56) |
| num_attention_heads | 128 |
| num_key_value_heads | 1 |
| q_lora_rank | 1536 |
| qk_rope_head_dim | 64 |
| o_lora_rank | 1024 |
| hc_mult | 4 |
| n_routed_experts | 384 (48 per EP rank) |
|
||||
|
||||
- Build: `screen -dmS build bash -c './build_and_run.sh 2>&1 | tee build.log'`
|
||||
- Currently using `--enforce-eager` + `CLAWMINE_DEBUG=1` for diagnostics
|
||||
- Don't hit the API with enforce-eager — JIT spikes crash the container
|
||||
- For real testing: use compilation-config `{"cudagraph_mode": "NONE", "custom_ops": ["all"]}` instead of enforce-eager
|
||||
## Next step: Post-init hook
|
||||
|
||||
The fix must run AFTER `process_weights_after_loading` and BEFORE the first inference. Options:
|
||||
|
||||
**Option A: Override `input_global_scale_inv` post-init**
|
||||
- Add a `_fix_nvfp4_activation_scales()` method
|
||||
- Call it from the right hook point (after quant method setup, before inference)
|
||||
- Compute correct `input_global_scale_inv` from BF16 warmup
|
||||
- Override the Parameter on each attention module
|
||||
|
||||
**Option B: Replace quant_method with UnquantizedLinearMethod post-init**
|
||||
- After `process_weights_after_loading`, dequant weights to BF16
|
||||
- Swap `quant_method` on attention modules to `UnquantizedLinearMethod`
|
||||
- This time the quant method won't overwrite us (it already ran)
|
||||
|
||||
**Option C: Override the quant config to skip attention modules**
|
||||
- Tell `CompressedTensorsW4A4Fp4` to skip attention projections
|
||||
- Then dequantize to BF16 ourselves
|
||||
- Cleanest but requires modifying the quant config
|
||||
|
||||
Option B is most straightforward. The quant method already ran and set up its attributes. We can then come in and replace everything with BF16.
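
For reference, the dequantization step in Option B amounts to decoding packed 4-bit codes and applying the scales. The sketch below is a minimal pure-Python illustration, not the model's own `_dequant_nvfp4_to_bf16` (whose body is not shown in this commit). It assumes the low nibble of each byte holds the first element and that a value is recovered as code × per-block scale × per-tensor global scale; both are assumptions that should be checked against the checkpoint format. All function names are hypothetical.

```python
# Illustrative NVFP4 -> float dequantization of one 16-element block.
E2M1_LUT = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]  # E2M1 magnitudes


def decode_e2m1(nibble: int) -> float:
    """Decode one 4-bit code: bit 3 is the sign, bits 0-2 index the LUT."""
    magnitude = E2M1_LUT[nibble & 0x7]
    return -magnitude if nibble & 0x8 else magnitude


def unpack_bytes(packed: bytes) -> list[float]:
    """Split each byte into two codes. ASSUMPTION: low nibble comes first."""
    out = []
    for byte in packed:
        out.append(decode_e2m1(byte & 0x0F))
        out.append(decode_e2m1(byte >> 4))
    return out


def dequant_block(packed: bytes, block_scale: float, global_scale: float) -> list[float]:
    """ASSUMPTION: value = code * per-block scale * per-tensor global scale."""
    return [v * block_scale * global_scale for v in unpack_bytes(packed)]


block = bytes([0x21, 0xF7] + [0x00] * 6)     # 8 bytes = 16 FP4 codes
values = dequant_block(block, block_scale=0.5, global_scale=2.0)
print(values[:4])                            # [0.5, 1.0, 6.0, -6.0]
```

Once the weights are plain BF16 values, no activation quantization is involved, so the bad `input_scale` no longer enters the computation.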
|
||||
|
||||
## Architecture notes
|
||||
|
||||
- Attention uses MLA (Multi-head Latent Attention) with 2-step Q projection (q_a → q_b)
|
||||
- `fused_wqa_wkv` = MergedColumnParallelLinear(q_a + kv fused)
|
||||
- `wo_a` = FP8 via fp8_einsum (no input_scale, weight-only)
|
||||
- `wo_b` = standard ColumnParallelLinear
|
||||
- `hc_pre` / `hc_post` = Head-Conditioned mixing (tilelang custom ops)
|
||||
- Dummy run zeros attention output by design (`out.zero_(); return`)
|
||||
- FlashMLA handles the actual MLA attention kernel
|
||||
|
||||
@@ -2344,6 +2344,8 @@ class DeepseekV4ForCausalLM(nn.Module):
|
||||
self.model._convert_nvfp4_post_load()
|
||||
print(" Warming up tilelang kernels...", flush=True)
|
||||
self._warmup_tilelang()
|
||||
print(" Registering post-quant-init fix...", flush=True)
|
||||
self._register_post_quant_fix()
|
||||
print(" NVFP4 model ready ✓", flush=True)
|
||||
|
||||
return loaded_params
|
||||
@@ -2410,3 +2412,62 @@ class DeepseekV4ForCausalLM(nn.Module):
|
||||
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
|
||||
return self.model.get_expert_mapping()
|
||||
|
||||
    def _register_post_quant_fix(self) -> None:
        """Register a one-shot forward pre-hook that fixes attention NVFP4
        activation scales AFTER process_weights_after_loading has run.

        process_weights_after_loading (called by vLLM's model loader AFTER
        load_weights returns) sets up FlashInferCutlassNvFp4LinearKernel
        with broken input_global_scale_inv from the checkpoint. This hook
        runs on the first forward call (guaranteed after all init) and:
        1. Dequantizes attention NVFP4 weights to BF16
        2. Replaces quant_method with UnquantizedLinearMethod
        3. Removes itself (one-shot)
        """
        # Only needed for NVFP4 quantized models
        quant_method_name = type(getattr(self.model.layers[0].attn.fused_wqa_wkv, 'quant_method', None)).__name__
        if 'nvfp4' not in quant_method_name.lower():
            # f-string so the detected quant method name is actually interpolated
            print(f" No NVFP4 attention fix needed (quant_method={quant_method_name})", flush=True)
            return

        def _fix_attn_nvfp4(module, args):
            """One-shot hook: dequant attention NVFP4 to BF16 after quant init."""
            print(" [CLAWMINE] Running post-quant-init BF16 fix...", flush=True)
            from vllm.model_executor.layers.linear import UnquantizedLinearMethod

            E2M1_LUT = torch.tensor([0, 0.5, 1, 1.5, 2, 3, 4, 6], dtype=torch.bfloat16)
            fixed = 0
            for layer_idx, layer in enumerate(module.model.layers):
                attn = layer.attn
                for proj_name in ["fused_wqa_wkv", "wq_b", "wo_b"]:
                    if not hasattr(attn, proj_name):
                        continue
                    mod = getattr(attn, proj_name)
                    if not hasattr(mod, "weight") or mod.weight.dtype not in (torch.uint8, torch.int8):
                        continue
                    # Dequantize to BF16
                    module.model._dequant_nvfp4_to_bf16(mod, E2M1_LUT)
                    # Replace quant method (quant method already ran, won't overwrite again)
                    mod.quant_method = UnquantizedLinearMethod()
                    # Clean up NVFP4 attributes that might confuse forward
                    for attr in ("weight_scale", "weight_scale_2", "input_scale",
                                 "input_global_scale", "input_global_scale_inv",
                                 "weight_global_scale", "alpha",
                                 "weight_scale_inv"):
                        if hasattr(mod, attr):
                            try:
                                delattr(mod, attr)
                            except (AttributeError, TypeError):
                                pass
                    fixed += 1

            print(f" [CLAWMINE] Fixed {fixed} attention projections → BF16", flush=True)
            # Remove this hook (one-shot)
            module._post_quant_fix_handle.remove()
            del module._post_quant_fix_handle
            print(" [CLAWMINE] Post-quant-init fix done ✓", flush=True)

        handle = self.register_forward_pre_hook(_fix_attn_nvfp4)
        self._post_quant_fix_handle = handle
|
||||
|
||||
|
||||