#!/usr/bin/python3
"""
Clean rewrite of the NVFP4→FP8/bf16 conversion.

Strategy:
  - wo_a, fused_wqa_wkv → FP8 (used with fp8_einsum, need weight_scale_inv)
  - wq_b, wo_b, gate_up_proj → bf16 (used via .forward(), just works)
  - compressor fused_wkv_wgate → bf16 (already handled in load path)
  - MoE experts → native NVFP4 (ModelOptNvFp4FusedMoE handles it)

This script patches a Python source file in place: it removes the methods
``_convert_nvfp4_attention_to_fp8`` and ``_convert_nvfp4_module_to_fp8``
(and everything between them), inserts the methods stored in
``new_methods`` at the same position, and renames the call
``self.model._convert_nvfp4_attention_to_fp8()`` to
``self.model._convert_nvfp4_post_load()``.

The patched text is syntax-checked with ``ast.parse`` *before* it is written,
so a failed patch never overwrites the target file.
"""
import ast
import re
import sys
import textwrap

filepath = "/root/nvidia-meeting/deepseek-v4-quant/patches/deepseek_v4.py"

# Line-anchored patterns for the two methods being replaced.  Group 1 of the
# first pattern captures the method's indentation so the end-of-method scan
# and the inserted code can follow the indentation actually used in the file.
_OLD_ATTENTION_DEF = re.compile(
    r"^([ \t]*)def _convert_nvfp4_attention_to_fp8\(self\):",
    re.MULTILINE,
)
_OLD_MODULE_DEF = re.compile(
    r"^[ \t]*def _convert_nvfp4_module_to_fp8\(self, mod, e2m1_lut, fp8_max\):",
    re.MULTILINE,
)

# Call-site rename applied to the whole file after the methods are swapped.
_OLD_CALL = "self.model._convert_nvfp4_attention_to_fp8()"
_NEW_CALL = "self.model._convert_nvfp4_post_load()"

# Replacement methods, written at a 4-space method indent.  This text is
# inserted verbatim (re-indented only if the target uses a different indent).
new_methods = '''    def _convert_nvfp4_post_load(self):
        """Post-load conversion of NVFP4 weights for vLLM compatibility.

        Strategy:
        - wo_a, fused_wqa_wkv: Convert NVFP4->FP8 (used with fp8_einsum)
        - wq_b, wo_b, gate_up_proj: Dequant NVFP4->bf16 (used via .forward())
        - MoE experts: Stay in native NVFP4 (ModelOptNvFp4FusedMoE)
        """
        E2M1_LUT = torch.tensor(
            [0, 0.5, 1, 1.5, 2, 3, 4, 6], dtype=torch.bfloat16
        )
        FP8_MAX = torch.finfo(torch.float8_e4m3fn).max

        # Layers that use fp8_einsum (need FP8 + weight_scale_inv)
        fp8_proj_names = {"fused_wqa_wkv", "wo_a"}
        # Layers that use normal .forward() (need bf16)
        bf16_proj_names = {"wq_b", "wo_b"}
        bf16_shared_names = {"gate_up_proj"}

        fp8_converted = 0
        bf16_converted = 0

        for layer_idx, layer in enumerate(self.layers):
            attn = layer.attn

            for proj_name in fp8_proj_names:
                if not hasattr(attn, proj_name):
                    continue
                mod = getattr(attn, proj_name)
                if not hasattr(mod, "weight") or mod.weight.dtype != torch.uint8:
                    continue
                self._convert_nvfp4_to_fp8(mod, E2M1_LUT, FP8_MAX)
                fp8_converted += 1

            for proj_name in bf16_proj_names:
                if not hasattr(attn, proj_name):
                    continue
                mod = getattr(attn, proj_name)
                if not hasattr(mod, "weight") or mod.weight.dtype != torch.uint8:
                    continue
                self._dequant_nvfp4_to_bf16(mod, E2M1_LUT)
                bf16_converted += 1

            # Shared experts
            ffn = layer.ffn
            if hasattr(ffn, "shared_experts"):
                for proj_name in bf16_shared_names:
                    if not hasattr(ffn.shared_experts, proj_name):
                        continue
                    mod = getattr(ffn.shared_experts, proj_name)
                    if not hasattr(mod, "weight") or mod.weight.dtype != torch.uint8:
                        continue
                    self._dequant_nvfp4_to_bf16(mod, E2M1_LUT)
                    bf16_converted += 1

        if fp8_converted > 0 or bf16_converted > 0:
            print(f"NVFP4 post-load: {fp8_converted} layers -> FP8, "
                  f"{bf16_converted} layers -> bf16, MoE experts stay NVFP4")

    def _dequant_nvfp4_to_bf16(self, mod, e2m1_lut):
        """Dequantize NVFP4 weight to bf16 for normal .forward() path."""
        w_uint8 = mod.weight.data
        device = w_uint8.device

        w_bf16 = self._unpack_nvfp4_to_bf16(w_uint8, e2m1_lut, device)

        # Dequantize with scales
        if hasattr(mod, "weight_scale") and hasattr(mod, "weight_scale_2"):
            block_scale = mod.weight_scale.data.to(torch.float32)
            if block_scale.dim() == 2 and w_bf16.dim() == 2:
                block_size = w_bf16.shape[1] // block_scale.shape[1]
                block_scale_expanded = block_scale.unsqueeze(-1).expand(
                    -1, -1, block_size
                ).reshape(w_bf16.shape)
            else:
                block_scale_expanded = block_scale

            global_scale = mod.weight_scale_2.data.max().item()
            input_scale = (
                mod.input_scale.data.max().item()
                if hasattr(mod, "input_scale") else 1.0
            )
            w_dequant = w_bf16.float() * block_scale_expanded * global_scale * input_scale
            w_dequant = w_dequant.to(torch.bfloat16)
        else:
            w_dequant = w_bf16

        # Replace weight with bf16 version
        mod.weight = torch.nn.Parameter(w_dequant, requires_grad=False)

        from vllm.model_executor.layers.linear import UnquantizedLinearMethod
        mod.quant_method = UnquantizedLinearMethod()

        for attr in ("weight_scale", "weight_scale_2", "input_scale", "weight_scale_inv"):
            if hasattr(mod, attr):
                delattr(mod, attr)

    def _convert_nvfp4_to_fp8(self, mod, e2m1_lut, fp8_max):
        """Convert NVFP4 weight to FP8 for fp8_einsum path."""
        w_uint8 = mod.weight.data
        device = w_uint8.device

        w_bf16 = self._unpack_nvfp4_to_bf16(w_uint8, e2m1_lut, device)

        # Dequantize with scales
        if hasattr(mod, "weight_scale") and hasattr(mod, "weight_scale_2"):
            block_scale = mod.weight_scale.data.to(torch.float32)
            if block_scale.dim() == 2 and w_bf16.dim() == 2:
                block_size = w_bf16.shape[1] // block_scale.shape[1]
                block_scale_expanded = block_scale.unsqueeze(-1).expand(
                    -1, -1, block_size
                ).reshape(w_bf16.shape)
            else:
                block_scale_expanded = block_scale

            global_scale = mod.weight_scale_2.data.max().item()
            input_scale = (
                mod.input_scale.data.max().item()
                if hasattr(mod, "input_scale") else 1.0
            )
            w_dequant = w_bf16.float() * block_scale_expanded * global_scale * input_scale
            w_dequant = w_dequant.to(torch.bfloat16)
        else:
            w_dequant = w_bf16

        # Re-quantize bf16 -> FP8 e4m3
        w_amax = w_dequant.abs().amax()
        if w_amax == 0:
            w_amax = torch.tensor(1.0, device=device)
        fp8_scale = w_amax / fp8_max
        w_fp8 = (w_dequant / fp8_scale).to(torch.float8_e4m3fn)
        weight_scale_inv = fp8_scale.to(torch.float32)

        mod.weight = torch.nn.Parameter(w_fp8, requires_grad=False)
        mod.weight_scale_inv = torch.nn.Parameter(
            weight_scale_inv.reshape(1), requires_grad=False
        )

        from vllm.model_executor.layers.linear import UnquantizedLinearMethod
        mod.quant_method = UnquantizedLinearMethod()

        for attr in ("weight_scale", "weight_scale_2", "input_scale"):
            if hasattr(mod, attr):
                delattr(mod, attr)

    def _unpack_nvfp4_to_bf16(self, w_uint8, e2m1_lut, device):
        """Unpack NVFP4 uint8 packed weights to bf16 using E2M1 format."""
        # Extract 4-bit FP4 values (0-15, bit 3 = sign)
        even_raw = (w_uint8 & 0x0F).int()
        odd_raw = ((w_uint8 >> 4) & 0x0F).int()

        # Sign: 0-7 = positive, 8-15 = negative
        even_sign = torch.where(even_raw >= 8, -1.0, 1.0).to(torch.bfloat16)
        odd_sign = torch.where(odd_raw >= 8, -1.0, 1.0).to(torch.bfloat16)

        # Magnitude index: lower 3 bits (0-7)
        even_vals = even_sign * e2m1_lut.to(device)[even_raw & 0x07]
        odd_vals = odd_sign * e2m1_lut.to(device)[odd_raw & 0x07]

        # Interleave and flatten
        w_bf16 = torch.stack([even_vals, odd_vals], dim=-1)
        w_bf16 = w_bf16.reshape(w_uint8.shape[0], -1).to(torch.bfloat16)
        return w_bf16

'''


class PatchError(Exception):
    """Raised when the target source lacks one of the methods to replace."""


def _find_block_end(text, start, max_indent):
    """Return the offset where the definition starting at ``start`` ends.

    Scans the lines after the one containing ``start`` and returns the offset
    of the first non-blank line whose leading-whitespace length is at most
    ``max_indent``.  If no such line exists (the definition is the last thing
    in the file) the end of ``text`` is returned, so the whole tail is treated
    as part of the definition.

    This is a purely textual heuristic: a comment or string continuation that
    sits at or below ``max_indent`` inside the method ends the scan early.
    """
    lines = text[start:].split("\n")
    # Skip the ``def`` line itself; +1 accounts for its newline.
    offset = start + len(lines[0]) + 1
    for line in lines[1:]:
        if line.strip() and len(line) - len(line.lstrip()) <= max_indent:
            return offset
        offset += len(line) + 1
    return len(text)


def patch_source(text):
    """Return ``text`` with the old conversion methods replaced.

    Args:
        text: Full contents of the source file to patch.

    Returns:
        The patched source: everything from the start of
        ``_convert_nvfp4_attention_to_fp8`` to the end of
        ``_convert_nvfp4_module_to_fp8`` is replaced by ``new_methods``, and
        every occurrence of the old call is renamed to the new one.

    Raises:
        PatchError: If either of the two old method definitions is missing.
    """
    first = _OLD_ATTENTION_DEF.search(text)
    if first is None:
        raise PatchError("Could not find _convert_nvfp4_attention_to_fp8")

    second = _OLD_MODULE_DEF.search(text, first.start())
    if second is None:
        raise PatchError("Could not find _convert_nvfp4_module_to_fp8")

    method_indent = first.group(1)
    end = _find_block_end(text, second.start(), len(method_indent))

    # ``new_methods`` is authored at 4 spaces; re-base it onto the indentation
    # the target file really uses (a no-op for 4-space files).
    replacement = textwrap.indent(textwrap.dedent(new_methods), method_indent)

    patched = text[:first.start()] + replacement + text[end:]
    return patched.replace(_OLD_CALL, _NEW_CALL)


def main(path=filepath):
    """Patch the file at ``path`` in place.

    Args:
        path: File to patch; defaults to the module-level ``filepath``.

    Returns:
        0 on success, 1 if the old methods were not found or the patched
        source does not parse.  The file is only written on success.
    """
    with open(path, "r", encoding="utf-8") as f:
        original = f.read()

    try:
        patched = patch_source(original)
    except PatchError as err:
        print(f"ERROR: {err}")
        return 1

    if _NEW_CALL not in patched:
        # Not fatal: the caller may live in another file, but flag it because
        # a stale call to the removed method would fail at runtime.
        print(f"WARNING: no call to {_OLD_CALL} found to update")

    # Validate before touching the file so a bad patch cannot corrupt it.
    try:
        ast.parse(patched)
    except SyntaxError as e:
        print(f"Syntax error at line {e.lineno}: {e.msg}")
        return 1
    print("Syntax OK")

    with open(path, "w", encoding="utf-8") as f:
        f.write(patched)

    print("Replaced conversion methods with clean FP8/bf16 split")
    return 0


if __name__ == "__main__":
    sys.exit(main())