Files
deepseek-v4-quant/tmp/fix_oa_fp8.py
biondizzle 02b8ea536f Update MEMORY.md and memory files with vLLM NVFP4 serving progress
Server running on B200 port 8000 with full NVFP4→vLLM bridge.
All critical bugs fixed: DeepGEMM scale format, compressor shapes, block scale values.
2026-05-11 02:02:49 +00:00

79 lines
3.3 KiB
Python

#!/usr/bin/python3
"""Replace the bf16->NVFP4 quantization handler with a simple bf16->FP8 conversion.

wo_a is used with fp8_einsum, so it needs an FP8 weight + weight_scale_inv.

Usage: fix_oa_fp8.py [PATH]   (PATH defaults to FILEPATH below)

The target file is rewritten in place, but only after the old handler has been
located and the patched source has passed a syntax check.  On any failure the
file is left untouched and the script exits with status 1.  Re-running on an
already-patched file is a no-op.
"""
import ast
import re
import sys
import textwrap
from typing import Optional

FILEPATH = "/root/nvidia-meeting/deepseek-v4-quant/patches/deepseek_v4.py"

# The old handler starts at its "# Handle o_a_proj ..." comment line and ends
# at the first `continue` indented exactly one level (4 spaces) deeper than
# that comment.  The comment's indentation is captured so the replacement is
# emitted at the same depth instead of hard-coding it.
OLD_HANDLER_RE = re.compile(
    r"^(?P<indent>[ \t]*)# Handle o_a_proj bf16 → wo_a uint8 mismatch:.*?"
    r"\n(?P=indent)    continue\n",
    re.DOTALL | re.MULTILINE,
)

# First line of NEW_HANDLER; its presence means the patch was already applied.
NEW_HANDLER_MARKER = "# Handle o_a_proj bf16 -> wo_a: convert to FP8 directly"

# Replacement handler, written at zero base indentation (re-indented on
# insertion).  It uses names from the enclosing weight-loading loop of the
# target file (`self`, `name`, `loaded_weight`, `param`, `loaded_params`,
# `torch`) -- presumably defined there; verify against deepseek_v4.py.
# NOTE(review): the condition matches *any* non-uint8 ".weight" loaded into a
# uint8 param, not only o_a_proj/wo_a -- confirm no other layer hits it.
# NOTE(review): a single per-tensor scale is stored (reshape(1)); confirm that
# fp8_einsum accepts that rather than per-block scales.
NEW_HANDLER = '''\
# Handle o_a_proj bf16 -> wo_a: convert to FP8 directly
# (wo_a is used with fp8_einsum, needs FP8 + weight_scale_inv)
if (name.endswith(".weight")
        and loaded_weight.dtype != torch.uint8
        and param.data.dtype == torch.uint8):
    # Per-tensor scale computed in fp32 so the stored scale is exactly the
    # one used for quantization; clamp guards the fp8 cast against rounding
    # overshoot past fp8_max.
    w_f32 = loaded_weight.to(torch.float32)
    w_amax = w_f32.abs().amax()
    if w_amax == 0:
        w_amax = torch.ones_like(w_amax)
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    fp8_scale = w_amax / fp8_max
    w_fp8 = (w_f32 / fp8_scale).clamp(-fp8_max, fp8_max)
    w_fp8 = w_fp8.to(torch.float8_e4m3fn)
    weight_scale_inv = fp8_scale.to(torch.float32)
    module_path = name.rsplit(".", 1)[0]
    mod = self
    for attr in module_path.split("."):
        if attr.isdigit():
            mod = mod[int(attr)]
        else:
            mod = getattr(mod, attr)
    mod.weight = torch.nn.Parameter(w_fp8, requires_grad=False)
    mod.weight_scale_inv = torch.nn.Parameter(
        weight_scale_inv.reshape(1), requires_grad=False
    )
    from vllm.model_executor.layers.linear import (
        UnquantizedLinearMethod,
    )
    mod.quant_method = UnquantizedLinearMethod()
    for attr in ("weight_scale", "weight_scale_2", "input_scale"):
        if hasattr(mod, attr):
            delattr(mod, attr)
    loaded_params.add(name)
    loaded_params.add(module_path + ".weight_scale_inv")
    continue
'''


def patch_source(source: str) -> Optional[str]:
    """Return *source* with the old bf16->NVFP4 handler replaced.

    The replacement is indented to match the old handler's comment line.

    Args:
        source: full text of the target Python file.

    Returns:
        The patched text, or None if the old handler block is not present.
    """
    match = OLD_HANDLER_RE.search(source)
    if match is None:
        return None
    replacement = textwrap.indent(NEW_HANDLER, match.group("indent"))
    # The match spans from the start of the comment line through the trailing
    # "continue\n", and NEW_HANDLER also ends in "continue\n", so the
    # surrounding text is spliced back unchanged.
    return source[:match.start()] + replacement + source[match.end():]


def main(argv=None) -> int:
    """Patch the target file in place; return a process exit status.

    Args:
        argv: optional argument list (defaults to sys.argv[1:]); its first
            element, if any, overrides FILEPATH.

    Returns:
        0 on success or when already patched, 1 if the handler cannot be
        found or the patched file would not parse (file left unchanged).
    """
    args = sys.argv[1:] if argv is None else argv
    path = args[0] if args else FILEPATH
    # Explicit UTF-8: the file and the search pattern contain "→".
    with open(path, encoding="utf-8") as f:
        source = f.read()

    patched = patch_source(source)
    if patched is None:
        if NEW_HANDLER_MARKER in source:
            print("Handler already replaced; nothing to do")
            return 0
        print("ERROR: could not find the bf16->NVFP4 handler block",
              file=sys.stderr)
        idx = source.find("and loaded_weight.dtype != torch.uint8\n")
        if idx != -1:
            print(f"Found condition at position {idx}", file=sys.stderr)
        return 1

    # Validate before writing so a bad patch never clobbers the target file.
    try:
        ast.parse(patched)
    except SyntaxError as e:
        print(f"Syntax error at line {e.lineno}: {e.msg}; file left unchanged",
              file=sys.stderr)
        return 1

    with open(path, "w", encoding="utf-8") as f:
        f.write(patched)
    print("Replaced bf16->NVFP4 handler with bf16->FP8 handler")
    print("Syntax OK")
    return 0


if __name__ == "__main__":
    sys.exit(main())