#!/usr/bin/python3
"""Fix: Only convert wo_a and fused_wqa_wkv to FP8 (used with fp8_einsum).

Keep wq_b, wo_b, gate_up_proj in NVFP4 (used via normal .forward()).

The script rewrites FILEPATH in place using two literal text replacements.
It is all-or-nothing: the file is written only when every replacement either
matched or was already applied by an earlier run, AND the patched text still
parses as Python. Exit status is 0 on success and 1 otherwise.
"""
import ast
import sys

FILEPATH = "/root/nvidia-meeting/deepseek-v4-quant/patches/deepseek_v4.py"

# Fix 1: Change the list of projections to convert.
OLD_PROJ_NAMES = 'attn_proj_names = {"fused_wqa_wkv", "wq_b", "wo_a", "wo_b"}'
NEW_PROJ_NAMES = 'attn_proj_names = {"fused_wqa_wkv", "wo_a"} # Only these use fp8_einsum'

# Fix 2: Remove shared_experts gate_up_proj from conversion.
# NOTE(review): matching is by exact text, so the indentation and blank lines
# inside OLD_SHARED must agree byte-for-byte with the target file. Confirm the
# layout below against deepseek_v4.py; on a mismatch the script reports
# "not found" and leaves the file untouched.
OLD_SHARED = '''        shared_expert_names = {"gate_up_proj"}

        converted = 0
        for layer_idx, layer in enumerate(self.layers):
            attn = layer.attn
            for proj_name in attn_proj_names:
                if not hasattr(attn, proj_name):
                    continue
                mod = getattr(attn, proj_name)
                if not hasattr(mod, "weight") or mod.weight.dtype != torch.uint8:
                    continue
                self._convert_nvfp4_module_to_fp8(mod, E2M1_LUT, FP8_MAX)
                converted += 1

            ffn = layer.ffn
            if hasattr(ffn, "shared_experts"):
                for proj_name in shared_expert_names:
                    if not hasattr(ffn.shared_experts, proj_name):
                        continue
                    mod = getattr(ffn.shared_experts, proj_name)
                    if not hasattr(mod, "weight") or mod.weight.dtype != torch.uint8:
                        continue
                    self._convert_nvfp4_module_to_fp8(mod, E2M1_LUT, FP8_MAX)
                    converted += 1'''

NEW_SHARED = '''        converted = 0
        for layer_idx, layer in enumerate(self.layers):
            attn = layer.attn
            for proj_name in attn_proj_names:
                if not hasattr(attn, proj_name):
                    continue
                mod = getattr(attn, proj_name)
                if not hasattr(mod, "weight") or mod.weight.dtype != torch.uint8:
                    continue
                self._convert_nvfp4_module_to_fp8(mod, E2M1_LUT, FP8_MAX)
                converted += 1
            # wq_b, wo_b, gate_up_proj stay in NVFP4 (normal .forward() path)'''

# (label, text to find, replacement text), applied in this order.
FIXES = (
    ("Fix 1 (attn_proj_names)", OLD_PROJ_NAMES, NEW_PROJ_NAMES),
    ("Fix 2 (shared_experts)", OLD_SHARED, NEW_SHARED),
)

APPLIED = "applied"
ALREADY_APPLIED = "already applied"
NOT_FOUND = "not found"


def apply_replacement(text, old, new):
    """Replace every occurrence of *old* in *text* with *new*.

    Returns:
        A ``(text, status)`` tuple. *status* is ``APPLIED`` when *old* was
        present, ``ALREADY_APPLIED`` when *old* is absent but *new* is
        present (a previous run did the work), and ``NOT_FOUND`` when
        neither is present. The text is returned unchanged in the last two
        cases.
    """
    if old in text:
        return text.replace(old, new), APPLIED
    if new in text:
        return text, ALREADY_APPLIED
    return text, NOT_FOUND


def patch_source(text):
    """Run every entry of FIXES over *text* without touching the filesystem.

    Returns:
        A ``(patched_text, report)`` tuple where *report* is a list of
        ``(label, status)`` pairs, one per fix, in FIXES order.
    """
    report = []
    for label, old, new in FIXES:
        text, status = apply_replacement(text, old, new)
        report.append((label, status))
    return text, report


def main():
    """Patch FILEPATH in place and return a process exit status.

    Returns:
        0 when the file ends up in the patched state (freshly patched or
        already patched); 1 when a pattern could not be located or the
        patched text is not valid Python. Nothing is written in the
        failure cases.
    """
    # Python source is UTF-8 by default; do not depend on the locale.
    with open(FILEPATH, "r", encoding="utf-8") as f:
        original = f.read()

    patched, report = patch_source(original)
    for label, status in report:
        print(f"{label}: {status}")

    # A partial patch would leave the two edits out of sync, so bail out
    # before writing if any pattern is missing.
    missing = [label for label, status in report if status == NOT_FOUND]
    if missing:
        print(
            f"Pattern not found for: {', '.join(missing)}; file left unchanged",
            file=sys.stderr,
        )
        return 1

    # Validate BEFORE writing so a bad replacement never reaches the disk.
    try:
        ast.parse(patched)
    except SyntaxError as e:
        print(f"Syntax error: {e}")
        return 1
    print("Syntax OK")

    if patched != original:
        with open(FILEPATH, "w", encoding="utf-8") as f:
            f.write(patched)

    print("Updated: only fused_wqa_wkv and wo_a converted to FP8")
    return 0


if __name__ == "__main__":
    sys.exit(main())