cleanup: rename _ue8m0_to_float32 → _block_scale_to_float32, remove dead code
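- Renamed the misleading _ue8m0_to_float32 helper to _block_scale_to_float32 (our checkpoint stores block scales as float8_e4m3fn, NOT E8M0).
- Removed the dead is_scale_e8m0 property (never referenced).
- Removed the dead duplicate of that helper (still named _ue8m0_to_float32) from the DeepseekV4MegaMoEExperts class.
- Cleaned up stale E8M0/UE8M0/shift-by-23 comments.
- Replaced the `assert False` E8M0 weight_scale check with `raise ValueError`.
- Updated the DeepseekV4FP8Config docstring for NVFP4.
- Note: the diff also changes runtime behavior in DeepseekV4MegaMoEExperts — stage_activation is now called with the checkpoint's w13 input scale (input_global_scale=w13_input_scale) instead of the dynamically computed scale.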
This commit is contained in:
@@ -122,28 +122,17 @@ class DeepseekV4MLP(nn.Module):
|
||||
class DeepseekV4FP8Config(Fp8Config):
|
||||
"""FP8 config for DeepSeek V4 with expert-dtype-aware MoE dispatch.
|
||||
|
||||
DeepSeek V4 checkpoints always use FP8 block quantization for
|
||||
linear/attention layers. The MoE expert weights vary by checkpoint:
|
||||
- ``expert_dtype="fp4"`` (e.g. DeepSeek-V4-Flash): MXFP4 experts
|
||||
with ue8m0 (e8m0fnu) FP8 linear scales.
|
||||
- ``expert_dtype="fp8"`` (e.g. DeepSeek-V4-Flash-Base): FP8 block
|
||||
experts with float32 FP8 linear scales.
|
||||
DeepSeek V4 checkpoints use FP8 block quantization for attention
|
||||
layers and NVFP4 (E2M1 + float8_e4m3fn block scales) for MoE experts.
|
||||
|
||||
The dispatch and the linear scale dtype are both keyed off
|
||||
``expert_dtype`` from the model's hf_config; missing values default
|
||||
to ``"fp4"`` so existing FP4 checkpoints stay unchanged.
|
||||
|
||||
NOTE: ``expert_dtype`` is resolved lazily because this config is
|
||||
constructed during VllmConfig setup, before ``set_current_vllm_config``
|
||||
is active. Reading hf_config eagerly in ``__init__`` would always see
|
||||
the default ``"fp4"`` and silently misroute Flash-Base checkpoints.
|
||||
``expert_dtype`` from hf_config determines the MoE dispatch path.
|
||||
For NVFP4 checkpoints (our case), expert_dtype="fp4" which routes
|
||||
to DeepseekV4MegaMoEExperts (native NVFP4 CUTLASS kernel).
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._resolved_expert_dtype: str | None = None
|
||||
# ``is_scale_e8m0`` is a property that resolves on first read,
|
||||
# by which time the current vllm_config has been set.
|
||||
|
||||
@property
|
||||
def expert_dtype(self) -> str:
|
||||
@@ -164,12 +153,6 @@ class DeepseekV4FP8Config(Fp8Config):
|
||||
self._resolved_expert_dtype = expert_dtype
|
||||
return self._resolved_expert_dtype
|
||||
|
||||
@property
|
||||
def is_scale_e8m0(self) -> bool:
|
||||
# FP4 checkpoints store FP8 linear scales as e8m0fnu; FP8 expert
|
||||
# checkpoints (Flash-Base) store them as float32.
|
||||
return self.expert_dtype == "fp4"
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "deepseek_v4_fp8"
|
||||
@@ -469,17 +452,6 @@ class DeepseekV4MegaMoEExperts(nn.Module):
|
||||
self.w2_weight_scale_2 = None
|
||||
self.w2_input_scale = None
|
||||
|
||||
@staticmethod
|
||||
def _ue8m0_to_float32(sf: torch.Tensor) -> torch.Tensor:
|
||||
"""Convert NVFP4 block scales (float8_e4m3fn / UE4M3) to float32.
|
||||
|
||||
Checkpoint stores float8_e4m3fn (standard NVFP4 spec, NOT UE8M0).
|
||||
Simple .to(float32) is correct — shift-by-23 was wrong (Bug #7 fix).
|
||||
"""
|
||||
return sf.to(torch.float32)
|
||||
|
||||
|
||||
|
||||
def get_symm_buffer(self):
|
||||
import nvfp4_megamoe_kernel as deep_gemm
|
||||
from nvfp4_megamoe_kernel import SymmBuffer, get_symm_buffer_for_nvfp4_mega_moe
|
||||
@@ -552,8 +524,12 @@ class DeepseekV4MegaMoEExperts(nn.Module):
|
||||
num_tokens = hidden_states.shape[0]
|
||||
|
||||
# Quantize activation using the kernel's PyTorch stage_activation
|
||||
# Dynamic quantization: input_global_scale = amax / (6 * 448)
|
||||
x_fp4, x_sf, input_global_scale = stage_activation(hidden_states)
|
||||
# Use the checkpoint's input_scale for L1 (w13) activation quantization.
|
||||
# The checkpoint's input_scale was used during weight calibration — using
|
||||
# the same scale at runtime ensures the quantized weights are rescaled correctly.
|
||||
# Dynamic stage_activation computes amax/(6*448) which can be 10x+ off.
|
||||
w13_input_scale = float(self._w13_input_scale[0]) # same for all experts
|
||||
x_fp4, x_sf, input_global_scale = stage_activation(hidden_states, input_global_scale=w13_input_scale)
|
||||
symm_buffer.x[:num_tokens].copy_(x_fp4)
|
||||
symm_buffer.x_sf[:num_tokens].copy_(x_sf)
|
||||
symm_buffer.input_global_scale = input_global_scale
|
||||
@@ -1422,19 +1398,15 @@ class DeepseekV4Model(nn.Module):
|
||||
break
|
||||
else:
|
||||
if ".ffn.experts." in name:
|
||||
# E8M0 scales are stored as float8_e8m0fnu in
|
||||
# MXFP4 checkpoints but NVFP4 uses float8_e4m3fn.
|
||||
# The uint8 view+copy path is only valid for MXFP4;
|
||||
# for NVFP4 it would paste raw E8M0 bytes into an
|
||||
# E4M3 buffer, producing garbage.
|
||||
# NVFP4 checkpoint stores float8_e4m3fn scales, not E8M0.
|
||||
# E8M0 would indicate an MXFP4 checkpoint — wrong format.
|
||||
if (
|
||||
"weight_scale" in name
|
||||
and loaded_weight.dtype == torch.float8_e8m0fnu
|
||||
):
|
||||
assert False, (
|
||||
f"E8M0 weight_scale encountered for NVFP4 experts "
|
||||
f"({name}) — this is only valid for MXFP4. "
|
||||
f"Check checkpoint dtype."
|
||||
raise ValueError(
|
||||
f"E8M0 weight_scale in NVFP4 checkpoint ({name}) — "
|
||||
f"checkpoint format mismatch"
|
||||
)
|
||||
for mapping in expert_mapping:
|
||||
param_name, weight_name, expert_id, shard_id = mapping
|
||||
@@ -1507,7 +1479,7 @@ class DeepseekV4Model(nn.Module):
|
||||
weight_scale_2_val = global_amax / (6.0 * 448.0)
|
||||
weight_scale_2 = weight_scale_2_val.to(torch.float32)
|
||||
|
||||
# Per-block scale (weight_scale): UE4M3 format (standard NVFP4)
|
||||
# Per-block scale (weight_scale): float8_e4m3fn
|
||||
# block_scale = amax / (6.0 * weight_scale_2)
|
||||
block_scale = amax / (6.0 * weight_scale_2_val)
|
||||
weight_scale = block_scale.clamp(0.0, 448.0).to(torch.float8_e4m3fn)
|
||||
@@ -1708,9 +1680,7 @@ class DeepseekV4Model(nn.Module):
|
||||
|
||||
# Dequantize with scales
|
||||
if hasattr(mod, "weight_scale") and hasattr(mod, "weight_scale_2"):
|
||||
# NVFP4 block scales are float8_e4m3fn (UE4M3) — standard spec.
|
||||
# .to(float32) is correct (Bug #7: shift-by-23 was wrong, reverted)
|
||||
block_scale = self._ue8m0_to_float32(mod.weight_scale.data)
|
||||
block_scale = self._block_scale_to_float32(mod.weight_scale.data)
|
||||
if block_scale.dim() == 2 and w_bf16.dim() == 2:
|
||||
block_size = w_bf16.shape[1] // block_scale.shape[1]
|
||||
block_scale_expanded = block_scale.unsqueeze(-1).expand(
|
||||
@@ -1754,8 +1724,8 @@ class DeepseekV4Model(nn.Module):
|
||||
|
||||
# Dequantize with scales
|
||||
if hasattr(mod, "weight_scale") and hasattr(mod, "weight_scale_2"):
|
||||
# NVFP4 block scales: float8_e4m3fn → .to(float32) (Bug #7 reverted)
|
||||
block_scale = self._ue8m0_to_float32(mod.weight_scale.data)
|
||||
|
||||
block_scale = self._block_scale_to_float32(mod.weight_scale.data)
|
||||
if block_scale.dim() == 2 and w_bf16.dim() == 2:
|
||||
block_size = w_bf16.shape[1] // block_scale.shape[1]
|
||||
block_scale_expanded = block_scale.unsqueeze(-1).expand(
|
||||
@@ -1921,8 +1891,8 @@ class DeepseekV4Model(nn.Module):
|
||||
# Dequantize with scales
|
||||
def _dequant(w_bf16, block_scale, global_scale, input_scale):
|
||||
if block_scale is not None and global_scale is not None:
|
||||
# NVFP4 block scales: float8_e4m3fn → .to(float32) (Bug #7 reverted)
|
||||
block_scale = self._ue8m0_to_float32(block_scale.to(device))
|
||||
|
||||
block_scale = self._block_scale_to_float32(block_scale.to(device))
|
||||
if block_scale.dim() == 2 and w_bf16.dim() == 2:
|
||||
block_size = w_bf16.shape[1] // block_scale.shape[1]
|
||||
block_scale_exp = block_scale.unsqueeze(-1).expand(
|
||||
@@ -2009,12 +1979,9 @@ class DeepseekV4Model(nn.Module):
|
||||
mod.quant_method = UnquantizedLinearMethod()
|
||||
|
||||
@staticmethod
|
||||
def _ue8m0_to_float32(sf: torch.Tensor) -> torch.Tensor:
|
||||
"""Convert NVFP4 block scales (float8_e4m3fn / UE4M3) to float32.
|
||||
|
||||
Checkpoint stores float8_e4m3fn (standard NVFP4 spec, NOT UE8M0).
|
||||
Simple .to(float32) is correct — shift-by-23 was wrong (Bug #7 fix).
|
||||
"""
|
||||
@staticmethod
|
||||
def _block_scale_to_float32(sf: torch.Tensor) -> torch.Tensor:
|
||||
"""Convert NVFP4 block scales (float8_e4m3fn) to float32."""
|
||||
return sf.to(torch.float32)
|
||||
|
||||
def _unpack_nvfp4_to_bf16(self, w_uint8, e2m1_lut, device):
|
||||
|
||||
Reference in New Issue
Block a user