#!/usr/bin/env python3
"""Update the bf16->uint8 handler to convert bf16->FP8 directly.

This is a one-shot text patch: it reads the file at ``filepath``, replaces
the opening of the existing bf16->uint8 weight handler (``old_handler``)
with a handler that converts the bf16 weight to FP8 and then ``continue``s
(``new_handler``), and writes the result back.

The code that followed the old handler's opening comment is left in place
after the injected ``continue``.

Exit status: 0 when the patch was applied or was already present, 1 when
the expected handler text could not be found (the file is left untouched).
"""
import sys

filepath = "/root/nvidia-meeting/deepseek-v4-quant/patches/deepseek_v4.py"

# NOTE(review): the whitespace inside the two literals below must match the
# target file byte-for-byte for the replacement to apply. The indentation
# here assumes the handler sits at 12 spaces inside the weight-loading loop
# -- confirm against deepseek_v4.py. A mismatch now aborts with an error
# instead of silently doing nothing.

# Anchor text: the header of the existing handler plus its first comment.
old_handler = '''            if (name.endswith(".weight")
                    and loaded_weight.dtype != torch.uint8
                    and param.data.dtype == torch.uint8):
                # Quantize bf16 → NVFP4 (E2M1 packed uint8 + scales)'''

# The bf16->uint8 handler needs to convert the bf16 weight directly to FP8
# since o_a_proj was NOT quantized by modelopt.
new_handler = '''            if (name.endswith(".weight")
                    and loaded_weight.dtype != torch.uint8
                    and param.data.dtype == torch.uint8):
                # o_a_proj was NOT quantized by modelopt (bf16, no scales)
                # Convert bf16 → FP8 directly, set weight_scale_inv
                w_bf16 = loaded_weight
                w_amax = w_bf16.abs().amax()
                if w_amax == 0:
                    w_amax = torch.tensor(1.0, device=w_bf16.device)
                fp8_max = torch.finfo(torch.float8_e4m3fn).max
                fp8_scale = w_amax / fp8_max
                w_fp8 = (w_bf16 / fp8_scale).to(torch.float8_e4m3fn)
                weight_scale_inv = fp8_scale.to(torch.float32)
                # Load FP8 weight directly (bypass the uint8 param)
                # Find the module and replace weight + quant method
                parts = name.rsplit(".", 1)
                module_path = parts[0]
                mod = self
                for attr in module_path.split("."):
                    if attr.isdigit():
                        mod = mod[int(attr)]
                    else:
                        mod = getattr(mod, attr)
                # Replace weight param with FP8 version
                mod.weight = torch.nn.Parameter(w_fp8, requires_grad=False)
                mod.weight_scale_inv = torch.nn.Parameter(
                    weight_scale_inv.reshape(1), requires_grad=False
                )
                # Switch quant method to FP8 linear
                from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
                from vllm.model_executor.layers.quantization.utils.quant_utils import (
                    Fp8MMQuantMethod,
                )
                mod.quant_method = Fp8LinearMethod(Fp8MMQuantMethod())
                # Clean up NVFP4 params
                for attr in ('weight_scale', 'weight_scale_2', 'input_scale'):
                    if hasattr(mod, attr):
                        delattr(mod, attr)
                loaded_params.add(name)
                loaded_params.add(name.replace('.weight', '.weight_scale_inv'))
                continue
                # OLD: Quantize bf16 -> NVFP4 (E2M1 packed uint8 + scales)'''


def apply_patch(source: str) -> str:
    """Return *source* with every ``old_handler`` replaced by ``new_handler``.

    Raises:
        ValueError: if ``old_handler`` does not occur in *source*, so that a
            patch which would change nothing is never reported as a success.
    """
    if old_handler not in source:
        raise ValueError(
            "expected bf16->uint8 handler block not found; nothing was patched"
        )
    return source.replace(old_handler, new_handler)


def main() -> int:
    """Patch the file at ``filepath`` in place and return a process exit code."""
    # Explicit UTF-8: the handler text contains a non-ASCII arrow, so the
    # locale's default encoding must not decide how it is read or written.
    with open(filepath, "r", encoding="utf-8") as f:
        original = f.read()

    # A previous run leaves new_handler in the file and removes the anchor;
    # treat that as success without rewriting the file.
    if old_handler not in original and new_handler in original:
        print("bf16->uint8 handler already converts to FP8; nothing to do")
        return 0

    try:
        patched = apply_patch(original)
    except ValueError as err:
        print(f"error: {err} ({filepath})", file=sys.stderr)
        return 1

    with open(filepath, "w", encoding="utf-8") as f:
        f.write(patched)

    print("Updated bf16->uint8 handler to convert to FP8 directly")
    return 0


if __name__ == "__main__":
    sys.exit(main())