Files
deepseek-v4-quant/tmp/fix9_oa.py
biondizzle 02b8ea536f Update MEMORY.md and memory files with vLLM NVFP4 serving progress
Server running on B200 port 8000 with full NVFP4→vLLM bridge.
All critical bugs fixed: DeepGEMM scale format, compressor shapes, block scale values.
2026-05-11 02:02:49 +00:00

69 lines
3.4 KiB
Python

#!/usr/bin/env python3
"""Patch deepseek_v4.py so the bf16 -> uint8 weight branch loads FP8.

o_a_proj was NOT quantized by modelopt: its checkpoint tensor is bf16 with
no NVFP4 scales, while the parameter it is loaded into is uint8.  This
script rewrites the start of that loader branch in ``filepath`` so the bf16
weight is converted to FP8 (a single per-tensor scale) and the module's
quant method is switched, instead of being re-quantized to NVFP4.  The
original branch body is left in place after the injected ``continue`` (so
it becomes unreachable), headed by an ``# OLD:`` comment.

The anchor is matched whitespace-insensitively and the injected code is
re-indented to the indentation found in the target, so the patch does not
depend on the exact layout of the target file.  Re-running is a no-op; a
missing anchor is reported as an error instead of a silent "success".
"""
import re
import sys
import textwrap

filepath = "/root/nvidia-meeting/deepseek-v4-quant/patches/deepseek_v4.py"

# Start of the existing bf16 -> uint8 handler: the branch condition plus its
# first body line.  Only its tokens matter; whitespace is matched loosely.
old_handler = '''if (name.endswith(".weight")
        and loaded_weight.dtype != torch.uint8
        and param.data.dtype == torch.uint8):
    # Quantize bf16 → NVFP4 (E2M1 packed uint8 + scales)'''

# Replacement for the handler's first body line, written relative to the
# branch body's indentation (column 0 here == body indent in the target).
# NOTE(review): upstream vLLM constructs Fp8LinearMethod from an Fp8Config;
# confirm that quant_utils in the installed vLLM really exports
# Fp8MMQuantMethod and that this constructor call is valid.  The scale is
# per-tensor with shape (1,); if this layer runs through a block-scaled FP8
# kernel, that layout may not be what the kernel expects -- TODO confirm.
new_body = '''\
# o_a_proj was NOT quantized by modelopt (bf16, no scales)
# Convert bf16 → FP8 directly, set weight_scale_inv
# Scale and divide in float32 (not bf16) and clamp to the e4m3fn range
# before the cast.
w_f32 = loaded_weight.to(torch.float32)
fp8_max = torch.finfo(torch.float8_e4m3fn).max
w_amax = w_f32.abs().amax()
if w_amax == 0:
    w_amax = torch.ones_like(w_amax)
fp8_scale = w_amax / fp8_max
w_fp8 = (w_f32 / fp8_scale).clamp(-fp8_max, fp8_max)
w_fp8 = w_fp8.to(torch.float8_e4m3fn)
weight_scale_inv = fp8_scale.to(torch.float32)
# Load FP8 weight directly (bypass the uint8 param)
# Find the module and replace weight + quant method
parts = name.rsplit(".", 1)
module_path = parts[0]
mod = self
for attr in module_path.split("."):
    if attr.isdigit():
        mod = mod[int(attr)]
    else:
        mod = getattr(mod, attr)
# Replace weight param with FP8 version
mod.weight = torch.nn.Parameter(w_fp8, requires_grad=False)
mod.weight_scale_inv = torch.nn.Parameter(
    weight_scale_inv.reshape(1), requires_grad=False
)
# Switch quant method to FP8 linear
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    Fp8MMQuantMethod,
)
mod.quant_method = Fp8LinearMethod(Fp8MMQuantMethod())
# Clean up NVFP4 params
for attr in ('weight_scale', 'weight_scale_2', 'input_scale'):
    if hasattr(mod, attr):
        delattr(mod, attr)
loaded_params.add(name)
loaded_params.add(name.replace('.weight', '.weight_scale_inv'))
continue
# OLD: Quantize bf16 -> NVFP4 (E2M1 packed uint8 + scales)'''

# Last injected line; its presence means the patch was already applied.
_APPLIED_MARKER = new_body.splitlines()[-1]


def _compile_handler_pattern(handler):
    """Build a whitespace-tolerant regex for *handler*.

    Groups: ``head`` is the branch condition up to and including its
    newline (kept verbatim on replacement); ``indent`` is the leading
    whitespace of the body comment line that follows it.
    """
    *cond_lines, comment = [ln.strip() for ln in handler.strip().splitlines()]
    cond = r"\s+".join(
        re.escape(tok) for line in cond_lines for tok in line.split()
    )
    comment_pat = r"[ \t]+".join(re.escape(tok) for tok in comment.split())
    return re.compile(
        r"(?P<head>\b" + cond + r"[ \t]*\n)(?P<indent>[ \t]*)" + comment_pat
    )


_HANDLER_RE = _compile_handler_pattern(old_handler)


def patch_source(source):
    """Return ``(patched_source, n_replaced)`` for the target's text.

    Every occurrence of the old handler is rewritten (as the previous
    ``str.replace`` did).  When the patch is already applied, the source is
    returned unchanged with ``n_replaced == 0``.

    Raises:
        ValueError: neither the old handler nor the applied patch is found.
    """
    def render(match):
        # Keep the target's own condition text; indent the new body to the
        # indentation its first body line already had.
        return match.group("head") + textwrap.indent(
            new_body, match.group("indent")
        )

    patched, count = _HANDLER_RE.subn(render, source)
    if count == 0 and _APPLIED_MARKER not in source:
        raise ValueError("bf16->uint8 handler not found; nothing was patched")
    return patched, count


def main(path=filepath):
    """Apply the patch to *path* in place and report the outcome."""
    with open(path, encoding="utf-8") as f:
        source = f.read()
    try:
        patched, count = patch_source(source)
    except ValueError as err:
        sys.exit(f"{path}: {err}")
    if count == 0:
        print("bf16->uint8 handler already converts to FP8; nothing to do")
        return
    with open(path, "w", encoding="utf-8") as f:
        f.write(patched)
    print("Updated bf16->uint8 handler to convert to FP8 directly")


if __name__ == "__main__":
    main()