Files
deepseek-v4-quant/tmp/apply_fixes.py

191 lines
10 KiB
Python
Raw Normal View History

#!/usr/bin/env python3
"""Apply all NVFP4 serving fixes to deepseek_v4.py.

Patches the vLLM DeepSeek-V4 model file in place:

* FIX 1 -- substr mappings: attention projections map to ``attn.*``
  instead of ``attn.mla_attn.*``.
* FIX 2 -- ``fused_skip_regex``: only the compressor scale tensors are
  skipped.
* FIX 3 -- ``load_weights`` fallback branch: a non-uint8 (bf16) checkpoint
  tensor landing on a uint8 NVFP4 weight (o_a_proj -> wo_a) switches that
  layer to ``UnquantizedLinearMethod``.

Multi-line blocks are located by comparing stripped, non-blank lines, so the
templates below match however deeply the target file indents them; each
replacement is re-indented to the indentation of the matched block's first
line. Re-running the script on an already-patched file leaves it unchanged.

Usage: ``apply_fixes.py [PATH]`` (default: ``filepath``). Exits with status
1 if any fix was neither applied nor found already applied.
"""
import sys
import textwrap

filepath = "/root/nvidia-meeting/deepseek-v4-quant/patches/deepseek_v4.py"

# ═══════════════════════════════════════════════════════════════
# FIX 1: Substr mapping — remove .mla_attn. from attention projections
# The model has fused_wqa_wkv, wq_b, wo_a, wo_b at attn.* level
# ═══════════════════════════════════════════════════════════════
replacements_1 = {
    '".self_attn.q_a_proj.": ".attn.mla_attn.wq_a."': '".self_attn.q_a_proj.": ".attn.wq_a."',
    '".self_attn.q_b_proj.": ".attn.mla_attn.wq_b."': '".self_attn.q_b_proj.": ".attn.wq_b."',
    '".self_attn.q_a_norm.": ".attn.mla_attn.q_norm."': '".self_attn.q_a_norm.": ".attn.q_norm."',
    '".self_attn.o_a_proj.": ".attn.mla_attn.wo_a."': '".self_attn.o_a_proj.": ".attn.wo_a."',
    '".self_attn.o_b_proj.": ".attn.mla_attn.wo_b."': '".self_attn.o_b_proj.": ".attn.wo_b."',
    '".self_attn.sinks": ".attn.mla_attn.attn_sink"': '".self_attn.sinks": ".attn.attn_sink"',
    '".self_attn.kv_proj.": ".attn.mla_attn.wkv."': '".self_attn.kv_proj.": ".attn.wkv."',
    '".self_attn.kv_norm.": ".attn.mla_attn.kv_norm."': '".self_attn.kv_norm.": ".attn.kv_norm."',
}

# NOTE(review): the relative indentation inside the templates below was
# reconstructed from a whitespace-stripped copy. Matching ignores
# indentation, but confirm the generated blocks read correctly after patching.

# ═══════════════════════════════════════════════════════════════
# FIX 2: Skip patterns — only skip compressor scale tensors
# Attention and shared expert scale tensors now correctly load
# ═══════════════════════════════════════════════════════════════
old_skip_block = '''\
fused_skip_regex = {
    # Compressor projections → fused_wkv_wgate (stacked)
    re.compile(r"\\.compressor\\.kv_proj\\.weight_scale$"): None,
    re.compile(r"\\.compressor\\.gate_proj\\.weight_scale$"): None,
    re.compile(r"\\.compressor\\.kv_proj\\.weight_scale_2$"): None,
    re.compile(r"\\.compressor\\.gate_proj\\.weight_scale_2$"): None,
    re.compile(r"\\.compressor\\.kv_proj\\.input_scale$"): None,
    re.compile(r"\\.compressor\\.gate_proj\\.input_scale$"): None,
    # Attention projections → fused_wqa_wkv (stacked)
    re.compile(r"\\.self_attn\\.kv_proj\\.weight_scale$"): None,
    re.compile(r"\\.self_attn\\.q_a_proj\\.weight_scale$"): None,
    re.compile(r"\\.self_attn\\.q_b_proj\\.weight_scale$"): None,
    re.compile(r"\\.self_attn\\.o_a_proj\\.weight_scale$"): None,
    re.compile(r"\\.self_attn\\.o_b_proj\\.weight_scale$"): None,
    re.compile(r"\\.self_attn\\.kv_proj\\.weight_scale_2$"): None,
    re.compile(r"\\.self_attn\\.q_a_proj\\.weight_scale_2$"): None,
    re.compile(r"\\.self_attn\\.q_b_proj\\.weight_scale_2$"): None,
    re.compile(r"\\.self_attn\\.o_a_proj\\.weight_scale_2$"): None,
    re.compile(r"\\.self_attn\\.o_b_proj\\.weight_scale_2$"): None,
    re.compile(r"\\.self_attn\\.kv_proj\\.input_scale$"): None,
    re.compile(r"\\.self_attn\\.q_a_proj\\.input_scale$"): None,
    re.compile(r"\\.self_attn\\.q_b_proj\\.input_scale$"): None,
    re.compile(r"\\.self_attn\\.o_a_proj\\.input_scale$"): None,
    re.compile(r"\\.self_attn\\.o_b_proj\\.input_scale$"): None,
    # Shared expert gate_proj/up_proj → gate_up_proj (stacked)
    re.compile(r"\\.shared_experts\\.gate_proj\\.weight_scale$"): None,
    re.compile(r"\\.shared_experts\\.up_proj\\.weight_scale$"): None,
    re.compile(r"\\.shared_experts\\.gate_proj\\.weight_scale_2$"): None,
    re.compile(r"\\.shared_experts\\.up_proj\\.weight_scale_2$"): None,
    re.compile(r"\\.shared_experts\\.gate_proj\\.input_scale$"): None,
    re.compile(r"\\.shared_experts\\.up_proj\\.input_scale$"): None,
}'''
new_skip_block = '''\
fused_skip_regex = {
    # Compressor projections → fused_wkv_wgate (stacked)
    # Compressor uses UnquantizedLinearMethod (quant_config=None),
    # so it only has a bf16 weight param — no scale params registered.
    # We unpack the NVFP4 uint8 weights to bf16 at load time.
    re.compile(r"\\.compressor\\.kv_proj\\.weight_scale$"): None,
    re.compile(r"\\.compressor\\.gate_proj\\.weight_scale$"): None,
    re.compile(r"\\.compressor\\.kv_proj\\.weight_scale_2$"): None,
    re.compile(r"\\.compressor\\.gate_proj\\.weight_scale_2$"): None,
    re.compile(r"\\.compressor\\.kv_proj\\.input_scale$"): None,
    re.compile(r"\\.compressor\\.gate_proj\\.input_scale$"): None,
    # Note: attention and shared expert scale tensors are NO LONGER
    # skipped. After fixing substr mappings, they correctly map to the
    # model's NVFP4 scale parameters (fused_wqa_wkv, wq_b, wo_a,
    # wo_b, gate_up_proj). They load via the stacking logic.
}'''

# ═══════════════════════════════════════════════════════════════
# FIX 3: Handle o_a_proj bf16 → wo_a uint8 mismatch
# modelopt didn't quantize o_a_proj (bf16, no scales).
# When loading bf16 into uint8, replace the layer's quant_method
# with UnquantizedLinearMethod so it runs in bf16.
# ═══════════════════════════════════════════════════════════════
old_else_block = '''\
else:
    if name not in params_dict:
        # ModelOpt NVFP4 export includes params not in the
        # vllm model (e.g., compressor.position_bias).
        # Skip them silently.
        continue
    param = params_dict[name]
    weight_loader = getattr(
        param, "weight_loader", default_weight_loader
    )
    weight_loader(param, loaded_weight)
    loaded_params.add(name)
    continue'''
new_else_block = '''\
else:
    if name not in params_dict:
        # ModelOpt NVFP4 export includes params not in the
        # vllm model (e.g., compressor.position_bias).
        # Skip them silently.
        continue
    param = params_dict[name]
    # Handle bf16 → uint8 mismatch for o_a_proj:
    # modelopt didn't quantize o_a_proj (bf16, no scales),
    # but ModelOptNvFp4Config creates wo_a with NVFP4 quant
    # (uint8 weight + scales). When loading bf16 into uint8,
    # we replace the quant method with UnquantizedLinearMethod
    # so the layer runs in bf16 at inference.
    if (name.endswith(".weight")
            and loaded_weight.dtype != torch.uint8
            and param.data.dtype == torch.uint8):
        from vllm.model_executor.layers.linear import (
            UnquantizedLinearMethod,
        )
        # e.g., layers.0.attn.wo_a
        module_path = name.rsplit(".", 1)[0]
        mod = self.get_submodule(module_path)
        if hasattr(mod, "quant_method"):
            mod.quant_method = UnquantizedLinearMethod()
        # Swap the uint8 weight param for the checkpoint tensor, copied
        # onto the replaced param's device (checkpoint tensors may be
        # on CPU while the model params are on the GPU).
        # NOTE(review): the full tensor is assigned without
        # tensor-parallel sharding — confirm wo_a is unsharded or TP=1.
        mod.weight = torch.nn.Parameter(
            loaded_weight.to(device=param.device, copy=True),
            requires_grad=False,
        )
        # The NVFP4 scale params stay registered at their init values;
        # the UnquantizedLinearMethod won't use them.
        loaded_params.add(name)
        continue
    weight_loader = getattr(
        param, "weight_loader", default_weight_loader
    )
    weight_loader(param, loaded_weight)
    loaded_params.add(name)
    continue'''


def _match_end(lines, start, pattern):
    """Return the index just past a match of *pattern* at ``lines[start]``, else None.

    *pattern* holds stripped, non-blank lines; blank lines in *lines* are
    skipped, and every other line is compared with whitespace stripped.
    """
    if lines[start].strip() != pattern[0]:
        return None
    k = 1
    i = start + 1
    while k < len(pattern):
        if i >= len(lines):
            return None
        stripped = lines[i].strip()
        i += 1
        if not stripped:
            continue  # blank lines inside the block are ignored
        if stripped != pattern[k]:
            return None
        k += 1
    return i


def replace_block(text, old, new):
    """Replace every occurrence of the line block *old* in *text* with *new*.

    Lines are compared with surrounding whitespace stripped and blank lines
    ignored, so *old* matches regardless of the target's indentation. *new*
    is dedented, then indented with the leading whitespace of the first
    matched line, so it must be written relative to its own first line.

    Returns ``(patched_text, number_of_replacements)``.

    Raises:
        ValueError: if *old* contains no non-blank line.
    """
    pattern = [ln.strip() for ln in old.splitlines() if ln.strip()]
    if not pattern:
        raise ValueError("old block is empty")
    template = textwrap.dedent(new).strip("\n")
    lines = text.splitlines(keepends=True)
    out = []
    count = 0
    i = 0
    while i < len(lines):
        end = _match_end(lines, i, pattern)
        if end is None:
            out.append(lines[i])
            i += 1
            continue
        first = lines[i]
        indent = first[: len(first) - len(first.lstrip())]
        last = lines[end - 1]
        # Keep whatever line ending (or none, at EOF) the block ended with.
        newline = last[len(last.rstrip("\r\n")):]
        out.append(textwrap.indent(template, indent) + newline)
        count += 1
        i = end
    return "".join(out), count


def _apply_substr_mappings(c):
    """FIX 1: rewrite attention substr mappings; return ``(text, all_found)``."""
    all_found = True
    for old, new in replacements_1.items():
        if old in c:
            c = c.replace(old, new)
            print(f" Fixed: {old[:50]}... → {new[:50]}...")
        elif new in c:
            print(f" Already applied: {new[:60]}...")
        else:
            print(f" NOT FOUND: {old[:60]}...")
            all_found = False
    # Update comment
    c = c.replace(
        '# Attention: self_attn → attn.mla_attn',
        '# Attention: self_attn → attn (projections at attn level, not mla_attn)'
    )
    if all_found:
        print("FIX 1 applied: substr mappings updated\n")
    else:
        print("FIX 1: some substr mappings were not found\n")
    return c, all_found


def _apply_block_fix(c, label, old, new, done_msg):
    """Replace block *old* with *new*; return ``(text, applied_or_present)``."""
    c, count = replace_block(c, old, new)
    if count:
        print(done_msg)
        return c, True
    # Not found: succeed only if a previous run already installed *new*.
    if replace_block(c, new, new)[1]:
        print(f"{label}: already applied\n")
        return c, True
    print(f"{label}: could not find the block to replace!\n")
    return c, False


def apply_fixes(c):
    """Apply FIX 1-3 to source text *c*; return ``(patched, failed_labels)``."""
    failed = []
    c, ok = _apply_substr_mappings(c)
    if not ok:
        failed.append("FIX 1")
    c, ok = _apply_block_fix(
        c, "FIX 2", old_skip_block, new_skip_block,
        "FIX 2 applied: skip patterns updated (only compressor scales skipped)\n",
    )
    if not ok:
        failed.append("FIX 2")
    c, ok = _apply_block_fix(
        c, "FIX 3", old_else_block, new_else_block,
        "FIX 3 applied: bf16→uint8 mismatch handling for o_a_proj\n",
    )
    if not ok:
        failed.append("FIX 3")
    return c, failed


def main(argv=None):
    """Patch the file named by ``argv[0]`` (default ``filepath``); return exit status."""
    args = sys.argv[1:] if argv is None else argv
    path = args[0] if args else filepath
    with open(path, encoding="utf-8") as f:
        c = f.read()
    c, failed = apply_fixes(c)
    with open(path, "w", encoding="utf-8") as f:
        f.write(c)
    if failed:
        print(f"\nFixes written to {path}; NOT applied: {', '.join(failed)}")
        return 1
    print("\nAll fixes written to", path)
    return 0


if __name__ == "__main__":
    sys.exit(main())