#!/usr/bin/python3 """ Apply ALL fixes to the S11 base version of deepseek_v4.py. This is a clean application of all fixes we've developed. """ filepath = "/root/nvidia-meeting/deepseek-v4-quant/patches/deepseek_v4.py" with open(filepath, 'r') as f: c = f.read() import ast def check(c, label): try: ast.parse(c) print(f" {label}: OK") return True except SyntaxError as e: print(f" {label}: SYNTAX ERROR at line {e.lineno}: {e.msg}") return False # ═══════════════════════════════════════════════════════════ # FIX 1: Substr mapping — remove .mla_attn. from attn projections # ═══════════════════════════════════════════════════════════ subs = { '".self_attn.q_a_proj.": ".attn.mla_attn.wq_a."': '".self_attn.q_a_proj.": ".attn.wq_a."', '".self_attn.q_b_proj.": ".attn.mla_attn.wq_b."': '".self_attn.q_b_proj.": ".attn.wq_b."', '".self_attn.q_a_norm.": ".attn.mla_attn.q_norm."': '".self_attn.q_a_norm.": ".attn.q_norm."', '".self_attn.o_a_proj.": ".attn.mla_attn.wo_a."': '".self_attn.o_a_proj.": ".attn.wo_a."', '".self_attn.o_b_proj.": ".attn.mla_attn.wo_b."': '".self_attn.o_b_proj.": ".attn.wo_b."', '".self_attn.sinks": ".attn.mla_attn.attn_sink"': '".self_attn.sinks": ".attn.attn_sink"', '".self_attn.kv_proj.": ".attn.mla_attn.wkv."': '".self_attn.kv_proj.": ".attn.wkv."', '".self_attn.kv_norm.": ".attn.mla_attn.kv_norm."': '".self_attn.kv_norm.": ".attn.kv_norm."', } for old, new in subs.items(): c = c.replace(old, new) check(c, "Fix 1 (substr)") # ═══════════════════════════════════════════════════════════ # FIX 2: Skip patterns — only skip compressor scales # ═══════════════════════════════════════════════════════════ # Remove attention and shared expert scale skip patterns lines_to_remove = [ ' re.compile(r"\\.self_attn\\.kv_proj\\.weight_scale$"): None,', ' re.compile(r"\\.self_attn\\.q_a_proj\\.weight_scale$"): None,', ' re.compile(r"\\.self_attn\\.q_b_proj\\.weight_scale$"): None,', ' re.compile(r"\\.self_attn\\.o_a_proj\\.weight_scale$"): None,', ' re.compile(r"\\.self_attn\\.o_b_proj\\.weight_scale$"): None,', ' re.compile(r"\\.self_attn\\.kv_proj\\.weight_scale_2$"): None,', ' re.compile(r"\\.self_attn\\.q_a_proj\\.weight_scale_2$"): None,', ' re.compile(r"\\.self_attn\\.q_b_proj\\.weight_scale_2$"): None,', ' re.compile(r"\\.self_attn\\.o_a_proj\\.weight_scale_2$"): None,', ' re.compile(r"\\.self_attn\\.o_b_proj\\.weight_scale_2$"): None,', ' re.compile(r"\\.self_attn\\.kv_proj\\.input_scale$"): None,', ' re.compile(r"\\.self_attn\\.q_a_proj\\.input_scale$"): None,', ' re.compile(r"\\.self_attn\\.q_b_proj\\.input_scale$"): None,', ' re.compile(r"\\.self_attn\\.o_a_proj\\.input_scale$"): None,', ' re.compile(r"\\.self_attn\\.o_b_proj\\.input_scale$"): None,', ' re.compile(r"\\.shared_experts\\.gate_proj\\.weight_scale$"): None,', ' re.compile(r"\\.shared_experts\\.up_proj\\.weight_scale$"): None,', ' re.compile(r"\\.shared_experts\\.gate_proj\\.weight_scale_2$"): None,', ' re.compile(r"\\.shared_experts\\.up_proj\\.weight_scale_2$"): None,', ' re.compile(r"\\.shared_experts\\.gate_proj\\.input_scale$"): None,', ' re.compile(r"\\.shared_experts\\.up_proj\\.input_scale$"): None,', ] for line in lines_to_remove: c = c.replace(line + "\n", "") c = c.replace(line, "") check(c, "Fix 2 (skip patterns)") # ═══════════════════════════════════════════════════════════ # FIX 3: Remove the 'head.weight' suffix mapping that causes # 'lm_head.weight' to become 'lm_lm_head.weight' # ═══════════════════════════════════════════════════════════ c = c.replace(' "head.weight": "lm_head.weight",\n', '') check(c, "Fix 3 (suffix)") # ═══════════════════════════════════════════════════════════ # FIX 4: Handle o_a_proj bf16 -> FP8 at load time # modelopt didn't quantize o_a_proj, but vLLM creates wo_a with NVFP4 # Convert bf16 -> FP8 and set weight_scale_inv # ═══════════════════════════════════════════════════════════ old_else = ''' else: if name not in params_dict: # ModelOpt NVFP4 export includes params not in the # vllm model (e.g., compressor.position_bias). # Skip them silently. continue param = params_dict[name] weight_loader = getattr( param, "weight_loader", default_weight_loader ) weight_loader(param, loaded_weight) loaded_params.add(name) continue''' new_else = ''' else: if name not in params_dict: continue param = params_dict[name] # Handle o_a_proj bf16 -> wo_a uint8 mismatch if (name.endswith(".weight") and loaded_weight.dtype != torch.uint8 and param.data.dtype == torch.uint8): # o_a_proj was NOT quantized by modelopt (bf16, no scales) # Convert bf16 -> FP8 and set weight_scale_inv w_bf16 = loaded_weight w_amax = w_bf16.abs().amax() if w_amax == 0: w_amax = torch.tensor(1.0, device=w_bf16.device) fp8_max = torch.finfo(torch.float8_e4m3fn).max fp8_scale = w_amax / fp8_max w_fp8 = (w_bf16 / fp8_scale).to(torch.float8_e4m3fn) weight_scale_inv = fp8_scale.to(torch.float32) # Replace the module weight and add weight_scale_inv parts = name.rsplit(".", 1) module_path = parts[0] mod = self for attr in module_path.split("."): if attr.isdigit(): mod = mod[int(attr)] else: mod = getattr(mod, attr) mod.weight = torch.nn.Parameter(w_fp8, requires_grad=False) mod.weight_scale_inv = torch.nn.Parameter( weight_scale_inv.reshape(1), requires_grad=False ) from vllm.model_executor.layers.linear import ( UnquantizedLinearMethod, ) mod.quant_method = UnquantizedLinearMethod() for attr in ("weight_scale", "weight_scale_2", "input_scale"): if hasattr(mod, attr): delattr(mod, attr) loaded_params.add(name) loaded_params.add(name.replace(".weight", ".weight_scale_inv")) continue weight_loader = getattr( param, "weight_loader", default_weight_loader ) weight_loader(param, loaded_weight) loaded_params.add(name) continue''' c = c.replace(old_else, new_else) check(c, "Fix 4 (o_a_proj bf16->FP8)") # ═══════════════════════════════════════════════════════════ # FIX 5: Add NVFP4->FP8 post-load conversion for attention # This converts all uint8 NVFP4 attention weights to FP8 # ═══════════════════════════════════════════════════════════ conversion_methods = ''' def _convert_nvfp4_attention_to_fp8(self): E2M1_LUT = torch.tensor( [0, 0.5, 1, 1.5, 2, 3, 4, 6], dtype=torch.bfloat16 ) FP8_MAX = torch.finfo(torch.float8_e4m3fn).max attn_proj_names = {"fused_wqa_wkv", "wq_b", "wo_a", "wo_b"} shared_expert_names = {"gate_up_proj"} converted = 0 for layer_idx, layer in enumerate(self.layers): attn = layer.attn for proj_name in attn_proj_names: if not hasattr(attn, proj_name): continue mod = getattr(attn, proj_name) if not hasattr(mod, "weight") or mod.weight.dtype != torch.uint8: continue self._convert_nvfp4_module_to_fp8(mod, E2M1_LUT, FP8_MAX) converted += 1 ffn = layer.ffn if hasattr(ffn, "shared_experts"): for proj_name in shared_expert_names: if not hasattr(ffn.shared_experts, proj_name): continue mod = getattr(ffn.shared_experts, proj_name) if not hasattr(mod, "weight") or mod.weight.dtype != torch.uint8: continue self._convert_nvfp4_module_to_fp8(mod, E2M1_LUT, FP8_MAX) converted += 1 if converted > 0: logger.info_once( "Converted %d NVFP4 attention/shared-expert layers to FP8", converted, ) def _convert_nvfp4_module_to_fp8(self, mod, e2m1_lut, fp8_max): w_uint8 = mod.weight.data device = w_uint8.device even_idx = (w_uint8 & 0x0F).int() odd_idx = ((w_uint8 >> 4) & 0x0F).int() even_vals = e2m1_lut.to(device)[even_idx] odd_vals = e2m1_lut.to(device)[odd_idx] w_bf16 = torch.stack([even_vals, odd_vals], dim=-1) w_bf16 = w_bf16.reshape(w_uint8.shape[0], -1).to(torch.bfloat16) if hasattr(mod, "weight_scale") and hasattr(mod, "weight_scale_2"): block_scale = mod.weight_scale.data.to(torch.float32) if block_scale.dim() == 2 and w_bf16.dim() == 2: block_size = w_bf16.shape[1] // block_scale.shape[1] block_scale_expanded = block_scale.unsqueeze(-1).expand( -1, -1, block_size ).reshape(w_bf16.shape) else: block_scale_expanded = block_scale global_scale = mod.weight_scale_2.data.max().item() input_scale = ( mod.input_scale.data.max().item() if hasattr(mod, "input_scale") else 1.0 ) w_dequant = w_bf16.float() * block_scale_expanded * global_scale * input_scale w_dequant = w_dequant.to(torch.bfloat16) else: w_dequant = w_bf16 w_amax = w_dequant.abs().amax() if w_amax == 0: w_amax = torch.tensor(1.0, device=device) fp8_scale = w_amax / fp8_max w_fp8 = (w_dequant / fp8_scale).to(torch.float8_e4m3fn) weight_scale_inv = fp8_scale.to(torch.float32) mod.weight = torch.nn.Parameter(w_fp8, requires_grad=False) mod.weight_scale_inv = torch.nn.Parameter( weight_scale_inv.reshape(1), requires_grad=False ) from vllm.model_executor.layers.linear import UnquantizedLinearMethod mod.quant_method = UnquantizedLinearMethod() for attr in ("weight_scale", "weight_scale_2", "input_scale"): if hasattr(mod, attr): delattr(mod, attr) ''' # Insert before DeepseekV4ForCausalLM class marker = "\n\nclass DeepseekV4ForCausalLM(nn.Module):" if marker in c: c = c.replace(marker, "\n" + conversion_methods + "\nclass DeepseekV4ForCausalLM(nn.Module):") print(" Fix 5: Inserted conversion methods") else: print(" Fix 5: Could not find class marker") check(c, "Fix 5 (NVFP4->FP8 methods)") # ═══════════════════════════════════════════════════════════ # FIX 6: Call the conversion from DeepseekV4ForCausalLM.load_weights # ═══════════════════════════════════════════════════════════ old_load = " self.model.finalize_mega_moe_weights()\n return loaded_params" new_load = " self.model.finalize_mega_moe_weights()\n self.model._convert_nvfp4_attention_to_fp8()\n return loaded_params" c = c.replace(old_load, new_load) check(c, "Fix 6 (call conversion)") # ═══════════════════════════════════════════════════════════ # Final validation # ═══════════════════════════════════════════════════════════ check(c, "FINAL") with open(filepath, 'w') as f: f.write(c) print("All fixes applied!")