Server running on the B200 (port 8000) with the full NVFP4→vLLM bridge. All critical bugs are fixed: DeepGEMM scale format, compressor shapes, and block scale values.
95 lines · 3.9 KiB · Python
#!/usr/bin/python3
"""Replace the old bf16->NVFP4 handler with a simple bf16->FP8 handler.

The target file is patched in place:

1. Find the first line containing both ``loaded_weight.dtype != torch.uint8``
   and ``param.data.dtype == torch.uint8``, then walk back to the
   ``if (name.endswith`` line that opens the handler.
2. Parse the file and take the handler to end at the ``continue`` that is
   the last statement of that ``if`` body.
3. Splice in the bf16->FP8 handler, re-indented to match the original
   ``if`` line, and write the file only if the result still parses.

Exits with status 1, leaving the file untouched, if any step fails.
"""

import ast
import sys
import textwrap

filepath = "/root/nvidia-meeting/deepseek-v4-quant/patches/deepseek_v4.py"

# A line counts as the handler's condition only if it holds every substring.
_MARKERS = (
    "loaded_weight.dtype != torch.uint8",
    "param.data.dtype == torch.uint8",
)
# Prefix (after stripping indentation) of the line that opens the handler.
_IF_PREFIX = "if (name.endswith"
# Number of lines scanned for _IF_PREFIX: the marker line plus two above it.
_LOOKBACK = 3

# Replacement handler, indented relative to its own `if` line;
# build_handler() prepends the indentation of the code being replaced.
_HANDLER_TEMPLATE = '''\
if (name.endswith(".weight")
        and loaded_weight.dtype != torch.uint8
        and param.data.dtype == torch.uint8):
    # o_a_proj was NOT quantized by modelopt (bf16, no scales)
    # wo_a is used with fp8_einsum: convert bf16 -> FP8 directly
    # Quantize in fp32 on the existing parameter's device (.to() is a
    # no-op when the tensor is already there) with one per-tensor scale.
    w_f32 = loaded_weight.to(device=param.device, dtype=torch.float32)
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    w_amax = w_f32.abs().amax()
    if w_amax == 0:
        w_amax = torch.ones_like(w_amax)
    fp8_scale = w_amax / fp8_max
    # Clamp before casting: the float8_e4m3fn cast does not saturate
    # (out-of-range values can become NaN).
    w_fp8 = (w_f32 / fp8_scale).clamp(-fp8_max, fp8_max).to(
        torch.float8_e4m3fn
    )
    weight_scale_inv = fp8_scale.reshape(1)
    module_path = name.rsplit(".", 1)[0]
    mod = self
    for attr in module_path.split("."):
        if attr.isdigit():
            mod = mod[int(attr)]
        else:
            mod = getattr(mod, attr)
    mod.weight = torch.nn.Parameter(w_fp8, requires_grad=False)
    mod.weight_scale_inv = torch.nn.Parameter(
        weight_scale_inv, requires_grad=False
    )
    from vllm.model_executor.layers.linear import (
        UnquantizedLinearMethod,
    )
    mod.quant_method = UnquantizedLinearMethod()
    for attr in ("weight_scale", "weight_scale_2", "input_scale"):
        if hasattr(mod, attr):
            delattr(mod, attr)
    loaded_params.add(name)
    # Build the scale name from module_path so only the trailing ".weight"
    # is swapped (components like ".weights_proj" stay intact).
    loaded_params.add(module_path + ".weight_scale_inv")
    continue
'''


def _leading_ws(line):
    """Return the leading whitespace of *line*."""
    return line[: len(line) - len(line.lstrip())]


def find_handler_start(lines):
    """Return the index of the handler's opening ``if`` line, or None.

    Uses the first line that contains every entry of _MARKERS. It scans that
    line and the lines just above it (_LOOKBACK lines in total, clipped at
    the top of the file) for _IF_PREFIX. If no scanned line matches, the
    marker line itself is returned.
    """
    for i, line in enumerate(lines):
        if all(marker in line for marker in _MARKERS):
            # max(..., -1) keeps index 0 reachable when the marker is near the top.
            for j in range(i, max(i - _LOOKBACK, -1), -1):
                if lines[j].strip().startswith(_IF_PREFIX):
                    return j
            return i  # fallback: assume the marker line opens the handler
    return None


def find_handler_end(lines, start):
    """Return the index of the ``continue`` that closes the ``if`` at *start*.

    The end comes from the parsed syntax tree instead of an indentation
    guess. A ``continue`` inside a nested loop of the old handler is
    therefore never mistaken for the handler's own.

    Returns None if no ``if`` statement begins on line *start*, or if its
    body does not end with ``continue``.

    Raises:
        SyntaxError: if *lines* do not parse as Python.
    """
    tree = ast.parse("".join(lines))
    for node in ast.walk(tree):
        # ast line numbers are 1-based; list indices are 0-based.
        if isinstance(node, ast.If) and node.lineno == start + 1:
            last = node.body[-1]
            return last.lineno - 1 if isinstance(last, ast.Continue) else None
    return None


def build_handler(indent):
    """Return the replacement handler as lines, each prefixed by *indent*."""
    return textwrap.indent(_HANDLER_TEMPLATE, indent).splitlines(keepends=True)


def main():
    """Patch *filepath* in place; return a process exit status."""
    with open(filepath, "r", encoding="utf-8") as f:
        lines = f.readlines()

    start = find_handler_start(lines)
    if start is None:
        print("Could not find handler start", file=sys.stderr)
        return 1

    try:
        end = find_handler_end(lines, start)
    except SyntaxError as e:
        print(f"Cannot parse {filepath}: line {e.lineno}: {e.msg}", file=sys.stderr)
        return 1
    if end is None:
        print("Could not find handler end", file=sys.stderr)
        return 1

    print(f"Replacing lines {start+1} to {end+1} ({end-start+1} lines)")
    print(f"First: {lines[start].rstrip()[:80]}")
    print(f"Last: {lines[end].rstrip()[:80]}")

    new_handler = build_handler(_leading_ws(lines[start]))
    source = "".join(lines[:start] + new_handler + lines[end + 1:])

    # Validate before touching the file so a broken patch never lands on disk.
    try:
        ast.parse(source)
    except SyntaxError as e:
        print(f"Syntax error at line {e.lineno}: {e.msg}", file=sys.stderr)
        print("File left unchanged", file=sys.stderr)
        return 1
    print("Syntax OK")

    with open(filepath, "w", encoding="utf-8") as f:
        f.write(source)

    print(f"Replaced {end-start+1} lines with {len(new_handler)} lines")
    return 0


if __name__ == "__main__":
    sys.exit(main())