diff --git a/vllm/patches/deepseek_v4.py b/vllm/patches/deepseek_v4.py index 481b67fc..01e618f2 100644 --- a/vllm/patches/deepseek_v4.py +++ b/vllm/patches/deepseek_v4.py @@ -1685,26 +1685,37 @@ class DeepseekV4Model(nn.Module): def _convert_nvfp4_post_load(self): """Post-load conversion of NVFP4 weights for vLLM compatibility. - Fixes the attention input_global_scale_inv (activation global scale) - by running a warmup forward and computing the correct scale from - actual activation magnitudes. The checkpoint input_scale values are - calibrated incorrectly and cause NaN during activation quantization. + All attention NVFP4 projections (except wo_a) are dequantized to BF16. + The checkpoint input_scale values cause NaN during activation quantization + in FlashInferCutlassNvFp4LinearKernel. BF16 bypasses this entirely. wo_a is converted to FP8 for fp8_einsum (no input_scale needed). Compressor weights are reconstructed from checkpoint sub-weights. """ - # wo_a → FP8 (fp8_einsum path, no input_scale) + bf16_proj_names = {"fused_wqa_wkv", "wq_b", "wo_b"} fp8_proj_names = {"wo_a"} + bf16_converted = 0 fp8_converted = 0 compressor_converted = 0 - input_scale_fixes = 0 _shard_index = self._build_shard_index("/model") if os.path.isdir("/model") else None from tqdm import tqdm - for layer_idx, layer in tqdm(enumerate(self.layers), total=len(self.layers), desc=" (fix)NVFP4 attn input_scale", unit="layer"): + for layer_idx, layer in tqdm(enumerate(self.layers), total=len(self.layers), desc=" (upcast)NVFP4→BF16 attn projs", unit="layer"): attn = layer.attn + # BF16 dequantization: attention projections (except wo_a) + for proj_name in bf16_proj_names: + if not hasattr(attn, proj_name): + continue + mod = getattr(attn, proj_name) + if not hasattr(mod, "weight"): + continue + if mod.weight.dtype in (torch.uint8, torch.int8): + E2M1_LUT = torch.tensor([0, 0.5, 1, 1.5, 2, 3, 4, 6], dtype=torch.bfloat16) + self._dequant_nvfp4_to_bf16(mod, E2M1_LUT) + bf16_converted += 1 + # FP8 conversion: wo_a (used by fp8_einsum, no input_scale) FP8_MAX = torch.finfo(torch.float8_e4m3fn).max for proj_name in fp8_proj_names: @@ -1718,68 +1729,6 @@ class DeepseekV4Model(nn.Module): self._convert_nvfp4_to_fp8(mod, E2M1_LUT, FP8_MAX) fp8_converted += 1 - # Fix input_global_scale_inv for NVFP4 attention projections - # The checkpoint input_scale is wrong. We compute the correct scale - # by dequantizing to BF16 temporarily and running a warmup. - for proj_name in ["fused_wqa_wkv", "wq_b", "wo_b"]: - if not hasattr(attn, proj_name): - if layer_idx == 0: - print(f"[CLAWMINE] Layer 0: {proj_name} NOT on attn") - continue - mod = getattr(attn, proj_name) - if layer_idx == 0: - print(f"[CLAWMINE] Layer 0: {proj_name} dtype={mod.weight.dtype} has_input_global_scale_inv={hasattr(mod, 'input_global_scale_inv')} has_input_scale={hasattr(mod, 'input_scale')}") - if not hasattr(mod, "input_global_scale_inv"): - continue - if mod.weight.dtype not in (torch.uint8, torch.int8): - continue - - # Temporarily dequantize weight to BF16 for warmup - E2M1_LUT = torch.tensor([0, 0.5, 1, 1.5, 2, 3, 4, 6], dtype=torch.bfloat16) - w_uint8 = mod.weight.data - w_bf16_unpacked = self._unpack_nvfp4_to_bf16(w_uint8, E2M1_LUT, w_uint8.device) - if hasattr(mod, "weight_scale") and hasattr(mod, "weight_scale_2"): - block_scale = self._block_scale_to_float32(mod.weight_scale.data) - if block_scale.dim() == 2 and w_bf16_unpacked.dim() == 2: - block_size = w_bf16_unpacked.shape[1] // block_scale.shape[1] - block_scale_expanded = block_scale.unsqueeze(-1).expand(-1, -1, block_size).reshape(w_bf16_unpacked.shape) - else: - block_scale_expanded = block_scale - global_scale = mod.weight_scale_2.data.max().item() - w_bf16_dequant = (w_bf16_unpacked.float() * block_scale_expanded * global_scale).to(torch.bfloat16) - else: - w_bf16_dequant = w_bf16_unpacked - - # Warmup: compute actual activation amax using BF16 reference - with torch.no_grad(): - in_features = w_bf16_dequant.shape[-1] - dummy_input = torch.randn(256, in_features, dtype=torch.bfloat16, device=mod.weight.device) * 2.0 - ref_output = torch.nn.functional.linear(dummy_input, w_bf16_dequant) - act_amax = ref_output.amax().item() - # Clean up temp tensors - del w_bf16_unpacked, w_bf16_dequant, ref_output - - # Set correct input_global_scale_inv: 1/(amax * headroom) - # scaled_fp4_quant divides by input_global_scale_inv - # so input_global_scale_inv should be ~ amax (to map amax → 1.0 in FP4) - headroom = 1.2 # slight headroom to avoid clipping - new_inv = act_amax * headroom if act_amax > 0 else 1.0 - new_scale = 1.0 / new_inv - - if layer_idx == 0: - old_inv = mod.input_global_scale_inv.item() if hasattr(mod.input_global_scale_inv, 'item') else float(mod.input_global_scale_inv) - old_scale = mod.input_global_scale.item() if hasattr(mod.input_global_scale, 'item') else float(mod.input_global_scale) - print(f"[CLAWMINE] Layer 0: {proj_name} scale_inv: {old_inv:.8f}→{new_inv:.8f} scale: {old_scale:.8f}→{new_scale:.8f} (act_amax={act_amax:.4f})") - - mod.input_global_scale = torch.nn.Parameter(torch.tensor(new_scale, dtype=torch.float32), requires_grad=False) - mod.input_global_scale_inv = torch.nn.Parameter(torch.tensor(new_inv, dtype=torch.float32), requires_grad=False) - # Update alpha: input_scale * weight_scale (both are the "1/x" form now) - if hasattr(mod, "weight_global_scale") and hasattr(mod, "alpha"): - wgs = mod.weight_global_scale.item() if hasattr(mod.weight_global_scale, 'item') else float(mod.weight_global_scale) - mod.alpha = torch.nn.Parameter(torch.tensor(new_scale * wgs, dtype=torch.float32), requires_grad=False) - - input_scale_fixes += 1 - # Compressor: still needs BF16 reconstruction mla_attn = getattr(attn, "mla_attn", None) if mla_attn is not None: