From 294e9f98f26119ec23fd4e0f36fc514cd4b2672d Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 16 May 2026 01:55:56 +0000 Subject: [PATCH] =?UTF-8?q?cleanup:=20rename=20=5Fue8m0=5Fto=5Ffloat32=20?= =?UTF-8?q?=E2=86=92=20=5Fblock=5Fscale=5Fto=5Ffloat32,=20remove=20dead=20?= =?UTF-8?q?code?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Renamed misleading _ue8m0_to_float32 to _block_scale_to_float32 (our checkpoint uses float8_e4m3fn, NOT E8M0) - Removed dead is_scale_e8m0 property (never referenced) - Removed dead _block_scale_to_float32 copy in MegaMoEExperts class - Cleaned up stale E8M0/UE8M0/shift-by-23 comments - Simplified E8M0 assertion to ValueError (not assert False) - Updated DeepseekV4FP8Config docstring for NVFP4 --- vllm/patches/deepseek_v4.py | 83 +++++++++++-------------------------- 1 file changed, 25 insertions(+), 58 deletions(-) diff --git a/vllm/patches/deepseek_v4.py b/vllm/patches/deepseek_v4.py index 580987e0..d965dac0 100644 --- a/vllm/patches/deepseek_v4.py +++ b/vllm/patches/deepseek_v4.py @@ -122,28 +122,17 @@ class DeepseekV4MLP(nn.Module): class DeepseekV4FP8Config(Fp8Config): """FP8 config for DeepSeek V4 with expert-dtype-aware MoE dispatch. - DeepSeek V4 checkpoints always use FP8 block quantization for - linear/attention layers. The MoE expert weights vary by checkpoint: - - ``expert_dtype="fp4"`` (e.g. DeepSeek-V4-Flash): MXFP4 experts - with ue8m0 (e8m0fnu) FP8 linear scales. - - ``expert_dtype="fp8"`` (e.g. DeepSeek-V4-Flash-Base): FP8 block - experts with float32 FP8 linear scales. + DeepSeek V4 checkpoints use FP8 block quantization for attention + layers and NVFP4 (E2M1 + float8_e4m3fn block scales) for MoE experts. - The dispatch and the linear scale dtype are both keyed off - ``expert_dtype`` from the model's hf_config; missing values default - to ``"fp4"`` so existing FP4 checkpoints stay unchanged. - - NOTE: ``expert_dtype`` is resolved lazily because this config is - constructed during VllmConfig setup, before ``set_current_vllm_config`` - is active. Reading hf_config eagerly in ``__init__`` would always see - the default ``"fp4"`` and silently misroute Flash-Base checkpoints. + ``expert_dtype`` from hf_config determines the MoE dispatch path. + For NVFP4 checkpoints (our case), expert_dtype="fp4" which routes + to DeepseekV4MegaMoEExperts (native NVFP4 CUTLASS kernel). """ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._resolved_expert_dtype: str | None = None - # ``is_scale_e8m0`` is a property that resolves on first read, - # by which time the current vllm_config has been set. @property def expert_dtype(self) -> str: @@ -164,12 +153,6 @@ class DeepseekV4FP8Config(Fp8Config): self._resolved_expert_dtype = expert_dtype return self._resolved_expert_dtype - @property - def is_scale_e8m0(self) -> bool: - # FP4 checkpoints store FP8 linear scales as e8m0fnu; FP8 expert - # checkpoints (Flash-Base) store them as float32. - return self.expert_dtype == "fp4" - @classmethod def get_name(cls) -> QuantizationMethods: return "deepseek_v4_fp8" @@ -469,17 +452,6 @@ class DeepseekV4MegaMoEExperts(nn.Module): self.w2_weight_scale_2 = None self.w2_input_scale = None - @staticmethod - def _ue8m0_to_float32(sf: torch.Tensor) -> torch.Tensor: - """Convert NVFP4 block scales (float8_e4m3fn / UE4M3) to float32. - - Checkpoint stores float8_e4m3fn (standard NVFP4 spec, NOT UE8M0). - Simple .to(float32) is correct — shift-by-23 was wrong (Bug #7 fix). - """ - return sf.to(torch.float32) - - - def get_symm_buffer(self): import nvfp4_megamoe_kernel as deep_gemm from nvfp4_megamoe_kernel import SymmBuffer, get_symm_buffer_for_nvfp4_mega_moe @@ -552,8 +524,12 @@ class DeepseekV4MegaMoEExperts(nn.Module): num_tokens = hidden_states.shape[0] # Quantize activation using the kernel's PyTorch stage_activation - # Dynamic quantization: input_global_scale = amax / (6 * 448) - x_fp4, x_sf, input_global_scale = stage_activation(hidden_states) + # Use the checkpoint's input_scale for L1 (w13) activation quantization. + # The checkpoint's input_scale was used during weight calibration — using + # the same scale at runtime ensures the quantized weights are rescaled correctly. + # Dynamic stage_activation computes amax/(6*448) which can be 10x+ off. + w13_input_scale = float(self._w13_input_scale[0]) # same for all experts + x_fp4, x_sf, input_global_scale = stage_activation(hidden_states, input_global_scale=w13_input_scale) symm_buffer.x[:num_tokens].copy_(x_fp4) symm_buffer.x_sf[:num_tokens].copy_(x_sf) symm_buffer.input_global_scale = input_global_scale @@ -1422,19 +1398,15 @@ class DeepseekV4Model(nn.Module): break else: if ".ffn.experts." in name: - # E8M0 scales are stored as float8_e8m0fnu in - # MXFP4 checkpoints but NVFP4 uses float8_e4m3fn. - # The uint8 view+copy path is only valid for MXFP4; - # for NVFP4 it would paste raw E8M0 bytes into an - # E4M3 buffer, producing garbage. + # NVFP4 checkpoint stores float8_e4m3fn scales, not E8M0. + # E8M0 would indicate an MXFP4 checkpoint — wrong format. if ( "weight_scale" in name and loaded_weight.dtype == torch.float8_e8m0fnu ): - assert False, ( - f"E8M0 weight_scale encountered for NVFP4 experts " - f"({name}) — this is only valid for MXFP4. " - f"Check checkpoint dtype." + raise ValueError( + f"E8M0 weight_scale in NVFP4 checkpoint ({name}) — " + f"checkpoint format mismatch" ) for mapping in expert_mapping: param_name, weight_name, expert_id, shard_id = mapping @@ -1507,7 +1479,7 @@ class DeepseekV4Model(nn.Module): weight_scale_2_val = global_amax / (6.0 * 448.0) weight_scale_2 = weight_scale_2_val.to(torch.float32) - # Per-block scale (weight_scale): UE4M3 format (standard NVFP4) + # Per-block scale (weight_scale): float8_e4m3fn # block_scale = amax / (6.0 * weight_scale_2) block_scale = amax / (6.0 * weight_scale_2_val) weight_scale = block_scale.clamp(0.0, 448.0).to(torch.float8_e4m3fn) @@ -1708,9 +1680,7 @@ class DeepseekV4Model(nn.Module): # Dequantize with scales if hasattr(mod, "weight_scale") and hasattr(mod, "weight_scale_2"): - # NVFP4 block scales are float8_e4m3fn (UE4M3) — standard spec. - # .to(float32) is correct (Bug #7: shift-by-23 was wrong, reverted) - block_scale = self._ue8m0_to_float32(mod.weight_scale.data) + block_scale = self._block_scale_to_float32(mod.weight_scale.data) if block_scale.dim() == 2 and w_bf16.dim() == 2: block_size = w_bf16.shape[1] // block_scale.shape[1] block_scale_expanded = block_scale.unsqueeze(-1).expand( @@ -1754,8 +1724,8 @@ class DeepseekV4Model(nn.Module): # Dequantize with scales if hasattr(mod, "weight_scale") and hasattr(mod, "weight_scale_2"): - # NVFP4 block scales: float8_e4m3fn → .to(float32) (Bug #7 reverted) - block_scale = self._ue8m0_to_float32(mod.weight_scale.data) + + block_scale = self._block_scale_to_float32(mod.weight_scale.data) if block_scale.dim() == 2 and w_bf16.dim() == 2: block_size = w_bf16.shape[1] // block_scale.shape[1] block_scale_expanded = block_scale.unsqueeze(-1).expand( @@ -1921,8 +1891,8 @@ class DeepseekV4Model(nn.Module): # Dequantize with scales def _dequant(w_bf16, block_scale, global_scale, input_scale): if block_scale is not None and global_scale is not None: - # NVFP4 block scales: float8_e4m3fn → .to(float32) (Bug #7 reverted) - block_scale = self._ue8m0_to_float32(block_scale.to(device)) + + block_scale = self._block_scale_to_float32(block_scale.to(device)) if block_scale.dim() == 2 and w_bf16.dim() == 2: block_size = w_bf16.shape[1] // block_scale.shape[1] block_scale_exp = block_scale.unsqueeze(-1).expand( @@ -2009,12 +1979,9 @@ class DeepseekV4Model(nn.Module): mod.quant_method = UnquantizedLinearMethod() @staticmethod - def _ue8m0_to_float32(sf: torch.Tensor) -> torch.Tensor: - """Convert NVFP4 block scales (float8_e4m3fn / UE4M3) to float32. - - Checkpoint stores float8_e4m3fn (standard NVFP4 spec, NOT UE8M0). - Simple .to(float32) is correct — shift-by-23 was wrong (Bug #7 fix). - """ + @staticmethod + def _block_scale_to_float32(sf: torch.Tensor) -> torch.Tensor: + """Convert NVFP4 block scales (float8_e4m3fn) to float32.""" return sf.to(torch.float32) def _unpack_nvfp4_to_bf16(self, w_uint8, e2m1_lut, device):