Update MEMORY.md and memory files with vLLM NVFP4 serving progress
Server running on B200 port 8000 with full NVFP4→vLLM bridge. All critical bugs fixed: DeepGEMM scale format, compressor shapes, block scale values.
This commit is contained in:
94
tmp/fix_replace_handler.py
Normal file
94
tmp/fix_replace_handler.py
Normal file
@@ -0,0 +1,94 @@
|
||||
#!/usr/bin/python3
|
||||
"""Replace the old bf16->NVFP4 handler with a simple bf16->FP8 handler."""
|
||||
|
||||
filepath = "/root/nvidia-meeting/deepseek-v4-quant/patches/deepseek_v4.py"
|
||||
|
||||
with open(filepath, 'r') as f:
|
||||
lines = f.readlines()
|
||||
|
||||
# Find the handler: starts with the if check, ends with continue
|
||||
start = None
|
||||
end = None
|
||||
for i, line in enumerate(lines):
|
||||
if 'loaded_weight.dtype != torch.uint8' in line and 'param.data.dtype == torch.uint8' in line:
|
||||
# Go back to find the if statement start
|
||||
for j in range(i, max(i-3, 0), -1):
|
||||
if lines[j].strip().startswith('if (name.endswith'):
|
||||
start = j
|
||||
break
|
||||
if start is None:
|
||||
start = i # fallback
|
||||
break
|
||||
|
||||
if start is None:
|
||||
print("Could not find handler start")
|
||||
exit(1)
|
||||
|
||||
# Find the end: the first 'continue' at indent level 20+ after start
|
||||
for i in range(start + 1, min(start + 200, len(lines))):
|
||||
stripped = lines[i].strip()
|
||||
if stripped == 'continue':
|
||||
indent = len(lines[i]) - len(lines[i].lstrip())
|
||||
if indent >= 20:
|
||||
end = i
|
||||
break
|
||||
|
||||
if end is None:
|
||||
print("Could not find handler end")
|
||||
exit(1)
|
||||
|
||||
print(f"Replacing lines {start+1} to {end+1} ({end-start+1} lines)")
|
||||
print(f"First: {lines[start].rstrip()[:80]}")
|
||||
print(f"Last: {lines[end].rstrip()[:80]}")
|
||||
|
||||
new_handler = [
|
||||
' if (name.endswith(".weight")\n',
|
||||
' and loaded_weight.dtype != torch.uint8\n',
|
||||
' and param.data.dtype == torch.uint8):\n',
|
||||
' # o_a_proj was NOT quantized by modelopt (bf16, no scales)\n',
|
||||
' # wo_a is used with fp8_einsum: convert bf16 -> FP8 directly\n',
|
||||
' w_bf16 = loaded_weight\n',
|
||||
' w_amax = w_bf16.abs().amax()\n',
|
||||
' if w_amax == 0:\n',
|
||||
' w_amax = torch.tensor(1.0, device=w_bf16.device)\n',
|
||||
' fp8_max = torch.finfo(torch.float8_e4m3fn).max\n',
|
||||
' fp8_scale = w_amax / fp8_max\n',
|
||||
' w_fp8 = (w_bf16 / fp8_scale).to(torch.float8_e4m3fn)\n',
|
||||
' weight_scale_inv = fp8_scale.to(torch.float32)\n',
|
||||
' parts = name.rsplit(".", 1)\n',
|
||||
' module_path = parts[0]\n',
|
||||
' mod = self\n',
|
||||
' for attr in module_path.split("."):\n',
|
||||
' if attr.isdigit():\n',
|
||||
' mod = mod[int(attr)]\n',
|
||||
' else:\n',
|
||||
' mod = getattr(mod, attr)\n',
|
||||
' mod.weight = torch.nn.Parameter(w_fp8, requires_grad=False)\n',
|
||||
' mod.weight_scale_inv = torch.nn.Parameter(\n',
|
||||
' weight_scale_inv.reshape(1), requires_grad=False\n',
|
||||
' )\n',
|
||||
' from vllm.model_executor.layers.linear import (\n',
|
||||
' UnquantizedLinearMethod,\n',
|
||||
' )\n',
|
||||
' mod.quant_method = UnquantizedLinearMethod()\n',
|
||||
' for attr in ("weight_scale", "weight_scale_2", "input_scale"):\n',
|
||||
' if hasattr(mod, attr):\n',
|
||||
' delattr(mod, attr)\n',
|
||||
' loaded_params.add(name)\n',
|
||||
' loaded_params.add(name.replace(".weight", ".weight_scale_inv"))\n',
|
||||
' continue\n',
|
||||
]
|
||||
|
||||
lines[start:end+1] = new_handler
|
||||
|
||||
with open(filepath, 'w') as f:
|
||||
f.writelines(lines)
|
||||
|
||||
import ast
|
||||
try:
|
||||
ast.parse(''.join(lines))
|
||||
print("Syntax OK")
|
||||
except SyntaxError as e:
|
||||
print(f"Syntax error at line {e.lineno}: {e.msg}")
|
||||
|
||||
print(f"Replaced {end-start+1} lines with {len(new_handler)} lines")
|
||||
Reference in New Issue
Block a user