#!/usr/bin/python3
"""Replace the bf16->NVFP4 quantization handler with a simple bf16->FP8 conversion.

wo_a is used with fp8_einsum, so it needs FP8 weight + weight_scale_inv.

Usage:
    <script> [PATH]

PATH defaults to DEFAULT_PATH. The target file is rewritten only when the old
handler is found AND the patched source parses; every failure exits non-zero
so callers (shell scripts, CI) can detect it. Re-running on an already
patched file is a no-op that exits 0.
"""

import ast
import re
import sys
import textwrap

DEFAULT_PATH = "/root/nvidia-meeting/deepseek-v4-quant/patches/deepseek_v4.py"

# Comment that opens the old bf16 -> NVFP4 (uint8) handler, and the comment
# that opens the replacement; the latter is used to detect a previous run.
_OLD_MARKER = "# Handle o_a_proj bf16 → wo_a uint8 mismatch:"
_NEW_MARKER = "# Handle o_a_proj bf16 -> wo_a: convert to FP8 directly"

# Old handler = marker comment through the first `continue` that sits exactly
# one indent level (4 spaces) deeper than the comment, i.e. the last statement
# of the handler's `if` body. Deeper (nested) `continue`s are not matched.
# The comment's leading whitespace is captured so the replacement can be
# indented to fit wherever the handler lives in the target file.
_HANDLER_RE = re.compile(
    r"^(?P<indent>[ \t]*)" + re.escape(_OLD_MARKER) + r".*?"
    r"\n(?P=indent)    continue\n",
    re.DOTALL | re.MULTILINE,
)

# Replacement handler, written at column 0 and re-indented on insertion.
# It relies on `name`, `loaded_weight`, `param`, `loaded_params`, `self` and
# `torch` being in scope at the insertion point, as the old handler did.
# NOTE(review): weight_scale_inv is a single per-tensor scale of shape (1,);
# confirm fp8_einsum accepts that rather than block-wise scales.
_NEW_HANDLER = '''\
# Handle o_a_proj bf16 -> wo_a: convert to FP8 directly
# (wo_a is used with fp8_einsum, needs FP8 + weight_scale_inv)
if (name.endswith(".weight")
        and loaded_weight.dtype != torch.uint8
        and param.data.dtype == torch.uint8):
    # Do the amax/scale math in float32 rather than the checkpoint dtype,
    # so the scale itself is not rounded to bf16 precision.
    w_f32 = loaded_weight.to(torch.float32)
    w_amax = w_f32.abs().amax()
    if w_amax == 0:
        # All-zero weight: any non-zero scale works; avoid dividing by zero.
        w_amax = torch.ones_like(w_amax)
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    fp8_scale = w_amax / fp8_max
    # Clamp before casting: the float8_e4m3fn cast does not saturate
    # out-of-range values.
    w_fp8 = (w_f32 / fp8_scale).clamp(-fp8_max, fp8_max).to(
        torch.float8_e4m3fn)
    weight_scale_inv = fp8_scale.to(torch.float32)
    # name ends with ".weight" (checked above); strip only that suffix --
    # str.replace would also rewrite ".weight" inside the module path.
    module_path = name[:-len(".weight")]
    mod = self
    for attr in module_path.split("."):
        if attr.isdigit():
            mod = mod[int(attr)]
        else:
            mod = getattr(mod, attr)
    mod.weight = torch.nn.Parameter(w_fp8, requires_grad=False)
    mod.weight_scale_inv = torch.nn.Parameter(
        weight_scale_inv.reshape(1), requires_grad=False
    )
    from vllm.model_executor.layers.linear import (
        UnquantizedLinearMethod,
    )
    mod.quant_method = UnquantizedLinearMethod()
    # Drop NVFP4-specific scale attributes left over from the old method.
    for attr in ("weight_scale", "weight_scale_2", "input_scale"):
        if hasattr(mod, attr):
            delattr(mod, attr)
    loaded_params.add(name)
    loaded_params.add(module_path + ".weight_scale_inv")
    continue
'''


def patch_source(src):
    """Return *src* with the old NVFP4 handler replaced by the FP8 handler.

    Args:
        src: Full text of the target Python file.

    Returns:
        The patched text, or None when the old handler block is not found.
    """
    m = _HANDLER_RE.search(src)
    if m is None:
        return None
    handler = textwrap.indent(_NEW_HANDLER, m.group("indent"))
    return src[:m.start()] + handler + src[m.end():]


def main(argv=None):
    """Patch the target file in place.

    Args:
        argv: Argument list (defaults to ``sys.argv[1:]``); the optional
            first element overrides DEFAULT_PATH.

    Returns:
        Process exit status: 0 on success or when already patched, 1 when the
        handler is missing or the patched source does not parse. In both
        failure cases the target file is left untouched.
    """
    args = sys.argv[1:] if argv is None else list(argv)
    filepath = args[0] if args else DEFAULT_PATH

    with open(filepath, "r", encoding="utf-8") as f:
        c = f.read()

    patched = patch_source(c)
    if patched is None:
        if _NEW_MARKER in c:
            print("FP8 handler already present; nothing to do")
            return 0
        print("Could not find handler block, trying alternate search...")
        # Diagnostic only: report whether the handler's condition exists.
        idx = c.find("and loaded_weight.dtype != torch.uint8\n")
        if idx >= 0:
            print(f"Found condition at position {idx}")
        else:
            print("ERROR: Could not find condition")
        return 1

    # Validate BEFORE writing so a bad patch never clobbers the target.
    try:
        ast.parse(patched)
    except SyntaxError as e:
        print(f"Syntax error at line {e.lineno}: {e.msg}")
        print(f"Not writing {filepath}")
        return 1
    print("Syntax OK")

    with open(filepath, "w", encoding="utf-8") as f:
        f.write(patched)
    print("Replaced bf16->NVFP4 handler with bf16->FP8 handler")
    return 0


if __name__ == "__main__":
    sys.exit(main())