Server running on the B200 (port 8000) with the full NVFP4→vLLM bridge. All critical bugs are fixed: DeepGEMM scale format, compressor shapes, and block scale values.
54 lines · 2.2 KiB · Python
#!/usr/bin/python3
|
|
"""Fix the E2M1 unpacking in _convert_nvfp4_module_to_fp8"""
|
|
|
|
filepath = "/root/nvidia-meeting/deepseek-v4-quant/patches/deepseek_v4.py"
|
|
|
|
with open(filepath, 'r') as f:
|
|
c = f.read()
|
|
|
|
# Fix the unpacking code in _convert_nvfp4_module_to_fp8
|
|
old_unpack = ''' even_idx = (w_uint8 & 0x0F).int()
|
|
odd_idx = ((w_uint8 >> 4) & 0x0F).int()
|
|
even_vals = e2m1_lut.to(device)[even_idx]
|
|
odd_vals = e2m1_lut.to(device)[odd_idx]'''
|
|
|
|
new_unpack = ''' # Extract 4-bit FP4 values (0-15, bit 3 = sign)
|
|
even_raw = (w_uint8 & 0x0F).int()
|
|
odd_raw = ((w_uint8 >> 4) & 0x0F).int()
|
|
# Sign: 0-7 = positive, 8-15 = negative
|
|
even_sign = torch.where(even_raw >= 8, -1.0, 1.0).to(torch.bfloat16)
|
|
odd_sign = torch.where(odd_raw >= 8, -1.0, 1.0).to(torch.bfloat16)
|
|
# Magnitude index: lower 3 bits (0-7)
|
|
even_vals = even_sign * e2m1_lut.to(device)[even_raw & 0x07]
|
|
odd_vals = odd_sign * e2m1_lut.to(device)[odd_raw & 0x07]'''
|
|
|
|
c = c.replace(old_unpack, new_unpack)
|
|
print("Fixed E2M1 unpacking in _convert_nvfp4_module_to_fp8")
|
|
|
|
# Also fix the E2M1 unpacking in the stacked params code
|
|
old_stacked_unpack = ''' even_idx = (loaded_weight & 0x0F).int()
|
|
odd_idx = ((loaded_weight >> 4) & 0x0F).int()
|
|
even_vals = E2M1_LUT[even_idx]
|
|
odd_vals = E2M1_LUT[odd_idx]'''
|
|
|
|
new_stacked_unpack = ''' # Extract 4-bit FP4 values (0-15, bit 3 = sign)
|
|
even_raw = (loaded_weight & 0x0F).int()
|
|
odd_raw = ((loaded_weight >> 4) & 0x0F).int()
|
|
even_sign = torch.where(even_raw >= 8, -1.0, 1.0).to(torch.bfloat16)
|
|
odd_sign = torch.where(odd_raw >= 8, -1.0, 1.0).to(torch.bfloat16)
|
|
even_vals = even_sign * E2M1_LUT[even_raw & 0x07]
|
|
odd_vals = odd_sign * E2M1_LUT[odd_raw & 0x07]'''
|
|
|
|
c = c.replace(old_stacked_unpack, new_stacked_unpack)
|
|
print("Fixed E2M1 unpacking in stacked params code")
|
|
|
|
with open(filepath, 'w') as f:
|
|
f.write(c)
|
|
|
|
import ast
|
|
try:
|
|
ast.parse(c)
|
|
print("Syntax OK")
|
|
except SyntaxError as e:
|
|
print(f"Syntax error: {e}")
|