#!/usr/bin/python3
"""Fix the E2M1 unpacking in _convert_nvfp4_module_to_fp8.

Patches the target ``deepseek_v4.py`` in place. The old code indexed the E2M1
lookup table with the raw 4-bit nibble. The replacement code instead splits
each nibble into a sign (bit 3) and a 3-bit magnitude index (lower 3 bits).
It then multiplies the looked-up magnitude by the sign. The same change is
applied to the stacked-params unpacking path.

Blocks are matched line by line, ignoring indentation and the amount of
whitespace between tokens. Each replacement reuses the indentation of the
block it replaces. The target file is written only if something changed and
the patched source still parses.

Exit status: 0 if every fix was applied (or was already present), 1 if a
pattern was not found or the patched source does not parse.
"""

import ast
import re
import sys

FILEPATH = "/root/nvidia-meeting/deepseek-v4-quant/patches/deepseek_v4.py"

# Unpacking in _convert_nvfp4_module_to_fp8: raw nibble used directly as index.
OLD_UNPACK = (
    "even_idx = (w_uint8 & 0x0F).int()",
    "odd_idx = ((w_uint8 >> 4) & 0x0F).int()",
    "even_vals = e2m1_lut.to(device)[even_idx]",
    "odd_vals = e2m1_lut.to(device)[odd_idx]",
)
NEW_UNPACK = (
    "# Extract 4-bit FP4 values (0-15, bit 3 = sign)",
    "even_raw = (w_uint8 & 0x0F).int()",
    "odd_raw = ((w_uint8 >> 4) & 0x0F).int()",
    "# Sign: 0-7 = positive, 8-15 = negative",
    "even_sign = torch.where(even_raw >= 8, -1.0, 1.0).to(torch.bfloat16)",
    "odd_sign = torch.where(odd_raw >= 8, -1.0, 1.0).to(torch.bfloat16)",
    "# Magnitude index: lower 3 bits (0-7)",
    "even_vals = even_sign * e2m1_lut.to(device)[even_raw & 0x07]",
    "odd_vals = odd_sign * e2m1_lut.to(device)[odd_raw & 0x07]",
)

# Unpacking in the stacked-params loading code.
OLD_STACKED_UNPACK = (
    "even_idx = (loaded_weight & 0x0F).int()",
    "odd_idx = ((loaded_weight >> 4) & 0x0F).int()",
    "even_vals = E2M1_LUT[even_idx]",
    "odd_vals = E2M1_LUT[odd_idx]",
)
NEW_STACKED_UNPACK = (
    "# Extract 4-bit FP4 values (0-15, bit 3 = sign)",
    "even_raw = (loaded_weight & 0x0F).int()",
    "odd_raw = ((loaded_weight >> 4) & 0x0F).int()",
    "even_sign = torch.where(even_raw >= 8, -1.0, 1.0).to(torch.bfloat16)",
    "odd_sign = torch.where(odd_raw >= 8, -1.0, 1.0).to(torch.bfloat16)",
    "even_vals = even_sign * E2M1_LUT[even_raw & 0x07]",
    "odd_vals = odd_sign * E2M1_LUT[odd_raw & 0x07]",
)

# (label used in messages, lines to find, replacement lines)
FIXES = (
    ("_convert_nvfp4_module_to_fp8", OLD_UNPACK, NEW_UNPACK),
    ("stacked params code", OLD_STACKED_UNPACK, NEW_STACKED_UNPACK),
)


def _line_regex(line):
    """Regex for one source line: same tokens, any run of spaces/tabs between."""
    return r"[ \t]+".join(re.escape(token) for token in line.split())


def _block_pattern(lines):
    """Compile a regex matching *lines* as consecutive source lines.

    Leading/trailing horizontal whitespace on each line is ignored; the first
    line's indentation is captured in the ``indent`` group.
    """
    body = r"[ \t]*\n[ \t]*".join(_line_regex(line) for line in lines)
    return re.compile(r"^(?P<indent>[ \t]*)" + body + r"[ \t]*$", re.MULTILINE)


def contains_block(text, lines):
    """Return True if *text* contains the consecutive block *lines*."""
    return _block_pattern(lines).search(text) is not None


def replace_block(text, old_lines, new_lines):
    """Replace every occurrence of the *old_lines* block in *text*.

    Each replacement is re-indented with the indentation of the block it
    replaces. Returns ``(new_text, number_of_replacements)``.
    """

    def _render(match):
        indent = match.group("indent")
        return "\n".join(indent + line.strip() for line in new_lines)

    # A function replacement avoids backslash/group interpretation in the text.
    return _block_pattern(old_lines).subn(_render, text)


def main(filepath=FILEPATH):
    """Apply all FIXES to *filepath*; return a process exit status."""
    with open(filepath, "r", encoding="utf-8") as f:
        original = f.read()

    patched = original
    missing = []
    for label, old_lines, new_lines in FIXES:
        patched, count = replace_block(patched, old_lines, new_lines)
        if count:
            print(f"Fixed E2M1 unpacking in {label}")
        elif contains_block(patched, new_lines):
            print(f"E2M1 unpacking in {label} already fixed; skipping")
        else:
            # Previously this was reported as "Fixed" even though nothing changed.
            print(f"WARNING: E2M1 unpacking pattern for {label} not found",
                  file=sys.stderr)
            missing.append(label)

    # Validate before writing so a bad patch never overwrites the target.
    try:
        ast.parse(patched)
    except SyntaxError as e:
        print(f"Syntax error: {e}")
        print(f"Not writing {filepath}", file=sys.stderr)
        return 1
    print("Syntax OK")

    if patched != original:
        with open(filepath, "w", encoding="utf-8") as f:
            f.write(patched)
        print(f"Wrote {filepath}")
    else:
        print("No changes to write")

    return 1 if missing else 0


if __name__ == "__main__":
    sys.exit(main())