#!/usr/bin/env python3 """Fix the bf16→uint8 handler to properly quantize to NVFP4 instead of switching to UnquantizedLinearMethod""" filepath = "/root/nvidia-meeting/deepseek-v4-quant/patches/deepseek_v4.py" with open(filepath, 'r') as f: c = f.read() old_handler = ''' # Handle bf16 → uint8 mismatch for o_a_proj: # modelopt didn't quantize o_a_proj (bf16, no scales), # but ModelOptNvFp4Config creates wo_a with NVFP4 quant # (uint8 weight + scales). When loading bf16 into uint8, # we replace the quant method with UnquantizedLinearMethod # so the layer runs in bf16 at inference. if (name.endswith(".weight") and loaded_weight.dtype != torch.uint8 and param.data.dtype == torch.uint8): # Replace this layer's quant method with unquantized from vllm.model_executor.layers.linear import ( UnquantizedLinearMethod, ) parts = name.rsplit(".", 1) module_path = parts[0] # e.g., layers.0.attn.wo_a # Find the module and override its quant method mod = self for attr in module_path.split("."): if attr.isdigit(): mod = mod[int(attr)] else: mod = getattr(mod, attr) if hasattr(mod, 'quant_method'): mod.quant_method = UnquantizedLinearMethod() # Replace the uint8 weight param with bf16 new_param = torch.nn.Parameter( loaded_weight.clone(), requires_grad=False ) mod.weight = new_param # Set weight_scale_inv = 1.0 (required by # DeepseekV4MLAModules forward pass which # reads wo_a.weight_scale_inv directly) mod.weight_scale_inv = torch.nn.Parameter( torch.tensor(1.0, dtype=torch.float32), requires_grad=False, ) # Also set input_scale to prevent missing attr errors if hasattr(mod, 'input_scale'): mod.input_scale = torch.nn.Parameter( torch.tensor(1.0, dtype=torch.float32), requires_grad=False, ) loaded_params.add(name) loaded_params.add(name.replace('.weight', '.weight_scale_inv')) continue''' new_handler = ''' # Handle bf16 → uint8 mismatch for o_a_proj: # modelopt didn't quantize o_a_proj (bf16, no scales), # but ModelOptNvFp4Config creates wo_a with NVFP4 quant # (uint8 weight + scales). We quantize the bf16 weight # to NVFP4 at load time so the layer runs in NVFP4 path. if (name.endswith(".weight") and loaded_weight.dtype != torch.uint8 and param.data.dtype == torch.uint8): # Quantize bf16 → NVFP4 (E2M1 packed uint8 + scales) w_bf16 = loaded_weight out_dim, in_dim = w_bf16.shape block_size = 16 assert in_dim % block_size == 0 n_blocks = in_dim // block_size # Reshape into blocks w_blocks = w_bf16.reshape(out_dim, n_blocks, block_size) # Compute per-block amax amax = w_blocks.abs().amax(dim=-1) # [out, n_blocks] # Global scale (weight_scale_2): max amax / (6.0 * 448.0) global_amax = amax.max() # Use 448.0 as the max e4m3 value for scale computation weight_scale_2_val = global_amax / (6.0 * 448.0) weight_scale_2 = weight_scale_2_val.to(torch.float32) # Per-block scale (weight_scale): fp8 e4m3 # block_scale = amax / (6.0 * weight_scale_2) block_scale = amax / (6.0 * weight_scale_2_val) # Clamp to fp8 e4m3 range and cast block_scale = block_scale.clamp(min=0, max=448.0) weight_scale = block_scale.to(torch.float8_e4m3fn) # Quantize to FP4 (E2M1) # E2M1 LUT: 0, 0.5, 1, 1.5, 2, 3, 4, 6 (positive) FP4_POS = torch.tensor( [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32, device=w_bf16.device, ) # For each block, dequantize the block scale from fp8 block_scale_f32 = weight_scale.to(torch.float32) # Scale the weight values: normalized = w / (block_scale * weight_scale_2) # We need to find the nearest FP4 value scaled = w_blocks / (block_scale_f32.unsqueeze(-1) * weight_scale_2_val) # Find nearest FP4 index (0-7 for magnitude) # Use absolute value for matching, then apply sign scaled_abs = scaled.abs() # Find closest FP4 value diff = (scaled_abs.unsqueeze(-1) - FP4_POS).abs() fp4_idx = diff.argmin(dim=-1) # [out, n_blocks, block_size] # Apply sign: negative values get bit 3 set sign = (scaled < 0).int() fp4_val = (sign << 3) | fp4_idx.int() # Pack: 2 FP4 values per uint8 byte # Even positions → lower nibble, Odd → upper nibble fp4_flat = fp4_val.reshape(out_dim, -1) # [out, in_dim] assert fp4_flat.shape[1] % 2 == 0 even = fp4_flat[:, 0::2] # lower nibble odd = fp4_flat[:, 1::2] # upper nibble packed = (odd << 4) | even weight_packed = packed.to(torch.uint8) # Reshape weight_scale to [out, n_blocks] weight_scale_2d = weight_scale.reshape(out_dim, n_blocks) # Load the quantized weight into the uint8 param weight_loader = param.weight_loader weight_loader(param, weight_packed) loaded_params.add(name) # Load scales into sibling params base = name.rsplit(".", 1)[0] # weight_scale ws_name = f"{base}.weight_scale" if ws_name in params_dict: ws_param = params_dict[ws_name] ws_loader = getattr(ws_param, "weight_loader", default_weight_loader) ws_loader(ws_param, weight_scale_2d) loaded_params.add(ws_name) # weight_scale_2 ws2_name = f"{base}.weight_scale_2" if ws2_name in params_dict: ws2_param = params_dict[ws2_name] ws2_loader = getattr(ws2_param, "weight_loader", default_weight_loader) ws2_loader(ws2_param, weight_scale_2.reshape(1)) loaded_params.add(ws2_name) # input_scale: use 1.0 default (dynamic quant) is_name = f"{base}.input_scale" if is_name in params_dict: is_param = params_dict[is_name] is_loader = getattr(is_param, "weight_loader", default_weight_loader) is_loader(is_param, torch.tensor(1.0, dtype=torch.float32)) loaded_params.add(is_name) continue''' if old_handler in c: c = c.replace(old_handler, new_handler) print('FIX 5 applied: Replaced UnquantizedLinearMethod with proper NVFP4 quantization') else: print('FIX 5: Could not find exact handler block, trying flexible match...') if 'UnquantizedLinearMethod' in c: print(' Found UnquantizedLinearMethod in code - manual fix needed') else: print(' UnquantizedLinearMethod not found - already replaced?') with open(filepath, 'w') as f: f.write(c)