Files
deepseek-v4-quant/tmp/fix_selective_fp8.py
biondizzle 02b8ea536f Update MEMORY.md and memory files with vLLM NVFP4 serving progress
Server running on B200 port 8000 with full NVFP4→vLLM bridge.
All critical bugs fixed: DeepGEMM scale format, compressor shapes, block scale values.
2026-05-11 02:02:49 +00:00

65 lines
2.5 KiB
Python

#!/usr/bin/python3
"""Fix: Only convert wo_a and fused_wqa_wkv to FP8 (used with fp8_einsum).
Keep wq_b, wo_b, gate_up_proj in NVFP4 (used via normal .forward()).

Patches the target ``deepseek_v4.py`` by exact-text substitution.  The file is
rewritten only when every anchor was found (or its replacement is already
present) and the patched text still parses as Python; otherwise it is left
untouched and the script exits with status 1.  Re-running on an
already-patched file is a no-op.

Usage: fix_selective_fp8.py [PATH]   (PATH defaults to DEFAULT_FILEPATH)
"""
import ast
import sys

DEFAULT_FILEPATH = "/root/nvidia-meeting/deepseek-v4-quant/patches/deepseek_v4.py"

# Fix 1: Change the list of projections to convert
old_proj_names = 'attn_proj_names = {"fused_wqa_wkv", "wq_b", "wo_a", "wo_b"}'
new_proj_names = 'attn_proj_names = {"fused_wqa_wkv", "wo_a"} # Only these use fp8_einsum'

# Fix 2: Remove shared_experts gate_up_proj from conversion
# NOTE(review): the lines of the two anchors below carry no indentation under
# their `for`/`if` headers, so they cannot occur verbatim as code in a file
# that parses.  If indentation was lost when this script was copied, Fix 2
# will not match -- main() now reports that instead of silently claiming
# success.  Confirm the exact whitespace against the target file.
old_shared = ''' shared_expert_names = {"gate_up_proj"}
converted = 0
for layer_idx, layer in enumerate(self.layers):
attn = layer.attn
for proj_name in attn_proj_names:
if not hasattr(attn, proj_name):
continue
mod = getattr(attn, proj_name)
if not hasattr(mod, "weight") or mod.weight.dtype != torch.uint8:
continue
self._convert_nvfp4_module_to_fp8(mod, E2M1_LUT, FP8_MAX)
converted += 1
ffn = layer.ffn
if hasattr(ffn, "shared_experts"):
for proj_name in shared_expert_names:
if not hasattr(ffn.shared_experts, proj_name):
continue
mod = getattr(ffn.shared_experts, proj_name)
if not hasattr(mod, "weight") or mod.weight.dtype != torch.uint8:
continue
self._convert_nvfp4_module_to_fp8(mod, E2M1_LUT, FP8_MAX)
converted += 1'''
new_shared = ''' converted = 0
for layer_idx, layer in enumerate(self.layers):
attn = layer.attn
for proj_name in attn_proj_names:
if not hasattr(attn, proj_name):
continue
mod = getattr(attn, proj_name)
if not hasattr(mod, "weight") or mod.weight.dtype != torch.uint8:
continue
self._convert_nvfp4_module_to_fp8(mod, E2M1_LUT, FP8_MAX)
converted += 1
# wq_b, wo_b, gate_up_proj stay in NVFP4 (normal .forward() path)'''

# (label, anchor text to find, replacement text), applied in this order.
FIXES = (
    ("attn_proj_names", old_proj_names, new_proj_names),
    ("shared_experts loop", old_shared, new_shared),
)


def apply_fixes(source):
    """Apply every entry of FIXES to *source*.

    Each anchor is replaced wherever it occurs (``str.replace`` semantics).
    A missing anchor is not an error when its replacement text is already
    present, i.e. the fix was applied by an earlier run.

    Returns:
        ``(patched_source, missing)`` where ``missing`` lists, in FIXES
        order, the labels of fixes whose anchor and replacement were both
        absent from *source*.
    """
    missing = []
    for label, old, new in FIXES:
        if old in source:
            source = source.replace(old, new)
        elif new not in source:
            missing.append(label)
    return source, missing


def main(argv=None):
    """Patch the file named by ``argv[0]`` (default: DEFAULT_FILEPATH).

    Args:
        argv: argument list without the program name; ``None`` means
            ``sys.argv[1:]``.

    Returns:
        0 on success (including "already patched"), 1 when an anchor is
        missing or the patched text does not parse.  On failure the file on
        disk is not modified.
    """
    args = sys.argv[1:] if argv is None else argv
    filepath = args[0] if args else DEFAULT_FILEPATH
    with open(filepath, 'r', encoding='utf-8') as f:
        original = f.read()
    c, missing = apply_fixes(original)
    if missing:
        # str.replace() silently no-ops on an absent anchor; fail loudly so a
        # stale or mis-indented anchor is not reported as a successful patch.
        print(f"Anchor not found, file left unchanged: {', '.join(missing)}")
        return 1
    if c == original:
        print("Already patched; nothing to do")
        return 0
    # Validate BEFORE writing so a bad patch never clobbers the target file.
    try:
        ast.parse(c)
    except SyntaxError as e:
        print(f"Syntax error: {e}")
        print("File left unchanged")
        return 1
    print("Syntax OK")
    with open(filepath, 'w', encoding='utf-8') as f:
        f.write(c)
    print("Updated: only fused_wqa_wkv and wo_a converted to FP8")
    return 0


if __name__ == "__main__":
    sys.exit(main())