Files
deepseek-v4-quant/tmp/fix_e2m1.py
biondizzle 02b8ea536f Update MEMORY.md and memory files with vLLM NVFP4 serving progress
Server running on B200 port 8000 with full NVFP4→vLLM bridge.
All critical bugs fixed: DeepGEMM scale format, compressor shapes, block scale values.
2026-05-11 02:02:49 +00:00

54 lines
2.2 KiB
Python

#!/usr/bin/python3
"""Fix the E2M1 unpacking in _convert_nvfp4_module_to_fp8"""
filepath = "/root/nvidia-meeting/deepseek-v4-quant/patches/deepseek_v4.py"
with open(filepath, 'r') as f:
c = f.read()
# Fix the unpacking code in _convert_nvfp4_module_to_fp8
old_unpack = ''' even_idx = (w_uint8 & 0x0F).int()
odd_idx = ((w_uint8 >> 4) & 0x0F).int()
even_vals = e2m1_lut.to(device)[even_idx]
odd_vals = e2m1_lut.to(device)[odd_idx]'''
new_unpack = ''' # Extract 4-bit FP4 values (0-15, bit 3 = sign)
even_raw = (w_uint8 & 0x0F).int()
odd_raw = ((w_uint8 >> 4) & 0x0F).int()
# Sign: 0-7 = positive, 8-15 = negative
even_sign = torch.where(even_raw >= 8, -1.0, 1.0).to(torch.bfloat16)
odd_sign = torch.where(odd_raw >= 8, -1.0, 1.0).to(torch.bfloat16)
# Magnitude index: lower 3 bits (0-7)
even_vals = even_sign * e2m1_lut.to(device)[even_raw & 0x07]
odd_vals = odd_sign * e2m1_lut.to(device)[odd_raw & 0x07]'''
c = c.replace(old_unpack, new_unpack)
print("Fixed E2M1 unpacking in _convert_nvfp4_module_to_fp8")
# Also fix the E2M1 unpacking in the stacked params code
old_stacked_unpack = ''' even_idx = (loaded_weight & 0x0F).int()
odd_idx = ((loaded_weight >> 4) & 0x0F).int()
even_vals = E2M1_LUT[even_idx]
odd_vals = E2M1_LUT[odd_idx]'''
new_stacked_unpack = ''' # Extract 4-bit FP4 values (0-15, bit 3 = sign)
even_raw = (loaded_weight & 0x0F).int()
odd_raw = ((loaded_weight >> 4) & 0x0F).int()
even_sign = torch.where(even_raw >= 8, -1.0, 1.0).to(torch.bfloat16)
odd_sign = torch.where(odd_raw >= 8, -1.0, 1.0).to(torch.bfloat16)
even_vals = even_sign * E2M1_LUT[even_raw & 0x07]
odd_vals = odd_sign * E2M1_LUT[odd_raw & 0x07]'''
c = c.replace(old_stacked_unpack, new_stacked_unpack)
print("Fixed E2M1 unpacking in stacked params code")
with open(filepath, 'w') as f:
f.write(c)
import ast
try:
ast.parse(c)
print("Syntax OK")
except SyntaxError as e:
print(f"Syntax error: {e}")