Server running on B200 port 8000 with full NVFP4→vLLM bridge. All critical bugs fixed: DeepGEMM scale format, compressor shapes, block scale values.
79 lines
3.3 KiB
Python
79 lines
3.3 KiB
Python
#!/usr/bin/python3
"""Replace the bf16->NVFP4 quantization handler with a simple bf16->FP8 conversion.

wo_a is used with fp8_einsum, so it needs FP8 weight + weight_scale_inv.

The script rewrites one handler block inside the patched ``deepseek_v4.py``:
the block that starts at the ``# Handle o_a_proj bf16 → wo_a uint8 mismatch:``
comment and ends at the first more-deeply-indented ``continue`` line after it.
The patched text is syntax-checked *before* it is written, so a failed patch
never overwrites the target file.
"""

import ast
import re
import sys
import textwrap

filepath = "/root/nvidia-meeting/deepseek-v4-quant/patches/deepseek_v4.py"

# The old handler: from its leading comment (at whatever indentation the
# target file uses) up to and including the first ``continue`` line that is
# indented deeper than that comment.
HANDLER_RE = re.compile(
    r"^(?P<indent>[ \t]*)# Handle o_a_proj bf16 → wo_a uint8 mismatch:"
    r".*?\n(?P=indent)[ \t]+continue\n",
    re.DOTALL | re.MULTILINE,
)

# Condition line used only for diagnostics when HANDLER_RE does not match.
FALLBACK_CONDITION = "and loaded_weight.dtype != torch.uint8\n"

# Replacement handler, stored with *relative* indentation only. It is
# re-indented at patch time to the indentation of the block it replaces.
NEW_HANDLER_TEMPLATE = textwrap.dedent('''\
    # Handle o_a_proj bf16 -> wo_a: convert to FP8 directly
    # (wo_a is used with fp8_einsum, needs FP8 + weight_scale_inv)
    if (name.endswith(".weight")
            and loaded_weight.dtype != torch.uint8
            and param.data.dtype == torch.uint8):
        w_bf16 = loaded_weight
        w_amax = w_bf16.abs().amax()
        if w_amax == 0:
            w_amax = torch.tensor(1.0, device=w_bf16.device)
        fp8_max = torch.finfo(torch.float8_e4m3fn).max
        fp8_scale = w_amax / fp8_max
        w_fp8 = (w_bf16 / fp8_scale).to(torch.float8_e4m3fn)
        weight_scale_inv = fp8_scale.to(torch.float32)
        parts = name.rsplit(".", 1)
        module_path = parts[0]
        mod = self
        for attr in module_path.split("."):
            if attr.isdigit():
                mod = mod[int(attr)]
            else:
                mod = getattr(mod, attr)
        mod.weight = torch.nn.Parameter(w_fp8, requires_grad=False)
        mod.weight_scale_inv = torch.nn.Parameter(
            weight_scale_inv.reshape(1), requires_grad=False
        )
        from vllm.model_executor.layers.linear import (
            UnquantizedLinearMethod,
        )
        mod.quant_method = UnquantizedLinearMethod()
        for attr in ("weight_scale", "weight_scale_2", "input_scale"):
            if hasattr(mod, attr):
                delattr(mod, attr)
        loaded_params.add(name)
        loaded_params.add(name.replace(".weight", ".weight_scale_inv"))
        continue
''')


def replace_handler(source):
    """Return *source* with the old handler swapped for the FP8 handler.

    The new handler is indented to match the comment line that opens the old
    handler. Returns ``None`` when the old handler cannot be located.
    """
    match = HANDLER_RE.search(source)
    if match is None:
        return None
    handler = textwrap.indent(NEW_HANDLER_TEMPLATE, match.group("indent"))
    return source[:match.start()] + handler + source[match.end():]


def syntax_error(source):
    """Return a description of the first syntax error in *source*, or ``None``."""
    try:
        ast.parse(source)
    except SyntaxError as e:
        return f"Syntax error at line {e.lineno}: {e.msg}"
    return None


def main():
    """Patch the target file in place; return a process exit status.

    Returns 0 when the file was patched and written, 1 when the handler was
    not found or the patched text failed the syntax check (in both failure
    cases the file on disk is left untouched).
    """
    # The search pattern contains a non-ASCII arrow, so do not depend on the
    # locale's default encoding.
    with open(filepath, "r", encoding="utf-8") as f:
        original = f.read()

    patched = replace_handler(original)
    if patched is None:
        print("Could not find handler block, trying alternate search...")
        # Diagnostic only: report whether the handler's condition line exists.
        idx = original.find(FALLBACK_CONDITION)
        if idx != -1:  # str.find returns -1 on a miss; offset 0 is a valid hit
            print(f"Found condition at position {idx}")
        else:
            print("ERROR: Could not find condition")
        return 1

    # Validate first so that a broken result is never persisted.
    error = syntax_error(patched)
    if error is not None:
        print(error)
        return 1

    with open(filepath, "w", encoding="utf-8") as f:
        f.write(patched)
    print("Replaced bf16->NVFP4 handler with bf16->FP8 handler")
    print("Syntax OK")
    return 0


if __name__ == "__main__":
    sys.exit(main())