#!/usr/bin/python3
"""Fix: move _convert_nvfp4 methods INSIDE DeepseekV4Model class (before hc_head).

The script edits the patched ``deepseek_v4.py`` in place:

1. It removes a copy of the ``_convert_nvfp4_*`` methods that sits after
   ``hc_head`` (everything from the method header up to
   ``class DeepseekV4ForCausalLM``).
2. It inserts the methods right after the body of
   ``finalize_mega_moe_weights`` and before the following ``@torch.compile``.
3. It parses the result with ``ast`` and only writes it back when it is
   syntactically valid.

Running the script again on an already fixed file changes nothing.
"""

import ast
import sys
from typing import Optional, Tuple

FILEPATH = "/root/nvidia-meeting/deepseek-v4-quant/patches/deepseek_v4.py"

# NOTE(review): the leading whitespace in the literals below must match the
# target file byte-for-byte (4-space methods, 12-space loop body assumed) --
# confirm against deepseek_v4.py.

# Last statement of finalize_mega_moe_weights(); the methods go right after it.
ANCHOR_LINE = "            layer.ffn.finalize_mega_moe_weights()\n"
DECORATOR = "@torch.compile"
# Anchor followed directly by the decorator: only present before insertion.
INSERT_MARKER = ANCHOR_LINE + "\n\n" + DECORATOR

# First line of the methods; used to locate a misplaced copy.
METHOD_HEADER = "    def _convert_nvfp4_attention_to_fp8(self):\n"
# A misplaced copy extends up to (not including) this marker.
CLASS_MARKER = "\n\nclass DeepseekV4ForCausalLM(nn.Module):"
# Anchor followed directly by the methods: present once insertion succeeded.
PLACED_PREFIX = ANCHOR_LINE + "\n" + METHOD_HEADER

# Source text of the methods to inject.
# The E2M1 table has 16 entries because the 4-bit codes extracted with
# ``& 0x0F`` range over 0..15 (bit 3 is the sign); an 8-entry table is indexed
# out of range by every negative weight.
# NOTE(review): the payload folds ``input_scale`` into the dequantized weight
# and installs ``UnquantizedLinearMethod`` on an FP8 weight -- both are kept
# as they were; confirm they are intended.
METHODS_SOURCE = '''    def _convert_nvfp4_attention_to_fp8(self):
        # 4-bit E2M1 codes: bit 3 is the sign, so all 16 codes need an entry.
        E2M1_LUT = torch.tensor(
            [0, 0.5, 1, 1.5, 2, 3, 4, 6,
             -0.0, -0.5, -1, -1.5, -2, -3, -4, -6],
            dtype=torch.bfloat16,
        )
        FP8_MAX = torch.finfo(torch.float8_e4m3fn).max
        attn_proj_names = {"fused_wqa_wkv", "wq_b", "wo_a", "wo_b"}
        shared_expert_names = {"gate_up_proj"}
        converted = 0
        for layer_idx, layer in enumerate(self.layers):
            attn = layer.attn
            for proj_name in attn_proj_names:
                if not hasattr(attn, proj_name):
                    continue
                mod = getattr(attn, proj_name)
                if not hasattr(mod, "weight") or mod.weight.dtype != torch.uint8:
                    continue
                self._convert_nvfp4_module_to_fp8(mod, E2M1_LUT, FP8_MAX)
                converted += 1
            ffn = layer.ffn
            if hasattr(ffn, "shared_experts"):
                for proj_name in shared_expert_names:
                    if not hasattr(ffn.shared_experts, proj_name):
                        continue
                    mod = getattr(ffn.shared_experts, proj_name)
                    if not hasattr(mod, "weight") or mod.weight.dtype != torch.uint8:
                        continue
                    self._convert_nvfp4_module_to_fp8(mod, E2M1_LUT, FP8_MAX)
                    converted += 1
        if converted > 0:
            logger.info_once(
                "Converted %d NVFP4 attention/shared-expert layers to FP8",
                converted,
            )

    def _convert_nvfp4_module_to_fp8(self, mod, e2m1_lut, fp8_max):
        w_uint8 = mod.weight.data
        device = w_uint8.device
        even_idx = (w_uint8 & 0x0F).int()
        odd_idx = ((w_uint8 >> 4) & 0x0F).int()
        even_vals = e2m1_lut.to(device)[even_idx]
        odd_vals = e2m1_lut.to(device)[odd_idx]
        w_bf16 = torch.stack([even_vals, odd_vals], dim=-1)
        w_bf16 = w_bf16.reshape(w_uint8.shape[0], -1).to(torch.bfloat16)
        if hasattr(mod, "weight_scale") and hasattr(mod, "weight_scale_2"):
            block_scale = mod.weight_scale.data.to(torch.float32)
            if block_scale.dim() == 2 and w_bf16.dim() == 2:
                block_size = w_bf16.shape[1] // block_scale.shape[1]
                block_scale_expanded = block_scale.unsqueeze(-1).expand(
                    -1, -1, block_size
                ).reshape(w_bf16.shape)
            else:
                block_scale_expanded = block_scale
            global_scale = mod.weight_scale_2.data.max().item()
            input_scale = (
                mod.input_scale.data.max().item()
                if hasattr(mod, "input_scale")
                else 1.0
            )
            w_dequant = w_bf16.float() * block_scale_expanded * global_scale * input_scale
            w_dequant = w_dequant.to(torch.bfloat16)
        else:
            w_dequant = w_bf16
        w_amax = w_dequant.abs().amax()
        if w_amax == 0:
            w_amax = torch.tensor(1.0, device=device)
        fp8_scale = w_amax / fp8_max
        w_fp8 = (w_dequant / fp8_scale).to(torch.float8_e4m3fn)
        weight_scale_inv = fp8_scale.to(torch.float32)
        mod.weight = torch.nn.Parameter(w_fp8, requires_grad=False)
        mod.weight_scale_inv = torch.nn.Parameter(
            weight_scale_inv.reshape(1), requires_grad=False
        )
        from vllm.model_executor.layers.linear import UnquantizedLinearMethod
        mod.quant_method = UnquantizedLinearMethod()
        for attr in ("weight_scale", "weight_scale_2", "input_scale"):
            if hasattr(mod, attr):
                delattr(mod, attr)
'''

# Replacement for INSERT_MARKER: anchor, the methods, then the decorator.
METHODS_BLOCK = ANCHOR_LINE + "\n" + METHODS_SOURCE + "\n\n" + DECORATOR


def remove_misplaced_methods(source: str) -> Tuple[str, str]:
    """Drop a copy of the methods that is not directly after the anchor line.

    Args:
        source: Full text of the file being patched.

    Returns:
        ``(new_source, message)``. ``new_source`` equals ``source`` when
        nothing was removed; ``message`` is the status line to print.
    """
    # A copy that directly follows the anchor is the correctly placed one;
    # start searching after it so it is never mistaken for the misplaced copy.
    search_from = 0
    placed_at = source.find(PLACED_PREFIX)
    if placed_at != -1:
        search_from = placed_at + len(PLACED_PREFIX)

    start = source.find(METHOD_HEADER, search_from)
    if start == -1:  # -1 means "absent"; offset 0 is a valid match
        return source, "No wrongly placed methods found"

    end = source.find(CLASS_MARKER, start)
    if end == -1:
        return source, "Could not find end marker"

    return source[:start] + source[end:], "Removed wrongly placed methods"


def insert_methods(source: str) -> Tuple[str, str]:
    """Insert the methods after ``finalize_mega_moe_weights`` if not yet there.

    Args:
        source: Full text of the file being patched.

    Returns:
        ``(new_source, message)``. ``new_source`` equals ``source`` when the
        methods are already in place or the insertion point is missing.
    """
    if INSERT_MARKER in source:
        # Replace a single occurrence so the methods cannot be added twice.
        patched = source.replace(INSERT_MARKER, METHODS_BLOCK, 1)
        return patched, "Inserted methods inside DeepseekV4Model class"

    if PLACED_PREFIX in source:
        return source, "Methods already inside DeepseekV4Model class"

    # Diagnostic only: report where (if anywhere) the anchor call appears.
    position = source.find("finalize_mega_moe_weights()")
    return source, (
        "Could not find insertion marker!\n"
        f"Found finalize at position {position}"
    )


def syntax_error(source: str) -> Optional[str]:
    """Return a description of the first syntax error in *source*, or None."""
    try:
        ast.parse(source)
    except SyntaxError as exc:
        return f"Syntax error at line {exc.lineno}: {exc.msg}"
    return None


def main() -> int:
    """Patch FILEPATH in place; return 0 on success, 1 if the result is invalid."""
    with open(FILEPATH, "r", encoding="utf-8") as f:
        original = f.read()

    # 1. Remove the wrongly placed methods (after hc_head, at top level).
    patched, message = remove_misplaced_methods(original)
    print(message)

    # 2. Insert the methods inside DeepseekV4Model.
    patched, message = insert_methods(patched)
    print(message)

    # 3. Never overwrite the target with something that does not even parse.
    error = syntax_error(patched)
    if error is not None:
        print(error)
        return 1
    print("Syntax OK")

    if patched != original:
        with open(FILEPATH, "w", encoding="utf-8") as f:
            f.write(patched)
    return 0


if __name__ == "__main__":
    sys.exit(main())