# Status: server running on a B200 (port 8000) with the full NVFP4→vLLM bridge.
# All critical bugs fixed: DeepGEMM scale format, compressor shapes, and block
# scale values.
# Attached script: Python, 191 lines, 10 KiB.
#!/usr/bin/env python3
"""Apply all NVFP4 serving fixes to deepseek_v4.py.

Usage::

    python <this_script> [PATH]

PATH defaults to ``DEFAULT_FILEPATH``. Three source-level rewrites are applied:

FIX 1  Substring weight-name mappings: drop the ``.mla_attn.`` hop so the
       checkpoint's attention projections map to ``attn.*`` (the model has
       fused_wqa_wkv, wq_b, wo_a, wo_b at the ``attn.*`` level).
FIX 2  ``fused_skip_regex``: only compressor scale tensors are skipped;
       attention and shared-expert scale tensors are loaded again.
FIX 3  ``load_weights`` fallback branch: a non-uint8 ``.weight`` landing on a
       uint8 parameter (o_a_proj, left bf16 by modelopt) switches that layer
       to ``UnquantizedLinearMethod`` and installs the bf16 weight directly.

Multi-line blocks (FIX 2, FIX 3) are matched line by line ignoring
indentation, so the patch does not depend on the target's exact whitespace.
Each fix reports "applied", "already applied" or "NOT applied". The file is
rewritten only when something changed, and the exit status is 1 when any fix
could not be applied.
"""
import re
import sys
import textwrap
from typing import List, Optional, Tuple

DEFAULT_FILEPATH = "/root/nvidia-meeting/deepseek-v4-quant/patches/deepseek_v4.py"


# ═══════════════════════════════════════════════════════════════
# Block matching helpers
# ═══════════════════════════════════════════════════════════════


def _block_pattern(block: str) -> re.Pattern:
    """Compile a regex that finds ``block`` regardless of indentation.

    Every non-blank line of ``block`` must appear verbatim (after stripping
    surrounding whitespace) on its own line, in order; blank lines and
    indentation are ignored. The leading whitespace of the first matched line
    is captured as the ``indent`` group.
    """
    lines = [line.strip() for line in block.splitlines() if line.strip()]
    if not lines:
        raise ValueError("block must contain at least one non-blank line")
    body = r"[ \t]*\n\s*".join(re.escape(line) for line in lines)
    return re.compile(r"^(?P<indent>[ \t]*)" + body + r"[ \t]*$", re.MULTILINE)


def replace_block(text: str, old: str, new: str) -> Tuple[str, int]:
    """Replace every occurrence of multi-line ``old`` in ``text`` with ``new``.

    ``old`` is matched with :func:`_block_pattern` (indentation-insensitive).
    ``new`` is dedented and re-indented so its first line starts at the
    matched block's indentation; its internal relative indentation is kept.

    Returns ``(new_text, number_of_replacements)``.
    """
    new_lines = textwrap.dedent(new).strip("\n").splitlines()

    def _render(match: "re.Match") -> str:
        indent = match.group("indent")
        return "\n".join(indent + line if line.strip() else "" for line in new_lines)

    # A function replacement is inserted literally, so backslashes in ``new``
    # (e.g. regex text such as r"\.compressor") are not reinterpreted.
    return _block_pattern(old).subn(_render, text)


def apply_block_fix(
    text: str, label: str, old: str, new: str, *, anchor: str, description: str
) -> Tuple[str, bool]:
    """Swap block ``old`` for ``new`` and report the outcome.

    Returns ``(text, ok)``. ``ok`` is True when the block was replaced now or
    the patched block is already present. On failure, ``anchor`` (a line
    expected inside ``old``) is searched for to give a diagnostic hint.
    """
    text, count = replace_block(text, old, new)
    if count:
        print(f"{label} applied: {description}\n")
        return text, True
    if _block_pattern(new).search(text):
        print(f"{label} already applied: {description}\n")
        return text, True

    print(f"{label}: NOT applied — could not find the original block "
          f"(nor the already-patched one).")
    hint = re.search(r"^[ \t]*" + re.escape(anchor), text, re.MULTILINE)
    if hint:
        line_no = text.count("\n", 0, hint.start()) + 1
        print(f"  Found {anchor!r} at line {line_no}, but the surrounding "
              f"block differs; patch it manually.\n")
    else:
        print(f"  Could not find {anchor!r} at all!\n")
    return text, False


# ═══════════════════════════════════════════════════════════════
# FIX 1: Substr mapping — remove .mla_attn. from attention projections
# The model has fused_wqa_wkv, wq_b, wo_a, wo_b at attn.* level
# ═══════════════════════════════════════════════════════════════

SUBSTR_MAPPING_FIXES = {
    '".self_attn.q_a_proj.": ".attn.mla_attn.wq_a."': '".self_attn.q_a_proj.": ".attn.wq_a."',
    '".self_attn.q_b_proj.": ".attn.mla_attn.wq_b."': '".self_attn.q_b_proj.": ".attn.wq_b."',
    '".self_attn.q_a_norm.": ".attn.mla_attn.q_norm."': '".self_attn.q_a_norm.": ".attn.q_norm."',
    '".self_attn.o_a_proj.": ".attn.mla_attn.wo_a."': '".self_attn.o_a_proj.": ".attn.wo_a."',
    '".self_attn.o_b_proj.": ".attn.mla_attn.wo_b."': '".self_attn.o_b_proj.": ".attn.wo_b."',
    '".self_attn.sinks": ".attn.mla_attn.attn_sink"': '".self_attn.sinks": ".attn.attn_sink"',
    '".self_attn.kv_proj.": ".attn.mla_attn.wkv."': '".self_attn.kv_proj.": ".attn.wkv."',
    '".self_attn.kv_norm.": ".attn.mla_attn.kv_norm."': '".self_attn.kv_norm.": ".attn.kv_norm."',
}

OLD_ATTN_COMMENT = "# Attention: self_attn → attn.mla_attn"
NEW_ATTN_COMMENT = "# Attention: self_attn → attn (projections at attn level, not mla_attn)"


def apply_fix1(text: str) -> Tuple[str, bool]:
    """FIX 1: point the attention substring mappings at ``attn.*``.

    Returns ``(text, ok)``; ``ok`` is False when some mapping was found
    neither in its original nor in its fixed form.
    """
    ok = True
    for old, new in SUBSTR_MAPPING_FIXES.items():
        if old in text:
            text = text.replace(old, new)
            print(f"  Fixed: {old[:50]}... → {new[:50]}...")
        elif new in text:
            print(f"  Already fixed: {new[:50]}...")
        else:
            print(f"  NOT FOUND: {old[:60]}...")
            ok = False
    # Cosmetic: keep the section comment in sync; not counted as a failure.
    text = text.replace(OLD_ATTN_COMMENT, NEW_ATTN_COMMENT)
    if ok:
        print("FIX 1 applied: substr mappings updated\n")
    else:
        print("FIX 1 INCOMPLETE: some substr mappings were not found\n")
    return text, ok


# ═══════════════════════════════════════════════════════════════
# FIX 2: Skip patterns — only skip compressor scale tensors
# Attention and shared expert scale tensors now correctly load
# ═══════════════════════════════════════════════════════════════

OLD_SKIP_BLOCK = r'''
fused_skip_regex = {
    # Compressor projections → fused_wkv_wgate (stacked)
    re.compile(r"\.compressor\.kv_proj\.weight_scale$"): None,
    re.compile(r"\.compressor\.gate_proj\.weight_scale$"): None,
    re.compile(r"\.compressor\.kv_proj\.weight_scale_2$"): None,
    re.compile(r"\.compressor\.gate_proj\.weight_scale_2$"): None,
    re.compile(r"\.compressor\.kv_proj\.input_scale$"): None,
    re.compile(r"\.compressor\.gate_proj\.input_scale$"): None,
    # Attention projections → fused_wqa_wkv (stacked)
    re.compile(r"\.self_attn\.kv_proj\.weight_scale$"): None,
    re.compile(r"\.self_attn\.q_a_proj\.weight_scale$"): None,
    re.compile(r"\.self_attn\.q_b_proj\.weight_scale$"): None,
    re.compile(r"\.self_attn\.o_a_proj\.weight_scale$"): None,
    re.compile(r"\.self_attn\.o_b_proj\.weight_scale$"): None,
    re.compile(r"\.self_attn\.kv_proj\.weight_scale_2$"): None,
    re.compile(r"\.self_attn\.q_a_proj\.weight_scale_2$"): None,
    re.compile(r"\.self_attn\.q_b_proj\.weight_scale_2$"): None,
    re.compile(r"\.self_attn\.o_a_proj\.weight_scale_2$"): None,
    re.compile(r"\.self_attn\.o_b_proj\.weight_scale_2$"): None,
    re.compile(r"\.self_attn\.kv_proj\.input_scale$"): None,
    re.compile(r"\.self_attn\.q_a_proj\.input_scale$"): None,
    re.compile(r"\.self_attn\.q_b_proj\.input_scale$"): None,
    re.compile(r"\.self_attn\.o_a_proj\.input_scale$"): None,
    re.compile(r"\.self_attn\.o_b_proj\.input_scale$"): None,
    # Shared expert gate_proj/up_proj → gate_up_proj (stacked)
    re.compile(r"\.shared_experts\.gate_proj\.weight_scale$"): None,
    re.compile(r"\.shared_experts\.up_proj\.weight_scale$"): None,
    re.compile(r"\.shared_experts\.gate_proj\.weight_scale_2$"): None,
    re.compile(r"\.shared_experts\.up_proj\.weight_scale_2$"): None,
    re.compile(r"\.shared_experts\.gate_proj\.input_scale$"): None,
    re.compile(r"\.shared_experts\.up_proj\.input_scale$"): None,
}
'''

NEW_SKIP_BLOCK = r'''
fused_skip_regex = {
    # Compressor projections → fused_wkv_wgate (stacked)
    # Compressor uses UnquantizedLinearMethod (quant_config=None),
    # so it only has a bf16 weight param — no scale params registered.
    # We unpack the NVFP4 uint8 weights to bf16 at load time.
    re.compile(r"\.compressor\.kv_proj\.weight_scale$"): None,
    re.compile(r"\.compressor\.gate_proj\.weight_scale$"): None,
    re.compile(r"\.compressor\.kv_proj\.weight_scale_2$"): None,
    re.compile(r"\.compressor\.gate_proj\.weight_scale_2$"): None,
    re.compile(r"\.compressor\.kv_proj\.input_scale$"): None,
    re.compile(r"\.compressor\.gate_proj\.input_scale$"): None,
    # Note: attention and shared expert scale tensors are NO LONGER
    # skipped. After fixing substr mappings, they correctly map to the
    # model's NVFP4 scale parameters (fused_wqa_wkv, wq_b, wo_a,
    # wo_b, gate_up_proj). They load via the stacking logic.
}
'''


def apply_fix2(text: str) -> Tuple[str, bool]:
    """FIX 2: skip only the compressor scale tensors in ``fused_skip_regex``."""
    return apply_block_fix(
        text, "FIX 2", OLD_SKIP_BLOCK, NEW_SKIP_BLOCK,
        anchor="fused_skip_regex = {",
        description="skip patterns updated (only compressor scales skipped)",
    )


# ═══════════════════════════════════════════════════════════════
# FIX 3: Handle o_a_proj bf16 → wo_a uint8 mismatch
# modelopt didn't quantize o_a_proj (bf16, no scales).
# When loading bf16 into uint8, replace the layer's quant_method
# with UnquantizedLinearMethod so it runs in bf16.
# ═══════════════════════════════════════════════════════════════

OLD_ELSE_BLOCK = '''
else:
    if name not in params_dict:
        # ModelOpt NVFP4 export includes params not in the
        # vllm model (e.g., compressor.position_bias).
        # Skip them silently.
        continue
    param = params_dict[name]
    weight_loader = getattr(
        param, "weight_loader", default_weight_loader
    )
    weight_loader(param, loaded_weight)
    loaded_params.add(name)
    continue
'''

NEW_ELSE_BLOCK = '''
else:
    if name not in params_dict:
        # ModelOpt NVFP4 export includes params not in the
        # vllm model (e.g., compressor.position_bias).
        # Skip them silently.
        continue
    param = params_dict[name]

    # Handle bf16 → uint8 mismatch for o_a_proj:
    # modelopt didn't quantize o_a_proj (bf16, no scales),
    # but ModelOptNvFp4Config creates wo_a with NVFP4 quant
    # (uint8 weight + scales). When loading bf16 into uint8,
    # we replace the quant method with UnquantizedLinearMethod
    # so the layer runs in bf16 at inference.
    if (name.endswith(".weight")
            and loaded_weight.dtype != torch.uint8
            and param.data.dtype == torch.uint8):
        from vllm.model_executor.layers.linear import (
            UnquantizedLinearMethod,
        )
        module_path = name.rsplit(".", 1)[0]  # e.g., layers.0.attn.wo_a
        # Find the module and override its quant method
        mod = self
        for attr in module_path.split("."):
            if attr.isdigit():
                mod = mod[int(attr)]
            else:
                mod = getattr(mod, attr)
        if hasattr(mod, "quant_method"):
            mod.quant_method = UnquantizedLinearMethod()
        # Replace the uint8 weight param with the bf16 tensor, placed on
        # the uint8 param's device: the checkpoint tensor may live on a
        # different device (e.g. CPU) than the model.
        # NOTE(review): this bypasses param.weight_loader, so any
        # tensor-parallel sharding it would apply is skipped — confirm
        # wo_a is not TP-sharded.
        mod.weight = torch.nn.Parameter(
            loaded_weight.to(device=param.device, copy=True),
            requires_grad=False,
        )
        # The old scale params stay registered at their init values;
        # UnquantizedLinearMethod does not use them.
        loaded_params.add(name)
        continue

    weight_loader = getattr(
        param, "weight_loader", default_weight_loader
    )
    weight_loader(param, loaded_weight)
    loaded_params.add(name)
    continue
'''


def apply_fix3(text: str) -> Tuple[str, bool]:
    """FIX 3: load unquantized o_a_proj weights into a bf16 layer."""
    return apply_block_fix(
        text, "FIX 3", OLD_ELSE_BLOCK, NEW_ELSE_BLOCK,
        anchor="if name not in params_dict:",
        description="bf16→uint8 mismatch handling for o_a_proj",
    )


# ═══════════════════════════════════════════════════════════════
# Driver
# ═══════════════════════════════════════════════════════════════


def apply_all_fixes(text: str) -> Tuple[str, List[str]]:
    """Apply FIX 1-3 in order; return ``(patched_text, labels_of_failed_fixes)``."""
    failed = []
    for label, fix in (("FIX 1", apply_fix1), ("FIX 2", apply_fix2), ("FIX 3", apply_fix3)):
        text, ok = fix(text)
        if not ok:
            failed.append(label)
    return text, failed


def main(argv: Optional[List[str]] = None) -> int:
    """Patch the target file in place and return the process exit status.

    ``argv`` holds the CLI arguments (without the program name); the only
    accepted argument is an optional path to the file to patch.
    """
    args = sys.argv[1:] if argv is None else list(argv)
    if len(args) > 1:
        print(f"usage: {sys.argv[0]} [PATH]", file=sys.stderr)
        return 2
    filepath = args[0] if args else DEFAULT_FILEPATH

    # Explicit UTF-8: the patch text contains non-ASCII characters (→, —),
    # so it must not depend on the locale's default encoding.
    with open(filepath, encoding="utf-8") as f:
        original = f.read()

    patched, failed = apply_all_fixes(original)

    if patched != original:
        with open(filepath, "w", encoding="utf-8") as f:
            f.write(patched)
        if failed:
            print("\nPartial fixes written to", filepath)
        else:
            print("\nAll fixes written to", filepath)
    else:
        print("\nNo changes written;", filepath, "left untouched")

    if failed:
        print("NOT applied:", ", ".join(failed), file=sys.stderr)
        return 1
    return 0


if __name__ == "__main__":
    sys.exit(main())