#!/usr/bin/python3
"""Replace the old bf16->NVFP4 handler with a simple bf16->FP8 handler.

Usage: python3 <this script> [path/to/deepseek_v4.py]

The old handler is located by its dtype check and replaced with a handler
that quantizes the bf16 weight to per-tensor FP8 (e4m3fn).  The patched
source is syntax-checked with ``ast.parse`` BEFORE it is written, so a bad
patch never clobbers the target file.
"""
import ast
import sys
import textwrap

DEFAULT_FILEPATH = "/root/nvidia-meeting/deepseek-v4-quant/patches/deepseek_v4.py"

# How many lines above the dtype-check line to search for ``if (name.endswith``.
_MAX_LOOKBACK = 3
# How many lines after the handler start to search for its closing ``continue``.
_MAX_HANDLER_LINES = 200
# The handler's closing ``continue`` sits inside nested loops, at least this deep.
_MIN_CONTINUE_INDENT = 20

# Replacement handler, written relative to column 0; it is re-indented to the
# column of the ``if`` line it replaces (see build_new_handler).
# NOTE(review): this stores a per-tensor scale of shape (1,) in
# ``weight_scale_inv``; confirm the fp8_einsum path consuming wo_a accepts it.
_NEW_HANDLER_TEMPLATE = '''\
if (name.endswith(".weight")
        and loaded_weight.dtype != torch.uint8
        and param.data.dtype == torch.uint8):
    # o_a_proj was NOT quantized by modelopt (bf16, no scales)
    # wo_a is used with fp8_einsum: convert bf16 -> FP8 directly
    # Scale is computed in float32 so it is not rounded to bf16.
    w_fp32 = loaded_weight.to(torch.float32)
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    w_amax = w_fp32.abs().amax()
    if w_amax == 0:
        # all-zero weight: any non-zero scale works; avoid dividing by 0
        w_amax = torch.tensor(1.0, device=w_fp32.device)
    fp8_scale = w_amax / fp8_max
    # clamp so rounding can never push a value past the FP8 range
    w_fp8 = (w_fp32 / fp8_scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    weight_scale_inv = fp8_scale
    module_path = name.rsplit(".", 1)[0]
    mod = self
    for attr in module_path.split("."):
        if attr.isdigit():
            mod = mod[int(attr)]
        else:
            mod = getattr(mod, attr)
    mod.weight = torch.nn.Parameter(w_fp8, requires_grad=False)
    mod.weight_scale_inv = torch.nn.Parameter(
        weight_scale_inv.reshape(1), requires_grad=False
    )
    from vllm.model_executor.layers.linear import (
        UnquantizedLinearMethod,
    )
    mod.quant_method = UnquantizedLinearMethod()
    for attr in ("weight_scale", "weight_scale_2", "input_scale"):
        if hasattr(mod, attr):
            delattr(mod, attr)
    loaded_params.add(name)
    loaded_params.add(f"{module_path}.weight_scale_inv")
    continue
'''


def _indent_of(line):
    """Return the leading whitespace of *line*."""
    return line[: len(line) - len(line.lstrip())]


def find_handler_start(lines):
    """Return the index of the old handler's first line, or None if absent.

    The handler is recognised by a line containing both dtype checks; the
    start is the ``if (name.endswith`` line up to ``_MAX_LOOKBACK`` lines
    above it, falling back to the matching line itself.
    """
    for i, line in enumerate(lines):
        if ('loaded_weight.dtype != torch.uint8' in line
                and 'param.data.dtype == torch.uint8' in line):
            # Walk upward (inclusive of line 0) looking for the ``if`` header.
            for j in range(i, max(i - _MAX_LOOKBACK, 0) - 1, -1):
                if lines[j].strip().startswith('if (name.endswith'):
                    return j
            return i  # fallback: the dtype-check line itself
    return None


def find_handler_end(lines, start):
    """Return the index of the handler's closing ``continue``, or None.

    Only the first ``continue`` indented at least ``_MIN_CONTINUE_INDENT``
    columns within ``_MAX_HANDLER_LINES`` lines after *start* counts.
    """
    stop = min(start + _MAX_HANDLER_LINES, len(lines))
    for i in range(start + 1, stop):
        if (lines[i].strip() == 'continue'
                and len(_indent_of(lines[i])) >= _MIN_CONTINUE_INDENT):
            return i
    return None


def build_new_handler(indent):
    """Return the replacement handler as newline-terminated lines.

    Every non-blank line of the template is prefixed with *indent*, the
    leading whitespace of the line being replaced.
    """
    return textwrap.indent(_NEW_HANDLER_TEMPLATE, indent).splitlines(keepends=True)


def main(argv=None):
    """Patch the target file in place; return a process exit status.

    Args:
        argv: Command-line arguments without the program name; the optional
            first element is the file to patch (defaults to DEFAULT_FILEPATH).

    Returns:
        0 on success, 1 if the handler was not found or the patched source
        does not parse (in which case the file is left untouched).
    """
    argv = sys.argv[1:] if argv is None else argv
    filepath = argv[0] if argv else DEFAULT_FILEPATH

    with open(filepath, 'r', encoding='utf-8') as f:
        lines = f.readlines()

    start = find_handler_start(lines)
    if start is None:
        print("Could not find handler start", file=sys.stderr)
        return 1
    end = find_handler_end(lines, start)
    if end is None:
        print("Could not find handler end", file=sys.stderr)
        return 1

    old_count = end - start + 1
    print(f"Replacing lines {start+1} to {end+1} ({old_count} lines)")
    print(f"First: {lines[start].rstrip()[:80]}")
    print(f"Last: {lines[end].rstrip()[:80]}")

    new_handler = build_new_handler(_indent_of(lines[start]))
    patched = ''.join(lines[:start] + new_handler + lines[end + 1:])

    # Validate before writing so a broken patch never corrupts the file.
    try:
        ast.parse(patched)
    except SyntaxError as e:
        print(f"Syntax error at line {e.lineno}: {e.msg}", file=sys.stderr)
        print(f"{filepath} left unchanged", file=sys.stderr)
        return 1
    print("Syntax OK")

    with open(filepath, 'w', encoding='utf-8') as f:
        f.write(patched)
    print(f"Replaced {old_count} lines with {len(new_handler)} lines")
    return 0


if __name__ == "__main__":
    sys.exit(main())