cleanup: rename _ue8m0_to_float32 → _block_scale_to_float32, remove dead code

- Renamed misleading _ue8m0_to_float32 to _block_scale_to_float32
  (our checkpoint uses float8_e4m3fn, NOT E8M0)
- Removed dead is_scale_e8m0 property (never referenced)
- Removed dead duplicate _ue8m0_to_float32 static method from the
  DeepseekV4MegaMoEExperts class (DeepseekV4Model keeps the renamed helper)
- Cleaned up stale E8M0/UE8M0/shift-by-23 comments
- Replaced the `assert False` on E8M0 expert weight scales with a raised
  ValueError and shortened its message
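- Updated DeepseekV4FP8Config docstring for NVFP4
- Behavior change (not cleanup): the MegaMoE forward path now passes the
  checkpoint's w13 input_scale to stage_activation instead of letting it
  derive a dynamic scale from amax / (6 * 448)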
This commit is contained in:
2026-05-16 01:55:56 +00:00
parent 4a624879ca
commit 294e9f98f2

View File

@@ -122,28 +122,17 @@ class DeepseekV4MLP(nn.Module):
class DeepseekV4FP8Config(Fp8Config):
"""FP8 config for DeepSeek V4 with expert-dtype-aware MoE dispatch.
DeepSeek V4 checkpoints always use FP8 block quantization for
linear/attention layers. The MoE expert weights vary by checkpoint:
- ``expert_dtype="fp4"`` (e.g. DeepSeek-V4-Flash): MXFP4 experts
with ue8m0 (e8m0fnu) FP8 linear scales.
- ``expert_dtype="fp8"`` (e.g. DeepSeek-V4-Flash-Base): FP8 block
experts with float32 FP8 linear scales.
DeepSeek V4 checkpoints use FP8 block quantization for attention
layers and NVFP4 (E2M1 + float8_e4m3fn block scales) for MoE experts.
The dispatch and the linear scale dtype are both keyed off
``expert_dtype`` from the model's hf_config; missing values default
to ``"fp4"`` so existing FP4 checkpoints stay unchanged.
NOTE: ``expert_dtype`` is resolved lazily because this config is
constructed during VllmConfig setup, before ``set_current_vllm_config``
is active. Reading hf_config eagerly in ``__init__`` would always see
the default ``"fp4"`` and silently misroute Flash-Base checkpoints.
``expert_dtype`` from hf_config determines the MoE dispatch path.
For NVFP4 checkpoints (our case), expert_dtype="fp4" which routes
to DeepseekV4MegaMoEExperts (native NVFP4 CUTLASS kernel).
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._resolved_expert_dtype: str | None = None
# ``is_scale_e8m0`` is a property that resolves on first read,
# by which time the current vllm_config has been set.
@property
def expert_dtype(self) -> str:
@@ -164,12 +153,6 @@ class DeepseekV4FP8Config(Fp8Config):
self._resolved_expert_dtype = expert_dtype
return self._resolved_expert_dtype
@property
def is_scale_e8m0(self) -> bool:
# FP4 checkpoints store FP8 linear scales as e8m0fnu; FP8 expert
# checkpoints (Flash-Base) store them as float32.
return self.expert_dtype == "fp4"
@classmethod
def get_name(cls) -> QuantizationMethods:
return "deepseek_v4_fp8"
@@ -469,17 +452,6 @@ class DeepseekV4MegaMoEExperts(nn.Module):
self.w2_weight_scale_2 = None
self.w2_input_scale = None
@staticmethod
def _ue8m0_to_float32(sf: torch.Tensor) -> torch.Tensor:
"""Convert NVFP4 block scales (float8_e4m3fn / UE4M3) to float32.
Checkpoint stores float8_e4m3fn (standard NVFP4 spec, NOT UE8M0).
Simple .to(float32) is correct — shift-by-23 was wrong (Bug #7 fix).
"""
return sf.to(torch.float32)
def get_symm_buffer(self):
import nvfp4_megamoe_kernel as deep_gemm
from nvfp4_megamoe_kernel import SymmBuffer, get_symm_buffer_for_nvfp4_mega_moe
@@ -552,8 +524,12 @@ class DeepseekV4MegaMoEExperts(nn.Module):
num_tokens = hidden_states.shape[0]
# Quantize activation using the kernel's PyTorch stage_activation
# Dynamic quantization: input_global_scale = amax / (6 * 448)
x_fp4, x_sf, input_global_scale = stage_activation(hidden_states)
# Use the checkpoint's input_scale for L1 (w13) activation quantization.
# The checkpoint's input_scale was used during weight calibration — using
# the same scale at runtime ensures the quantized weights are rescaled correctly.
# Dynamic stage_activation computes amax/(6*448) which can be 10x+ off.
w13_input_scale = float(self._w13_input_scale[0]) # same for all experts
x_fp4, x_sf, input_global_scale = stage_activation(hidden_states, input_global_scale=w13_input_scale)
symm_buffer.x[:num_tokens].copy_(x_fp4)
symm_buffer.x_sf[:num_tokens].copy_(x_sf)
symm_buffer.input_global_scale = input_global_scale
@@ -1422,19 +1398,15 @@ class DeepseekV4Model(nn.Module):
break
else:
if ".ffn.experts." in name:
# E8M0 scales are stored as float8_e8m0fnu in
# MXFP4 checkpoints but NVFP4 uses float8_e4m3fn.
# The uint8 view+copy path is only valid for MXFP4;
# for NVFP4 it would paste raw E8M0 bytes into an
# E4M3 buffer, producing garbage.
# NVFP4 checkpoint stores float8_e4m3fn scales, not E8M0.
# E8M0 would indicate an MXFP4 checkpoint — wrong format.
if (
"weight_scale" in name
and loaded_weight.dtype == torch.float8_e8m0fnu
):
assert False, (
f"E8M0 weight_scale encountered for NVFP4 experts "
f"({name}) — this is only valid for MXFP4. "
f"Check checkpoint dtype."
raise ValueError(
f"E8M0 weight_scale in NVFP4 checkpoint ({name}) — "
f"checkpoint format mismatch"
)
for mapping in expert_mapping:
param_name, weight_name, expert_id, shard_id = mapping
@@ -1507,7 +1479,7 @@ class DeepseekV4Model(nn.Module):
weight_scale_2_val = global_amax / (6.0 * 448.0)
weight_scale_2 = weight_scale_2_val.to(torch.float32)
# Per-block scale (weight_scale): UE4M3 format (standard NVFP4)
# Per-block scale (weight_scale): float8_e4m3fn
# block_scale = amax / (6.0 * weight_scale_2)
block_scale = amax / (6.0 * weight_scale_2_val)
weight_scale = block_scale.clamp(0.0, 448.0).to(torch.float8_e4m3fn)
@@ -1708,9 +1680,7 @@ class DeepseekV4Model(nn.Module):
# Dequantize with scales
if hasattr(mod, "weight_scale") and hasattr(mod, "weight_scale_2"):
# NVFP4 block scales are float8_e4m3fn (UE4M3) — standard spec.
# .to(float32) is correct (Bug #7: shift-by-23 was wrong, reverted)
block_scale = self._ue8m0_to_float32(mod.weight_scale.data)
block_scale = self._block_scale_to_float32(mod.weight_scale.data)
if block_scale.dim() == 2 and w_bf16.dim() == 2:
block_size = w_bf16.shape[1] // block_scale.shape[1]
block_scale_expanded = block_scale.unsqueeze(-1).expand(
@@ -1754,8 +1724,8 @@ class DeepseekV4Model(nn.Module):
# Dequantize with scales
if hasattr(mod, "weight_scale") and hasattr(mod, "weight_scale_2"):
# NVFP4 block scales: float8_e4m3fn → .to(float32) (Bug #7 reverted)
block_scale = self._ue8m0_to_float32(mod.weight_scale.data)
block_scale = self._block_scale_to_float32(mod.weight_scale.data)
if block_scale.dim() == 2 and w_bf16.dim() == 2:
block_size = w_bf16.shape[1] // block_scale.shape[1]
block_scale_expanded = block_scale.unsqueeze(-1).expand(
@@ -1921,8 +1891,8 @@ class DeepseekV4Model(nn.Module):
# Dequantize with scales
def _dequant(w_bf16, block_scale, global_scale, input_scale):
if block_scale is not None and global_scale is not None:
# NVFP4 block scales: float8_e4m3fn → .to(float32) (Bug #7 reverted)
block_scale = self._ue8m0_to_float32(block_scale.to(device))
block_scale = self._block_scale_to_float32(block_scale.to(device))
if block_scale.dim() == 2 and w_bf16.dim() == 2:
block_size = w_bf16.shape[1] // block_scale.shape[1]
block_scale_exp = block_scale.unsqueeze(-1).expand(
@@ -2009,12 +1979,9 @@ class DeepseekV4Model(nn.Module):
mod.quant_method = UnquantizedLinearMethod()
@staticmethod
def _ue8m0_to_float32(sf: torch.Tensor) -> torch.Tensor:
"""Convert NVFP4 block scales (float8_e4m3fn / UE4M3) to float32.
Checkpoint stores float8_e4m3fn (standard NVFP4 spec, NOT UE8M0).
Simple .to(float32) is correct — shift-by-23 was wrong (Bug #7 fix).
"""
@staticmethod
def _block_scale_to_float32(sf: torch.Tensor) -> torch.Tensor:
"""Convert NVFP4 block scales (float8_e4m3fn) to float32."""
return sf.to(torch.float32)
def _unpack_nvfp4_to_bf16(self, w_uint8, e2m1_lut, device):