fix CRITICAL #7: UE8M0 block scale misinterpreted as E4M3
scale_fmt=ue8m0 means weight_scale bytes are E8M0 format (power-of-2 only). A simple .to(float32) misinterprets them as E4M3 (which has mantissa bits), producing completely wrong block scale values and garbled output. Fix: add _ue8m0_to_float32() that reinterprets raw uint8 bits as IEEE 754 exponent field: (raw_byte << 23).view(float32) = 2^(raw-127). Applied to: - _dequant_nvfp4_to_bf16 (BF16 layers: fused_wqa_wkv, wq_b, wo_b) - _convert_nvfp4_to_fp8 (wo_a FP8 conversion) - _reconstruct_compressor_weight (compressor fused_wkv_wgate) - BF16->FP4 quantization path (stores as UE8M0, reads back correctly)
This commit is contained in:
@@ -29,6 +29,7 @@ print(f"""
|
||||
#4 compressor indexer — sub_path for .indexer keys
|
||||
#5 block scale dtype — must be float32, not float8_e4m3fn
|
||||
#6 block scale values — torch.full(fp8_scale) not torch.ones
|
||||
#7 UE8M0 block scale — .to(float32) misinterprets E8M0 as E4M3
|
||||
{'='*70}
|
||||
""")
|
||||
# ==============================================================================
|
||||
@@ -1535,12 +1536,16 @@ class DeepseekV4Model(nn.Module):
|
||||
weight_scale_2_val = global_amax / (6.0 * 448.0)
|
||||
weight_scale_2 = weight_scale_2_val.to(torch.float32)
|
||||
|
||||
# Per-block scale (weight_scale): fp8 e4m3
|
||||
# block_scale = amax / (6.0 * weight_scale_2)
|
||||
# Per-block scale (weight_scale): UE8M0 format
|
||||
# scale_fmt=ue8m0: block_scale = 2^(exp-127), stored as
|
||||
# uint8 exponent byte viewed as float8_e4m3fn
|
||||
block_scale = amax / (6.0 * weight_scale_2_val)
|
||||
# Clamp to fp8 e4m3 range and cast
|
||||
block_scale = block_scale.clamp(min=0, max=448.0)
|
||||
weight_scale = block_scale.to(torch.float8_e4m3fn)
|
||||
# Convert to UE8M0: floor to nearest power of 2
|
||||
# UE8M0 exponent = floor(log2(block_scale)) + 127
|
||||
block_scale_clamped = block_scale.clamp(min=2**-127)
|
||||
block_scale_exp = torch.floor(torch.log2(block_scale_clamped)).to(torch.int32) + 127
|
||||
block_scale_exp = block_scale_exp.clamp(0, 254).to(torch.uint8)
|
||||
weight_scale = block_scale_exp.view(torch.float8_e4m3fn)
|
||||
|
||||
# Quantize to FP4 (E2M1)
|
||||
# E2M1 LUT: 0, 0.5, 1, 1.5, 2, 3, 4, 6 (positive)
|
||||
@@ -1548,8 +1553,8 @@ class DeepseekV4Model(nn.Module):
|
||||
[0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0],
|
||||
dtype=torch.float32, device=w_bf16.device,
|
||||
)
|
||||
# For each block, dequantize the block scale from fp8
|
||||
block_scale_f32 = weight_scale.to(torch.float32)
|
||||
# For each block, dequantize the block scale from UE8M0
|
||||
block_scale_f32 = (block_scale_exp.to(torch.int32) << 23).view(torch.float32)
|
||||
# Scale the weight values: normalized = w / (block_scale * weight_scale_2)
|
||||
# We need to find the nearest FP4 value
|
||||
scaled = w_blocks / (block_scale_f32.unsqueeze(-1) * weight_scale_2_val)
|
||||
@@ -1743,7 +1748,10 @@ class DeepseekV4Model(nn.Module):
|
||||
|
||||
# Dequantize with scales
|
||||
if hasattr(mod, "weight_scale") and hasattr(mod, "weight_scale_2"):
|
||||
block_scale = mod.weight_scale.data.to(torch.float32)
|
||||
# scale_fmt=ue8m0: weight_scale bytes are E8M0 format (power-of-2 only).
|
||||
# A simple .to(float32) misinterprets them as E4M3. Must reinterpret
|
||||
# the raw uint8 bits as IEEE 754 exponent field.
|
||||
block_scale = self._ue8m0_to_float32(mod.weight_scale.data)
|
||||
if block_scale.dim() == 2 and w_bf16.dim() == 2:
|
||||
block_size = w_bf16.shape[1] // block_scale.shape[1]
|
||||
block_scale_expanded = block_scale.unsqueeze(-1).expand(
|
||||
@@ -1785,7 +1793,8 @@ class DeepseekV4Model(nn.Module):
|
||||
|
||||
# Dequantize with scales
|
||||
if hasattr(mod, "weight_scale") and hasattr(mod, "weight_scale_2"):
|
||||
block_scale = mod.weight_scale.data.to(torch.float32)
|
||||
# scale_fmt=ue8m0: reinterpret E8M0 bytes as float32
|
||||
block_scale = self._ue8m0_to_float32(mod.weight_scale.data)
|
||||
if block_scale.dim() == 2 and w_bf16.dim() == 2:
|
||||
block_size = w_bf16.shape[1] // block_scale.shape[1]
|
||||
block_scale_expanded = block_scale.unsqueeze(-1).expand(
|
||||
@@ -1939,7 +1948,8 @@ class DeepseekV4Model(nn.Module):
|
||||
# Dequantize with scales
|
||||
def _dequant(w_bf16, block_scale, global_scale, input_scale):
|
||||
if block_scale is not None and global_scale is not None:
|
||||
block_scale = block_scale.to(device).to(torch.float32)
|
||||
# scale_fmt=ue8m0: reinterpret E8M0 bytes as float32
|
||||
block_scale = self._ue8m0_to_float32(block_scale.to(device))
|
||||
if block_scale.dim() == 2 and w_bf16.dim() == 2:
|
||||
block_size = w_bf16.shape[1] // block_scale.shape[1]
|
||||
block_scale_exp = block_scale.unsqueeze(-1).expand(
|
||||
@@ -2028,6 +2038,20 @@ class DeepseekV4Model(nn.Module):
|
||||
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
|
||||
mod.quant_method = UnquantizedLinearMethod()
|
||||
|
||||
@staticmethod
|
||||
def _ue8m0_to_float32(sf: torch.Tensor) -> torch.Tensor:
|
||||
"""Convert UE8M0 (E8M0 power-of-2) scale bytes to float32.
|
||||
|
||||
NVFP4 checkpoints with scale_fmt=ue8m0 store per-block weight scales as
|
||||
E8M0 format (8-bit exponent, no mantissa). The value = 2^(raw_byte - 127).
|
||||
The bytes are loaded as float8_e4m3fn by safetensors, but a simple
|
||||
.to(float32) misinterprets them as E4M3 (which has mantissa bits).
|
||||
Correct conversion: place the raw uint8 bits into the exponent field
|
||||
of an IEEE 754 float32 (bits 23-30), yielding 2^(raw-127) * implicit_1.
|
||||
"""
|
||||
raw_uint8 = sf.view(torch.uint8)
|
||||
return (raw_uint8.to(torch.int32) << 23).view(torch.float32)
|
||||
|
||||
def _unpack_nvfp4_to_bf16(self, w_uint8, e2m1_lut, device):
|
||||
"""Unpack NVFP4 uint8 packed weights to bf16 using E2M1 format."""
|
||||
# Extract 4-bit FP4 values (0-15, bit 3 = sign)
|
||||
|
||||
Reference in New Issue
Block a user