Server running on B200 port 8000 with full NVFP4→vLLM bridge. All critical bugs fixed: DeepGEMM scale format, compressor shapes, block scale values.
65 lines
2.5 KiB
Python
65 lines
2.5 KiB
Python
#!/usr/bin/python3
"""Fix: Only convert wo_a and fused_wqa_wkv to FP8 (used with fp8_einsum).

Keep wq_b, wo_b, gate_up_proj in NVFP4 (used via normal .forward()).

The script edits ``filepath`` in place with two textual fixes:

1. Shrink the ``attn_proj_names`` set to ``{"fused_wqa_wkv", "wo_a"}``.
2. Drop the ``shared_expert_names`` set and the ``ffn.shared_experts``
   conversion loop, leaving a comment in its place.

The patched text is syntax-checked *before* it is written, and nothing is
written unless both fixes are either applied or already present.
"""

import ast
import sys

filepath = "/root/nvidia-meeting/deepseek-v4-quant/patches/deepseek_v4.py"

# Fix 1: Change the list of projections to convert
old_proj_names = 'attn_proj_names = {"fused_wqa_wkv", "wq_b", "wo_a", "wo_b"}'
new_proj_names = 'attn_proj_names = {"fused_wqa_wkv", "wo_a"} # Only these use fp8_einsum'

# Fix 2: Remove shared_experts gate_up_proj from conversion.
# The block is described by its whitespace-stripped lines so that matching
# does not depend on the exact indentation used in the target file.
OLD_CONVERSION_LINES = (
    'shared_expert_names = {"gate_up_proj"}',
    'converted = 0',
    'for layer_idx, layer in enumerate(self.layers):',
    'attn = layer.attn',
    'for proj_name in attn_proj_names:',
    'if not hasattr(attn, proj_name):',
    'continue',
    'mod = getattr(attn, proj_name)',
    'if not hasattr(mod, "weight") or mod.weight.dtype != torch.uint8:',
    'continue',
    'self._convert_nvfp4_module_to_fp8(mod, E2M1_LUT, FP8_MAX)',
    'converted += 1',
    'ffn = layer.ffn',
    'if hasattr(ffn, "shared_experts"):',
    'for proj_name in shared_expert_names:',
    'if not hasattr(ffn.shared_experts, proj_name):',
    'continue',
    'mod = getattr(ffn.shared_experts, proj_name)',
    'if not hasattr(mod, "weight") or mod.weight.dtype != torch.uint8:',
    'continue',
    'self._convert_nvfp4_module_to_fp8(mod, E2M1_LUT, FP8_MAX)',
    'converted += 1',
)

# Comment left where the shared-experts loop used to be; it also serves as
# the marker that Fix 2 has already been applied.
STAY_NVFP4_COMMENT = "# wq_b, wo_b, gate_up_proj stay in NVFP4 (normal .forward() path)"


def remove_shared_expert_conversion(source):
    """Strip the shared-experts FP8 conversion loop from *source*.

    Every run of consecutive lines whose stripped text equals
    ``OLD_CONVERSION_LINES`` is rewritten: the ``shared_expert_names`` line
    and everything from ``ffn = layer.ffn`` onward are dropped, the attention
    loop in between is kept verbatim (with the file's own indentation), and
    ``STAY_NVFP4_COMMENT`` is appended at the indentation of the removed
    ``ffn = layer.ffn`` line.

    Returns:
        A ``(patched_source, hits)`` tuple where ``hits`` is the number of
        blocks that were rewritten.
    """
    lines = source.split("\n")
    stripped = [line.strip() for line in lines]
    width = len(OLD_CONVERSION_LINES)
    first_dropped = OLD_CONVERSION_LINES.index("ffn = layer.ffn")

    out = []
    hits = 0
    pos = 0
    while pos < len(lines):
        if tuple(stripped[pos:pos + width]) != OLD_CONVERSION_LINES:
            out.append(lines[pos])
            pos += 1
            continue
        ffn_line = lines[pos + first_dropped]
        indent = ffn_line[:len(ffn_line) - len(ffn_line.lstrip())]
        # Skip index 0 (shared_expert_names) and keep the attention loop.
        out.extend(lines[pos + 1:pos + first_dropped])
        out.append(indent + STAY_NVFP4_COMMENT)
        hits += 1
        pos += width
    return "\n".join(out), hits


def apply_fixes(source):
    """Apply both fixes to *source* and report anything that did not match.

    A fix whose result is already present counts as applied, so running the
    script twice is harmless.

    Returns:
        A ``(patched_source, problems)`` tuple; ``problems`` is a list of
        human-readable messages and is empty when both fixes are in place.
    """
    problems = []

    if old_proj_names in source:
        source = source.replace(old_proj_names, new_proj_names)
    elif new_proj_names not in source:
        problems.append("attn_proj_names assignment not found")

    source, hits = remove_shared_expert_conversion(source)
    if not hits:
        already_done = any(
            line.strip() == STAY_NVFP4_COMMENT for line in source.split("\n")
        )
        if not already_done:
            problems.append("shared_experts conversion block not found")

    return source, problems


def main():
    """Patch ``filepath`` in place; return a process exit status."""
    with open(filepath, "r", encoding="utf-8") as f:
        c = f.read()

    patched, problems = apply_fixes(c)
    if problems:
        for problem in problems:
            print(f"Patch failed: {problem}", file=sys.stderr)
        return 1

    # Validate before touching the file so a bad patch never lands on disk.
    try:
        ast.parse(patched)
    except SyntaxError as e:
        print(f"Syntax error: {e}")
        return 1
    print("Syntax OK")

    if patched == c:
        print("Already patched: only fused_wqa_wkv and wo_a converted to FP8")
        return 0

    with open(filepath, "w", encoding="utf-8") as f:
        f.write(patched)
    print("Updated: only fused_wqa_wkv and wo_a converted to FP8")
    return 0


if __name__ == "__main__":
    sys.exit(main())