Server running on B200 port 8000 with full NVFP4→vLLM bridge. All critical bugs fixed: DeepGEMM scale format, compressor shapes, block scale values.
263 lines
14 KiB
Python
263 lines
14 KiB
Python
#!/usr/bin/python3
"""
Apply ALL fixes to the S11 base version of deepseek_v4.py.

This is a clean application of all fixes we've developed.
"""
|
|
|
|
# Absolute path of the module that this script patches in place.
filepath = "/root/nvidia-meeting/deepseek-v4-quant/patches/deepseek_v4.py"

# Pull the whole target file into one string; every fix below rewrites `c`.
with open(filepath, 'r') as source_file:
    c = source_file.read()
|
|
|
|
import ast
|
|
|
|
def check(c, label):
    """Report whether ``c`` parses as Python source.

    Args:
        c: Source text to validate.
        label: Short tag printed in front of the result line.

    Returns:
        True when ``ast.parse`` accepts ``c``, False otherwise. One status
        line is printed in either case.
    """
    try:
        ast.parse(c)
    except SyntaxError as e:
        print(f" {label}: SYNTAX ERROR at line {e.lineno}: {e.msg}")
        return False
    except ValueError as e:
        # Before Python 3.12, source containing NUL bytes raises ValueError
        # instead of SyntaxError; report it as a failed check, not a crash.
        print(f" {label}: SYNTAX ERROR: {e}")
        return False
    print(f" {label}: OK")
    return True
|
|
|
|
# ═══════════════════════════════════════════════════════════
# FIX 1: Substr mapping — remove .mla_attn. from attn projections
# ═══════════════════════════════════════════════════════════
# (checkpoint-side key, vLLM-side tail). Each pair expands to one mapping
# entry whose target currently reads ".attn.mla_attn.<tail>" and should read
# ".attn.<tail>".
_attn_renames = (
    (".self_attn.q_a_proj.", "wq_a."),
    (".self_attn.q_b_proj.", "wq_b."),
    (".self_attn.q_a_norm.", "q_norm."),
    (".self_attn.o_a_proj.", "wo_a."),
    (".self_attn.o_b_proj.", "wo_b."),
    (".self_attn.sinks", "attn_sink"),
    (".self_attn.kv_proj.", "wkv."),
    (".self_attn.kv_norm.", "kv_norm."),
)
subs = {
    f'"{hf_key}": ".attn.mla_attn.{tail}"': f'"{hf_key}": ".attn.{tail}"'
    for hf_key, tail in _attn_renames
}
for stale_entry, fixed_entry in subs.items():
    c = c.replace(stale_entry, fixed_entry)
check(c, "Fix 1 (substr)")
|
|
|
|
# ═══════════════════════════════════════════════════════════
# FIX 2: Skip patterns — only skip compressor scales
# ═══════════════════════════════════════════════════════════
# Remove attention and shared expert scale skip patterns.
# One entry per (owner, projection, scale kind), in the order the entries
# were originally listed: all projections for weight_scale, then
# weight_scale_2, then input_scale.
lines_to_remove = [
    f' re.compile(r"\\.{owner}\\.{proj}\\.{kind}$"): None,'
    for owner, projs in (
        ("self_attn", ("kv_proj", "q_a_proj", "q_b_proj", "o_a_proj", "o_b_proj")),
        ("shared_experts", ("gate_proj", "up_proj")),
    )
    for kind in ("weight_scale", "weight_scale_2", "input_scale")
    for proj in projs
]

# Compare on stripped line content so the removal does not depend on how the
# target file indents these entries. An exact-text replace either misses the
# line entirely or strips the text and leaves its indentation glued onto the
# following line. Whole lines are dropped, newline included.
_unwanted = {entry.strip() for entry in lines_to_remove}
_all_lines = c.splitlines(keepends=True)
_kept_lines = [ln for ln in _all_lines if ln.strip() not in _unwanted]
c = "".join(_kept_lines)
# Report the outcome the same way Fix 5 does, so a silent no-op is visible.
print(f" Fix 2: removed {len(_all_lines) - len(_kept_lines)} of {len(lines_to_remove)} skip patterns")
check(c, "Fix 2 (skip patterns)")
|
|
|
|
# ═══════════════════════════════════════════════════════════
# FIX 3: Remove the 'head.weight' suffix mapping that causes
# 'lm_head.weight' to become 'lm_lm_head.weight'
# ═══════════════════════════════════════════════════════════
# Match on stripped line content: an exact-text replace only works when the
# leading whitespace in the pattern equals the target file's indentation,
# and it fails silently when it does not.
_suffix_entry = '"head.weight": "lm_head.weight",'
_lines_before = c.splitlines(keepends=True)
_lines_after = [ln for ln in _lines_before if ln.strip() != _suffix_entry]
c = "".join(_lines_after)
if len(_lines_after) == len(_lines_before):
    print(" Fix 3: Could not find 'head.weight' suffix mapping")
else:
    print(" Fix 3: Removed 'head.weight' suffix mapping")
check(c, "Fix 3 (suffix)")
|
|
|
|
# ═══════════════════════════════════════════════════════════
|
|
# FIX 4: Handle o_a_proj bf16 -> FP8 at load time
|
|
# modelopt didn't quantize o_a_proj, but vLLM creates wo_a with NVFP4
|
|
# Convert bf16 -> FP8 and set weight_scale_inv
|
|
# ═══════════════════════════════════════════════════════════
|
|
# Text of the fallback `else:` branch that Fix 4 looks for in `c`. It is used
# below as the search argument of `c.replace(old_else, new_else)`, so it has
# to match the target file character for character.
# NOTE(review): the indentation inside this literal looks collapsed, and the
# lone "|" lines are part of the string as written here -- confirm this text
# against the real deepseek_v4.py before relying on the replacement.
old_else = ''' else:
|
|
if name not in params_dict:
|
|
# ModelOpt NVFP4 export includes params not in the
|
|
# vllm model (e.g., compressor.position_bias).
|
|
# Skip them silently.
|
|
continue
|
|
param = params_dict[name]
|
|
weight_loader = getattr(
|
|
param, "weight_loader", default_weight_loader
|
|
)
|
|
weight_loader(param, loaded_weight)
|
|
loaded_params.add(name)
|
|
continue'''
|
|
|
|
# Replacement text for the branch matched by `old_else` (used below as the
# second argument of `c.replace`). Compared with `old_else` it:
#   * drops the explanatory comment on the `name not in params_dict` skip;
#   * adds a branch for names ending in ".weight" whose loaded tensor is not
#     uint8 while the destination parameter is uint8. That branch derives one
#     scale from the tensor's absolute maximum and the float8_e4m3fn maximum,
#     casts the scaled tensor to float8_e4m3fn, walks the dotted module path
#     of `name` starting at `self`, installs `weight` and `weight_scale_inv`
#     parameters on that module, sets `quant_method` to
#     UnquantizedLinearMethod(), deletes weight_scale / weight_scale_2 /
#     input_scale when present, and adds both names to `loaded_params`;
#   * otherwise ends with the same weight_loader path as `old_else`.
# NOTE(review): `name.replace(".weight", ".weight_scale_inv")` rewrites every
# ".weight" occurrence in the name, not only the suffix -- confirm that is safe.
# NOTE(review): the emitted code uses `torch`; presumably the target module
# already imports it -- verify.
# NOTE(review): the indentation inside this literal looks collapsed, and the
# lone "|" lines are part of the string as written here -- confirm before use.
new_else = ''' else:
|
|
if name not in params_dict:
|
|
continue
|
|
param = params_dict[name]
|
|
|
|
# Handle o_a_proj bf16 -> wo_a uint8 mismatch
|
|
if (name.endswith(".weight")
|
|
and loaded_weight.dtype != torch.uint8
|
|
and param.data.dtype == torch.uint8):
|
|
# o_a_proj was NOT quantized by modelopt (bf16, no scales)
|
|
# Convert bf16 -> FP8 and set weight_scale_inv
|
|
w_bf16 = loaded_weight
|
|
w_amax = w_bf16.abs().amax()
|
|
if w_amax == 0:
|
|
w_amax = torch.tensor(1.0, device=w_bf16.device)
|
|
fp8_max = torch.finfo(torch.float8_e4m3fn).max
|
|
fp8_scale = w_amax / fp8_max
|
|
w_fp8 = (w_bf16 / fp8_scale).to(torch.float8_e4m3fn)
|
|
weight_scale_inv = fp8_scale.to(torch.float32)
|
|
# Replace the module weight and add weight_scale_inv
|
|
parts = name.rsplit(".", 1)
|
|
module_path = parts[0]
|
|
mod = self
|
|
for attr in module_path.split("."):
|
|
if attr.isdigit():
|
|
mod = mod[int(attr)]
|
|
else:
|
|
mod = getattr(mod, attr)
|
|
mod.weight = torch.nn.Parameter(w_fp8, requires_grad=False)
|
|
mod.weight_scale_inv = torch.nn.Parameter(
|
|
weight_scale_inv.reshape(1), requires_grad=False
|
|
)
|
|
from vllm.model_executor.layers.linear import (
|
|
UnquantizedLinearMethod,
|
|
)
|
|
mod.quant_method = UnquantizedLinearMethod()
|
|
for attr in ("weight_scale", "weight_scale_2", "input_scale"):
|
|
if hasattr(mod, attr):
|
|
delattr(mod, attr)
|
|
loaded_params.add(name)
|
|
loaded_params.add(name.replace(".weight", ".weight_scale_inv"))
|
|
continue
|
|
|
|
weight_loader = getattr(
|
|
param, "weight_loader", default_weight_loader
|
|
)
|
|
weight_loader(param, loaded_weight)
|
|
loaded_params.add(name)
|
|
continue'''
|
|
|
|
# Swap the fallback else-branch for the version that handles bf16 weights
# landing in uint8 parameters. str.replace is a silent no-op when the search
# text is absent, so report the outcome explicitly (same wording as Fix 5).
if old_else in c:
    c = c.replace(old_else, new_else)
    print(" Fix 4: Replaced fallback else-branch")
else:
    print(" Fix 4: Could not find fallback else-branch")
check(c, "Fix 4 (o_a_proj bf16->FP8)")
|
|
|
|
# ═══════════════════════════════════════════════════════════
# FIX 5: Add NVFP4->FP8 post-load conversion for attention
# This converts all uint8 NVFP4 attention weights to FP8
# ═══════════════════════════════════════════════════════════
# Source text of two methods that get spliced in just before the
# DeepseekV4ForCausalLM class statement, i.e. at the end of whichever class
# body precedes it; hence the 4-space method indentation.
#
# The E2M1 lookup table has 16 entries: a 4-bit code is sign (bit 3) plus a
# 3-bit magnitude, and both nibbles are masked with 0x0F, so indices run 0..15.
# With only the 8 magnitudes, every negative weight indexed out of range.
#
# NOTE(review): the dequantization multiplies by input_scale as well as
# weight_scale_2. input_scale looks like an activation scale rather than a
# weight scale -- confirm it belongs in the weight reconstruction.
conversion_methods = '''
    def _convert_nvfp4_attention_to_fp8(self):
        E2M1_LUT = torch.tensor(
            [0, 0.5, 1, 1.5, 2, 3, 4, 6,
             -0.0, -0.5, -1, -1.5, -2, -3, -4, -6],
            dtype=torch.bfloat16,
        )
        FP8_MAX = torch.finfo(torch.float8_e4m3fn).max
        attn_proj_names = {"fused_wqa_wkv", "wq_b", "wo_a", "wo_b"}
        shared_expert_names = {"gate_up_proj"}
        converted = 0
        for layer_idx, layer in enumerate(self.layers):
            attn = layer.attn
            for proj_name in attn_proj_names:
                if not hasattr(attn, proj_name):
                    continue
                mod = getattr(attn, proj_name)
                if not hasattr(mod, "weight") or mod.weight.dtype != torch.uint8:
                    continue
                self._convert_nvfp4_module_to_fp8(mod, E2M1_LUT, FP8_MAX)
                converted += 1
            ffn = layer.ffn
            if hasattr(ffn, "shared_experts"):
                for proj_name in shared_expert_names:
                    if not hasattr(ffn.shared_experts, proj_name):
                        continue
                    mod = getattr(ffn.shared_experts, proj_name)
                    if not hasattr(mod, "weight") or mod.weight.dtype != torch.uint8:
                        continue
                    self._convert_nvfp4_module_to_fp8(mod, E2M1_LUT, FP8_MAX)
                    converted += 1
        if converted > 0:
            logger.info_once(
                "Converted %d NVFP4 attention/shared-expert layers to FP8",
                converted,
            )

    def _convert_nvfp4_module_to_fp8(self, mod, e2m1_lut, fp8_max):
        w_uint8 = mod.weight.data
        device = w_uint8.device
        even_idx = (w_uint8 & 0x0F).int()
        odd_idx = ((w_uint8 >> 4) & 0x0F).int()
        even_vals = e2m1_lut.to(device)[even_idx]
        odd_vals = e2m1_lut.to(device)[odd_idx]
        w_bf16 = torch.stack([even_vals, odd_vals], dim=-1)
        w_bf16 = w_bf16.reshape(w_uint8.shape[0], -1).to(torch.bfloat16)
        if hasattr(mod, "weight_scale") and hasattr(mod, "weight_scale_2"):
            block_scale = mod.weight_scale.data.to(torch.float32)
            if block_scale.dim() == 2 and w_bf16.dim() == 2:
                block_size = w_bf16.shape[1] // block_scale.shape[1]
                block_scale_expanded = block_scale.unsqueeze(-1).expand(
                    -1, -1, block_size
                ).reshape(w_bf16.shape)
            else:
                block_scale_expanded = block_scale
            global_scale = mod.weight_scale_2.data.max().item()
            input_scale = (
                mod.input_scale.data.max().item()
                if hasattr(mod, "input_scale")
                else 1.0
            )
            w_dequant = w_bf16.float() * block_scale_expanded * global_scale * input_scale
            w_dequant = w_dequant.to(torch.bfloat16)
        else:
            w_dequant = w_bf16
        w_amax = w_dequant.abs().amax()
        if w_amax == 0:
            w_amax = torch.tensor(1.0, device=device)
        fp8_scale = w_amax / fp8_max
        w_fp8 = (w_dequant / fp8_scale).to(torch.float8_e4m3fn)
        weight_scale_inv = fp8_scale.to(torch.float32)
        mod.weight = torch.nn.Parameter(w_fp8, requires_grad=False)
        mod.weight_scale_inv = torch.nn.Parameter(
            weight_scale_inv.reshape(1), requires_grad=False
        )
        from vllm.model_executor.layers.linear import UnquantizedLinearMethod
        mod.quant_method = UnquantizedLinearMethod()
        for attr in ("weight_scale", "weight_scale_2", "input_scale"):
            if hasattr(mod, attr):
                delattr(mod, attr)

'''
|
|
|
|
# Splice the conversion methods in immediately ahead of the
# DeepseekV4ForCausalLM class statement.
marker = "\n\nclass DeepseekV4ForCausalLM(nn.Module):"
if marker not in c:
    print(" Fix 5: Could not find class marker")
else:
    c = c.replace(
        marker,
        f"\n{conversion_methods}\nclass DeepseekV4ForCausalLM(nn.Module):",
    )
    print(" Fix 5: Inserted conversion methods")

check(c, "Fix 5 (NVFP4->FP8 methods)")
|
|
|
|
# ═══════════════════════════════════════════════════════════
# FIX 6: Call the conversion from DeepseekV4ForCausalLM.load_weights
# ═══════════════════════════════════════════════════════════
import re

# Find the `finalize_mega_moe_weights()` call that is directly followed by
# `return loaded_params`, capturing the call's own indentation so the inserted
# line lines up with it. A fixed-whitespace literal only matches when it
# guesses the target's indentation exactly, and str.replace does not say
# whether it matched.
_finalize_tail = re.compile(
    r"^(?P<indent>[ \t]*)self\.model\.finalize_mega_moe_weights\(\)\n"
    r"(?P<ret>[ \t]*return loaded_params)",
    re.MULTILINE,
)
c, _n_hooked = _finalize_tail.subn(
    r"\g<indent>self.model.finalize_mega_moe_weights()\n"
    r"\g<indent>self.model._convert_nvfp4_attention_to_fp8()\n"
    r"\g<ret>",
    c,
)
if _n_hooked:
    print(" Fix 6: Inserted conversion call")
else:
    print(" Fix 6: Could not find finalize_mega_moe_weights() call site")
check(c, "Fix 6 (call conversion)")
|
|
|
|
# ═══════════════════════════════════════════════════════════
# Final validation
# ═══════════════════════════════════════════════════════════
# Refuse to overwrite the target with text that does not parse. The result of
# this check used to be ignored, so a broken patch still replaced the file.
if not check(c, "FINAL"):
    raise SystemExit("Aborting: patched source does not parse; file left unchanged")

with open(filepath, 'w') as f:
    f.write(c)

print("All fixes applied!")
|