Files
deepseek-v4-quant/tmp/fix_clean_conversion.py
biondizzle 02b8ea536f Update MEMORY.md and memory files with vLLM NVFP4 serving progress
Server running on B200 port 8000 with full NVFP4→vLLM bridge.
All critical bugs fixed: DeepGEMM scale format, compressor shapes, block scale values.
2026-05-11 02:02:49 +00:00

224 lines
8.9 KiB
Python

#!/usr/bin/python3
"""
Clean rewrite of the NVFP4→FP8/bf16 conversion.

Patches deepseek_v4.py in place: the old _convert_nvfp4_attention_to_fp8 /
_convert_nvfp4_module_to_fp8 methods are replaced with a new set of
conversion methods, and the load_weights call site is renamed to match.

Strategy:
- wo_a, fused_wqa_wkv → FP8 (used with fp8_einsum, need weight_scale_inv)
- wq_b, wo_b, gate_up_proj → bf16 (used via .forward(), just works)
- compressor fused_wkv_wgate → bf16 (already handled in load path)
- MoE experts → native NVFP4 (ModelOptNvFp4FusedMoE handles it)

Usage: fix_clean_conversion.py [path/to/deepseek_v4.py]
"""
import sys

DEFAULT_FILEPATH = "/root/nvidia-meeting/deepseek-v4-quant/patches/deepseek_v4.py"
# Optional first CLI argument overrides the default checkout location.
filepath = sys.argv[1] if len(sys.argv) > 1 else DEFAULT_FILEPATH
# Explicit encoding: the default is locale-dependent.
with open(filepath, encoding="utf-8") as f:
    c = f.read()
# Locate the span to replace: from the _convert_nvfp4_attention_to_fp8 def line
# through the end of _convert_nvfp4_module_to_fp8. The second method is
# expected to follow the first; everything between them is removed as well.
import re
import sys

m1_start = c.find("    def _convert_nvfp4_attention_to_fp8(self):")
if m1_start < 0:
    print("ERROR: Could not find _convert_nvfp4_attention_to_fp8", file=sys.stderr)
    sys.exit(1)
m2_start = c.find(
    "    def _convert_nvfp4_module_to_fp8(self, mod, e2m1_lut, fp8_max):",
    m1_start,
)
if m2_start < 0:
    print("ERROR: Could not find _convert_nvfp4_module_to_fp8", file=sys.stderr)
    sys.exit(1)

# The second method ends at the first later non-blank line indented <= 4
# columns (next method/decorator, class-level code, or module-level code).
# Blank lines before that line are swallowed together with the method.
# NOTE(review): a column-0 comment or multi-line-string continuation inside
# the method body would also stop the scan early -- check the resulting diff.
lines_after = c[m2_start:].split("\n")
end_line = len(lines_after)  # nothing follows the method: replace through EOF
for i, line in enumerate(lines_after[1:], 1):
    if line.strip() and len(line) - len(line.lstrip()) <= 4:
        end_line = i
        break
# +1 per line restores the "\n" removed by split(); clamp for the EOF case,
# where the final piece has no trailing newline.
end_pos = min(len(c), m2_start + sum(len(ln) + 1 for ln in lines_after[:end_line]))
# Replacement methods, spliced into the model class at method indentation.
new_methods = '''    def _convert_nvfp4_post_load(self):
        """Post-load conversion of NVFP4 weights for vLLM compatibility.

        Strategy:
        - wo_a, fused_wqa_wkv: Convert NVFP4->FP8 (used with fp8_einsum)
        - wq_b, wo_b, gate_up_proj: Dequant NVFP4->bf16 (used via .forward())
        - MoE experts: Stay in native NVFP4 (ModelOptNvFp4FusedMoE)

        Only submodules whose weight is still packed (dtype torch.uint8) are
        converted, so a second call finds nothing left to do.
        """
        # E2M1 magnitudes, indexed by the low 3 bits of a 4-bit FP4 code.
        e2m1_lut = torch.tensor(
            [0, 0.5, 1, 1.5, 2, 3, 4, 6], dtype=torch.bfloat16
        )
        fp8_max = torch.finfo(torch.float8_e4m3fn).max
        # Layers that use fp8_einsum (need FP8 + weight_scale_inv).
        fp8_proj_names = ("fused_wqa_wkv", "wo_a")
        # Layers that use normal .forward() (need bf16).
        bf16_proj_names = ("wq_b", "wo_b")
        bf16_shared_names = ("gate_up_proj",)

        def packed_nvfp4(parent, names):
            # Yield parent.<name> for every name whose weight is packed uint8.
            for name in names:
                mod = getattr(parent, name, None)
                weight = getattr(mod, "weight", None)
                if weight is not None and weight.dtype == torch.uint8:
                    yield mod

        fp8_converted = 0
        bf16_converted = 0
        for layer in self.layers:
            attn = getattr(layer, "attn", None)
            for mod in packed_nvfp4(attn, fp8_proj_names):
                self._convert_nvfp4_to_fp8(mod, e2m1_lut, fp8_max)
                fp8_converted += 1
            for mod in packed_nvfp4(attn, bf16_proj_names):
                self._dequant_nvfp4_to_bf16(mod, e2m1_lut)
                bf16_converted += 1
            # Shared experts go to bf16; routed experts are not touched here.
            shared = getattr(getattr(layer, "ffn", None), "shared_experts", None)
            for mod in packed_nvfp4(shared, bf16_shared_names):
                self._dequant_nvfp4_to_bf16(mod, e2m1_lut)
                bf16_converted += 1
        if fp8_converted or bf16_converted:
            print(f"NVFP4 post-load: {fp8_converted} layers -> FP8, "
                  f"{bf16_converted} layers -> bf16, MoE experts stay NVFP4")

    def _dequant_nvfp4_weight(self, mod, e2m1_lut):
        """Return mod's packed NVFP4 weight dequantized to float32.

        w = fp4 * weight_scale (per block) * weight_scale_2 (global).
        input_scale (the activation global scale in ModelOpt NVFP4
        checkpoints) is intentionally NOT applied: it belongs to the
        activation side, and the converted layer no longer uses it.
        Without weight_scale/weight_scale_2 the raw E2M1 values are returned.
        """
        w = self._unpack_nvfp4_to_bf16(
            mod.weight.data, e2m1_lut, mod.weight.device
        ).float()
        if not (hasattr(mod, "weight_scale") and hasattr(mod, "weight_scale_2")):
            return w
        block_scale = mod.weight_scale.data.to(torch.float32)
        if block_scale.dim() == 2 and w.dim() == 2:
            rows, cols = w.shape
            # A mismatch here (e.g. padded/swizzled scales) would otherwise
            # mis-tile silently or fail with an opaque reshape error.
            if block_scale.shape[0] != rows or cols % block_scale.shape[1]:
                raise ValueError(
                    f"weight_scale shape {tuple(block_scale.shape)} does not "
                    f"tile unpacked weight shape {tuple(w.shape)}"
                )
            block_size = cols // block_scale.shape[1]
            block_scale = block_scale.repeat_interleave(block_size, dim=1)
        # NOTE(review): .max() collapses per-shard global scales of fused
        # layers (e.g. fused_wqa_wkv) into one value; exact only if every
        # shard was exported with the same weight_scale_2 -- confirm.
        global_scale = mod.weight_scale_2.data.max().item()
        return w * block_scale * global_scale

    def _dequant_nvfp4_to_bf16(self, mod, e2m1_lut):
        """Dequantize NVFP4 weight to bf16 for normal .forward() path."""
        w_dequant = self._dequant_nvfp4_weight(mod, e2m1_lut).to(torch.bfloat16)
        # Replace weight with bf16 version and run the layer unquantized.
        mod.weight = torch.nn.Parameter(w_dequant, requires_grad=False)
        from vllm.model_executor.layers.linear import UnquantizedLinearMethod
        mod.quant_method = UnquantizedLinearMethod()
        for attr in ("weight_scale", "weight_scale_2", "input_scale",
                     "weight_scale_inv"):
            if hasattr(mod, attr):
                delattr(mod, attr)

    def _convert_nvfp4_to_fp8(self, mod, e2m1_lut, fp8_max):
        """Convert NVFP4 weight to per-tensor-scaled FP8 for fp8_einsum."""
        w = self._dequant_nvfp4_weight(mod, e2m1_lut)
        # Re-quantize in float32 so the scale is not rounded to bf16, and
        # clamp: a float8_e4m3fn cast of an out-of-range value can give NaN
        # instead of saturating.
        w_amax = w.abs().amax()
        if w_amax == 0:
            w_amax = torch.ones((), dtype=torch.float32, device=w.device)
        fp8_scale = w_amax / fp8_max
        w_fp8 = (w / fp8_scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
        mod.weight = torch.nn.Parameter(w_fp8, requires_grad=False)
        mod.weight_scale_inv = torch.nn.Parameter(
            fp8_scale.to(torch.float32).reshape(1), requires_grad=False
        )
        from vllm.model_executor.layers.linear import UnquantizedLinearMethod
        mod.quant_method = UnquantizedLinearMethod()
        for attr in ("weight_scale", "weight_scale_2", "input_scale"):
            if hasattr(mod, attr):
                delattr(mod, attr)

    def _unpack_nvfp4_to_bf16(self, w_uint8, e2m1_lut, device):
        """Unpack NVFP4 uint8 packed weights to bf16 using E2M1 format.

        Each byte holds two 4-bit codes: the low nibble is the earlier
        element, the high nibble the later one. Bit 3 of a code is the sign
        and bits 0-2 index e2m1_lut. Rows are kept; columns double.
        """
        lut = e2m1_lut.to(device)
        codes = torch.stack(
            [w_uint8 & 0x0F, (w_uint8 >> 4) & 0x0F], dim=-1
        ).long()
        magnitude = lut[codes & 0x07]
        values = torch.where(codes >= 8, -magnitude, magnitude)
        return values.reshape(w_uint8.shape[0], -1).to(torch.bfloat16)

'''
# Splice the new methods over the old span, validate, then write back.
import ast
import sys

c = c[:m1_start] + new_methods + c[end_pos:]

# Also update the call from DeepseekV4ForCausalLM.load_weights.
old_call = "self.model._convert_nvfp4_attention_to_fp8()"
new_call = "self.model._convert_nvfp4_post_load()"
if old_call not in c:
    print(f"WARNING: call site {old_call!r} not found; "
          f"{new_call} will not be invoked automatically", file=sys.stderr)
c = c.replace(old_call, new_call)

# Validate BEFORE writing so a bad splice never leaves a corrupted file.
try:
    ast.parse(c, filename=filepath)
except SyntaxError as e:
    print(f"Syntax error at line {e.lineno}: {e.msg}", file=sys.stderr)
    print(f"Aborting: {filepath} left unchanged", file=sys.stderr)
    sys.exit(1)
print("Syntax OK")

with open(filepath, "w", encoding="utf-8") as f:
    f.write(c)
print("Replaced conversion methods with clean FP8/bf16 split")