Files
deepseek-v4-quant/tmp/fix_replace_handler.py

95 lines
3.9 KiB
Python
Raw Normal View History

#!/usr/bin/python3
"""Replace the old bf16->NVFP4 handler with a simple bf16->FP8 handler."""
# Absolute path of the module whose weight-loading handler gets rewritten.
filepath = "/root/nvidia-meeting/deepseek-v4-quant/patches/deepseek_v4.py"

# Read the whole file as a list of lines (line endings kept) so a slice of
# it can be swapped out and the result written back unchanged elsewhere.
with open(filepath) as source_file:
    lines = list(source_file)
# Locate the first line of the old handler.
#
# The anchor is the first line that carries BOTH dtype checks of the handler's
# `if` condition.  Since that condition may be wrapped over several lines, the
# anchor line and up to two lines before it are inspected for the opening
# `if (name.endswith`.  `start` ends up as a 0-based index into `lines`.
start = None
end = None
for i, line in enumerate(lines):
    if ('loaded_weight.dtype != torch.uint8' not in line
            or 'param.data.dtype == torch.uint8' not in line):
        continue
    # The stop bound is max(i - 3, -1) rather than max(i - 3, 0) so that
    # line 0 is still inspected when the anchor is among the first lines.
    for j in range(i, max(i - 3, -1), -1):
        if lines[j].strip().startswith('if (name.endswith'):
            start = j
            break
    if start is None:
        # Best-effort fallback: replace from the anchor line itself.  If the
        # condition really began on an earlier line this leaves that line
        # behind, so make the situation visible instead of staying silent.
        print(f"Warning: no 'if (name.endswith' found near line {i + 1}; "
              "starting replacement at that line")
        start = i
    break  # only the first anchor is considered

if start is None:
    # SystemExit with a string prints it to stderr and exits with status 1;
    # unlike the site-provided exit(), it is always available.
    raise SystemExit("Could not find handler start")
# Locate the last line of the old handler.
#
# The handler is taken to end at the first line after `start` that consists
# solely of `continue` and has at least 20 leading whitespace characters.
# Only lines start+1 .. start+199 (clamped to the file length) are scanned.
# `end` ends up as a 0-based, inclusive index into `lines`.
end = None
for i in range(start + 1, min(start + 200, len(lines))):
    candidate = lines[i]
    if candidate.strip() != 'continue':
        continue
    indent = len(candidate) - len(candidate.lstrip())
    if indent >= 20:
        end = i
        break

if end is None:
    # SystemExit with a string prints it to stderr and exits with status 1;
    # unlike the site-provided exit(), it is always available.
    raise SystemExit("Could not find handler end")

# Show the 1-based range and the first/last matched lines (truncated to 80
# characters) so the operator can sanity-check what is about to be replaced.
print(f"Replacing lines {start+1} to {end+1} ({end-start+1} lines)")
print(f"First: {lines[start].rstrip()[:80]}")
print(f"Last: {lines[end].rstrip()[:80]}")
# Replacement source for the handler, one string per output line (each ends
# with "\n").  The list is spliced into `lines` in place of the old handler.
#
# The generated code handles a parameter whose name ends in ".weight" when the
# checkpoint tensor is not uint8 but the model parameter is: it derives one
# scale for the whole tensor (amax / float8_e4m3fn max), casts the weight to
# float8_e4m3fn, installs `weight` and `weight_scale_inv` on the owning
# module, switches that module to UnquantizedLinearMethod, drops the listed
# scale attributes if present, records both names in `loaded_params`, and
# continues with the next weight.
#
# NOTE(review): every literal below starts with a single space in this copy,
# whereas the end-of-handler search in this script only accepts a `continue`
# with >= 20 leading whitespace characters.  Confirm these literals carry the
# target file's real indentation before running — TODO verify.
new_handler = [
# --- condition: bf16 (non-uint8) checkpoint weight vs. uint8 parameter ---
' if (name.endswith(".weight")\n',
' and loaded_weight.dtype != torch.uint8\n',
' and param.data.dtype == torch.uint8):\n',
' # o_a_proj was NOT quantized by modelopt (bf16, no scales)\n',
' # wo_a is used with fp8_einsum: convert bf16 -> FP8 directly\n',
# --- quantize: single scale = amax / fp8_max (amax of 0 is replaced by 1.0) ---
' w_bf16 = loaded_weight\n',
' w_amax = w_bf16.abs().amax()\n',
' if w_amax == 0:\n',
' w_amax = torch.tensor(1.0, device=w_bf16.device)\n',
' fp8_max = torch.finfo(torch.float8_e4m3fn).max\n',
' fp8_scale = w_amax / fp8_max\n',
' w_fp8 = (w_bf16 / fp8_scale).to(torch.float8_e4m3fn)\n',
' weight_scale_inv = fp8_scale.to(torch.float32)\n',
# --- resolve the owning module by walking the dotted name from `self`;
# --- all-digit components are used as integer indices ---
' parts = name.rsplit(".", 1)\n',
' module_path = parts[0]\n',
' mod = self\n',
' for attr in module_path.split("."):\n',
' if attr.isdigit():\n',
' mod = mod[int(attr)]\n',
' else:\n',
' mod = getattr(mod, attr)\n',
# --- install the FP8 weight and its scale as non-trainable parameters ---
' mod.weight = torch.nn.Parameter(w_fp8, requires_grad=False)\n',
' mod.weight_scale_inv = torch.nn.Parameter(\n',
' weight_scale_inv.reshape(1), requires_grad=False\n',
' )\n',
# --- swap the quant method and remove the listed scale attributes ---
' from vllm.model_executor.layers.linear import (\n',
' UnquantizedLinearMethod,\n',
' )\n',
' mod.quant_method = UnquantizedLinearMethod()\n',
' for attr in ("weight_scale", "weight_scale_2", "input_scale"):\n',
' if hasattr(mod, attr):\n',
' delattr(mod, attr)\n',
# --- mark both names as loaded and move on to the next weight ---
' loaded_params.add(name)\n',
' loaded_params.add(name.replace(".weight", ".weight_scale_inv"))\n',
' continue\n',
]
import ast

# Splice the replacement over the old handler; `start` and `end` are both
# inclusive 0-based indices, hence the `end+1` slice bound.
lines[start:end+1] = new_handler

# Validate the patched source BEFORE touching the file on disk, so a splice
# that yields invalid Python never overwrites the target module.
try:
    ast.parse(''.join(lines))
except SyntaxError as e:
    print(f"Syntax error at line {e.lineno}: {e.msg}")
    # Exit with status 1 (message goes to stderr) without writing anything.
    raise SystemExit(f"{filepath} left unchanged")
print("Syntax OK")

# Only a syntactically valid result is written back.
with open(filepath, 'w') as f:
    f.writelines(lines)

# `start`/`end` still describe the old range, so this is the old line count.
print(f"Replaced {end-start+1} lines with {len(new_handler)} lines")