From f86892e26b2bf15e941152f09efd8d2f3669e36b Mon Sep 17 00:00:00 2001 From: biondizzle Date: Mon, 18 May 2026 15:43:46 +0000 Subject: [PATCH] Replace BF16 dequant with input_scale warmup fix for attention NVFP4 Instead of dequantizing attention weights to BF16 (which had issues with MergedColumnParallelLinear and different weight_scale_2 values), keep the NVFP4 path but fix the activation global scale. Compute correct input_global_scale_inv by: 1. Temporarily dequantizing weight to BF16 2. Running warmup forward with random input 3. Computing actual activation amax 4. Setting scale_inv = amax * headroom This preserves the original NVFP4 quantization pipeline. --- vllm/patches/deepseek_v4.py | 99 +++++++++++++++++++++++++------------ 1 file changed, 67 insertions(+), 32 deletions(-) diff --git a/vllm/patches/deepseek_v4.py b/vllm/patches/deepseek_v4.py index 6a0c4d5c..c1b7cf01 100644 --- a/vllm/patches/deepseek_v4.py +++ b/vllm/patches/deepseek_v4.py @@ -1685,49 +1685,26 @@ class DeepseekV4Model(nn.Module): def _convert_nvfp4_post_load(self): """Post-load conversion of NVFP4 weights for vLLM compatibility. - All attention NVFP4 projections are dequantized to BF16 because - the checkpoint input_scale values cause NaN during activation - quantization in FlashInferCutlassNvFp4LinearKernel. BF16 bypasses - the broken input_scale entirely. + Fixes the attention input_global_scale_inv (activation global scale) + by running a warmup forward and computing the correct scale from + actual activation magnitudes. The checkpoint input_scale values are + calibrated incorrectly and cause NaN during activation quantization. - Compressor weights are reconstructed from checkpoint sub-weights - because the stacking weight_loader corrupts NVFP4 uint8 data. + wo_a is converted to FP8 for fp8_einsum (no input_scale needed). + Compressor weights are reconstructed from checkpoint sub-weights. """ - # All attention projections to dequantize to BF16 - # wo_a is excluded — it uses fp8_einsum (no input_scale, weight-only FP8) - # wq_a and wkv are fused into fused_wqa_wkv - bf16_proj_names = {"wq_b", "wo_b", "fused_wqa_wkv"} + # wo_a → FP8 (fp8_einsum path, no input_scale) fp8_proj_names = {"wo_a"} - bf16_converted = 0 fp8_converted = 0 compressor_converted = 0 + input_scale_fixes = 0 _shard_index = self._build_shard_index("/model") if os.path.isdir("/model") else None from tqdm import tqdm - for layer_idx, layer in tqdm(enumerate(self.layers), total=len(self.layers), desc=" (upcast)NVFP4→BF16 attn projs", unit="layer"): + for layer_idx, layer in tqdm(enumerate(self.layers), total=len(self.layers), desc=" (fix)NVFP4 attn input_scale", unit="layer"): attn = layer.attn - # BF16 dequantization: attention projections (except wo_a) - for proj_name in bf16_proj_names: - if not hasattr(attn, proj_name): - if layer_idx == 0: - print(f"[CLAWMINE] Layer 0: {proj_name} NOT FOUND on attn (type={type(attn).__name__})") - continue - mod = getattr(attn, proj_name) - if not hasattr(mod, "weight"): - if layer_idx == 0: - print(f"[CLAWMINE] Layer 0: {proj_name} has no weight attr") - continue - if layer_idx == 0: - print(f"[CLAWMINE] Layer 0: {proj_name} weight dtype={mod.weight.dtype}") - if mod.weight.dtype in (torch.uint8, torch.int8): - E2M1_LUT = torch.tensor([0, 0.5, 1, 1.5, 2, 3, 4, 6], dtype=torch.bfloat16) - self._dequant_nvfp4_to_bf16(mod, E2M1_LUT) - if layer_idx == 0: - print(f"[CLAWMINE] Layer 0: {proj_name} AFTER dequant: dtype={mod.weight.dtype} amax={mod.weight.data.amax().item():.4f} NaN={torch.isnan(mod.weight.data).any().item()} quant_method={type(mod.quant_method).__name__}") - bf16_converted += 1 - # FP8 conversion: wo_a (used by fp8_einsum, no input_scale) FP8_MAX = torch.finfo(torch.float8_e4m3fn).max for proj_name in fp8_proj_names: @@ -1741,6 +1718,64 @@ class DeepseekV4Model(nn.Module): self._convert_nvfp4_to_fp8(mod, E2M1_LUT, FP8_MAX) fp8_converted += 1 + # Fix input_global_scale_inv for NVFP4 attention projections + # The checkpoint input_scale is wrong. We compute the correct scale + # by dequantizing to BF16 temporarily and running a warmup. + for proj_name in ["fused_wqa_wkv", "wq_b", "wo_b"]: + if not hasattr(attn, proj_name): + continue + mod = getattr(attn, proj_name) + if not hasattr(mod, "input_global_scale_inv"): + continue + if mod.weight.dtype not in (torch.uint8, torch.int8): + continue + + # Temporarily dequantize weight to BF16 for warmup + E2M1_LUT = torch.tensor([0, 0.5, 1, 1.5, 2, 3, 4, 6], dtype=torch.bfloat16) + w_uint8 = mod.weight.data + w_bf16_unpacked = self._unpack_nvfp4_to_bf16(w_uint8, E2M1_LUT, w_uint8.device) + if hasattr(mod, "weight_scale") and hasattr(mod, "weight_scale_2"): + block_scale = self._block_scale_to_float32(mod.weight_scale.data) + if block_scale.dim() == 2 and w_bf16_unpacked.dim() == 2: + block_size = w_bf16_unpacked.shape[1] // block_scale.shape[1] + block_scale_expanded = block_scale.unsqueeze(-1).expand(-1, -1, block_size).reshape(w_bf16_unpacked.shape) + else: + block_scale_expanded = block_scale + global_scale = mod.weight_scale_2.data.max().item() + w_bf16_dequant = (w_bf16_unpacked.float() * block_scale_expanded * global_scale).to(torch.bfloat16) + else: + w_bf16_dequant = w_bf16_unpacked + + # Warmup: compute actual activation amax using BF16 reference + with torch.no_grad(): + in_features = w_bf16_dequant.shape[-1] + dummy_input = torch.randn(256, in_features, dtype=torch.bfloat16, device=mod.weight.device) * 2.0 + ref_output = torch.nn.functional.linear(dummy_input, w_bf16_dequant) + act_amax = ref_output.amax().item() + # Clean up temp tensors + del w_bf16_unpacked, w_bf16_dequant, ref_output + + # Set correct input_global_scale_inv: 1/(amax * headroom) + # scaled_fp4_quant divides by input_global_scale_inv + # so input_global_scale_inv should be ~ amax (to map amax → 1.0 in FP4) + headroom = 1.2 # slight headroom to avoid clipping + new_inv = act_amax * headroom if act_amax > 0 else 1.0 + new_scale = 1.0 / new_inv + + if layer_idx == 0: + old_inv = mod.input_global_scale_inv.item() if hasattr(mod.input_global_scale_inv, 'item') else float(mod.input_global_scale_inv) + old_scale = mod.input_global_scale.item() if hasattr(mod.input_global_scale, 'item') else float(mod.input_global_scale) + print(f"[CLAWMINE] Layer 0: {proj_name} scale_inv: {old_inv:.8f}→{new_inv:.8f} scale: {old_scale:.8f}→{new_scale:.8f} (act_amax={act_amax:.4f})") + + mod.input_global_scale = torch.nn.Parameter(torch.tensor(new_scale, dtype=torch.float32), requires_grad=False) + mod.input_global_scale_inv = torch.nn.Parameter(torch.tensor(new_inv, dtype=torch.float32), requires_grad=False) + # Update alpha: input_scale * weight_scale (both are the "1/x" form now) + if hasattr(mod, "weight_global_scale") and hasattr(mod, "alpha"): + wgs = mod.weight_global_scale.item() if hasattr(mod.weight_global_scale, 'item') else float(mod.weight_global_scale) + mod.alpha = torch.nn.Parameter(torch.tensor(new_scale * wgs, dtype=torch.float32), requires_grad=False) + + input_scale_fixes += 1 + # Compressor: still needs BF16 reconstruction mla_attn = getattr(attn, "mla_attn", None) if mla_attn is not None: