Files
nvfp4-megamoe-kernel/vllm/patches/deepseek_v4.py
biondizzle 10c14ddb49 Fix NVFP4 mapper: layer norms, hc params, indexer path, q_a_norm
- input_layernorm → attn_norm, post_attention_layernorm → ffn_norm
- hc_head.fn/base/scale → hc_head_fn/base/scale
- attn_hc/ffn_hc → hc_attn/hc_ffn (dot to underscore)
- q_a_norm → q_norm, sinks → attn_sink
- Indexer params: self_attn.compressor.indexer → attn.indexer
  (not attn.mla_attn.compressor.indexer)
2026-05-19 00:24:26 +00:00

1925 lines
71 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import typing
from collections.abc import Callable, Iterable
from itertools import islice
import regex as re
import torch
import torch.nn as nn
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.distributed import (
get_ep_group,
get_pp_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.activation import SiluAndMul, SiluAndMulWithClamp
from vllm.model_executor.layers.deepseek_v4_attention import (
DeepseekV4Indexer,
DeepseekV4MLAModules,
DeepseekV4MultiHeadLatentAttentionWrapper,
)
from vllm.model_executor.layers.fused_moe import FusedMoE, GateLinear
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
from vllm.model_executor.layers.fused_moe.router.fused_topk_bias_router import (
fused_topk_bias,
)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import (
QuantizationConfig,
QuantizationMethods,
)
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
from vllm.model_executor.layers.quantization.mxfp4 import Mxfp4MoEMethod
from vllm.model_executor.layers.quantization.utils.quant_utils import (
is_layer_skipped,
)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import SupportsPP
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.triton_utils import tl, triton
from vllm.utils.torch_utils import direct_register_custom_op
from .utils import (
AutoWeightsLoader,
PPMissingLayer,
WeightsMapper,
extract_layer_index,
is_pp_missing_parameter,
make_layers,
maybe_prefix,
)
# Values of the hf_config ``expert_dtype`` field that DeepseekV4FP8Config
# accepts; its ``expert_dtype`` property raises ValueError for anything else.
_DEEPSEEK_V4_EXPERT_DTYPES = ("fp4", "fp8")
class DeepseekV4MLP(nn.Module):
    """Gated feed-forward block: fused gate/up projection, SiLU-and-mul, down projection.

    Args:
        hidden_size: Input and output width of the block.
        intermediate_size: Width of each of the gate and up projections.
        hidden_act: Activation name; only ``"silu"`` is accepted.
        swiglu_limit: When given, the clamped activation
            (``SiluAndMulWithClamp``) is used with this limit.
        quant_config: Quantization config forwarded to both linear layers.
        reduce_results: Forwarded to the down projection.
        is_sequence_parallel: Forwarded as ``disable_tp`` to both linear layers.
        prefix: Module name prefix used to name the two linear layers.

    Raises:
        ValueError: If ``hidden_act`` is not ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        swiglu_limit: float | None = None,
        quant_config: QuantizationConfig | None = None,
        reduce_results: bool = True,
        is_sequence_parallel: bool = False,
        prefix: str = "",
    ) -> None:
        super().__init__()

        # In sequence-parallel mode the activations are already sharded across
        # the TP ranks, so the weights are replicated (disable_tp) and no
        # collective is needed. Otherwise this is standard tensor parallelism
        # with an all-reduce at the end.
        shared_linear_kwargs = dict(
            bias=False,
            quant_config=quant_config,
            disable_tp=is_sequence_parallel,
        )
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size, intermediate_size],
            prefix=f"{prefix}.gate_up_proj",
            **shared_linear_kwargs,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            reduce_results=reduce_results,
            prefix=f"{prefix}.down_proj",
            **shared_linear_kwargs,
        )

        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )
        self.act_fn = (
            SiluAndMul()
            if swiglu_limit is None
            else SiluAndMulWithClamp(swiglu_limit)
        )

    def forward(self, x):
        """Apply gate/up projection, activation and down projection to ``x``."""
        # Both linear layers return a pair; only the first element is used.
        fused_gate_up, _ = self.gate_up_proj(x)
        activated = self.act_fn(fused_gate_up)
        projected, _ = self.down_proj(activated)
        return projected
class DeepseekV4FP8Config(Fp8Config):
    """FP8 quantization config for DeepSeek V4 whose MoE method follows the expert dtype.

    Linear and attention layers of DeepSeek V4 checkpoints always use FP8
    block quantization; only the MoE expert weights differ per checkpoint:

    - ``expert_dtype="fp4"`` (e.g. DeepSeek-V4-Flash): MXFP4 experts, with
      FP8 linear scales stored as ue8m0 (e8m0fnu).
    - ``expert_dtype="fp8"`` (e.g. DeepSeek-V4-Flash-Base): FP8 block experts,
      with FP8 linear scales stored as float32.

    Both the MoE dispatch and the linear scale dtype are derived from the
    ``expert_dtype`` field of the model's hf_config. A missing field is
    treated as ``"fp4"`` so existing FP4 checkpoints keep working unchanged.

    NOTE: ``expert_dtype`` is resolved lazily. This config object is built
    during VllmConfig setup, before ``set_current_vllm_config`` is active, so
    reading hf_config in ``__init__`` would always yield the ``"fp4"``
    default and silently misroute Flash-Base checkpoints.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Filled in by the first successful read of ``expert_dtype``.
        # ``is_scale_e8m0`` is a property built on top of it, so it is also
        # resolved on first read, by which time the current vllm_config has
        # been set.
        self._resolved_expert_dtype: str | None = None

    @property
    def expert_dtype(self) -> str:
        """Expert weight dtype (``"fp4"`` or ``"fp8"``) taken from hf_config.

        Returns ``"fp4"`` without caching while the current vllm_config cannot
        be read; once a value has been read and validated it is cached.

        Raises:
            ValueError: If hf_config carries an unsupported ``expert_dtype``.
        """
        cached = self._resolved_expert_dtype
        if cached is not None:
            return cached

        try:
            hf_config = get_current_vllm_config().model_config.hf_config
        except Exception:
            # vllm_config not yet set; defer the decision until a later call
            # lands inside set_current_vllm_config.
            return "fp4"

        resolved = getattr(hf_config, "expert_dtype", "fp4")
        if resolved not in _DEEPSEEK_V4_EXPERT_DTYPES:
            raise ValueError(
                f"Unsupported DeepSeek V4 expert_dtype={resolved!r}; "
                f"expected one of {_DEEPSEEK_V4_EXPERT_DTYPES}."
            )
        self._resolved_expert_dtype = resolved

        from vllm.logger import init_logger

        init_logger(__name__).info_once(
            "DeepSeek V4 expert_dtype resolved to %r", resolved
        )
        return resolved

    @property
    def is_scale_e8m0(self) -> bool:
        """Whether FP8 linear scales are stored as e8m0fnu.

        True for FP4-expert checkpoints; FP8-expert checkpoints (Flash-Base)
        store the scales as float32.
        """
        return self.expert_dtype == "fp4"

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        """Return the registered name of this quantization method."""
        return "deepseek_v4_fp8"

    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant, hf_config=None
    ) -> QuantizationMethods | None:
        """Select ``"deepseek_v4_fp8"`` for DeepSeek V4 FP8 checkpoints.

        The override applies only when ``hf_quant_cfg`` is a dict whose
        ``quant_method`` is ``"fp8"`` or ``"deepseek_v4_fp8"`` and, in
        addition, either the model type is ``"deepseek_v4"`` or the user asked
        for ``"deepseek_v4_fp8"`` explicitly. Otherwise ``None`` is returned.
        """
        is_fp8_checkpoint = isinstance(hf_quant_cfg, dict) and hf_quant_cfg.get(
            "quant_method"
        ) in ("fp8", "deepseek_v4_fp8")
        if not is_fp8_checkpoint:
            return None

        is_deepseek_v4 = getattr(hf_config, "model_type", None) == "deepseek_v4"
        if is_deepseek_v4 or user_quant == "deepseek_v4_fp8":
            return "deepseek_v4_fp8"
        return None

    def get_quant_method(self, layer, prefix):
        """Return the quantization method to use for ``layer``.

        ``FusedMoE`` layers get an unquantized method when the layer is
        skipped and ``Mxfp4MoEMethod`` when the experts are fp4. Everything
        else, including fp8 experts, is delegated to ``Fp8Config``.
        """
        if not isinstance(layer, FusedMoE):
            return super().get_quant_method(layer, prefix)

        layer_is_skipped = is_layer_skipped(
            prefix=prefix,
            ignored_layers=self.ignored_layers,
            fused_mapping=self.packed_modules_mapping,
        )
        if layer_is_skipped:
            return UnquantizedFusedMoEMethod(layer.moe_config)
        if self.expert_dtype == "fp4":
            return Mxfp4MoEMethod(layer.moe_config)

        # expert_dtype == "fp8": defer to Fp8Config, which returns
        # Fp8MoEMethod with block-wise float32 scales.
        return super().get_quant_method(layer, prefix)

    def is_mxfp4_quant(self, prefix, layer):
        """Return True when ``layer`` is a ``FusedMoE`` layer with fp4 experts."""
        if not isinstance(layer, FusedMoE):
            return False
        return self.expert_dtype == "fp4"
# Triton kernel that stages MegaMoE inputs. Each program handles one token
# (program_id 0) and one BLOCK_K-wide slice of the hidden dimension
# (program_id 1):
#   * the slice is split into BLOCK_K // GROUP_K groups, each quantized to
#     float8e4nv with its own power-of-two scale;
#   * the groups' biased scale exponents are packed, 8 bits each, into a single
#     int32 written to x_sf[token, k_block];
#   * the program with k_block_id == 0 additionally copies the token's top-k
#     ids (cast to int64) and top-k weights into the output buffers.
# Every stride and size argument is declared tl.constexpr.
@triton.jit
def _deepseek_v4_stage_mega_moe_inputs_kernel(
hidden_states,
x_fp8,
x_sf,
topk_ids,
topk_weights,
topk_idx_out,
topk_weights_out,
hidden_stride_m: tl.constexpr,
hidden_stride_k: tl.constexpr,
x_stride_m: tl.constexpr,
x_stride_k: tl.constexpr,
x_sf_stride_m: tl.constexpr,
x_sf_stride_k: tl.constexpr,
topk_ids_stride_m: tl.constexpr,
topk_ids_stride_k: tl.constexpr,
topk_weights_stride_m: tl.constexpr,
topk_weights_stride_k: tl.constexpr,
topk_idx_stride_m: tl.constexpr,
topk_idx_stride_k: tl.constexpr,
topk_weights_out_stride_m: tl.constexpr,
topk_weights_out_stride_k: tl.constexpr,
hidden_size: tl.constexpr,
top_k: tl.constexpr,
BLOCK_K: tl.constexpr,
GROUP_K: tl.constexpr,
BLOCK_TOPK: tl.constexpr,
) -> None:
token_id = tl.program_id(0)
k_block_id = tl.program_id(1)
# Columns of this K block; columns at or beyond hidden_size are masked and
# load as 0.0.
k_offsets = k_block_id * BLOCK_K + tl.arange(0, BLOCK_K)
k_mask = k_offsets < hidden_size
hidden = tl.load(
hidden_states + token_id * hidden_stride_m + k_offsets * hidden_stride_k,
mask=k_mask,
other=0.0,
).to(tl.float32)
# Per-group absolute maximum, floored at 1e-4 so the scale is never zero.
num_groups: tl.constexpr = BLOCK_K // GROUP_K
hidden_groups = tl.reshape(tl.abs(hidden), [num_groups, GROUP_K])
amax = tl.max(hidden_groups, axis=1)
amax = tl.maximum(amax, 1.0e-4)
# 448 is the largest finite float8 e4m3 magnitude, so amax maps onto the top
# of the fp8 range.
scale = amax / 448.0
# Round the scale up to a power of two: take the biased float32 exponent and
# add one whenever any mantissa bit is set.
scale_bits = scale.to(tl.uint32, bitcast=True)
scale_exp = ((scale_bits >> 23) & 0xFF) + ((scale_bits & 0x7FFFFF) != 0).to(
tl.uint32
)
# Clamp the biased exponent to [1, 254], then rebuild a float32 that has this
# exponent and a zero mantissa.
scale_exp = tl.minimum(tl.maximum(scale_exp, 1), 254)
rounded_scale = (scale_exp << 23).to(tl.float32, bitcast=True)
# Divide every group by its rounded scale and convert to fp8.
hidden_groups = tl.reshape(hidden, [num_groups, GROUP_K])
scaled = hidden_groups * (1.0 / rounded_scale)[:, None]
scaled = tl.reshape(scaled, [BLOCK_K])
fp8 = scaled.to(tl.float8e4nv)
tl.store(
x_fp8 + token_id * x_stride_m + k_offsets * x_stride_k,
fp8,
mask=k_mask,
)
# Pack the exponents: group i occupies bits [8 * i, 8 * i + 8) of the word.
# NOTE(review): fitting into one int32 assumes BLOCK_K // GROUP_K <= 4; the
# launcher in this file uses BLOCK_K=128 and GROUP_K=32, i.e. exactly 4.
scale_offsets = tl.arange(0, num_groups)
packed_scale = tl.sum(scale_exp << (scale_offsets * 8), axis=0).to(tl.int32)
tl.store(
x_sf + token_id * x_sf_stride_m + k_block_id * x_sf_stride_k,
packed_scale,
)
# Only the first K-block program of each token copies the routing data.
if k_block_id == 0:
# BLOCK_TOPK lanes are used and lanes at or beyond top_k are masked out.
topk_offsets = tl.arange(0, BLOCK_TOPK)
topk_mask = topk_offsets < top_k
ids = tl.load(
topk_ids + token_id * topk_ids_stride_m + topk_offsets * topk_ids_stride_k,
mask=topk_mask,
other=0,
).to(tl.int64)
tl.store(
topk_idx_out
+ token_id * topk_idx_stride_m
+ topk_offsets * topk_idx_stride_k,
ids,
mask=topk_mask,
)
weights = tl.load(
topk_weights
+ token_id * topk_weights_stride_m
+ topk_offsets * topk_weights_stride_k,
mask=topk_mask,
other=0.0,
)
tl.store(
topk_weights_out
+ token_id * topk_weights_out_stride_m
+ topk_offsets * topk_weights_out_stride_k,
weights,
mask=topk_mask,
)
def _stage_deepseek_v4_mega_moe_inputs(
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
x_fp8: torch.Tensor,
x_sf: torch.Tensor,
topk_idx_out: torch.Tensor,
topk_weights_out: torch.Tensor,
) -> None:
num_tokens, hidden_size = hidden_states.shape
if num_tokens == 0:
return
if hidden_size % 128 != 0:
raise ValueError(
"DeepSeek V4 MegaMoE input staging requires hidden_size to be "
"a multiple of 128."
)
top_k = topk_ids.shape[1]
if topk_weights.shape != topk_ids.shape:
raise ValueError(
"DeepSeek V4 MegaMoE input staging requires topk_weights and "
"topk_ids to have the same shape."
)
block_k = 128
grid = (num_tokens, triton.cdiv(hidden_size, block_k))
block_topk = triton.next_power_of_2(top_k)
_deepseek_v4_stage_mega_moe_inputs_kernel[grid](
hidden_states,
x_fp8,
x_sf,
topk_ids,
topk_weights,
topk_idx_out,
topk_weights_out,
hidden_states.stride(0),
hidden_states.stride(1),
x_fp8.stride(0),
x_fp8.stride(1),
x_sf.stride(0),
x_sf.stride(1),
topk_ids.stride(0),
topk_ids.stride(1),
topk_weights.stride(0),
topk_weights.stride(1),
topk_idx_out.stride(0),
topk_idx_out.stride(1),
topk_weights_out.stride(0),
topk_weights_out.stride(1),
hidden_size,
top_k,
BLOCK_K=block_k,
GROUP_K=32,
BLOCK_TOPK=block_topk,
num_warps=4,
)
def make_deepseek_v4_expert_params_mapping(
    num_experts: int,
) -> list[tuple[str, str, int, str]]:
    """Build the checkpoint-to-parameter mapping for routed expert weights.

    For every expert id in ``range(num_experts)`` and every shard
    (``w1``, ``w2``, ``w3``, in that order) the result holds one tuple
    ``(param_prefix, checkpoint_prefix, expert_id, shard_id)``. ``w1`` and
    ``w3`` map to the fused ``experts.w13_`` parameters, ``w2`` maps to
    ``experts.w2_``.
    """
    param_prefix_by_shard = {
        "w1": "experts.w13_",
        "w2": "experts.w2_",
        "w3": "experts.w13_",
    }
    return [
        (param_prefix, f"experts.{expert_id}.{shard_id}.", expert_id, shard_id)
        for expert_id in range(num_experts)
        for shard_id, param_prefix in param_prefix_by_shard.items()
    ]
class DeepseekV4MegaMoEExperts(nn.Module):
    """Routed-expert weights plus launcher for the DeepGEMM MegaMoE kernel.

    The module owns the expert weights for global expert ids
    ``[experts_start_idx, experts_start_idx + num_local_experts)``. Weights
    are loaded as raw ``uint8`` tensors, converted once into the layout
    DeepGEMM expects (``finalize_weights``) and consumed through the
    ``torch.ops.vllm.deepseek_v4_mega_moe_experts`` custom op.
    """

    # Symmetric buffers shared by all instances. The key is (id of the EP
    # device group, device index, num_experts, max_num_tokens, top_k,
    # hidden_size, intermediate_size).
    _symm_buffer_cache: dict[tuple[int, int, int, int, int, int, int], object] = {}

    def __init__(
        self,
        vllm_config: VllmConfig,
        *,
        num_experts: int,
        num_local_experts: int,
        experts_start_idx: int,
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        prefix: str = "",
    ):
        """Allocate zero-filled expert parameters and register the layer.

        Raises:
            ValueError: If ``prefix`` is already present in the static
                forward context.
        """
        super().__init__()
        self.prefix = prefix
        self.num_experts = num_experts
        self.num_local_experts = num_local_experts
        self.experts_start_idx = experts_start_idx
        self.experts_end_idx = experts_start_idx + num_local_experts
        self.top_k = top_k
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens

        weight_attrs = {"weight_loader": self.weight_loader}

        def _uint8_expert_param(
            rows: int, cols: int, *, block_scale: bool = False
        ) -> nn.Parameter:
            # One frozen, zero-initialized uint8 tensor per local expert, wired
            # to this module's weight loader.
            param = nn.Parameter(
                torch.zeros(num_local_experts, rows, cols, dtype=torch.uint8),
                requires_grad=False,
            )
            set_weight_attrs(param, weight_attrs)
            if block_scale:
                param.quant_method = "block"
            return param

        # Weights use half the logical inner dimension (presumably two 4-bit
        # values per byte); scales hold one byte per 32 inner elements.
        self.w13_weight = _uint8_expert_param(
            2 * intermediate_size, hidden_size // 2
        )
        self.w13_weight_scale = _uint8_expert_param(
            2 * intermediate_size, hidden_size // 32, block_scale=True
        )
        self.w2_weight = _uint8_expert_param(hidden_size, intermediate_size // 2)
        self.w2_weight_scale = _uint8_expert_param(
            hidden_size, intermediate_size // 32, block_scale=True
        )

        self._transformed_l1_weights: tuple[torch.Tensor, torch.Tensor] | None = None
        self._transformed_l2_weights: tuple[torch.Tensor, torch.Tensor] | None = None

        # Register in the static forward context so the custom-op wrapper can
        # find this module by name from inside a torch.compile graph.
        static_context = vllm_config.compilation_config.static_forward_context
        if prefix in static_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        static_context[prefix] = self

    def _map_global_expert_id(self, expert_id: int) -> int:
        """Return the local index of ``expert_id``, or -1 if it lives on another rank."""
        if self.experts_start_idx <= expert_id < self.experts_end_idx:
            return expert_id - self.experts_start_idx
        return -1

    def weight_loader(
        self,
        param: nn.Parameter,
        loaded_weight: torch.Tensor,
        weight_name: str,
        shard_id: str,
        expert_id: int,
        return_success: bool = False,
    ) -> bool | None:
        """Copy one expert shard from the checkpoint into ``param``.

        Args:
            param: Destination parameter, indexed by local expert id.
            loaded_weight: Checkpoint tensor for a single expert shard.
            weight_name: Name of the destination parameter; it must contain
                ``"w13_"`` for shards ``w1``/``w3`` and ``"w2_"`` for ``w2``.
            shard_id: One of ``"w1"``, ``"w2"``, ``"w3"``.
            expert_id: Global expert id.
            return_success: When true, report the outcome as a bool;
                otherwise always return ``None``.

        Returns:
            With ``return_success``: ``True`` after a copy, ``False`` when the
            expert is not local or the parameter name does not match the
            shard. Without it: ``None``.

        Raises:
            ValueError: For an unknown ``shard_id`` or a shape mismatch.
        """
        skipped = False if return_success else None
        loaded = True if return_success else None

        local_expert_id = self._map_global_expert_id(expert_id)
        if local_expert_id == -1:
            return skipped

        target = param.data[local_expert_id]
        if shard_id in ("w1", "w3"):
            if "w13_" not in weight_name:
                return skipped
            # w1 fills the first half of the fused w13 rows, w3 the second.
            row_offset = self.intermediate_size if shard_id == "w3" else 0
            target = target.narrow(0, row_offset, self.intermediate_size)
        elif shard_id == "w2":
            if "w2_" not in weight_name:
                return skipped
        else:
            raise ValueError(f"Unsupported expert shard id: {shard_id}")

        if target.shape != loaded_weight.shape:
            raise ValueError(
                f"DeepSeek V4 MegaMoE expert weight shape mismatch for "
                f"{weight_name}: parameter shard {tuple(target.shape)} "
                f"vs checkpoint {tuple(loaded_weight.shape)}"
            )
        target.copy_(loaded_weight)
        return loaded

    @staticmethod
    def _ue8m0_uint8_to_float(sf: torch.Tensor) -> torch.Tensor:
        """Reinterpret uint8 exponents as float32 powers of two.

        Each byte is shifted into the float32 exponent field (bits 23 and up)
        and the resulting int32 bit pattern is viewed as float32.
        """
        exponent_bits = sf.to(torch.int32) << 23
        return exponent_bits.view(torch.float32)

    def _check_runtime_supported(self) -> None:
        """Validate that the MegaMoE kernel can run with the current setup.

        Raises:
            NotImplementedError: Without CUDA, when the weights are not on a
                CUDA device, or when the device's major compute capability is
                not 10.
            ValueError: If hidden or intermediate size is not a multiple of 128.
        """
        if not torch.cuda.is_available():
            raise NotImplementedError("DeepSeek V4 MegaMoE requires CUDA.")

        weights_device = self.w13_weight.device
        if weights_device.type != "cuda":
            raise NotImplementedError(
                "DeepSeek V4 MegaMoE expert weights must be loaded on CUDA."
            )

        major_capability = torch.cuda.get_device_capability(weights_device)[0]
        if major_capability != 10:
            raise NotImplementedError("DeepGEMM MegaMoE requires SM100 GPUs.")

        sizes_aligned = (
            self.hidden_size % 128 == 0 and self.intermediate_size % 128 == 0
        )
        if not sizes_aligned:
            raise ValueError(
                "DeepGEMM MegaMoE requires hidden and intermediate sizes "
                "to be multiples of 128."
            )

    def finalize_weights(self) -> None:
        """Convert the loaded weights into DeepGEMM's layout (runs once).

        Later calls return immediately. After the conversion the loader-side
        parameters are replaced with ``None``.
        """
        if self._transformed_l1_weights is not None:
            return
        self._check_runtime_supported()
        import vllm.third_party.deep_gemm as deep_gemm

        def _required_scale_layout(
            scale: nn.Parameter, mn: int, k: int
        ) -> torch.Tensor:
            # uint8 exponents -> float32 scales -> DeepGEMM scale layout with
            # (1, 32) granularity.
            return deep_gemm.transform_sf_into_required_layout(
                self._ue8m0_uint8_to_float(scale.data).contiguous(),
                mn,
                k,
                (1, 32),
                self.num_local_experts,
            )

        l1_scale = _required_scale_layout(
            self.w13_weight_scale, 2 * self.intermediate_size, self.hidden_size
        )
        l2_scale = _required_scale_layout(
            self.w2_weight_scale, self.hidden_size, self.intermediate_size
        )
        l1_weight = self.w13_weight.data.view(torch.int8).contiguous()
        l2_weight = self.w2_weight.data.view(torch.int8).contiguous()
        transformed = deep_gemm.transform_weights_for_mega_moe(
            (l1_weight, l1_scale),
            (l2_weight, l2_scale),
        )
        self._transformed_l1_weights, self._transformed_l2_weights = transformed

        # The MegaMoE kernels only read the transformed tensors, so the
        # loader-side parameters are dropped. transform_weights_for_mega_moe
        # allocates a fresh L1 weight (see _interleave_l1_weights) and fresh
        # scale tensors for L1/L2. Only the L2 weight aliases the original
        # storage, and _transformed_l2_weights keeps that storage alive.
        for param_name in (
            "w13_weight",
            "w13_weight_scale",
            "w2_weight",
            "w2_weight_scale",
        ):
            setattr(self, param_name, None)

    def get_symm_buffer(self):
        """Return the MegaMoE symmetric buffer, creating and caching it on first use."""
        import vllm.third_party.deep_gemm as deep_gemm

        group = get_ep_group().device_group
        device = torch.accelerator.current_device_index()
        buffer_shape_args = (
            self.num_experts,
            self.max_num_tokens,
            self.top_k,
            self.hidden_size,
            self.intermediate_size,
        )
        cache_key = (id(group), device, *buffer_shape_args)

        cache = self._symm_buffer_cache
        cached = cache.get(cache_key)
        if cached is not None:
            return cached

        created = deep_gemm.get_symm_buffer_for_mega_moe(group, *buffer_shape_args)
        cache[cache_key] = created
        return created

    def forward(
        self,
        hidden_states: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        *,
        activation_clamp: float | None,
        fast_math: bool = True,
    ) -> torch.Tensor:
        """Run the routed experts and return a bfloat16 tensor shaped like the input.

        Raises:
            ValueError: If there are more tokens than ``max_num_tokens``, the
                size the symmetric buffer is created for.
        """
        num_tokens = hidden_states.shape[0]
        if num_tokens > self.max_num_tokens:
            raise ValueError(
                f"DeepSeek V4 MegaMoE got {num_tokens} tokens, "
                f"but the symmetric buffer was sized for {self.max_num_tokens}."
            )
        output = torch.empty_like(hidden_states, dtype=torch.bfloat16)
        # The custom op fills ``output`` in place; the layer is looked up by
        # its prefix inside the op.
        torch.ops.vllm.deepseek_v4_mega_moe_experts(
            hidden_states,
            topk_weights,
            topk_ids,
            output,
            self.prefix,
            activation_clamp,
            fast_math,
        )
        return output

    def _run_mega_moe(
        self,
        hidden_states: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        y: torch.Tensor,
        activation_clamp: float | None,
        fast_math: bool,
    ) -> None:
        """Stage the inputs into the symmetric buffer and launch the MegaMoE kernel.

        The result is written into ``y``.
        """
        import vllm.third_party.deep_gemm as deep_gemm

        symm_buffer = self.get_symm_buffer()
        token_rows = slice(None, hidden_states.shape[0])
        _stage_deepseek_v4_mega_moe_inputs(
            hidden_states,
            topk_weights,
            topk_ids,
            symm_buffer.x[token_rows],
            symm_buffer.x_sf[token_rows],
            symm_buffer.topk_idx[token_rows],
            symm_buffer.topk_weights[token_rows],
        )

        # Normally already done during weight loading; repeated here to cover
        # the dummy weight loading case.
        self.finalize_weights()
        l1_weights = self._transformed_l1_weights
        l2_weights = self._transformed_l2_weights
        assert l1_weights is not None
        assert l2_weights is not None
        deep_gemm.fp8_fp4_mega_moe(
            y,
            l1_weights,
            l2_weights,
            symm_buffer,
            activation_clamp=activation_clamp,
            fast_math=fast_math,
        )
# Tag the loader function object with ``supports_moe_loading = True``.
# NOTE(review): presumably the weight-loading machinery checks this flag before
# passing the MoE-specific arguments (shard_id / expert_id) — confirm there.
DeepseekV4MegaMoEExperts.weight_loader.supports_moe_loading = True # type: ignore[attr-defined]
def _deepseek_v4_mega_moe_experts_op(
    hidden_states: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    out: torch.Tensor,
    layer_name: str,
    activation_clamp: float | None,
    fast_math: bool,
) -> None:
    """Real implementation of the ``deepseek_v4_mega_moe_experts`` custom op.

    Looks up the experts layer registered under ``layer_name`` in the current
    forward context and lets it run the MegaMoE kernel, writing into ``out``.
    """
    experts_layer = get_forward_context().no_compile_layers[layer_name]
    experts_layer._run_mega_moe(
        hidden_states,
        topk_weights,
        topk_ids,
        out,
        activation_clamp,
        fast_math,
    )
def _deepseek_v4_mega_moe_experts_op_fake(
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
out: torch.Tensor,
layer_name: str,
activation_clamp: float | None,
fast_math: bool,
) -> None:
return None
# Register the op under the name "deepseek_v4_mega_moe_experts";
# DeepseekV4MegaMoEExperts.forward invokes it as
# torch.ops.vllm.deepseek_v4_mega_moe_experts. The op is declared as mutating
# its ``out`` argument, and the fake implementation is a no-op.
direct_register_custom_op(
op_name="deepseek_v4_mega_moe_experts",
op_func=_deepseek_v4_mega_moe_experts_op,
mutates_args=["out"],
fake_impl=_deepseek_v4_mega_moe_experts_op_fake,
)
class DeepseekV4MoE(nn.Module):
    """DeepSeek V4 mixture-of-experts block with two interchangeable backends.

    With ``kernel_config.moe_backend == "deep_gemm_mega_moe"`` the routed
    experts run through ``DeepseekV4MegaMoEExperts`` and routing plus shared
    experts are handled in this module. Otherwise everything is delegated to
    ``FusedMoE``.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        """Build the gate, optional shared experts and the routed experts.

        Raises:
            NotImplementedError: If MegaMoE is requested without expert
                parallelism, with a scoring function other than
                ``sqrtsoftplus``, or with non-fp4 experts.
        """
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.prefix = prefix

        moe_backend = vllm_config.kernel_config.moe_backend
        self.use_mega_moe = moe_backend == "deep_gemm_mega_moe"
        expert_parallel = vllm_config.parallel_config.enable_expert_parallel
        if self.use_mega_moe and not expert_parallel:
            raise NotImplementedError(
                "DeepSeek V4 MegaMoE currently requires expert parallel. "
                "Enable it with --enable-expert-parallel, or pick a different "
                "moe backend."
            )

        self.routed_scaling_factor = getattr(config, "routed_scaling_factor", 1.0)
        self.hidden_size = config.hidden_size
        self.n_routed_experts = config.n_routed_experts
        self.n_activated_experts = config.num_experts_per_tok
        self.moe_intermediate_size = config.moe_intermediate_size
        self.swiglu_limit = config.swiglu_limit
        self.renormalize = config.norm_topk_prob
        self.scoring_func = getattr(config, "scoring_func", "sqrtsoftplus")

        if self.use_mega_moe:
            if self.scoring_func != "sqrtsoftplus":
                raise NotImplementedError(
                    "DeepSeek V4 MegaMoE currently supports sqrtsoftplus routing only."
                )
            if getattr(config, "expert_dtype", "fp4") != "fp4":
                raise NotImplementedError(
                    "DeepSeek V4 MegaMoE only supports fp4 experts; got expert_dtype="
                    f"{config.expert_dtype!r}. Drop --kernel-config moe_backend="
                    "deep_gemm_mega_moe for this checkpoint."
                )

        self.gate = GateLinear(
            config.hidden_size,
            config.n_routed_experts,
            out_dtype=torch.float32,
            bias=False,
            prefix=f"{prefix}.gate",
        )
        # Both routing extras start out absent; at most one is created below.
        self.gate.e_score_correction_bias = None
        self.gate.tid2eid = None

        is_hash_moe = extract_layer_index(prefix) < config.num_hash_layers
        self.hash_indices_dtype = torch.int64 if self.use_mega_moe else torch.int32
        if is_hash_moe:
            # Hash MoE does not use e_score_correction_bias. The table is
            # filled with randint rather than left uninitialized so that dummy
            # mode (--load-format="dummy") never sees garbage expert ids that
            # would cause invalid memory accesses.
            token_to_expert_table = torch.randint(
                0,
                config.n_routed_experts,
                (config.vocab_size, config.num_experts_per_tok),
                dtype=self.hash_indices_dtype,
            )
            self.gate.tid2eid = nn.Parameter(
                token_to_expert_table, requires_grad=False
            )
        elif getattr(config, "topk_method", None) == "noaux_tc":
            self.gate.e_score_correction_bias = nn.Parameter(
                torch.empty(config.n_routed_experts, dtype=torch.float32),
                requires_grad=False,
            )

        if config.n_shared_experts is not None:
            shared_intermediate_size = (
                config.moe_intermediate_size * config.n_shared_experts
            )
            self.shared_experts = DeepseekV4MLP(
                hidden_size=config.hidden_size,
                intermediate_size=shared_intermediate_size,
                hidden_act=config.hidden_act,
                swiglu_limit=self.swiglu_limit,
                quant_config=quant_config,
                reduce_results=self.use_mega_moe,
                prefix=f"{prefix}.shared_experts",
            )
        else:
            self.shared_experts = None

        if self.use_mega_moe:
            self._init_mega_moe_experts(vllm_config, config, prefix)
        else:
            self._init_fused_moe_experts(config, quant_config, prefix)

    def _init_mega_moe_experts(
        self,
        vllm_config: VllmConfig,
        config,
        prefix: str,
    ) -> None:
        """Create the MegaMoE experts, splitting experts evenly over the EP group."""
        self.ep_group = get_ep_group()
        self.ep_size = self.ep_group.world_size
        self.ep_rank = self.ep_group.rank_in_group
        assert config.n_routed_experts % self.ep_size == 0

        experts_per_rank = config.n_routed_experts // self.ep_size
        first_expert = self.ep_rank * experts_per_rank
        self.n_local_experts = experts_per_rank
        self.experts_start_idx = first_expert
        self.experts_end_idx = first_expert + experts_per_rank

        self.experts = DeepseekV4MegaMoEExperts(
            vllm_config,
            num_experts=config.n_routed_experts,
            num_local_experts=experts_per_rank,
            experts_start_idx=first_expert,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            prefix=f"{prefix}.experts",
        )

    def _init_fused_moe_experts(
        self,
        config,
        quant_config,
        prefix: str,
    ) -> None:
        """Create the ``FusedMoE`` experts, which also receive the gate and shared experts."""
        self.tp_rank = get_tensor_model_parallel_rank()
        assert config.n_routed_experts % self.tp_size == 0

        experts_per_rank = config.n_routed_experts // self.tp_size
        first_expert = self.tp_rank * experts_per_rank
        self.n_local_experts = experts_per_rank
        self.experts_start_idx = first_expert
        self.experts_end_idx = first_expert + experts_per_rank

        self.experts = FusedMoE(
            shared_experts=self.shared_experts,
            gate=self.gate,
            num_experts=config.n_routed_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            renormalize=config.norm_topk_prob,
            quant_config=quant_config,
            prefix=f"{prefix}.experts",
            scoring_func=self.scoring_func,
            routed_scaling_factor=self.routed_scaling_factor,
            e_score_correction_bias=self.gate.e_score_correction_bias,
            hash_indices_table=self.gate.tid2eid,
            swiglu_limit=self.swiglu_limit,
            router_logits_dtype=torch.float32,
        )

    def forward(
        self, hidden_states: torch.Tensor, input_ids: torch.Tensor | None = None
    ) -> torch.Tensor:
        """Run the MoE block and return a tensor with the input's shape.

        Raises:
            ValueError: If this layer uses hash routing (``gate.tid2eid`` is
                set) and ``input_ids`` is missing.
        """
        gate = self.gate
        if input_ids is None and gate.tid2eid is not None:
            raise ValueError("DeepSeek V4 hash MoE routing requires input_ids.")
        if not self.use_mega_moe:
            return self._forward_fused_moe(hidden_states, input_ids)

        input_shape = hidden_states.shape
        router_logits, _ = gate(hidden_states)

        bias_param = gate.e_score_correction_bias
        correction_bias = None if bias_param is None else bias_param.data
        topk_weights, topk_ids = fused_topk_bias(
            hidden_states=hidden_states,
            gating_output=router_logits,
            scoring_func=self.scoring_func,
            e_score_correction_bias=correction_bias,
            topk=self.n_activated_experts,
            renormalize=self.renormalize,
            indices_type=self.hash_indices_dtype,
            input_tokens=input_ids,
            hash_indices_table=gate.tid2eid,
            routed_scaling_factor=self.routed_scaling_factor,
        )

        if self.swiglu_limit is None:
            activation_clamp = None
        else:
            activation_clamp = float(self.swiglu_limit)
        routed_output = self.experts(
            hidden_states,
            topk_weights,
            topk_ids,
            activation_clamp=activation_clamp,
        )

        # Shared experts are added in place on top of the routed result.
        if self.shared_experts is not None:
            routed_output += self.shared_experts(hidden_states)
        return routed_output.view(input_shape)

    def _forward_fused_moe(
        self, hidden_states: torch.Tensor, input_ids: torch.Tensor | None = None
    ) -> torch.Tensor:
        """Run the ``FusedMoE`` backend and return a tensor with the input's shape."""
        input_shape = hidden_states.shape
        if self.experts.is_internal_router:
            # The gate/router runs inside FusedMoE, which therefore receives
            # the hidden states in place of router logits.
            router_input = hidden_states
        else:
            router_input, _ = self.gate(hidden_states)

        moe_output = self.experts(
            hidden_states=hidden_states,
            router_logits=router_input,
            input_ids=input_ids,
        )
        return moe_output.view(input_shape)

    def finalize_mega_moe_weights(self) -> None:
        """Finalize the expert weights on the MegaMoE backend; no-op otherwise."""
        if not self.use_mega_moe:
            return
        self.experts.finalize_weights()
# Attention block of a DeepSeek V4 decoder layer. __init__ builds the
# projections (fused_wqa_wkv, wq_b, wo_a, wo_b), the q/kv RMS norms, the
# attention sink parameter, the rotary embedding and, for compress_ratio == 4,
# an indexer; all of them are handed to DeepseekV4MLAModules and wrapped in
# DeepseekV4MultiHeadLatentAttentionWrapper. forward() only delegates to that
# wrapper.
class DeepseekV4Attention(nn.Module):
def __init__(
self,
vllm_config: VllmConfig,
prefix: str,
topk_indices_buffer: torch.Tensor | None = None,
aux_stream_list: list[torch.cuda.Stream] | None = None,
):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
# The layer index is parsed from the module prefix.
layer_id = extract_layer_index(prefix)
self.layer_id = layer_id
self.hidden_size = config.hidden_size
self.n_heads = config.num_attention_heads
tp_size = get_tensor_model_parallel_world_size()
assert self.n_heads % tp_size == 0
self.n_local_heads = self.n_heads // tp_size
self.q_lora_rank = config.q_lora_rank
self.o_lora_rank = config.o_lora_rank
self.head_dim = config.head_dim
self.rope_head_dim = config.qk_rope_head_dim
self.nope_head_dim = self.head_dim - self.rope_head_dim
self.n_groups = config.o_groups
# NOTE(review): unlike n_heads, divisibility of n_groups by tp_size is not
# asserted here.
self.n_local_groups = self.n_groups // tp_size
self.window_size = config.sliding_window
# NOTE(zyongye) The compress ratio can't be 0, hence the max(1, ...).
# Layers at or beyond num_hidden_layers (the MTP layer) are not included
# in the compress ratio list, so they fall back to a ratio of 1.
if layer_id < config.num_hidden_layers:
self.compress_ratio = max(1, config.compress_ratios[layer_id])
else:
self.compress_ratio = 1
self.eps = config.rms_norm_eps
self.max_position_embeddings = config.max_position_embeddings
# Padded to min 64 heads for FlashMLA, initialized to -inf
# (no sink effect). Weight loading fills the first n_local_heads slots.
padded_heads = max(self.n_local_heads, 64)
self.attn_sink = nn.Parameter(
torch.full((padded_heads,), -float("inf"), dtype=torch.float32),
requires_grad=False,
)
# One merged projection whose two outputs have widths q_lora_rank and
# head_dim.
self.fused_wqa_wkv = MergedColumnParallelLinear(
self.hidden_size,
[self.q_lora_rank, self.head_dim],
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.fused_wqa_wkv",
disable_tp=True, # fused ReplicatedLinear
)
self.q_norm = RMSNorm(self.q_lora_rank, self.eps)
self.wq_b = ColumnParallelLinear(
self.q_lora_rank,
self.n_heads * self.head_dim,
bias=False,
quant_config=quant_config,
return_bias=False,
prefix=f"{prefix}.wq_b",
)
self.kv_norm = RMSNorm(self.head_dim, self.eps)
# wo_a is NOT quantized in the NVFP4 checkpoint (modelopt left it as bfloat16),
# but the attention forward pass expects FP8 (weight + weight_scale_inv).
# Pass quant_config=None to load bfloat16, then process_weights_after_loading
# will handle the FP8 quantization.
self.wo_a = ColumnParallelLinear(
self.n_heads * self.head_dim // self.n_groups,
self.n_groups * self.o_lora_rank,
bias=False,
quant_config=None,
return_bias=False,
prefix=f"{prefix}.wo_a",
)
# Plain attributes attached to the layer object.
# NOTE(review): presumably read by the attention wrapper to run wo_a as a
# batched matmul over the local groups — confirm there.
self.wo_a.is_bmm = True
self.wo_a.bmm_batch_size = self.n_local_groups
self.wo_b = RowParallelLinear(
self.n_groups * self.o_lora_rank,
self.hidden_size,
bias=False,
quant_config=quant_config,
return_bias=False,
prefix=f"{prefix}.wo_b",
)
self.softmax_scale = self.head_dim**-0.5
# NOTE(review): assumes config.quantization_config is a mapping that
# contains "scale_fmt"; a missing key fails here.
self.scale_fmt = config.quantization_config["scale_fmt"]
# Stored from config.rope_scaling; not read again within this class.
self.rope_parameters = config.rope_scaling
# Initialize rotary embedding BEFORE DeepseekV4MLAModules (which needs it)
# NOTE: rope_parameters is the same dict object as config.rope_parameters
# (no copy is made), so every write below also changes the shared hf_config.
# rope_theta is rewritten by each layer immediately before its get_rope call.
rope_parameters = config.rope_parameters
rope_parameters["rope_theta"] = (
config.compress_rope_theta if self.compress_ratio > 1 else config.rope_theta
)
# Any non-default rope_type is replaced by one of the two DeepSeek variants,
# chosen by the apply_yarn_scaling entry (treated as True when absent).
if config.rope_parameters["rope_type"] != "default":
config.rope_parameters["rope_type"] = (
"deepseek_yarn"
if config.rope_parameters.get("apply_yarn_scaling", True)
else "deepseek_llama_scaling"
)
rope_parameters["mscale"] = 0 # Disable mscale
rope_parameters["mscale_all_dim"] = 0 # Disable mscale
rope_parameters["is_deepseek_v4"] = True
rope_parameters["rope_dim"] = self.rope_head_dim
self.rotary_emb = get_rope(
self.head_dim,
max_position=self.max_position_embeddings,
rope_parameters=rope_parameters,
is_neox_style=False,
)
self.indexer = None
if self.compress_ratio == 4:
# Only C4A uses sparse attention and hence has indexer.
self.indexer = DeepseekV4Indexer(
vllm_config,
config=config,
hidden_size=self.hidden_size,
q_lora_rank=self.q_lora_rank,
quant_config=quant_config,
cache_config=vllm_config.cache_config,
topk_indices_buffer=topk_indices_buffer,
compress_ratio=self.compress_ratio,
prefix=f"{prefix}.indexer",
)
# Bundle the sub-modules for the wrapper; the indexer is given the same
# rotary embedding object as the attention itself.
mla_modules = DeepseekV4MLAModules(
vllm_config=vllm_config,
fused_wqa_wkv=self.fused_wqa_wkv,
q_norm=self.q_norm,
wq_b=self.wq_b,
kv_norm=self.kv_norm,
wo_a=self.wo_a,
wo_b=self.wo_b,
attn_sink=self.attn_sink,
rotary_emb=self.rotary_emb,
indexer=self.indexer,
indexer_rotary_emb=self.rotary_emb,
topk_indices_buffer=topk_indices_buffer,
aux_stream_list=aux_stream_list,
)
# head_dim is passed as v_head_dim and kv_lora_rank as well; the wrapper
# reuses this module's own prefix.
self.mla_attn = DeepseekV4MultiHeadLatentAttentionWrapper(
hidden_size=self.hidden_size,
num_heads=self.n_local_heads,
head_dim=self.head_dim,
scale=self.softmax_scale,
qk_nope_head_dim=self.nope_head_dim,
qk_rope_head_dim=self.rope_head_dim,
v_head_dim=self.head_dim,
q_lora_rank=self.q_lora_rank,
kv_lora_rank=self.head_dim,
o_lora_rank=self.o_lora_rank,
mla_modules=mla_modules,
window_size=self.window_size,
compress_ratio=self.compress_ratio,
cache_config=vllm_config.cache_config,
quant_config=quant_config,
prefix=prefix,
)
# Delegates to the MLA wrapper with the arguments unchanged and returns its
# result.
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
llama_4_scaling: torch.Tensor | None,
):
return self.mla_attn(positions, hidden_states, llama_4_scaling)
class DeepseekV4DecoderLayer(nn.Module):
def __init__(
    self,
    vllm_config,
    prefix,
    topk_indices_buffer: torch.Tensor | None = None,
    aux_stream_list: list[torch.cuda.Stream] | None = None,
):
    """Build the attention and MoE sub-blocks, their norms and the hc parameters.

    For each of the attention and FFN branches three frozen float32
    parameters are allocated uninitialized: ``hc_*_fn`` of shape
    ``(mix_hc, hc_dim)``, ``hc_*_base`` of shape ``(mix_hc,)`` and
    ``hc_*_scale`` of shape ``(3,)``, where
    ``mix_hc = (2 + hc_mult) * hc_mult`` and ``hc_dim = hc_mult * hidden_size``.
    """
    super().__init__()
    # Lazy import to avoid a top-level tilelang dependency. Importing the
    # module registers both torch.ops.vllm.mhc_pre and mhc_post.
    import vllm.model_executor.layers.mhc  # noqa: F401

    config = vllm_config.model_config.hf_config
    self.hidden_size = config.hidden_size
    self.rms_norm_eps = config.rms_norm_eps

    self.attn = DeepseekV4Attention(
        vllm_config,
        prefix=f"{prefix}.attn",
        topk_indices_buffer=topk_indices_buffer,
        aux_stream_list=aux_stream_list,
    )
    self.ffn = DeepseekV4MoE(vllm_config, prefix=f"{prefix}.ffn")
    self.attn_norm = RMSNorm(self.hidden_size, self.rms_norm_eps)
    self.ffn_norm = RMSNorm(self.hidden_size, self.rms_norm_eps)

    self.hc_mult = config.hc_mult
    self.hc_sinkhorn_iters = config.hc_sinkhorn_iters
    self.hc_eps = config.hc_eps
    self.hc_post_alpha = 2.0

    mix_hc = (2 + self.hc_mult) * self.hc_mult
    hc_dim = self.hc_mult * self.hidden_size

    def _frozen_fp32(*shape: int) -> nn.Parameter:
        # Uninitialized float32 parameter that never receives gradients.
        return nn.Parameter(
            torch.empty(shape, dtype=torch.float32),
            requires_grad=False,
        )

    self.hc_attn_fn = _frozen_fp32(mix_hc, hc_dim)
    self.hc_ffn_fn = _frozen_fp32(mix_hc, hc_dim)
    self.hc_attn_base = _frozen_fp32(mix_hc)
    self.hc_ffn_base = _frozen_fp32(mix_hc)
    self.hc_attn_scale = _frozen_fp32(3)
    self.hc_ffn_scale = _frozen_fp32(3)
def hc_pre(
self,
x: torch.Tensor,
hc_fn: torch.Tensor,
hc_scale: torch.Tensor,
hc_base: torch.Tensor,
):
post_mix, res_mix, layer_input = torch.ops.vllm.mhc_pre(
residual=x,
fn=hc_fn,
hc_scale=hc_scale,
hc_base=hc_base,
rms_eps=self.rms_norm_eps,
hc_pre_eps=self.hc_eps,
hc_sinkhorn_eps=self.hc_eps,
hc_post_mult_value=self.hc_post_alpha,
sinkhorn_repeat=self.hc_sinkhorn_iters,
)
return layer_input, post_mix, res_mix
def hc_post(
self,
x: torch.Tensor,
residual: torch.Tensor,
post: torch.Tensor,
comb: torch.Tensor,
):
return torch.ops.vllm.mhc_post(x, residual, post, comb)
def forward(
self,
x: torch.Tensor,
positions: torch.Tensor,
input_ids: torch.Tensor | None,
post_mix: torch.Tensor | None,
res_mix: torch.Tensor | None,
residual: torch.Tensor | None,
) -> torch.Tensor:
if residual is None:
# Run standalone hc_pre on first layer
residual = x
x, post_mix, res_mix = self.hc_pre(
x, self.hc_attn_fn, self.hc_attn_scale, self.hc_attn_base
)
else:
residual, post_mix, res_mix, x = torch.ops.vllm.mhc_fused_post_pre(
x,
residual,
post_mix,
res_mix,
self.hc_attn_fn,
self.hc_attn_scale,
self.hc_attn_base,
self.rms_norm_eps,
self.hc_eps,
self.hc_eps,
self.hc_post_alpha,
self.hc_sinkhorn_iters,
)
x = self.attn_norm(x)
x = self.attn(positions, x, None)
residual, post_mix, res_mix, x = torch.ops.vllm.mhc_fused_post_pre(
x,
residual,
post_mix,
res_mix,
self.hc_ffn_fn,
self.hc_ffn_scale,
self.hc_ffn_base,
self.rms_norm_eps,
self.hc_eps,
self.hc_eps,
self.hc_post_alpha,
self.hc_sinkhorn_iters,
)
x = self.ffn_norm(x)
x = self.ffn(x, input_ids)
return x, residual, post_mix, res_mix
@support_torch_compile
class DeepseekV4Model(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
    """Build the embedding, decoder layers, final norm and hc_head params.

    Args:
        vllm_config: Engine configuration. The HF config, quantization
            config, kernel/parallel/scheduler settings and model dtype are
            all read from it.
        prefix: Parameter-name prefix for the sub-modules created here.

    Raises:
        NotImplementedError: If the ``deep_gemm_mega_moe`` backend is
            selected without expert parallelism enabled.
    """
    super().__init__()
    hf_config = vllm_config.model_config.hf_config
    quant_config = vllm_config.quant_config

    # Select weight mapper based on quantization method. NVFP4
    # (modelopt_fp4) checkpoints use different key naming than the default
    # MXFP4 format.
    quant_name = (
        None
        if quant_config is None
        else getattr(quant_config, "get_name", lambda: None)()
    )
    if quant_name == "modelopt_fp4":
        mapper = _make_deepseek_v4_nvfp4_weights_mapper()
    elif getattr(hf_config, "expert_dtype", "fp4") != "fp4":
        mapper = _make_deepseek_v4_weights_mapper("fp8")
    else:
        mapper = _make_deepseek_v4_weights_mapper("fp4")
    self.hf_to_vllm_mapper = mapper

    self.config = hf_config

    self.use_mega_moe = (
        vllm_config.kernel_config.moe_backend == "deep_gemm_mega_moe"
    )
    expert_parallel = vllm_config.parallel_config.enable_expert_parallel
    if self.use_mega_moe and not expert_parallel:
        raise NotImplementedError(
            "DeepSeek V4 MegaMoE currently requires expert parallel. "
            "Enable it with --enable-expert-parallel, or pick a different "
            "moe backend."
        )

    self.vocab_size = hf_config.vocab_size
    self.hc_eps = hf_config.hc_eps
    self.hc_mult = hf_config.hc_mult
    self.hc_dim = self.hc_mult * hf_config.hidden_size
    self.rms_norm_eps = hf_config.rms_norm_eps

    # Three aux streams: one per non-default input GEMM in
    # DeepseekV4MultiHeadLatentAttentionWrapper.attn_gemm_parallel_execute
    # (compressor kv_score, indexer.weights_proj, indexer.compressor
    # kv_score). fused_wqa_wkv stays on the default stream.
    # Disabled on ROCm because of hang issues.
    if current_platform.is_rocm():
        aux_stream_list = None
    else:
        aux_stream_list = [torch.cuda.Stream() for _ in range(3)]

    self.device = current_platform.device_type
    max_tokens = vllm_config.scheduler_config.max_num_batched_tokens

    # Reserved topk indices buffer for all Indexer layers to reuse.
    self.topk_indices_buffer = torch.empty(
        max_tokens,
        hf_config.index_topk,
        dtype=torch.int32,
        device=self.device,
    )

    pp_group = get_pp_group()

    # Only the first PP rank owns the token embedding.
    self.embed_tokens = (
        VocabParallelEmbedding(
            hf_config.vocab_size,
            hf_config.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.embed_tokens",
        )
        if pp_group.is_first_rank
        else PPMissingLayer()
    )

    def build_layer(prefix):
        # Factory handed to make_layers; every layer shares the same topk
        # buffer and aux streams.
        return DeepseekV4DecoderLayer(
            vllm_config,
            prefix=prefix,
            topk_indices_buffer=self.topk_indices_buffer,
            aux_stream_list=aux_stream_list,
        )

    self.start_layer, self.end_layer, self.layers = make_layers(
        hf_config.num_hidden_layers,
        build_layer,
        prefix=f"{prefix}.layers",
    )

    # Only the last PP rank owns the final norm.
    self.norm = (
        RMSNorm(hf_config.hidden_size, self.rms_norm_eps)
        if pp_group.is_last_rank
        else PPMissingLayer()
    )

    def frozen_fp32(*shape: int) -> nn.Parameter:
        # Uninitialized, non-trainable float32 parameter of `shape`.
        return nn.Parameter(
            torch.empty(*shape, dtype=torch.float32),
            requires_grad=False,
        )

    self.hc_head_fn = frozen_fp32(self.hc_mult, self.hc_dim)
    self.hc_head_base = frozen_fp32(self.hc_mult)
    self.hc_head_scale = frozen_fp32(1)

    # Pre-hc_head residual stream buffer for the MTP draft. Stable address
    # (outside the cudagraph pool) so the copy_ in forward() refreshes it
    # correctly across captured shapes. Only allocated on the last PP
    # rank, which is where MTP target hidden states are produced.
    self._mtp_hidden_buffer = (
        torch.empty(
            max_tokens,
            self.hc_dim,
            dtype=vllm_config.model_config.dtype,
            device=self.device,
        )
        if pp_group.is_last_rank
        else None
    )
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
    """Return ``self.embed_tokens(input_ids)``."""
    token_embeddings = self.embed_tokens(input_ids)
    return token_embeddings
def make_empty_intermediate_tensors(
    self,
    batch_size: int,
    dtype: torch.dtype,
    device: torch.device,
) -> IntermediateTensors:
    """Create zero-filled placeholder tensors for pipeline-parallel hand-off.

    PP intermediate tensors carry the multi-stream hidden_states of shape
    ``(num_tokens, hc_mult, hidden_size)``: V4 expands the token embedding
    to ``hc_mult`` streams before the first decoder layer and keeps that
    shape until ``hc_head()`` collapses it.

    Returns:
        An ``IntermediateTensors`` with a single ``"hidden_states"`` entry
        of shape ``(batch_size, self.hc_mult, self.config.hidden_size)``.
    """
    stream_shape = (batch_size, self.hc_mult, self.config.hidden_size)
    placeholder = torch.zeros(stream_shape, dtype=dtype, device=device)
    return IntermediateTensors({"hidden_states": placeholder})
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None,
inputs_embeds: torch.Tensor | None = None,
) -> torch.Tensor | IntermediateTensors:
# Run this pipeline rank's decoder layers over the multi-stream hidden
# state.
#
# Returns an IntermediateTensors holding "hidden_states" on every rank
# except the last. The last rank returns the tensor produced by hc_head()
# followed by self.norm.
if get_pp_group().is_first_rank:
# Prefer caller-supplied embeddings; otherwise embed the token ids.
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.embed_input_ids(input_ids)
# Insert a stream axis before the last dim and tile it hc_mult times
# (for 2-D embeddings: (tokens, hidden) -> (tokens, hc_mult, hidden)).
hidden_states = hidden_states.unsqueeze(-2).repeat(1, self.hc_mult, 1)
else:
# Later PP ranks resume from the previous rank's hidden states.
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
# NOTE(review): this dereferences input_ids unconditionally; confirm it is
# never None when the mega-MoE backend is active.
if self.use_mega_moe:
input_ids = input_ids.to(torch.int64)
# The first layer sees residual=None and runs its standalone hc_pre;
# each layer then threads (residual, post_mix, res_mix) to the next.
residual, post_mix, res_mix = None, None, None
for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual, post_mix, res_mix = layer(
hidden_states,
positions,
input_ids,
post_mix,
res_mix,
residual,
)
else:
# for/else with no `break` in the loop: this always runs once the loop
# finishes, using the last layer iterated to apply the pending hc_post.
# NOTE(review): `layer` would be unbound if this rank owned no layers;
# confirm make_layers never yields an empty range.
hidden_states = layer.hc_post(hidden_states, residual, post_mix, res_mix)
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
# Stash pre-hc_head residual for the MTP draft (captured copy_).
# flatten(1) merges the trailing dims so each row matches the buffer's
# hc_dim-wide rows.
num_tokens = hidden_states.shape[0]
self._mtp_hidden_buffer[:num_tokens].copy_(hidden_states.flatten(1))
# Collapse the hc_mult streams back to one hidden vector per token.
hidden_states = hc_head(
hidden_states,
self.hc_head_fn,
self.hc_head_scale,
self.hc_head_base,
self.rms_norm_eps,
self.hc_eps,
)
hidden_states = self.norm(hidden_states)
return hidden_states
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
# Load (name, tensor) checkpoint pairs into this model's parameters and
# return the set of parameter names that were loaded or buffered.
#
# Each weight is first tried against stacked_params_mapping. Weights that
# match no stacked entry fall through to the for-else branch, which
# handles routed experts, attn_sink and plain parameters in that order.
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("gate_up_proj", "w1", 0),
("gate_up_proj", "w3", 1),
("attn.fused_wqa_wkv", "attn.wq_a", 0),
("attn.fused_wqa_wkv", "attn.wkv", 1),
("compressor.fused_wkv_wgate", "compressor.wkv", 0),
("compressor.fused_wkv_wgate", "compressor.wgate", 1),
# Indexer's compressor (same stacking pattern)
("indexer.compressor.fused_wkv_wgate", "indexer.compressor.wkv", 0),
("indexer.compressor.fused_wkv_wgate", "indexer.compressor.wgate", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
# TP for attention
# [head_rank_start, head_rank_end) is this TP rank's slice of the
# attention heads; it is only used for attn_sink below.
tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
n_head = self.config.num_attention_heads
n_local_head = n_head // tp_size
head_rank_start = n_local_head * tp_rank
head_rank_end = n_local_head * (tp_rank + 1)
# Pre-compute expert mapping ONCE.
expert_mapping = self.get_expert_mapping()
# NVFP4 compressor/indexer scale params need special handling:
# wkv.input_scale (shape [1]) + wgate.input_scale (shape [1])
# must be concatenated into fused_wkv_wgate.input_scale (shape [2]).
# The default stacking path fails because PerTensorScaleParameter's
# weight_loader asserts shape equality.
# We buffer them and load once both shards are available.
# Layout: fused parameter name -> {shard_id: tensor}.
compressor_scale_buffer: dict[str, dict[int, torch.Tensor]] = {}
for name, loaded_weight in weights:
for param_name, weight_name, shard_id in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
if ".experts." in name:
continue
# Plain substring match against the checkpoint shard name.
if weight_name not in name:
continue
# `name` is rebound to the fused parameter name here. Every path
# below leaves the loop via `break`, so the for-else fallback never
# sees the rewritten name.
name = name.replace(weight_name, param_name)
if is_pp_missing_parameter(name, self):
break
if name not in params_dict:
# The stacked param doesn't exist — skip
# (e.g. indexer.compressor.fused_wkv_wgate on layers
# that don't have the full indexer structure)
break
param = params_dict[name]
weight_loader = param.weight_loader
# NVFP4 scale params for stacked fused_wkv_wgate need
# special handling: each shard (wkv, wgate) has scale
# shape [1] or [head_dim, K], but the fused param has
# shape [2] or [2*head_dim, K]. The default stacking
# weight_loader can't handle this for PerTensorScale or
# ModelWeight scale params. Buffer and concatenate.
is_compressor_scale = (
"fused_wkv_wgate" in name
and name.endswith((
"input_scale",
"weight_scale",
"weight_scale_2",
))
)
# NOTE(review): `name not in params_dict` was already handled just
# above, so this branch looks unreachable; confirm and remove.
if is_compressor_scale:
# Verify the fused param exists before buffering
if name not in params_dict:
print(
f"COMPRESSOR_SCALE_SKIP: {name} not in params_dict",
flush=True,
)
break
if is_compressor_scale:
# Buffer the shard for later concatenation
# The name is recorded as loaded now, although the copy into the
# parameter only happens after the main loop.
if name not in compressor_scale_buffer:
compressor_scale_buffer[name] = {}
compressor_scale_buffer[name][shard_id] = loaded_weight
loaded_params.add(name)
break
weight_loader(param, loaded_weight, shard_id)
loaded_params.add(name)
break
else:
if ".experts." in name:
# E8M0 scales are stored as float8_e8m0fnu in
# checkpoints but the MoE param is uint8. copy_()
# would do a numeric conversion (e.g. 2^-7 → 0),
# destroying the raw exponent bytes.
if (
"weight_scale" in name
and loaded_weight.dtype == torch.float8_e8m0fnu
):
loaded_weight = loaded_weight.view(torch.uint8)
for mapping in expert_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
name_mapped = name.replace(weight_name, param_name)
if is_pp_missing_parameter(name_mapped, self):
continue
param = params_dict[name_mapped]
# We should ask the weight loader to return success or not
# here since otherwise we may skip experts with other
# available replicas.
weight_loader = typing.cast(
Callable[..., bool], param.weight_loader
)
success = weight_loader(
param,
loaded_weight,
name_mapped,
shard_id=shard_id,
expert_id=expert_id,
return_success=True,
)
if success:
name = name_mapped
break
# NOTE(review): if no mapping entry matched this weight, `name_mapped`
# is presumably either unbound or left over from an earlier weight;
# confirm every ".experts." key matches some entry.
loaded_params.add(name_mapped)
continue
elif "attn_sink" in name:
if is_pp_missing_parameter(name, self):
continue
# Copy only this TP rank's head slice into the front of the parameter.
narrow_weight = loaded_weight[head_rank_start:head_rank_end]
n = narrow_weight.shape[0]
params_dict[name][:n].copy_(narrow_weight)
loaded_params.add(name)
continue
else:
if is_pp_missing_parameter(name, self):
continue
if name not in params_dict:
print(f"Skipping weight {name} (not in model params)",
flush=True)
continue
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
# Print shape/dtype diagnostics for the failing weight, then re-raise.
# NOTE(review): unlike the other paths, a successful load here does not
# add `name` to loaded_params; confirm whether that is intended.
try:
weight_loader(param, loaded_weight)
except (AssertionError, RuntimeError) as e:
print(
f"WEIGHT_LOAD_FAIL: name={name} "
f"param_shape={param.data.shape if hasattr(param, 'data') else '?'} "
f"loaded_shape={loaded_weight.shape} "
f"loaded_dtype={loaded_weight.dtype} "
f"error={e}",
flush=True,
)
raise
# Load buffered compressor/indexer scale params.
# These are NVFP4 quantization scales that need concatenation
# across shards (wkv=shard0, wgate=shard1) before loading.
for name, shards in compressor_scale_buffer.items():
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
if len(shards) == 2:
# Concatenate shard 0 and shard 1 along dim 0.
# Scales may be 0-dim scalars (input_scale, weight_scale_2)
# or N-dim tensors (weight_scale); reshape scalars to 1-d.
s0, s1 = shards[0], shards[1]
if s0.ndim == 0:
s0 = s0.reshape(1)
if s1.ndim == 0:
s1 = s1.reshape(1)
stacked = torch.cat([s0, s1], dim=0)
else:
# NOTE(review): reached when only one shard was buffered. If that shard
# is shard 1, shards[0] raises KeyError; confirm both always arrive.
stacked = shards[0]
# NOTE(review): this shape check is an assert, so it is skipped under
# `python -O`.
assert param.data.shape == stacked.shape, (
f"Scale shape mismatch for {name}: "
f"param={param.data.shape} loaded={stacked.shape}"
)
param.data.copy_(stacked)
return loaded_params
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
    """Return the expert checkpoint-to-parameter mapping for this rank.

    The first decoder layer owned by this rank decides which mapping is
    built: the mega-MoE one when its ``ffn.use_mega_moe`` is set, otherwise
    the ``FusedMoE`` mapping for w1/w2/w3 checkpoint names.

    Returns:
        Tuples of ``(param_name, weight_name, expert_id, shard_id)``.

    Raises:
        StopIteration: If this rank owns no decoder layers.
    """
    local_layers = islice(self.layers, self.start_layer, self.end_layer)
    probe_layer = next(local_layers)
    num_experts = self.config.n_routed_experts

    if not probe_layer.ffn.use_mega_moe:
        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        return FusedMoE.make_expert_params_mapping(
            self,
            ckpt_gate_proj_name="w1",
            ckpt_down_proj_name="w2",
            ckpt_up_proj_name="w3",
            num_experts=num_experts,
        )
    return make_deepseek_v4_expert_params_mapping(num_experts)
def finalize_mega_moe_weights(self) -> None:
    """Finalize per-layer weights after loading.

    For every decoder layer owned by this rank, call the MoE block's
    ``finalize_mega_moe_weights()`` and then convert ``attn.wo_a`` to FP8
    if its weight is still bfloat16.
    """
    for decoder_layer in islice(self.layers, self.start_layer, self.end_layer):
        decoder_layer.ffn.finalize_mega_moe_weights()

        # Quantize wo_a to FP8 (checkpoint has bfloat16, forward expects FP8)
        attention = decoder_layer.attn
        if not hasattr(attention, "wo_a"):
            continue
        if attention.wo_a.weight.dtype == torch.bfloat16:
            self._quantize_wo_a_to_fp8(attention.wo_a)
@staticmethod
def _quantize_wo_a_to_fp8(wo_a: ColumnParallelLinear) -> None:
"""Quantize wo_a weight from bfloat16 to float8_e4m3fn.
The attention forward pass (fused_inv_rope_fp8_quant + einsum)
expects wo_a.weight as FP8 and wo_a.weight_scale_inv as float32.
The NVFP4 checkpoint stores wo_a as bfloat16, so we quantize here.
Uses per-tensor symmetric quantization (same as modelopt FP8).
"""
weight_bf16 = wo_a.weight.data
# Per-tensor FP8 quantization: scale = amax / fp8_max
fp8_max = torch.finfo(torch.float8_e4m3fn).max # 448.0
amax = weight_bf16.abs().max().float()
scale = amax / fp8_max
# Avoid division by zero
if scale == 0:
scale = torch.tensor(1.0, device=scale.device)
scale_inv = 1.0 / scale
weight_fp8 = (weight_bf16.float() * scale).to(torch.float8_e4m3fn)
wo_a.weight = torch.nn.Parameter(weight_fp8, requires_grad=False)
wo_a.weight_scale_inv = torch.nn.Parameter(
scale_inv.clone(), requires_grad=False
)
@torch.compile(backend=current_platform.simple_compile_backend)
def hc_head(
    hidden_states: torch.Tensor,
    hc_fn: torch.Tensor,
    hc_scale: torch.Tensor,
    hc_base: torch.Tensor,
    rms_norm_eps: float,
    hc_eps: float,
) -> torch.Tensor:
    """Collapse the stream axis via ``hc_head_fused_kernel``.

    ``hidden_states`` is viewed as ``(-1, hc_mult, hidden_size)`` using its
    last two dims. A ``(num_tokens, hidden_size)`` output buffer is
    allocated (always bfloat16, whatever the input dtype) and handed to
    ``torch.ops.vllm.hc_head_fused_kernel``, which is expected to fill it.

    Returns:
        The output buffer reshaped to the input's leading dims plus
        ``hidden_size``.
    """
    leading_shape = hidden_states.shape[:-2]
    hc_mult, hidden_size = hidden_states.shape[-2:]

    streams = hidden_states.view(-1, hc_mult, hidden_size)
    collapsed = torch.empty(
        streams.shape[0],
        hidden_size,
        dtype=torch.bfloat16,
        device=hidden_states.device,
    )

    torch.ops.vllm.hc_head_fused_kernel(
        streams,
        hc_fn,
        hc_scale,
        hc_base,
        collapsed,
        hidden_size,
        rms_norm_eps,
        hc_eps,
        hc_mult,
    )
    return collapsed.view(*leading_shape, hidden_size)
def _make_deepseek_v4_weights_mapper(expert_dtype: str) -> WeightsMapper:
    """Build the checkpoint-to-vLLM name mapper for the default key layout.

    Args:
        expert_dtype: ``"fp4"`` selects the MXFP4-expert scale naming; any
            other value selects the FP8-expert naming.

    Returns:
        A ``WeightsMapper`` with prefix, regex, suffix and substring
        renames.
    """
    scale_regex = {}
    if expert_dtype == "fp4":
        # MXFP4 experts use Mxfp4MoEMethod, which registers scales as
        # ``w{1,2,3}_weight_scale`` (no _inv suffix). This entry is inserted
        # first so routed-expert scales are renamed before the catch-all.
        scale_regex[re.compile(r"(\.experts\.\d+\.w[123])\.scale$")] = (
            r"\1.weight_scale"
        )
    # Everything else maps to ``weight_scale_inv``: FP8 linear and shared
    # experts (Fp8LinearMethod block scales), and FP8 experts via
    # Fp8MoEMethod (block_quant=True, ``w{13,2}_weight_scale_inv``).
    scale_regex[re.compile(r"\.scale$")] = ".weight_scale_inv"

    prefix_renames = {
        "layers.": "model.layers.",
        "embed.": "model.embed.",
        "norm.": "model.norm.",
        "hc_head": "model.hc_head",
        "mtp.": "model.mtp.",
    }
    suffix_renames = {
        "head.weight": "lm_head.weight",
        "embed.weight": "embed_tokens.weight",
        ".ffn.gate.bias": ".ffn.gate.e_score_correction_bias",
    }
    substr_renames = {
        ".attn.compressor.": ".attn.mla_attn.compressor.",
        ".shared_experts.w2": ".shared_experts.down_proj",
    }
    return WeightsMapper(
        orig_to_new_prefix=prefix_renames,
        orig_to_new_regex=scale_regex,
        orig_to_new_suffix=suffix_renames,
        orig_to_new_substr=substr_renames,
    )
def _make_deepseek_v4_nvfp4_weights_mapper() -> WeightsMapper:
    """Build the name mapper for NVFP4 (ModelOpt) DeepSeek-V4 checkpoints.

    NVFP4 checkpoints name their keys differently from the upstream MXFP4
    format:

    - Routed experts use gate_proj/up_proj/down_proj (not w1/w3/w2).
    - Scales already carry .weight_scale / .weight_scale_2 / .input_scale
      suffixes.
    - The shared expert uses down_proj (not w2).
    - Attention keys live under ``.self_attn.`` and are renamed to
      ``.attn.``.

    This is the mapper to use when quantization is modelopt_fp4.
    """
    # Routed experts: checkpoint projection name -> w1/w3/w2.
    expert_rename_regex = {
        re.compile(rf"(\.experts\.\d+\.){ckpt_name}\."): rf"\1{model_name}."
        for ckpt_name, model_name in (
            ("gate_proj", "w1"),
            ("up_proj", "w3"),
            ("down_proj", "w2"),
        )
    }

    # The NVFP4 checkpoint already uses lm_head / embed_tokens directly, so
    # no suffix renames are needed (unlike the MXFP4 upstream format).
    suffix_renames = {}

    # NOTE: specific renames MUST come before general ones (applied in
    # order). The groups below are merged in exactly that order.

    # Indexer params. MUST come before the .self_attn.compressor. rewrite so
    # indexer keys are captured before that rewrite moves them under
    # mla_attn.compressor. The checkpoint puts the indexer under
    # self_attn.compressor.indexer.*, but the model has it at attn.indexer.*
    # (a sibling of the compressor, NOT nested under it).
    indexer_renames = {
        ".self_attn.compressor.indexer.q_b_proj.": ".attn.indexer.wq_b.",
        ".self_attn.compressor.indexer.weights_proj.": ".attn.indexer.weights_proj.",
        ".self_attn.compressor.indexer.kv_norm.": ".attn.indexer.k_norm.",
        ".self_attn.compressor.indexer.kv_proj.": ".attn.indexer.compressor.wkv.",
        ".self_attn.compressor.indexer.gate_proj.": ".attn.indexer.compressor.wgate.",
        ".self_attn.compressor.indexer.position_bias": ".attn.indexer.compressor.ape",
    }

    # Non-indexer compressor: checkpoint uses kv_proj/gate_proj, model uses
    # wkv/wgate (for stacking into fused_wkv_wgate). The final entry MUST
    # come after the indexer renames so the remaining (non-indexer)
    # .self_attn.compressor. keys become .attn.mla_attn.compressor.
    compressor_renames = {
        "compressor.kv_proj.": "compressor.wkv.",
        "compressor.gate_proj.": "compressor.wgate.",
        "compressor.kv_norm.": "compressor.norm.",
        "compressor.position_bias": "compressor.ape",
        ".self_attn.compressor.": ".attn.mla_attn.compressor.",
    }

    # Attention projections (specific before .self_attn. -> .attn.).
    attention_renames = {
        ".self_attn.q_a_proj.": ".attn.wq_a.",
        ".self_attn.kv_proj.": ".attn.wkv.",
        ".self_attn.q_b_proj.": ".attn.wq_b.",
        ".self_attn.o_a_proj.": ".attn.wo_a.",
        ".self_attn.o_b_proj.": ".attn.wo_b.",
        ".self_attn.q_a_norm.": ".attn.q_norm.",
        ".self_attn.kv_norm.": ".attn.kv_norm.",
        ".self_attn.sinks": ".attn.attn_sink",
    }

    # Shared expert projections (specific before .mlp. -> .ffn.).
    shared_expert_renames = {
        ".mlp.shared_experts.gate_proj.": ".ffn.shared_experts.w1.",
        ".mlp.shared_experts.up_proj.": ".ffn.shared_experts.w3.",
        ".mlp.shared_experts.down_proj.": ".ffn.shared_experts.down_proj.",
    }

    # General renames.
    general_renames = {
        ".mlp.": ".ffn.",
        ".self_attn.": ".attn.",
    }

    # Layer norms (checkpoint uses input_layernorm /
    # post_attention_layernorm, model uses attn_norm / ffn_norm).
    layernorm_renames = {
        "input_layernorm.": "attn_norm.",
        "post_attention_layernorm.": "ffn_norm.",
    }

    # Per-layer HC params (checkpoint uses attn_hc / ffn_hc with a dot,
    # model uses hc_attn / hc_ffn with an underscore).
    layer_hc_renames = {
        ".attn_hc.fn": ".hc_attn_fn",
        ".attn_hc.base": ".hc_attn_base",
        ".attn_hc.scale": ".hc_attn_scale",
        ".ffn_hc.fn": ".hc_ffn_fn",
        ".ffn_hc.base": ".hc_ffn_base",
        ".ffn_hc.scale": ".hc_ffn_scale",
    }

    # Top-level hc_head params (checkpoint uses hc_head.fn etc., model uses
    # hc_head_fn etc.).
    head_hc_renames = {
        "hc_head.fn": "hc_head_fn",
        "hc_head.base": "hc_head_base",
        "hc_head.scale": "hc_head_scale",
    }

    # Dict unpacking preserves insertion order; no key appears in more than
    # one group.
    substr_renames = {
        **indexer_renames,
        **compressor_renames,
        **attention_renames,
        **shared_expert_renames,
        **general_renames,
        **layernorm_renames,
        **layer_hc_renames,
        **head_hc_renames,
    }

    prefix_renames = {
        "layers.": "model.layers.",
        "embed_tokens.": "model.embed_tokens.",
        "norm.": "model.norm.",
        "hc_head": "model.hc_head",
        "mtp.": "model.mtp.",
    }
    return WeightsMapper(
        orig_to_new_prefix=prefix_renames,
        orig_to_new_regex=expert_rename_regex,
        orig_to_new_suffix=suffix_renames,
        orig_to_new_substr=substr_renames,
    )
class DeepseekV4ForCausalLM(nn.Module, SupportsPP):
    """DeepSeek-V4 causal LM: the decoder stack plus LM head and logits.

    Wraps ``model_cls`` (``DeepseekV4Model``) and, on the last
    pipeline-parallel rank, a ``ParallelLMHead``. Checkpoint keys are
    translated with ``hf_to_vllm_mapper``, which is chosen per instance from
    the quantization method and the config's ``expert_dtype``.
    """

    model_cls = DeepseekV4Model

    # Class-level default: the original FP4-expert checkpoint layout.
    # __init__ replaces it per instance for modelopt_fp4 quantization or
    # when expert_dtype != "fp4".
    hf_to_vllm_mapper = _make_deepseek_v4_weights_mapper("fp4")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Select the weight mapper and build the model, head and logits.

        Args:
            vllm_config: Engine configuration; the HF config and the
                quantization config are read from it.
            prefix: Parameter-name prefix for the sub-modules.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config

        # NVFP4 (modelopt_fp4) checkpoints use different key naming than the
        # default format, so the quantization method takes precedence over
        # expert_dtype. An unconditional expert_dtype-based reassignment
        # after this choice would put the generic mapper back for NVFP4
        # checkpoints whose config carries a non-"fp4" expert_dtype.
        is_nvfp4 = (
            quant_config is not None
            and getattr(quant_config, "get_name", lambda: None)()
            == "modelopt_fp4"
        )
        if is_nvfp4:
            self.hf_to_vllm_mapper = _make_deepseek_v4_nvfp4_weights_mapper()
        else:
            # "fp4" yields the MXFP4 scale naming; any other value yields
            # the FP8 naming.
            expert_dtype = getattr(config, "expert_dtype", "fp4")
            self.hf_to_vllm_mapper = _make_deepseek_v4_weights_mapper(
                expert_dtype
            )

        self.config = config

        self.model = self.model_cls(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )

        # Only the last PP rank owns the LM head.
        if get_pp_group().is_last_rank:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
        else:
            self.lm_head = PPMissingLayer()

        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the wrapped model's embeddings for ``input_ids``."""
        return self.model.embed_input_ids(input_ids)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Run ``self.logits_processor`` with the LM head on ``hidden_states``."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Delegate to the wrapped model and return its result unchanged."""
        hidden_states = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        return hidden_states

    def get_mtp_target_hidden_states(self) -> torch.Tensor | None:
        """Pre-hc_head residual stream buffer (max_num_batched_tokens,
        hc_mult * hidden_size) for the MTP draft model. Populated by
        forward(); valid after each target step. ``None`` when the wrapped
        model has no such buffer."""
        return getattr(self.model, "_mtp_hidden_buffer", None)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, then finalize the per-layer weights.

        Keys containing ``"mtp."`` are skipped. The remaining keys are
        renamed through ``self.hf_to_vllm_mapper`` by ``AutoWeightsLoader``.

        Returns:
            The set of parameter names reported as loaded.
        """
        loader = AutoWeightsLoader(self, skip_substrs=["mtp."])
        loaded_params = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
        self.model.finalize_mega_moe_weights()
        return loaded_params

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return the wrapped model's expert parameter mapping."""
        return self.model.get_expert_mapping()