diff --git a/patches/deepseek_v4.py b/patches/deepseek_v4.py index 8b8a805..3663efc 100644 --- a/patches/deepseek_v4.py +++ b/patches/deepseek_v4.py @@ -29,6 +29,7 @@ print(f""" #4 compressor indexer — sub_path for .indexer keys #5 block scale dtype — must be float32, not float8_e4m3fn #6 block scale values — torch.full(fp8_scale) not torch.ones + #7 UE8M0 block scale — .to(float32) misinterprets E8M0 as E4M3 {'='*70} """) # ============================================================================== @@ -1535,12 +1536,16 @@ class DeepseekV4Model(nn.Module): weight_scale_2_val = global_amax / (6.0 * 448.0) weight_scale_2 = weight_scale_2_val.to(torch.float32) - # Per-block scale (weight_scale): fp8 e4m3 - # block_scale = amax / (6.0 * weight_scale_2) + # Per-block scale (weight_scale): UE8M0 format + # scale_fmt=ue8m0: block_scale = 2^(exp-127), stored as + # uint8 exponent byte viewed as float8_e4m3fn block_scale = amax / (6.0 * weight_scale_2_val) - # Clamp to fp8 e4m3 range and cast - block_scale = block_scale.clamp(min=0, max=448.0) - weight_scale = block_scale.to(torch.float8_e4m3fn) + # Convert to UE8M0: floor to nearest power of 2 + # UE8M0 exponent = floor(log2(block_scale)) + 127 + block_scale_clamped = block_scale.clamp(min=2**-127) + block_scale_exp = torch.floor(torch.log2(block_scale_clamped)).to(torch.int32) + 127 + block_scale_exp = block_scale_exp.clamp(0, 254).to(torch.uint8) + weight_scale = block_scale_exp.view(torch.float8_e4m3fn) # Quantize to FP4 (E2M1) # E2M1 LUT: 0, 0.5, 1, 1.5, 2, 3, 4, 6 (positive) @@ -1548,8 +1553,8 @@ class DeepseekV4Model(nn.Module): [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32, device=w_bf16.device, ) - # For each block, dequantize the block scale from fp8 - block_scale_f32 = weight_scale.to(torch.float32) + # For each block, dequantize the block scale from UE8M0 + block_scale_f32 = (block_scale_exp.to(torch.int32) << 23).view(torch.float32) # Scale the weight values: normalized = w / (block_scale * weight_scale_2) # We need to find the nearest FP4 value scaled = w_blocks / (block_scale_f32.unsqueeze(-1) * weight_scale_2_val) @@ -1743,7 +1748,10 @@ class DeepseekV4Model(nn.Module): # Dequantize with scales if hasattr(mod, "weight_scale") and hasattr(mod, "weight_scale_2"): - block_scale = mod.weight_scale.data.to(torch.float32) + # scale_fmt=ue8m0: weight_scale bytes are E8M0 format (power-of-2 only). + # A simple .to(float32) misinterprets them as E4M3. Must reinterpret + # the raw uint8 bits as IEEE 754 exponent field. + block_scale = self._ue8m0_to_float32(mod.weight_scale.data) if block_scale.dim() == 2 and w_bf16.dim() == 2: block_size = w_bf16.shape[1] // block_scale.shape[1] block_scale_expanded = block_scale.unsqueeze(-1).expand( @@ -1785,7 +1793,8 @@ class DeepseekV4Model(nn.Module): # Dequantize with scales if hasattr(mod, "weight_scale") and hasattr(mod, "weight_scale_2"): - block_scale = mod.weight_scale.data.to(torch.float32) + # scale_fmt=ue8m0: reinterpret E8M0 bytes as float32 + block_scale = self._ue8m0_to_float32(mod.weight_scale.data) if block_scale.dim() == 2 and w_bf16.dim() == 2: block_size = w_bf16.shape[1] // block_scale.shape[1] block_scale_expanded = block_scale.unsqueeze(-1).expand( @@ -1939,7 +1948,8 @@ class DeepseekV4Model(nn.Module): # Dequantize with scales def _dequant(w_bf16, block_scale, global_scale, input_scale): if block_scale is not None and global_scale is not None: - block_scale = block_scale.to(device).to(torch.float32) + # scale_fmt=ue8m0: reinterpret E8M0 bytes as float32 + block_scale = self._ue8m0_to_float32(block_scale.to(device)) if block_scale.dim() == 2 and w_bf16.dim() == 2: block_size = w_bf16.shape[1] // block_scale.shape[1] block_scale_exp = block_scale.unsqueeze(-1).expand( @@ -2028,6 +2038,20 @@ class DeepseekV4Model(nn.Module): from vllm.model_executor.layers.linear import UnquantizedLinearMethod mod.quant_method = UnquantizedLinearMethod() + @staticmethod + def _ue8m0_to_float32(sf: torch.Tensor) -> torch.Tensor: + """Convert UE8M0 (E8M0 power-of-2) scale bytes to float32. + + NVFP4 checkpoints with scale_fmt=ue8m0 store per-block weight scales as + E8M0 format (8-bit exponent, no mantissa). The value = 2^(raw_byte - 127). + The bytes are loaded as float8_e4m3fn by safetensors, but a simple + .to(float32) misinterprets them as E4M3 (which has mantissa bits). + Correct conversion: place the raw uint8 bits into the exponent field + of an IEEE 754 float32 (bits 23-30), yielding 2^(raw-127) * implicit_1. + """ + raw_uint8 = sf.view(torch.uint8) + return (raw_uint8.to(torch.int32) << 23).view(torch.float32) + def _unpack_nvfp4_to_bf16(self, w_uint8, e2m1_lut, device): """Unpack NVFP4 uint8 packed weights to bf16 using E2M1 format.""" # Extract 4-bit FP4 values (0-15, bit 3 = sign)