Files
deepseek-v4-quant/tmp/apply_all_fixes.py
biondizzle 02b8ea536f Update MEMORY.md and memory files with vLLM NVFP4 serving progress
Server running on B200 port 8000 with full NVFP4→vLLM bridge.
All critical bugs fixed: DeepGEMM scale format, compressor shapes, block scale values.
2026-05-11 02:02:49 +00:00

263 lines
14 KiB
Python

#!/usr/bin/python3
"""
Apply ALL fixes to the S11 base version of deepseek_v4.py.
This is a clean application of all fixes we've developed.
"""
# Location of the model source file that this script patches.
filepath = "/root/nvidia-meeting/deepseek-v4-quant/patches/deepseek_v4.py"
# Load the entire file into `c`; the fixes that follow rebind `c` to the
# progressively patched text.
with open(filepath) as source_file:
    c = source_file.read()
import ast
def check(c: str, label: str) -> bool:
    """Report whether ``c`` parses as Python source.

    Prints one status line for the step named by ``label``.

    Args:
        c: Source text to validate.
        label: Name of the step being validated; echoed in the printed line.

    Returns:
        True if ``ast.parse`` accepts ``c``; False if it raises SyntaxError.
    """
    # Keep the try body to the single call that can raise.
    try:
        ast.parse(c)
    except SyntaxError as e:
        print(f" {label}: SYNTAX ERROR at line {e.lineno}: {e.msg}")
        return False
    print(f" {label}: OK")
    return True
# ═══════════════════════════════════════════════════════════
# FIX 1: Substr mapping — remove .mla_attn. from attn projections
# ═══════════════════════════════════════════════════════════
# Each key is the mapping entry text searched for in the target file; each
# value is the same entry with the ".mla_attn" path component removed.
subs = {
    '".self_attn.q_a_proj.": ".attn.mla_attn.wq_a."': '".self_attn.q_a_proj.": ".attn.wq_a."',
    '".self_attn.q_b_proj.": ".attn.mla_attn.wq_b."': '".self_attn.q_b_proj.": ".attn.wq_b."',
    '".self_attn.q_a_norm.": ".attn.mla_attn.q_norm."': '".self_attn.q_a_norm.": ".attn.q_norm."',
    '".self_attn.o_a_proj.": ".attn.mla_attn.wo_a."': '".self_attn.o_a_proj.": ".attn.wo_a."',
    '".self_attn.o_b_proj.": ".attn.mla_attn.wo_b."': '".self_attn.o_b_proj.": ".attn.wo_b."',
    '".self_attn.sinks": ".attn.mla_attn.attn_sink"': '".self_attn.sinks": ".attn.attn_sink"',
    '".self_attn.kv_proj.": ".attn.mla_attn.wkv."': '".self_attn.kv_proj.": ".attn.wkv."',
    '".self_attn.kv_norm.": ".attn.mla_attn.kv_norm."': '".self_attn.kv_norm.": ".attn.kv_norm."',
}
# str.replace is a silent no-op when the text is absent, so track which
# entries were not found and report them instead of letting them pass unseen.
unmatched = []
for old, new in subs.items():
    if old in c:
        c = c.replace(old, new)
    else:
        unmatched.append(old)
for old in unmatched:
    print(f" Fix 1: mapping entry not found, nothing replaced: {old}")
check(c, "Fix 1 (substr)")
# ═══════════════════════════════════════════════════════════
# FIX 2: Skip patterns — only skip compressor scales
# ═══════════════════════════════════════════════════════════
# Remove attention and shared expert scale skip patterns
# Source lines to delete from the target file: skip-pattern entries for the
# attention projections and the shared-expert gate/up projections.
# NOTE(review): each literal must match the target file's line exactly,
# including leading whitespace — confirm the indentation against the file.
lines_to_remove = [
    ' re.compile(r"\\.self_attn\\.kv_proj\\.weight_scale$"): None,',
    ' re.compile(r"\\.self_attn\\.q_a_proj\\.weight_scale$"): None,',
    ' re.compile(r"\\.self_attn\\.q_b_proj\\.weight_scale$"): None,',
    ' re.compile(r"\\.self_attn\\.o_a_proj\\.weight_scale$"): None,',
    ' re.compile(r"\\.self_attn\\.o_b_proj\\.weight_scale$"): None,',
    ' re.compile(r"\\.self_attn\\.kv_proj\\.weight_scale_2$"): None,',
    ' re.compile(r"\\.self_attn\\.q_a_proj\\.weight_scale_2$"): None,',
    ' re.compile(r"\\.self_attn\\.q_b_proj\\.weight_scale_2$"): None,',
    ' re.compile(r"\\.self_attn\\.o_a_proj\\.weight_scale_2$"): None,',
    ' re.compile(r"\\.self_attn\\.o_b_proj\\.weight_scale_2$"): None,',
    ' re.compile(r"\\.self_attn\\.kv_proj\\.input_scale$"): None,',
    ' re.compile(r"\\.self_attn\\.q_a_proj\\.input_scale$"): None,',
    ' re.compile(r"\\.self_attn\\.q_b_proj\\.input_scale$"): None,',
    ' re.compile(r"\\.self_attn\\.o_a_proj\\.input_scale$"): None,',
    ' re.compile(r"\\.self_attn\\.o_b_proj\\.input_scale$"): None,',
    ' re.compile(r"\\.shared_experts\\.gate_proj\\.weight_scale$"): None,',
    ' re.compile(r"\\.shared_experts\\.up_proj\\.weight_scale$"): None,',
    ' re.compile(r"\\.shared_experts\\.gate_proj\\.weight_scale_2$"): None,',
    ' re.compile(r"\\.shared_experts\\.up_proj\\.weight_scale_2$"): None,',
    ' re.compile(r"\\.shared_experts\\.gate_proj\\.input_scale$"): None,',
    ' re.compile(r"\\.shared_experts\\.up_proj\\.input_scale$"): None,',
]
# Count the entries that are actually present so a run that removes nothing
# is visible in the output rather than indistinguishable from success.
found = 0
for line in lines_to_remove:
    if line not in c:
        continue
    found += 1
    # Prefer removing the entry together with its newline; the second replace
    # catches an occurrence that is not followed by a newline.
    c = c.replace(line + "\n", "")
    c = c.replace(line, "")
print(f" Fix 2: removed {found} of {len(lines_to_remove)} skip patterns")
check(c, "Fix 2 (skip patterns)")
# ═══════════════════════════════════════════════════════════
# FIX 3: Remove the 'head.weight' suffix mapping that causes
# 'lm_head.weight' to become 'lm_lm_head.weight'
# ═══════════════════════════════════════════════════════════
# The suffix-mapping entry to delete, including its trailing newline.
# NOTE(review): leading whitespace must match the target file exactly — confirm.
head_suffix_entry = ' "head.weight": "lm_head.weight",\n'
if head_suffix_entry in c:
    c = c.replace(head_suffix_entry, '')
else:
    # Without this report a missing entry would look identical to success.
    print(" Fix 3: suffix mapping entry not found, nothing removed")
check(c, "Fix 3 (suffix)")
# ═══════════════════════════════════════════════════════════
# FIX 4: Handle o_a_proj bf16 -> FP8 at load time
# modelopt didn't quantize o_a_proj, but vLLM creates wo_a with NVFP4
# Convert bf16 -> FP8 and set weight_scale_inv
# ═══════════════════════════════════════════════════════════
# `old_else` is the search anchor: the text of the fallback `else:` branch that
# looks a parameter up in `params_dict` and hands it to its weight loader.
# `new_else` is its replacement, which adds a special case for a non-uint8
# checkpoint weight arriving at a uint8 parameter before the normal path.
# NOTE(review): both literals must match the target file byte for byte; the
# leading whitespace inside them looks collapsed in this copy of the script —
# confirm against the target file before relying on the match.
old_else = ''' else:
if name not in params_dict:
# ModelOpt NVFP4 export includes params not in the
# vllm model (e.g., compressor.position_bias).
# Skip them silently.
continue
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
loaded_params.add(name)
continue'''
new_else = ''' else:
if name not in params_dict:
continue
param = params_dict[name]
# Handle o_a_proj bf16 -> wo_a uint8 mismatch
if (name.endswith(".weight")
and loaded_weight.dtype != torch.uint8
and param.data.dtype == torch.uint8):
# o_a_proj was NOT quantized by modelopt (bf16, no scales)
# Convert bf16 -> FP8 and set weight_scale_inv
w_bf16 = loaded_weight
w_amax = w_bf16.abs().amax()
if w_amax == 0:
w_amax = torch.tensor(1.0, device=w_bf16.device)
fp8_max = torch.finfo(torch.float8_e4m3fn).max
fp8_scale = w_amax / fp8_max
w_fp8 = (w_bf16 / fp8_scale).to(torch.float8_e4m3fn)
weight_scale_inv = fp8_scale.to(torch.float32)
# Replace the module weight and add weight_scale_inv
parts = name.rsplit(".", 1)
module_path = parts[0]
mod = self
for attr in module_path.split("."):
if attr.isdigit():
mod = mod[int(attr)]
else:
mod = getattr(mod, attr)
mod.weight = torch.nn.Parameter(w_fp8, requires_grad=False)
mod.weight_scale_inv = torch.nn.Parameter(
weight_scale_inv.reshape(1), requires_grad=False
)
from vllm.model_executor.layers.linear import (
UnquantizedLinearMethod,
)
mod.quant_method = UnquantizedLinearMethod()
for attr in ("weight_scale", "weight_scale_2", "input_scale"):
if hasattr(mod, attr):
delattr(mod, attr)
loaded_params.add(name)
loaded_params.add(name.replace(".weight", ".weight_scale_inv"))
continue
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
loaded_params.add(name)
continue'''
if old_else in c:
    c = c.replace(old_else, new_else)
else:
    # str.replace would silently do nothing here and the syntax check below
    # would still print OK, so make the missing anchor explicit.
    print(" Fix 4: Could not find fallback else branch, nothing replaced")
check(c, "Fix 4 (o_a_proj bf16->FP8)")
# ═══════════════════════════════════════════════════════════
# FIX 5: Add NVFP4->FP8 post-load conversion for attention and shared experts
# This converts uint8 NVFP4 attention projection weights and shared-expert
# gate_up_proj weights to FP8
# ═══════════════════════════════════════════════════════════
# Source text of two methods that gets spliced into the target file directly
# in front of the `class DeepseekV4ForCausalLM` statement:
#   _convert_nvfp4_attention_to_fp8 - walks self.layers and converts every
#       uint8-weight attention projection / shared-expert gate_up_proj module.
#   _convert_nvfp4_module_to_fp8 - unpacks two 4-bit codes per byte through the
#       lookup table, applies the scales, and re-quantizes to float8_e4m3fn.
# NOTE(review): the inserted text must carry the indentation of the class it
# lands in (presumably the model class reached as `self.model`); the leading
# whitespace inside this literal looks collapsed in this copy — confirm.
conversion_methods = '''
def _convert_nvfp4_attention_to_fp8(self):
E2M1_LUT = torch.tensor(
[0, 0.5, 1, 1.5, 2, 3, 4, 6], dtype=torch.bfloat16
)
FP8_MAX = torch.finfo(torch.float8_e4m3fn).max
attn_proj_names = {"fused_wqa_wkv", "wq_b", "wo_a", "wo_b"}
shared_expert_names = {"gate_up_proj"}
converted = 0
for layer_idx, layer in enumerate(self.layers):
attn = layer.attn
for proj_name in attn_proj_names:
if not hasattr(attn, proj_name):
continue
mod = getattr(attn, proj_name)
if not hasattr(mod, "weight") or mod.weight.dtype != torch.uint8:
continue
self._convert_nvfp4_module_to_fp8(mod, E2M1_LUT, FP8_MAX)
converted += 1
ffn = layer.ffn
if hasattr(ffn, "shared_experts"):
for proj_name in shared_expert_names:
if not hasattr(ffn.shared_experts, proj_name):
continue
mod = getattr(ffn.shared_experts, proj_name)
if not hasattr(mod, "weight") or mod.weight.dtype != torch.uint8:
continue
self._convert_nvfp4_module_to_fp8(mod, E2M1_LUT, FP8_MAX)
converted += 1
if converted > 0:
logger.info_once(
"Converted %d NVFP4 attention/shared-expert layers to FP8",
converted,
)
def _convert_nvfp4_module_to_fp8(self, mod, e2m1_lut, fp8_max):
w_uint8 = mod.weight.data
device = w_uint8.device
even_idx = (w_uint8 & 0x0F).int()
odd_idx = ((w_uint8 >> 4) & 0x0F).int()
even_vals = e2m1_lut.to(device)[even_idx]
odd_vals = e2m1_lut.to(device)[odd_idx]
w_bf16 = torch.stack([even_vals, odd_vals], dim=-1)
w_bf16 = w_bf16.reshape(w_uint8.shape[0], -1).to(torch.bfloat16)
if hasattr(mod, "weight_scale") and hasattr(mod, "weight_scale_2"):
block_scale = mod.weight_scale.data.to(torch.float32)
if block_scale.dim() == 2 and w_bf16.dim() == 2:
block_size = w_bf16.shape[1] // block_scale.shape[1]
block_scale_expanded = block_scale.unsqueeze(-1).expand(
-1, -1, block_size
).reshape(w_bf16.shape)
else:
block_scale_expanded = block_scale
global_scale = mod.weight_scale_2.data.max().item()
input_scale = (
mod.input_scale.data.max().item()
if hasattr(mod, "input_scale")
else 1.0
)
w_dequant = w_bf16.float() * block_scale_expanded * global_scale * input_scale
w_dequant = w_dequant.to(torch.bfloat16)
else:
w_dequant = w_bf16
w_amax = w_dequant.abs().amax()
if w_amax == 0:
w_amax = torch.tensor(1.0, device=device)
fp8_scale = w_amax / fp8_max
w_fp8 = (w_dequant / fp8_scale).to(torch.float8_e4m3fn)
weight_scale_inv = fp8_scale.to(torch.float32)
mod.weight = torch.nn.Parameter(w_fp8, requires_grad=False)
mod.weight_scale_inv = torch.nn.Parameter(
weight_scale_inv.reshape(1), requires_grad=False
)
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
mod.quant_method = UnquantizedLinearMethod()
for attr in ("weight_scale", "weight_scale_2", "input_scale"):
if hasattr(mod, attr):
delattr(mod, attr)
'''
# Insert before DeepseekV4ForCausalLM class
marker = "\n\nclass DeepseekV4ForCausalLM(nn.Module):"
if "def _convert_nvfp4_attention_to_fp8(" in c:
    # The marker text is still present after an insertion (the inserted text
    # ends with a newline and is followed by "\nclass ..."), so without this
    # guard a second run would insert the methods again.
    print(" Fix 5: Conversion methods already present, skipping insert")
elif marker in c:
    c = c.replace(marker, "\n" + conversion_methods + "\nclass DeepseekV4ForCausalLM(nn.Module):")
    print(" Fix 5: Inserted conversion methods")
else:
    print(" Fix 5: Could not find class marker")
check(c, "Fix 5 (NVFP4->FP8 methods)")
# ═══════════════════════════════════════════════════════════
# FIX 6: Call the conversion from DeepseekV4ForCausalLM.load_weights
# ═══════════════════════════════════════════════════════════
# Anchor: the finalize call followed by the return of `loaded_params`.
# Replacement: the same two lines with the NVFP4->FP8 conversion call between.
# NOTE(review): whitespace after each "\n" must match the target file's
# indentation exactly — confirm against the file.
old_load = " self.model.finalize_mega_moe_weights()\n return loaded_params"
new_load = " self.model.finalize_mega_moe_weights()\n self.model._convert_nvfp4_attention_to_fp8()\n return loaded_params"
if old_load in c:
    c = c.replace(old_load, new_load)
else:
    # A missing anchor means the conversion is never invoked; say so rather
    # than letting the syntax check below report a misleading OK.
    print(" Fix 6: Could not find load_weights return site, conversion call not inserted")
check(c, "Fix 6 (call conversion)")
# ═══════════════════════════════════════════════════════════
# Final validation
# ═══════════════════════════════════════════════════════════
# Gate the write on the final parse check: overwriting the target with text
# that does not parse would leave the model file unimportable.
if not check(c, "FINAL"):
    # SystemExit with a message prints it to stderr and exits with status 1.
    raise SystemExit(f"Refusing to overwrite {filepath}: patched source does not parse")
with open(filepath, 'w') as f:
    f.write(c)
print("All fixes applied!")