- input_layernorm → attn_norm, post_attention_layernorm → ffn_norm
- hc_head.fn/base/scale → hc_head_fn/base/scale
- attn_hc/ffn_hc → hc_attn/hc_ffn (dot to underscore)
- q_a_norm → q_norm, sinks → attn_sink
- Indexer params: self_attn.compressor.indexer → attn.indexer (not attn.mla_attn.compressor.indexer)
1925 lines
71 KiB
Python
1925 lines
71 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
import typing
|
|
from collections.abc import Callable, Iterable
|
|
from itertools import islice
|
|
|
|
import regex as re
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from vllm.compilation.decorators import support_torch_compile
|
|
from vllm.config import VllmConfig, get_current_vllm_config
|
|
from vllm.distributed import (
|
|
get_ep_group,
|
|
get_pp_group,
|
|
get_tensor_model_parallel_rank,
|
|
get_tensor_model_parallel_world_size,
|
|
)
|
|
from vllm.forward_context import get_forward_context
|
|
from vllm.model_executor.layers.activation import SiluAndMul, SiluAndMulWithClamp
|
|
from vllm.model_executor.layers.deepseek_v4_attention import (
|
|
DeepseekV4Indexer,
|
|
DeepseekV4MLAModules,
|
|
DeepseekV4MultiHeadLatentAttentionWrapper,
|
|
)
|
|
from vllm.model_executor.layers.fused_moe import FusedMoE, GateLinear
|
|
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
|
|
from vllm.model_executor.layers.fused_moe.router.fused_topk_bias_router import (
|
|
fused_topk_bias,
|
|
)
|
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
|
from vllm.model_executor.layers.linear import (
|
|
ColumnParallelLinear,
|
|
MergedColumnParallelLinear,
|
|
RowParallelLinear,
|
|
)
|
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
|
from vllm.model_executor.layers.quantization import (
|
|
QuantizationConfig,
|
|
QuantizationMethods,
|
|
)
|
|
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
|
|
from vllm.model_executor.layers.quantization.mxfp4 import Mxfp4MoEMethod
|
|
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
|
is_layer_skipped,
|
|
)
|
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
|
ParallelLMHead,
|
|
VocabParallelEmbedding,
|
|
)
|
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
|
from vllm.model_executor.models.interfaces import SupportsPP
|
|
from vllm.model_executor.utils import set_weight_attrs
|
|
from vllm.platforms import current_platform
|
|
from vllm.sequence import IntermediateTensors
|
|
from vllm.triton_utils import tl, triton
|
|
from vllm.utils.torch_utils import direct_register_custom_op
|
|
|
|
from .utils import (
|
|
AutoWeightsLoader,
|
|
PPMissingLayer,
|
|
WeightsMapper,
|
|
extract_layer_index,
|
|
is_pp_missing_parameter,
|
|
make_layers,
|
|
maybe_prefix,
|
|
)
|
|
|
|
# Values accepted for the hf_config ``expert_dtype`` field. The
# ``DeepseekV4FP8Config.expert_dtype`` property raises ValueError for anything
# outside this tuple.
_DEEPSEEK_V4_EXPERT_DTYPES = ("fp4", "fp8")
|
|
|
|
|
|
class DeepseekV4MLP(nn.Module):
    """Gated feed-forward block: merged gate/up projection, SiLU-and-multiply
    activation (optionally clamped), then a down projection."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        swiglu_limit: float | None = None,
        quant_config: QuantizationConfig | None = None,
        reduce_results: bool = True,
        is_sequence_parallel: bool = False,
        prefix: str = "",
    ) -> None:
        """Build the two projections and pick the activation.

        Args:
            hidden_size: Input/output width of the block.
            intermediate_size: Width of each of the gate and up projections.
            hidden_act: Activation name; only ``"silu"`` is accepted.
            swiglu_limit: When not None, the clamped activation variant is
                used with this limit.
            quant_config: Quantization config forwarded to both projections.
            reduce_results: Forwarded to the down projection.
            is_sequence_parallel: Forwarded as ``disable_tp`` to both
                projections.
            prefix: Module-name prefix for the two projections.

        Raises:
            ValueError: If ``hidden_act`` is not ``"silu"``.
        """
        super().__init__()

        # With sequence parallelism the activations are already sharded across
        # the TP group, so the weights are replicated (disable_tp=True) and no
        # collective is needed. Otherwise this is plain TP with an all-reduce
        # at the end of the down projection.
        common_linear_kwargs = dict(
            bias=False,
            quant_config=quant_config,
            disable_tp=is_sequence_parallel,
        )
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size] * 2,
            prefix=f"{prefix}.gate_up_proj",
            **common_linear_kwargs,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            reduce_results=reduce_results,
            prefix=f"{prefix}.down_proj",
            **common_linear_kwargs,
        )

        if hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {hidden_act}. Only silu is supported for now."
            )
        self.act_fn = (
            SiluAndMul()
            if swiglu_limit is None
            else SiluAndMulWithClamp(swiglu_limit)
        )

    def forward(self, x):
        """Apply gate/up projection, activation and down projection to ``x``."""
        merged_gate_up, _ = self.gate_up_proj(x)
        activated = self.act_fn(merged_gate_up)
        projected, _ = self.down_proj(activated)
        return projected
|
|
|
|
|
|
class DeepseekV4FP8Config(Fp8Config):
    """FP8 quantization config for DeepSeek V4 that picks the MoE method from
    the checkpoint's expert dtype.

    Linear/attention layers in DeepSeek V4 checkpoints are always FP8
    block-quantized; what differs between checkpoints is the expert format:

    - ``expert_dtype="fp4"`` (e.g. DeepSeek-V4-Flash): MXFP4 experts, and
      FP8 linear scales stored as ue8m0 (e8m0fnu).
    - ``expert_dtype="fp8"`` (e.g. DeepSeek-V4-Flash-Base): FP8 block
      experts, and FP8 linear scales stored as float32.

    Both the MoE dispatch and the linear scale dtype follow ``expert_dtype``
    read from the model's hf_config. A missing attribute means ``"fp4"``, so
    existing FP4 checkpoints behave as before.

    NOTE: the lookup is lazy on purpose. This object is created while the
    VllmConfig is still being assembled, i.e. before
    ``set_current_vllm_config`` is active; reading hf_config in ``__init__``
    would always yield the ``"fp4"`` default and silently misroute
    Flash-Base checkpoints.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Filled in by the ``expert_dtype`` property on its first successful
        # lookup. ``is_scale_e8m0`` is likewise a property, so it is only
        # evaluated once the current vllm_config is available.
        self._resolved_expert_dtype: str | None = None

    @property
    def expert_dtype(self) -> str:
        """Expert weight format, ``"fp4"`` or ``"fp8"``.

        Returns the cached value when one exists. Otherwise reads
        ``expert_dtype`` from the current vllm_config's hf_config, validates
        it, caches it and logs it once. If the current vllm_config cannot be
        read, ``"fp4"`` is returned *without* caching so a later call can
        still resolve the real value.

        Raises:
            ValueError: If hf_config carries an unsupported ``expert_dtype``.
        """
        cached = self._resolved_expert_dtype
        if cached is not None:
            return cached

        try:
            hf_config = get_current_vllm_config().model_config.hf_config
        except Exception:
            # vllm_config is not set yet: postpone the decision until a later
            # call runs inside set_current_vllm_config.
            return "fp4"

        resolved = getattr(hf_config, "expert_dtype", "fp4")
        if resolved not in _DEEPSEEK_V4_EXPERT_DTYPES:
            raise ValueError(
                f"Unsupported DeepSeek V4 expert_dtype={resolved!r}; "
                f"expected one of {_DEEPSEEK_V4_EXPERT_DTYPES}."
            )
        self._resolved_expert_dtype = resolved

        from vllm.logger import init_logger

        init_logger(__name__).info_once(
            "DeepSeek V4 expert_dtype resolved to %r", resolved
        )
        return resolved

    @property
    def is_scale_e8m0(self) -> bool:
        """True when FP8 linear scales are stored as e8m0fnu.

        That is the case for FP4-expert checkpoints; FP8-expert checkpoints
        (Flash-Base) store float32 scales.
        """
        return self.expert_dtype == "fp4"

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        """Return the registry name of this quantization method."""
        return "deepseek_v4_fp8"

    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant, hf_config=None
    ) -> QuantizationMethods | None:
        """Claim FP8-quantized DeepSeek V4 checkpoints for this config.

        Returns ``"deepseek_v4_fp8"`` when ``hf_quant_cfg`` is a dict whose
        ``quant_method`` is ``"fp8"`` or ``"deepseek_v4_fp8"`` and either the
        model type is ``"deepseek_v4"`` or the user explicitly asked for
        ``"deepseek_v4_fp8"``. Returns None in every other case.
        """
        if not isinstance(hf_quant_cfg, dict):
            return None
        if hf_quant_cfg.get("quant_method") not in ("fp8", "deepseek_v4_fp8"):
            return None

        is_v4_model = getattr(hf_config, "model_type", None) == "deepseek_v4"
        if is_v4_model or user_quant == "deepseek_v4_fp8":
            return "deepseek_v4_fp8"
        return None

    def get_quant_method(self, layer, prefix):
        """Select the quantization method for ``layer``.

        FusedMoE layers get, in order of precedence: the unquantized method
        if the layer is in the skip list, the MXFP4 method when experts are
        fp4, otherwise whatever the FP8 base config returns. Non-MoE layers
        always defer to the base config.
        """
        if not isinstance(layer, FusedMoE):
            return super().get_quant_method(layer, prefix)

        layer_is_skipped = is_layer_skipped(
            prefix=prefix,
            ignored_layers=self.ignored_layers,
            fused_mapping=self.packed_modules_mapping,
        )
        if layer_is_skipped:
            return UnquantizedFusedMoEMethod(layer.moe_config)
        if self.expert_dtype == "fp4":
            return Mxfp4MoEMethod(layer.moe_config)
        # expert_dtype == "fp8": the base Fp8Config handles it and returns
        # Fp8MoEMethod with block-wise float32 scales.
        return super().get_quant_method(layer, prefix)

    def is_mxfp4_quant(self, prefix, layer):
        """Return True for FusedMoE layers when experts are fp4.

        ``prefix`` is accepted for interface compatibility and not used.
        """
        return isinstance(layer, FusedMoE) and self.expert_dtype == "fp4"
|
|
|
|
|
|
@triton.jit
def _deepseek_v4_stage_mega_moe_inputs_kernel(
    hidden_states,
    x_fp8,
    x_sf,
    topk_ids,
    topk_weights,
    topk_idx_out,
    topk_weights_out,
    hidden_stride_m: tl.constexpr,
    hidden_stride_k: tl.constexpr,
    x_stride_m: tl.constexpr,
    x_stride_k: tl.constexpr,
    x_sf_stride_m: tl.constexpr,
    x_sf_stride_k: tl.constexpr,
    topk_ids_stride_m: tl.constexpr,
    topk_ids_stride_k: tl.constexpr,
    topk_weights_stride_m: tl.constexpr,
    topk_weights_stride_k: tl.constexpr,
    topk_idx_stride_m: tl.constexpr,
    topk_idx_stride_k: tl.constexpr,
    topk_weights_out_stride_m: tl.constexpr,
    topk_weights_out_stride_k: tl.constexpr,
    hidden_size: tl.constexpr,
    top_k: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_K: tl.constexpr,
    BLOCK_TOPK: tl.constexpr,
) -> None:
    """Quantize one BLOCK_K-wide slice of one token to FP8 and stage routing data.

    Grid axis 0 selects the token, grid axis 1 selects the BLOCK_K-wide slice
    of the hidden dimension. Each program:

    1. loads its slice of ``hidden_states`` as float32,
    2. computes one power-of-two scale per GROUP_K consecutive elements,
    3. writes the scaled values to ``x_fp8`` as ``float8e4nv``,
    4. packs the per-group scale exponents (one byte each) into a single
       int32 and writes it to ``x_sf[token, k_block]``.

    Additionally, the program with ``k_block_id == 0`` copies the token's
    top-k expert ids (cast to int64) and top-k weights into
    ``topk_idx_out`` / ``topk_weights_out``.

    All ``*_stride_m`` / ``*_stride_k`` arguments are the element strides of
    the corresponding tensor along the token axis and the second axis.
    """
    token_id = tl.program_id(0)
    k_block_id = tl.program_id(1)

    # Column offsets of this program's slice; the mask guards the tail when
    # hidden_size is not a multiple of BLOCK_K.
    k_offsets = k_block_id * BLOCK_K + tl.arange(0, BLOCK_K)
    k_mask = k_offsets < hidden_size
    hidden = tl.load(
        hidden_states + token_id * hidden_stride_m + k_offsets * hidden_stride_k,
        mask=k_mask,
        other=0.0,
    ).to(tl.float32)

    # Per-group absolute maximum, floored at 1e-4 so the scale is never zero.
    num_groups: tl.constexpr = BLOCK_K // GROUP_K
    hidden_groups = tl.reshape(tl.abs(hidden), [num_groups, GROUP_K])
    amax = tl.max(hidden_groups, axis=1)
    amax = tl.maximum(amax, 1.0e-4)

    # 448.0 is the largest finite float8 e4m3 magnitude, so amax maps to the
    # top of the FP8 range before the scale is rounded.
    scale = amax / 448.0
    # Round the scale UP to a power of two by working on the float32 bits:
    # take the biased exponent (bits 23..30; the sign bit is 0 because
    # amax > 0) and add one whenever any mantissa bit is set.
    scale_bits = scale.to(tl.uint32, bitcast=True)
    scale_exp = ((scale_bits >> 23) & 0xFF) + ((scale_bits & 0x7FFFFF) != 0).to(
        tl.uint32
    )
    # Keep the exponent inside [1, 254] (excludes the all-zeros and all-ones
    # exponent encodings).
    scale_exp = tl.minimum(tl.maximum(scale_exp, 1), 254)
    # Rebuild a float32 with that exponent and a zero mantissa.
    rounded_scale = (scale_exp << 23).to(tl.float32, bitcast=True)

    # Divide every group by its rounded scale and convert to FP8.
    hidden_groups = tl.reshape(hidden, [num_groups, GROUP_K])
    scaled = hidden_groups * (1.0 / rounded_scale)[:, None]
    scaled = tl.reshape(scaled, [BLOCK_K])
    fp8 = scaled.to(tl.float8e4nv)
    tl.store(
        x_fp8 + token_id * x_stride_m + k_offsets * x_stride_k,
        fp8,
        mask=k_mask,
    )

    # Pack the group exponents into one int32: group i occupies byte i
    # (shift by i * 8). This only fits when num_groups <= 4.
    scale_offsets = tl.arange(0, num_groups)
    packed_scale = tl.sum(scale_exp << (scale_offsets * 8), axis=0).to(tl.int32)
    tl.store(
        x_sf + token_id * x_sf_stride_m + k_block_id * x_sf_stride_k,
        packed_scale,
    )

    # Routing data is per token, not per slice: only the first slice's
    # program copies it, so each token is written exactly once.
    if k_block_id == 0:
        # BLOCK_TOPK may exceed top_k; the mask drops the padding lanes.
        topk_offsets = tl.arange(0, BLOCK_TOPK)
        topk_mask = topk_offsets < top_k

        ids = tl.load(
            topk_ids + token_id * topk_ids_stride_m + topk_offsets * topk_ids_stride_k,
            mask=topk_mask,
            other=0,
        ).to(tl.int64)
        tl.store(
            topk_idx_out
            + token_id * topk_idx_stride_m
            + topk_offsets * topk_idx_stride_k,
            ids,
            mask=topk_mask,
        )

        weights = tl.load(
            topk_weights
            + token_id * topk_weights_stride_m
            + topk_offsets * topk_weights_stride_k,
            mask=topk_mask,
            other=0.0,
        )
        tl.store(
            topk_weights_out
            + token_id * topk_weights_out_stride_m
            + topk_offsets * topk_weights_out_stride_k,
            weights,
            mask=topk_mask,
        )
|
|
|
|
|
|
def _stage_deepseek_v4_mega_moe_inputs(
    hidden_states: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    x_fp8: torch.Tensor,
    x_sf: torch.Tensor,
    topk_idx_out: torch.Tensor,
    topk_weights_out: torch.Tensor,
) -> None:
    """Launch the staging kernel that fills the MegaMoE input buffers.

    ``hidden_states`` is 2-D (tokens x hidden size). The kernel quantizes it
    into ``x_fp8`` with packed scales in ``x_sf`` and copies the routing
    tensors into ``topk_idx_out`` / ``topk_weights_out``. All outputs are
    written in place; nothing is returned. A zero-token input is a no-op.

    Raises:
        ValueError: If the hidden size is not a multiple of 128, or if
            ``topk_weights`` and ``topk_ids`` differ in shape.
    """
    num_tokens, hidden_size = hidden_states.shape
    if num_tokens == 0:
        return
    if hidden_size % 128 != 0:
        raise ValueError(
            "DeepSeek V4 MegaMoE input staging requires hidden_size to be "
            "a multiple of 128."
        )
    top_k = topk_ids.shape[1]
    if topk_weights.shape != topk_ids.shape:
        raise ValueError(
            "DeepSeek V4 MegaMoE input staging requires topk_weights and "
            "topk_ids to have the same shape."
        )

    block_k = 128
    launch_grid = (num_tokens, triton.cdiv(hidden_size, block_k))
    block_topk = triton.next_power_of_2(top_k)

    # Tensor arguments in the exact order the kernel declares them; the
    # strides follow in the same tensor order, axis 0 then axis 1 for each.
    kernel_tensors = (
        hidden_states,
        x_fp8,
        x_sf,
        topk_ids,
        topk_weights,
        topk_idx_out,
        topk_weights_out,
    )
    kernel_strides = [
        tensor.stride(axis) for tensor in kernel_tensors for axis in (0, 1)
    ]

    _deepseek_v4_stage_mega_moe_inputs_kernel[launch_grid](
        *kernel_tensors,
        *kernel_strides,
        hidden_size,
        top_k,
        BLOCK_K=block_k,
        GROUP_K=32,
        BLOCK_TOPK=block_topk,
        num_warps=4,
    )
|
|
|
|
|
|
def make_deepseek_v4_expert_params_mapping(
    num_experts: int,
) -> list[tuple[str, str, int, str]]:
    """Build the checkpoint-name to fused-parameter mapping for all experts.

    Each entry is ``(fused_param_prefix, checkpoint_prefix, expert_id,
    shard_id)``. For every expert the shards ``w1``, ``w2``, ``w3`` are
    emitted in that order; ``w1`` and ``w3`` map to ``"experts.w13_"`` and
    ``w2`` maps to ``"experts.w2_"``.

    Args:
        num_experts: Number of experts to generate entries for.

    Returns:
        A list with ``3 * num_experts`` entries, ordered by expert id.
    """
    mapping: list[tuple[str, str, int, str]] = []
    for expert_id in range(num_experts):
        for shard_id in ("w1", "w2", "w3"):
            # In the checkpoint the weight is named after its shard id.
            fused_prefix = "experts.w2_" if shard_id == "w2" else "experts.w13_"
            checkpoint_prefix = f"experts.{expert_id}.{shard_id}."
            mapping.append((fused_prefix, checkpoint_prefix, expert_id, shard_id))
    return mapping
|
|
|
|
|
|
class DeepseekV4MegaMoEExperts(nn.Module):
    """Routed-expert weights and execution for the DeepGEMM MegaMoE backend.

    The module owns the packed (uint8) expert weights and their uint8 scales
    for the experts local to this rank, converts them once into the layout
    DeepGEMM expects (``finalize_weights``) and runs the experts through the
    ``deepseek_v4_mega_moe_experts`` custom op.
    """

    # Symmetric buffers are cached at class level, i.e. shared by every
    # instance whose (group, device, sizes) key matches.
    _symm_buffer_cache: dict[tuple[int, int, int, int, int, int, int], object] = {}

    def __init__(
        self,
        vllm_config: VllmConfig,
        *,
        num_experts: int,
        num_local_experts: int,
        experts_start_idx: int,
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        prefix: str = "",
    ):
        """Allocate the loader-side parameters and register the layer.

        Args:
            vllm_config: Supplies ``max_num_batched_tokens`` and the
                compilation config's static forward context.
            num_experts: Global number of routed experts.
            num_local_experts: Number of experts held by this rank.
            experts_start_idx: Global id of this rank's first expert.
            top_k: Number of experts selected per token.
            hidden_size: Model hidden size.
            intermediate_size: Per-expert intermediate size.
            prefix: Layer name; also the key in the static forward context.

        Raises:
            ValueError: If ``prefix`` is already registered in the static
                forward context.
        """
        super().__init__()
        self.prefix = prefix
        self.num_experts = num_experts
        self.num_local_experts = num_local_experts
        self.experts_start_idx = experts_start_idx
        self.experts_end_idx = experts_start_idx + num_local_experts
        self.top_k = top_k
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens

        loader_attrs = {"weight_loader": self.weight_loader}
        # (attribute name, rows, columns, is_block_scale). Weights hold half
        # as many bytes as input features; scales hold one byte per 32 input
        # features. Allocation order matches the parameter registration order.
        param_layout = (
            ("w13_weight", 2 * intermediate_size, hidden_size // 2, False),
            ("w13_weight_scale", 2 * intermediate_size, hidden_size // 32, True),
            ("w2_weight", hidden_size, intermediate_size // 2, False),
            ("w2_weight_scale", hidden_size, intermediate_size // 32, True),
        )
        for attr_name, rows, cols, is_block_scale in param_layout:
            param = nn.Parameter(
                torch.zeros(num_local_experts, rows, cols, dtype=torch.uint8),
                requires_grad=False,
            )
            # Assigning an nn.Parameter on a Module registers it.
            setattr(self, attr_name, param)
            set_weight_attrs(param, loader_attrs)
            if is_block_scale:
                param.quant_method = "block"

        # Populated by finalize_weights(); None until then.
        self._transformed_l1_weights: tuple[torch.Tensor, torch.Tensor] | None = None
        self._transformed_l2_weights: tuple[torch.Tensor, torch.Tensor] | None = None

        # The custom-op wrapper receives only the layer name, so the module
        # must be discoverable by that name from inside a torch.compile graph.
        static_context = vllm_config.compilation_config.static_forward_context
        if prefix in static_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        static_context[prefix] = self

    def _map_global_expert_id(self, expert_id: int) -> int:
        """Translate a global expert id to a local index, or -1 if not local."""
        if self.experts_start_idx <= expert_id < self.experts_end_idx:
            return expert_id - self.experts_start_idx
        return -1

    def weight_loader(
        self,
        param: nn.Parameter,
        loaded_weight: torch.Tensor,
        weight_name: str,
        shard_id: str,
        expert_id: int,
        return_success: bool = False,
    ) -> bool | None:
        """Copy one expert shard from the checkpoint into ``param``.

        ``w1`` fills the first ``intermediate_size`` rows of the expert's
        slice and ``w3`` the following ``intermediate_size`` rows; ``w2``
        fills the whole slice. Nothing is copied when the expert is not
        local to this rank or when ``weight_name`` does not belong to the
        shard's fused parameter (``"w13_"`` for w1/w3, ``"w2_"`` for w2).

        Returns:
            With ``return_success=True``: True if data was copied, False if
            the call was skipped. Otherwise always None.

        Raises:
            ValueError: For an unknown ``shard_id`` or a shape mismatch
                between the destination shard and ``loaded_weight``.
        """
        skipped = False if return_success else None

        local_expert_id = self._map_global_expert_id(expert_id)
        if local_expert_id == -1:
            return skipped

        destination = param.data[local_expert_id]
        if shard_id == "w2":
            if "w2_" not in weight_name:
                return skipped
        elif shard_id in ("w1", "w3"):
            if "w13_" not in weight_name:
                return skipped
            first_row = self.intermediate_size if shard_id == "w3" else 0
            destination = destination.narrow(0, first_row, self.intermediate_size)
        else:
            raise ValueError(f"Unsupported expert shard id: {shard_id}")

        if destination.shape != loaded_weight.shape:
            raise ValueError(
                f"DeepSeek V4 MegaMoE expert weight shape mismatch for "
                f"{weight_name}: parameter shard {tuple(destination.shape)} "
                f"vs checkpoint {tuple(loaded_weight.shape)}"
            )
        destination.copy_(loaded_weight)
        return True if return_success else None

    @staticmethod
    def _ue8m0_uint8_to_float(sf: torch.Tensor) -> torch.Tensor:
        """Reinterpret uint8 exponent bytes as float32 powers of two.

        Each byte is placed in the float32 exponent field (bits 23..30) with
        zero sign and mantissa, so byte ``e`` becomes ``2 ** (e - 127)``.
        """
        exponent_field = sf.to(torch.int32) << 23
        return exponent_field.view(torch.float32)

    def _check_runtime_supported(self) -> None:
        """Verify the device and shapes satisfy the MegaMoE requirements.

        Raises:
            NotImplementedError: Without CUDA, when the weights are not on a
                CUDA device, or when the device's major compute capability
                is not 10.
            ValueError: If hidden or intermediate size is not a multiple
                of 128.
        """
        if not torch.cuda.is_available():
            raise NotImplementedError("DeepSeek V4 MegaMoE requires CUDA.")

        weight_device = self.w13_weight.device
        if weight_device.type != "cuda":
            raise NotImplementedError(
                "DeepSeek V4 MegaMoE expert weights must be loaded on CUDA."
            )

        major_capability = torch.cuda.get_device_capability(weight_device)[0]
        if major_capability != 10:
            raise NotImplementedError("DeepGEMM MegaMoE requires SM100 GPUs.")

        sizes_aligned = (
            self.hidden_size % 128 == 0 and self.intermediate_size % 128 == 0
        )
        if not sizes_aligned:
            raise ValueError(
                "DeepGEMM MegaMoE requires hidden and intermediate sizes "
                "to be multiples of 128."
            )

    def finalize_weights(self) -> None:
        """Convert the loaded weights into the DeepGEMM MegaMoE layout.

        Idempotent: returns immediately once the transformed weights exist.
        After conversion the four loader-side parameters are set to None.
        """
        if self._transformed_l1_weights is not None:
            return

        self._check_runtime_supported()
        import vllm.third_party.deep_gemm as deep_gemm

        def _scale_in_deep_gemm_layout(scale_param, out_features, in_features):
            # uint8 exponents -> float32 -> layout required by DeepGEMM, with
            # one scale per (1, 32) weight tile.
            return deep_gemm.transform_sf_into_required_layout(
                self._ue8m0_uint8_to_float(scale_param.data).contiguous(),
                out_features,
                in_features,
                (1, 32),
                self.num_local_experts,
            )

        w13_scale = _scale_in_deep_gemm_layout(
            self.w13_weight_scale, 2 * self.intermediate_size, self.hidden_size
        )
        w2_scale = _scale_in_deep_gemm_layout(
            self.w2_weight_scale, self.hidden_size, self.intermediate_size
        )

        l1_pair = (self.w13_weight.data.view(torch.int8).contiguous(), w13_scale)
        l2_pair = (self.w2_weight.data.view(torch.int8).contiguous(), w2_scale)
        self._transformed_l1_weights, self._transformed_l2_weights = (
            deep_gemm.transform_weights_for_mega_moe(l1_pair, l2_pair)
        )

        # The MegaMoE kernels read only the transformed tensors, so the
        # loader-side parameters can go. transform_weights_for_mega_moe
        # allocates a new L1 weight (see _interleave_l1_weights) and new
        # scale tensors for L1/L2. Only the L2 weight aliases the original
        # storage, and _transformed_l2_weights keeps that storage alive.
        for stale_param in (
            "w13_weight",
            "w13_weight_scale",
            "w2_weight",
            "w2_weight_scale",
        ):
            setattr(self, stale_param, None)

    def get_symm_buffer(self):
        """Return the DeepGEMM symmetric buffer for this configuration.

        Buffers are cached per (EP device group identity, current device
        index, num_experts, max_num_tokens, top_k, hidden_size,
        intermediate_size) and created on first use.
        """
        import vllm.third_party.deep_gemm as deep_gemm

        ep_device_group = get_ep_group().device_group
        device_index = torch.accelerator.current_device_index()
        buffer_dims = (
            self.num_experts,
            self.max_num_tokens,
            self.top_k,
            self.hidden_size,
            self.intermediate_size,
        )
        cache_key = (id(ep_device_group), device_index, *buffer_dims)

        cached_buffer = self._symm_buffer_cache.get(cache_key)
        if cached_buffer is not None:
            return cached_buffer

        new_buffer = deep_gemm.get_symm_buffer_for_mega_moe(
            ep_device_group, *buffer_dims
        )
        self._symm_buffer_cache[cache_key] = new_buffer
        return new_buffer

    def forward(
        self,
        hidden_states: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        *,
        activation_clamp: float | None,
        fast_math: bool = True,
    ) -> torch.Tensor:
        """Run the routed experts and return a new bfloat16 output tensor.

        The work is dispatched through the ``deepseek_v4_mega_moe_experts``
        custom op, which fills the output tensor in place.

        Raises:
            ValueError: If more tokens are passed than ``max_num_tokens``.
        """
        token_count = hidden_states.shape[0]
        if token_count > self.max_num_tokens:
            raise ValueError(
                f"DeepSeek V4 MegaMoE got {hidden_states.shape[0]} tokens, "
                f"but the symmetric buffer was sized for {self.max_num_tokens}."
            )

        output = torch.empty_like(hidden_states, dtype=torch.bfloat16)
        torch.ops.vllm.deepseek_v4_mega_moe_experts(
            hidden_states,
            topk_weights,
            topk_ids,
            output,
            self.prefix,
            activation_clamp,
            fast_math,
        )
        return output

    def _run_mega_moe(
        self,
        hidden_states: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        y: torch.Tensor,
        activation_clamp: float | None,
        fast_math: bool,
    ) -> None:
        """Stage the inputs into the symmetric buffer and run the kernel.

        The result is written into ``y``.
        """
        import vllm.third_party.deep_gemm as deep_gemm

        symm_buffer = self.get_symm_buffer()
        token_count = hidden_states.shape[0]
        # Only the first `token_count` rows of each buffer tensor are staged.
        _stage_deepseek_v4_mega_moe_inputs(
            hidden_states,
            topk_weights,
            topk_ids,
            symm_buffer.x[:token_count],
            symm_buffer.x_sf[:token_count],
            symm_buffer.topk_idx[:token_count],
            symm_buffer.topk_weights[:token_count],
        )

        # Normally already done during weight loading; repeated here (it is
        # a no-op then) to cover the dummy weight loading case.
        self.finalize_weights()

        l1_weights = self._transformed_l1_weights
        l2_weights = self._transformed_l2_weights
        assert l1_weights is not None
        assert l2_weights is not None
        deep_gemm.fp8_fp4_mega_moe(
            y,
            l1_weights,
            l2_weights,
            symm_buffer,
            activation_clamp=activation_clamp,
            fast_math=fast_math,
        )


# Flag the loader function so it is recognised as supporting MoE loading.
DeepseekV4MegaMoEExperts.weight_loader.supports_moe_loading = True  # type: ignore[attr-defined]
|
|
|
|
|
|
def _deepseek_v4_mega_moe_experts_op(
    hidden_states: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    out: torch.Tensor,
    layer_name: str,
    activation_clamp: float | None,
    fast_math: bool,
) -> None:
    """Real implementation of the ``deepseek_v4_mega_moe_experts`` custom op.

    Looks up the expert module registered under ``layer_name`` in the current
    forward context and lets it run MegaMoE; the result is written into
    ``out``.
    """
    experts_layer = get_forward_context().no_compile_layers[layer_name]
    experts_layer._run_mega_moe(
        hidden_states=hidden_states,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        y=out,
        activation_clamp=activation_clamp,
        fast_math=fast_math,
    )
|
|
|
|
|
|
def _deepseek_v4_mega_moe_experts_op_fake(
|
|
hidden_states: torch.Tensor,
|
|
topk_weights: torch.Tensor,
|
|
topk_ids: torch.Tensor,
|
|
out: torch.Tensor,
|
|
layer_name: str,
|
|
activation_clamp: float | None,
|
|
fast_math: bool,
|
|
) -> None:
|
|
return None
|
|
|
|
|
|
# Register the MegaMoE expert computation as a custom op. It is invoked as
# ``torch.ops.vllm.deepseek_v4_mega_moe_experts`` from
# DeepseekV4MegaMoEExperts.forward. ``out`` is declared as mutated because the
# op writes its result into that tensor and returns None.
direct_register_custom_op(
    op_name="deepseek_v4_mega_moe_experts",
    op_func=_deepseek_v4_mega_moe_experts_op,
    mutates_args=["out"],
    fake_impl=_deepseek_v4_mega_moe_experts_op_fake,
)
|
|
|
|
|
|
class DeepseekV4MoE(nn.Module):
    """DeepSeek V4 mixture-of-experts layer.

    Holds the router gate, the optional shared experts and the routed
    experts. The routed experts are either a ``FusedMoE`` layer (default) or
    ``DeepseekV4MegaMoEExperts`` when the kernel config selects the
    ``deep_gemm_mega_moe`` backend.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        """Build gate, shared experts and routed experts.

        Raises:
            NotImplementedError: If MegaMoE is selected without expert
                parallelism, with a scoring function other than
                ``sqrtsoftplus``, or with non-fp4 experts.
        """
        super().__init__()

        self.tp_size = get_tensor_model_parallel_world_size()
        hf_config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.prefix = prefix

        self.use_mega_moe = (
            vllm_config.kernel_config.moe_backend == "deep_gemm_mega_moe"
        )
        expert_parallel = vllm_config.parallel_config.enable_expert_parallel
        if self.use_mega_moe and not expert_parallel:
            raise NotImplementedError(
                "DeepSeek V4 MegaMoE currently requires expert parallel. "
                "Enable it with --enable-expert-parallel, or pick a different "
                "moe backend."
            )

        self.routed_scaling_factor = getattr(hf_config, "routed_scaling_factor", 1.0)
        self.hidden_size = hf_config.hidden_size

        self.n_routed_experts = hf_config.n_routed_experts
        self.n_activated_experts = hf_config.num_experts_per_tok
        self.moe_intermediate_size = hf_config.moe_intermediate_size
        self.swiglu_limit = hf_config.swiglu_limit
        self.renormalize = hf_config.norm_topk_prob
        self.scoring_func = getattr(hf_config, "scoring_func", "sqrtsoftplus")

        if self.use_mega_moe:
            if self.scoring_func != "sqrtsoftplus":
                raise NotImplementedError(
                    "DeepSeek V4 MegaMoE currently supports sqrtsoftplus routing only."
                )
            if getattr(hf_config, "expert_dtype", "fp4") != "fp4":
                raise NotImplementedError(
                    "DeepSeek V4 MegaMoE only supports fp4 experts; got expert_dtype="
                    f"{hf_config.expert_dtype!r}. Drop --kernel-config moe_backend="
                    "deep_gemm_mega_moe for this checkpoint."
                )

        self.gate = GateLinear(
            hf_config.hidden_size,
            hf_config.n_routed_experts,
            out_dtype=torch.float32,
            bias=False,
            prefix=f"{prefix}.gate",
        )
        # Both optional routing tensors start out absent; at most one of them
        # is replaced by a parameter below.
        self.gate.e_score_correction_bias = None
        self.gate.tid2eid = None
        is_hash_moe = extract_layer_index(prefix) < hf_config.num_hash_layers
        self.hash_indices_dtype = torch.int64 if self.use_mega_moe else torch.int32

        if is_hash_moe:
            # Hash MoE has no e_score_correction_bias. The table is filled
            # with valid random expert ids rather than left uninitialized, so
            # dummy mode (--load-format="dummy") cannot hit garbage indices
            # and access invalid memory.
            self.gate.tid2eid = nn.Parameter(
                torch.randint(
                    0,
                    hf_config.n_routed_experts,
                    (hf_config.vocab_size, hf_config.num_experts_per_tok),
                    dtype=self.hash_indices_dtype,
                ),
                requires_grad=False,
            )
        elif getattr(hf_config, "topk_method", None) == "noaux_tc":
            self.gate.e_score_correction_bias = nn.Parameter(
                torch.empty(hf_config.n_routed_experts, dtype=torch.float32),
                requires_grad=False,
            )

        if hf_config.n_shared_experts is not None:
            shared_intermediate_size = (
                hf_config.moe_intermediate_size * hf_config.n_shared_experts
            )
            self.shared_experts = DeepseekV4MLP(
                hidden_size=hf_config.hidden_size,
                intermediate_size=shared_intermediate_size,
                hidden_act=hf_config.hidden_act,
                swiglu_limit=self.swiglu_limit,
                quant_config=quant_config,
                # Results are reduced here only on the MegaMoE path.
                reduce_results=self.use_mega_moe,
                prefix=f"{prefix}.shared_experts",
            )
        else:
            self.shared_experts = None

        if self.use_mega_moe:
            self._init_mega_moe_experts(vllm_config, hf_config, prefix)
        else:
            self._init_fused_moe_experts(hf_config, quant_config, prefix)

    def _init_mega_moe_experts(
        self,
        vllm_config: VllmConfig,
        config,
        prefix: str,
    ) -> None:
        """Create the MegaMoE experts, partitioned evenly over the EP group."""
        self.ep_group = get_ep_group()
        self.ep_size = self.ep_group.world_size
        self.ep_rank = self.ep_group.rank_in_group
        assert config.n_routed_experts % self.ep_size == 0

        # Each EP rank owns one contiguous range of expert ids.
        experts_per_rank = config.n_routed_experts // self.ep_size
        first_expert = self.ep_rank * experts_per_rank
        self.n_local_experts = experts_per_rank
        self.experts_start_idx = first_expert
        self.experts_end_idx = first_expert + experts_per_rank

        self.experts = DeepseekV4MegaMoEExperts(
            vllm_config,
            num_experts=config.n_routed_experts,
            num_local_experts=self.n_local_experts,
            experts_start_idx=self.experts_start_idx,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            prefix=f"{prefix}.experts",
        )

    def _init_fused_moe_experts(
        self,
        config,
        quant_config,
        prefix: str,
    ) -> None:
        """Create the FusedMoE experts; gate and shared experts are handed to it."""
        self.tp_rank = get_tensor_model_parallel_rank()
        assert config.n_routed_experts % self.tp_size == 0

        experts_per_rank = config.n_routed_experts // self.tp_size
        first_expert = self.tp_rank * experts_per_rank
        self.n_local_experts = experts_per_rank
        self.experts_start_idx = first_expert
        self.experts_end_idx = first_expert + experts_per_rank

        self.experts = FusedMoE(
            shared_experts=self.shared_experts,
            gate=self.gate,
            num_experts=config.n_routed_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            renormalize=config.norm_topk_prob,
            quant_config=quant_config,
            prefix=f"{prefix}.experts",
            scoring_func=self.scoring_func,
            routed_scaling_factor=self.routed_scaling_factor,
            e_score_correction_bias=self.gate.e_score_correction_bias,
            hash_indices_table=self.gate.tid2eid,
            swiglu_limit=self.swiglu_limit,
            router_logits_dtype=torch.float32,
        )

    def forward(
        self, hidden_states: torch.Tensor, input_ids: torch.Tensor | None = None
    ) -> torch.Tensor:
        """Route tokens through the experts and return the result in the
        input's shape.

        On the MegaMoE path the routing is computed here with
        ``fused_topk_bias`` and the shared experts' output is added to the
        routed output. On the default path the call is delegated to
        ``_forward_fused_moe``.

        Raises:
            ValueError: If this is a hash-routed layer (``tid2eid`` is set)
                and ``input_ids`` is None.
        """
        if self.gate.tid2eid is not None and input_ids is None:
            raise ValueError("DeepSeek V4 hash MoE routing requires input_ids.")

        if not self.use_mega_moe:
            return self._forward_fused_moe(hidden_states, input_ids)

        org_shape = hidden_states.shape
        router_logits, _ = self.gate(hidden_states)

        bias_param = self.gate.e_score_correction_bias
        correction_bias = None if bias_param is None else bias_param.data
        topk_weights, topk_ids = fused_topk_bias(
            hidden_states=hidden_states,
            gating_output=router_logits,
            scoring_func=self.scoring_func,
            e_score_correction_bias=correction_bias,
            topk=self.n_activated_experts,
            renormalize=self.renormalize,
            indices_type=self.hash_indices_dtype,
            input_tokens=input_ids,
            hash_indices_table=self.gate.tid2eid,
            routed_scaling_factor=self.routed_scaling_factor,
        )

        if self.swiglu_limit is None:
            activation_clamp = None
        else:
            activation_clamp = float(self.swiglu_limit)

        final_hidden_states = self.experts(
            hidden_states,
            topk_weights,
            topk_ids,
            activation_clamp=activation_clamp,
        )

        if self.shared_experts is not None:
            # In-place accumulation onto the routed experts' output.
            final_hidden_states += self.shared_experts(hidden_states)

        return final_hidden_states.view(org_shape)

    def _forward_fused_moe(
        self, hidden_states: torch.Tensor, input_ids: torch.Tensor | None = None
    ) -> torch.Tensor:
        """Run the FusedMoE path and return the result in the input's shape."""
        org_shape = hidden_states.shape

        if self.experts.is_internal_router:
            # The gate/router runs inside FusedMoE, which therefore receives
            # the raw hidden states in place of router logits.
            router_input = hidden_states
        else:
            router_input, _ = self.gate(hidden_states)

        final_hidden_states = self.experts(
            hidden_states=hidden_states,
            router_logits=router_input,
            input_ids=input_ids,
        )
        return final_hidden_states.view(org_shape)

    def finalize_mega_moe_weights(self) -> None:
        """Finalize the expert weights on the MegaMoE path; no-op otherwise."""
        if not self.use_mega_moe:
            return
        self.experts.finalize_weights()
|
|
|
|
|
|
class DeepseekV4Attention(nn.Module):
    """Attention sub-layer of a DeepSeek V4 decoder block.

    The constructor builds the projection layers (``fused_wqa_wkv``, ``wq_b``,
    ``wo_a``, ``wo_b``), the norms, the rotary embedding, the per-head
    ``attn_sink`` parameter and, for layers whose compress ratio is 4, a
    ``DeepseekV4Indexer``. All of them are handed to a
    ``DeepseekV4MultiHeadLatentAttentionWrapper`` (``self.mla_attn``), to
    which ``forward`` delegates.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str,
        topk_indices_buffer: torch.Tensor | None = None,
        aux_stream_list: list[torch.cuda.Stream] | None = None,
    ):
        """Create the attention sub-modules for the layer named by ``prefix``.

        Args:
            vllm_config: Engine configuration; the HF model config, the
                quantization config and the cache config are read from it.
            prefix: Parameter-name prefix of this module. The layer index is
                extracted from it with ``extract_layer_index``.
            topk_indices_buffer: Buffer passed through to the indexer (when
                one is built) and to ``DeepseekV4MLAModules``.
            aux_stream_list: CUDA streams passed through to
                ``DeepseekV4MLAModules``.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        layer_id = extract_layer_index(prefix)

        self.layer_id = layer_id
        self.hidden_size = config.hidden_size
        self.n_heads = config.num_attention_heads
        tp_size = get_tensor_model_parallel_world_size()
        assert self.n_heads % tp_size == 0

        self.n_local_heads = self.n_heads // tp_size
        self.q_lora_rank = config.q_lora_rank
        self.o_lora_rank = config.o_lora_rank
        self.head_dim = config.head_dim
        self.rope_head_dim = config.qk_rope_head_dim
        self.nope_head_dim = self.head_dim - self.rope_head_dim
        self.n_groups = config.o_groups
        # NOTE(review): unlike n_heads above, divisibility of n_groups by
        # tp_size is not asserted here.
        self.n_local_groups = self.n_groups // tp_size
        self.window_size = config.sliding_window
        # NOTE(zyongye) The compress ratio can't be 0, so it is clamped to at
        # least 1. Layers whose index is past num_hidden_layers (the MTP
        # layer) are not included in the compress ratio list and fall back
        # to 1 as well.
        if layer_id < config.num_hidden_layers:
            self.compress_ratio = max(1, config.compress_ratios[layer_id])
        else:
            self.compress_ratio = 1
        self.eps = config.rms_norm_eps
        self.max_position_embeddings = config.max_position_embeddings

        # Padded to min 64 heads for FlashMLA, initialized to -inf
        # (no sink effect). Weight loading fills the first n_local_heads slots.
        padded_heads = max(self.n_local_heads, 64)
        self.attn_sink = nn.Parameter(
            torch.full((padded_heads,), -float("inf"), dtype=torch.float32),
            requires_grad=False,
        )

        # Single linear producing both the q_lora_rank-wide and the
        # head_dim-wide outputs; tensor parallelism is disabled for it.
        self.fused_wqa_wkv = MergedColumnParallelLinear(
            self.hidden_size,
            [self.q_lora_rank, self.head_dim],
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.fused_wqa_wkv",
            disable_tp=True,  # fused ReplicatedLinear
        )
        self.q_norm = RMSNorm(self.q_lora_rank, self.eps)
        self.wq_b = ColumnParallelLinear(
            self.q_lora_rank,
            self.n_heads * self.head_dim,
            bias=False,
            quant_config=quant_config,
            return_bias=False,
            prefix=f"{prefix}.wq_b",
        )

        self.kv_norm = RMSNorm(self.head_dim, self.eps)
        # wo_a is NOT quantized in the NVFP4 checkpoint (modelopt left it as bfloat16),
        # but the attention forward pass expects FP8 (weight + weight_scale_inv).
        # Pass quant_config=None to load bfloat16, then process_weights_after_loading
        # will handle the FP8 quantization.
        self.wo_a = ColumnParallelLinear(
            self.n_heads * self.head_dim // self.n_groups,
            self.n_groups * self.o_lora_rank,
            bias=False,
            quant_config=None,
            return_bias=False,
            prefix=f"{prefix}.wo_a",
        )
        # Extra attributes attached to the layer object; presumably read by
        # the attention implementation — confirm against the consumer.
        self.wo_a.is_bmm = True
        self.wo_a.bmm_batch_size = self.n_local_groups
        self.wo_b = RowParallelLinear(
            self.n_groups * self.o_lora_rank,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
            return_bias=False,
            prefix=f"{prefix}.wo_b",
        )
        self.softmax_scale = self.head_dim**-0.5
        # Raises KeyError if the checkpoint's quantization_config has no
        # "scale_fmt" entry.
        self.scale_fmt = config.quantization_config["scale_fmt"]

        self.rope_parameters = config.rope_scaling

        # Initialize rotary embedding BEFORE DeepseekV4MLAModules (which needs it)
        # NOTE: `rope_parameters` is the very same dict object as
        # `config.rope_parameters`, so every assignment below mutates the
        # shared HF config in place. Statement order matters here: the
        # values must be written before `get_rope` is called, and anything
        # constructed afterwards that reads the config sees this layer's
        # values.
        rope_parameters = config.rope_parameters
        rope_parameters["rope_theta"] = (
            config.compress_rope_theta if self.compress_ratio > 1 else config.rope_theta
        )
        if config.rope_parameters["rope_type"] != "default":
            config.rope_parameters["rope_type"] = (
                "deepseek_yarn"
                if config.rope_parameters.get("apply_yarn_scaling", True)
                else "deepseek_llama_scaling"
            )
        rope_parameters["mscale"] = 0  # Disable mscale
        rope_parameters["mscale_all_dim"] = 0  # Disable mscale
        rope_parameters["is_deepseek_v4"] = True
        rope_parameters["rope_dim"] = self.rope_head_dim
        self.rotary_emb = get_rope(
            self.head_dim,
            max_position=self.max_position_embeddings,
            rope_parameters=rope_parameters,
            is_neox_style=False,
        )

        self.indexer = None
        if self.compress_ratio == 4:
            # Only C4A uses sparse attention and hence has indexer.
            self.indexer = DeepseekV4Indexer(
                vllm_config,
                config=config,
                hidden_size=self.hidden_size,
                q_lora_rank=self.q_lora_rank,
                quant_config=quant_config,
                cache_config=vllm_config.cache_config,
                topk_indices_buffer=topk_indices_buffer,
                compress_ratio=self.compress_ratio,
                prefix=f"{prefix}.indexer",
            )

        # Bundle of the sub-modules built above. The indexer shares this
        # layer's rotary embedding (`indexer_rotary_emb=self.rotary_emb`).
        mla_modules = DeepseekV4MLAModules(
            vllm_config=vllm_config,
            fused_wqa_wkv=self.fused_wqa_wkv,
            q_norm=self.q_norm,
            wq_b=self.wq_b,
            kv_norm=self.kv_norm,
            wo_a=self.wo_a,
            wo_b=self.wo_b,
            attn_sink=self.attn_sink,
            rotary_emb=self.rotary_emb,
            indexer=self.indexer,
            indexer_rotary_emb=self.rotary_emb,
            topk_indices_buffer=topk_indices_buffer,
            aux_stream_list=aux_stream_list,
        )
        # head_dim is passed as the head size, the value head size and the
        # kv_lora_rank alike.
        self.mla_attn = DeepseekV4MultiHeadLatentAttentionWrapper(
            hidden_size=self.hidden_size,
            num_heads=self.n_local_heads,
            head_dim=self.head_dim,
            scale=self.softmax_scale,
            qk_nope_head_dim=self.nope_head_dim,
            qk_rope_head_dim=self.rope_head_dim,
            v_head_dim=self.head_dim,
            q_lora_rank=self.q_lora_rank,
            kv_lora_rank=self.head_dim,
            o_lora_rank=self.o_lora_rank,
            mla_modules=mla_modules,
            window_size=self.window_size,
            compress_ratio=self.compress_ratio,
            cache_config=vllm_config.cache_config,
            quant_config=quant_config,
            prefix=prefix,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        llama_4_scaling: torch.Tensor | None,
    ):
        """Delegate to ``self.mla_attn`` with the arguments unchanged.

        Returns whatever ``self.mla_attn`` returns.
        """
        return self.mla_attn(positions, hidden_states, llama_4_scaling)
|
|
|
|
|
|
class DeepseekV4DecoderLayer(nn.Module):
|
|
def __init__(
|
|
self,
|
|
vllm_config,
|
|
prefix,
|
|
topk_indices_buffer: torch.Tensor | None = None,
|
|
aux_stream_list: list[torch.cuda.Stream] | None = None,
|
|
):
|
|
super().__init__()
|
|
|
|
# Lazy import to avoid top-level tilelang dependency.
|
|
# Registers both torch.ops.vllm.mhc_pre and mhc_post
|
|
import vllm.model_executor.layers.mhc # noqa: F401
|
|
|
|
config = vllm_config.model_config.hf_config
|
|
self.hidden_size = config.hidden_size
|
|
|
|
self.rms_norm_eps = config.rms_norm_eps
|
|
self.attn = DeepseekV4Attention(
|
|
vllm_config,
|
|
prefix=f"{prefix}.attn",
|
|
topk_indices_buffer=topk_indices_buffer,
|
|
aux_stream_list=aux_stream_list,
|
|
)
|
|
self.ffn = DeepseekV4MoE(vllm_config, prefix=f"{prefix}.ffn")
|
|
|
|
self.attn_norm = RMSNorm(self.hidden_size, self.rms_norm_eps)
|
|
self.ffn_norm = RMSNorm(self.hidden_size, self.rms_norm_eps)
|
|
self.hc_mult = config.hc_mult
|
|
self.hc_sinkhorn_iters = config.hc_sinkhorn_iters
|
|
self.hc_eps = config.hc_eps
|
|
self.hc_post_alpha = 2.0
|
|
mix_hc = (2 + self.hc_mult) * self.hc_mult
|
|
hc_dim = self.hc_mult * self.hidden_size
|
|
self.hc_attn_fn = nn.Parameter(
|
|
torch.empty(
|
|
(mix_hc, hc_dim),
|
|
dtype=torch.float32,
|
|
),
|
|
requires_grad=False,
|
|
)
|
|
self.hc_ffn_fn = nn.Parameter(
|
|
torch.empty(
|
|
(mix_hc, hc_dim),
|
|
dtype=torch.float32,
|
|
),
|
|
requires_grad=False,
|
|
)
|
|
self.hc_attn_base = nn.Parameter(
|
|
torch.empty(
|
|
mix_hc,
|
|
dtype=torch.float32,
|
|
),
|
|
requires_grad=False,
|
|
)
|
|
self.hc_ffn_base = nn.Parameter(
|
|
torch.empty(
|
|
mix_hc,
|
|
dtype=torch.float32,
|
|
),
|
|
requires_grad=False,
|
|
)
|
|
self.hc_attn_scale = nn.Parameter(
|
|
torch.empty(
|
|
3,
|
|
dtype=torch.float32,
|
|
),
|
|
requires_grad=False,
|
|
)
|
|
self.hc_ffn_scale = nn.Parameter(
|
|
torch.empty(
|
|
3,
|
|
dtype=torch.float32,
|
|
),
|
|
requires_grad=False,
|
|
)
|
|
|
|
def hc_pre(
|
|
self,
|
|
x: torch.Tensor,
|
|
hc_fn: torch.Tensor,
|
|
hc_scale: torch.Tensor,
|
|
hc_base: torch.Tensor,
|
|
):
|
|
post_mix, res_mix, layer_input = torch.ops.vllm.mhc_pre(
|
|
residual=x,
|
|
fn=hc_fn,
|
|
hc_scale=hc_scale,
|
|
hc_base=hc_base,
|
|
rms_eps=self.rms_norm_eps,
|
|
hc_pre_eps=self.hc_eps,
|
|
hc_sinkhorn_eps=self.hc_eps,
|
|
hc_post_mult_value=self.hc_post_alpha,
|
|
sinkhorn_repeat=self.hc_sinkhorn_iters,
|
|
)
|
|
return layer_input, post_mix, res_mix
|
|
|
|
def hc_post(
|
|
self,
|
|
x: torch.Tensor,
|
|
residual: torch.Tensor,
|
|
post: torch.Tensor,
|
|
comb: torch.Tensor,
|
|
):
|
|
return torch.ops.vllm.mhc_post(x, residual, post, comb)
|
|
|
|
def forward(
|
|
self,
|
|
x: torch.Tensor,
|
|
positions: torch.Tensor,
|
|
input_ids: torch.Tensor | None,
|
|
post_mix: torch.Tensor | None,
|
|
res_mix: torch.Tensor | None,
|
|
residual: torch.Tensor | None,
|
|
) -> torch.Tensor:
|
|
if residual is None:
|
|
# Run standalone hc_pre on first layer
|
|
residual = x
|
|
x, post_mix, res_mix = self.hc_pre(
|
|
x, self.hc_attn_fn, self.hc_attn_scale, self.hc_attn_base
|
|
)
|
|
else:
|
|
residual, post_mix, res_mix, x = torch.ops.vllm.mhc_fused_post_pre(
|
|
x,
|
|
residual,
|
|
post_mix,
|
|
res_mix,
|
|
self.hc_attn_fn,
|
|
self.hc_attn_scale,
|
|
self.hc_attn_base,
|
|
self.rms_norm_eps,
|
|
self.hc_eps,
|
|
self.hc_eps,
|
|
self.hc_post_alpha,
|
|
self.hc_sinkhorn_iters,
|
|
)
|
|
|
|
x = self.attn_norm(x)
|
|
x = self.attn(positions, x, None)
|
|
|
|
residual, post_mix, res_mix, x = torch.ops.vllm.mhc_fused_post_pre(
|
|
x,
|
|
residual,
|
|
post_mix,
|
|
res_mix,
|
|
self.hc_ffn_fn,
|
|
self.hc_ffn_scale,
|
|
self.hc_ffn_base,
|
|
self.rms_norm_eps,
|
|
self.hc_eps,
|
|
self.hc_eps,
|
|
self.hc_post_alpha,
|
|
self.hc_sinkhorn_iters,
|
|
)
|
|
|
|
x = self.ffn_norm(x)
|
|
x = self.ffn(x, input_ids)
|
|
return x, residual, post_mix, res_mix
|
|
|
|
|
|
@support_torch_compile
class DeepseekV4Model(nn.Module):
    """DeepSeek V4 decoder stack: embedding, decoder layers, HC head, norm.

    The token embedding is replicated into ``hc_mult`` streams before the
    first decoder layer; ``hc_head`` collapses the streams again on the last
    pipeline-parallel rank before the final ``RMSNorm``.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the stack for the layers owned by this pipeline rank.

        Raises:
            NotImplementedError: If the ``deep_gemm_mega_moe`` backend is
                selected without expert parallelism enabled.
        """
        super().__init__()

        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config

        # Select weight mapper based on quantization method.
        # NVFP4 (modelopt_fp4) checkpoints use different key naming
        # than the default MXFP4 format.
        if (
            quant_config is not None
            and getattr(quant_config, "get_name", lambda: None)() == "modelopt_fp4"
        ):
            self.hf_to_vllm_mapper = _make_deepseek_v4_nvfp4_weights_mapper()
        elif getattr(config, "expert_dtype", "fp4") != "fp4":
            self.hf_to_vllm_mapper = _make_deepseek_v4_weights_mapper("fp8")
        else:
            self.hf_to_vllm_mapper = _make_deepseek_v4_weights_mapper("fp4")
        self.config = config
        self.use_mega_moe = (
            vllm_config.kernel_config.moe_backend == "deep_gemm_mega_moe"
        )
        if self.use_mega_moe and not vllm_config.parallel_config.enable_expert_parallel:
            raise NotImplementedError(
                "DeepSeek V4 MegaMoE currently requires expert parallel. "
                "Enable it with --enable-expert-parallel, or pick a different "
                "moe backend."
            )
        self.vocab_size = config.vocab_size
        self.hc_eps = config.hc_eps
        self.hc_mult = config.hc_mult
        self.hc_dim = self.hc_mult * config.hidden_size
        self.rms_norm_eps = config.rms_norm_eps

        # Three aux streams: one per non-default input GEMM in
        # DeepseekV4MultiHeadLatentAttentionWrapper.attn_gemm_parallel_execute
        # (compressor kv_score, indexer.weights_proj, indexer.compressor
        # kv_score). fused_wqa_wkv stays on the default stream.
        # Disable them on ROCm because of hang issues.
        aux_stream_list = (
            None
            if current_platform.is_rocm()
            else [torch.cuda.Stream() for _ in range(3)]
        )

        self.device = current_platform.device_type
        # Reserved topk indices buffer for all Indexer layers to reuse.
        self.topk_indices_buffer = torch.empty(
            vllm_config.scheduler_config.max_num_batched_tokens,
            config.index_topk,
            dtype=torch.int32,
            device=self.device,
        )

        if get_pp_group().is_first_rank:
            self.embed_tokens = VocabParallelEmbedding(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=f"{prefix}.embed_tokens",
            )
        else:
            self.embed_tokens = PPMissingLayer()

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: DeepseekV4DecoderLayer(
                vllm_config,
                prefix=prefix,
                topk_indices_buffer=self.topk_indices_buffer,
                aux_stream_list=aux_stream_list,
            ),
            prefix=f"{prefix}.layers",
        )

        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, self.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

        # Uninitialized fp32 parameters of the HC head; filled at load time.
        self.hc_head_fn = nn.Parameter(
            torch.empty(
                self.hc_mult,
                self.hc_dim,
                dtype=torch.float32,
            ),
            requires_grad=False,
        )
        self.hc_head_base = nn.Parameter(
            torch.empty(
                self.hc_mult,
                dtype=torch.float32,
            ),
            requires_grad=False,
        )
        self.hc_head_scale = nn.Parameter(
            torch.empty(1, dtype=torch.float32),
            requires_grad=False,
        )

        # Pre-hc_head residual stream buffer for the MTP draft. Stable
        # address (outside the cudagraph pool) so the copy_ in forward()
        # refreshes it correctly across captured shapes. Only allocated on
        # the last PP rank — that's where MTP target hidden states are
        # produced.
        if get_pp_group().is_last_rank:
            self._mtp_hidden_buffer = torch.empty(
                vllm_config.scheduler_config.max_num_batched_tokens,
                self.hc_dim,
                dtype=vllm_config.model_config.dtype,
                device=self.device,
            )
        else:
            self._mtp_hidden_buffer = None

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up the token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def make_empty_intermediate_tensors(
        self,
        batch_size: int,
        dtype: torch.dtype,
        device: torch.device,
    ) -> IntermediateTensors:
        """Return zero-filled placeholder tensors for pipeline parallelism."""
        # PP intermediate tensors carry the multi-stream hidden_states
        # of shape (num_tokens, hc_mult, hidden_size) — V4 expands the
        # token embedding to hc_mult streams before the first decoder
        # layer and keeps that shape until hc_head() collapses it.
        return IntermediateTensors(
            {
                "hidden_states": torch.zeros(
                    (batch_size, self.hc_mult, self.config.hidden_size),
                    dtype=dtype,
                    device=device,
                ),
            }
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the decoder layers owned by this pipeline rank.

        Returns:
            ``IntermediateTensors`` holding the multi-stream hidden states on
            every rank but the last; on the last rank, the normalized hidden
            states after ``hc_head``.
        """
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
            # Replicate the embedding into hc_mult identical streams.
            hidden_states = hidden_states.unsqueeze(-2).repeat(1, self.hc_mult, 1)
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]

        if self.use_mega_moe:
            input_ids = input_ids.to(torch.int64)

        residual, post_mix, res_mix = None, None, None
        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states, residual, post_mix, res_mix = layer(
                hidden_states,
                positions,
                input_ids,
                post_mix,
                res_mix,
                residual,
            )
        # Each layer leaves the post step of its FFN to its successor; the
        # last layer of this rank has none, so apply it here. (The loop never
        # breaks, so this is equivalent to the former ``for ... else``.)
        hidden_states = layer.hc_post(hidden_states, residual, post_mix, res_mix)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({"hidden_states": hidden_states})

        # Stash pre-hc_head residual for the MTP draft (captured copy_).
        num_tokens = hidden_states.shape[0]
        self._mtp_hidden_buffer[:num_tokens].copy_(hidden_states.flatten(1))

        hidden_states = hc_head(
            hidden_states,
            self.hc_head_fn,
            self.hc_head_scale,
            self.hc_head_base,
            self.rms_norm_eps,
            self.hc_eps,
        )
        hidden_states = self.norm(hidden_states)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Args:
            weights: ``(name, tensor)`` pairs, with names already mapped to
                this module's naming scheme.

        Returns:
            The names of the parameters that received data.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "w1", 0),
            ("gate_up_proj", "w3", 1),
            ("attn.fused_wqa_wkv", "attn.wq_a", 0),
            ("attn.fused_wqa_wkv", "attn.wkv", 1),
            ("compressor.fused_wkv_wgate", "compressor.wkv", 0),
            ("compressor.fused_wkv_wgate", "compressor.wgate", 1),
            # Indexer's compressor (same stacking pattern)
            ("indexer.compressor.fused_wkv_wgate", "indexer.compressor.wkv", 0),
            ("indexer.compressor.fused_wkv_wgate", "indexer.compressor.wgate", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        # TP for attention
        tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()
        n_head = self.config.num_attention_heads
        n_local_head = n_head // tp_size
        head_rank_start = n_local_head * tp_rank
        head_rank_end = n_local_head * (tp_rank + 1)

        # Pre-compute expert mapping ONCE.
        expert_mapping = self.get_expert_mapping()

        # NVFP4 compressor/indexer scale params need special handling:
        # wkv.input_scale (shape [1]) + wgate.input_scale (shape [1])
        # must be concatenated into fused_wkv_wgate.input_scale (shape [2]).
        # The default stacking path fails because PerTensorScaleParameter's
        # weight_loader asserts shape equality.
        # We buffer them and load once both shards are available.
        compressor_scale_buffer: dict[str, dict[int, torch.Tensor]] = {}
        compressor_scale_suffixes = (
            "input_scale",
            "weight_scale",
            "weight_scale_2",
        )

        for name, loaded_weight in weights:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if ".experts." in name:
                    continue
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)

                if is_pp_missing_parameter(name, self):
                    break
                if name not in params_dict:
                    # The stacked param doesn't exist — skip
                    # (e.g. indexer.compressor.fused_wkv_wgate on layers
                    # that don't have the full indexer structure)
                    break
                param = params_dict[name]

                # NVFP4 scale params for stacked fused_wkv_wgate need
                # special handling: each shard (wkv, wgate) has scale
                # shape [1] or [head_dim, K], but the fused param has
                # shape [2] or [2*head_dim, K]. The default stacking
                # weight_loader can't handle this for PerTensorScale or
                # ModelWeight scale params. Buffer and concatenate.
                if "fused_wkv_wgate" in name and name.endswith(
                    compressor_scale_suffixes
                ):
                    # Buffer the shard for later concatenation
                    compressor_scale_buffer.setdefault(name, {})[shard_id] = (
                        loaded_weight
                    )
                    loaded_params.add(name)
                    break

                param.weight_loader(param, loaded_weight, shard_id)
                loaded_params.add(name)
                break
            else:
                if ".experts." in name:
                    # E8M0 scales are stored as float8_e8m0fnu in
                    # checkpoints but the MoE param is uint8. copy_()
                    # would do a numeric conversion (e.g. 2^-7 → 0),
                    # destroying the raw exponent bytes.
                    if (
                        "weight_scale" in name
                        and loaded_weight.dtype == torch.float8_e8m0fnu
                    ):
                        loaded_weight = loaded_weight.view(torch.uint8)
                    for mapping in expert_mapping:
                        param_name, weight_name, expert_id, shard_id = mapping
                        if weight_name not in name:
                            continue
                        name_mapped = name.replace(weight_name, param_name)
                        if is_pp_missing_parameter(name_mapped, self):
                            continue
                        param = params_dict[name_mapped]
                        # We should ask the weight loader to return success or not
                        # here since otherwise we may skip experts with other
                        # available replicas.
                        weight_loader = typing.cast(
                            Callable[..., bool], param.weight_loader
                        )
                        success = weight_loader(
                            param,
                            loaded_weight,
                            name_mapped,
                            shard_id=shard_id,
                            expert_id=expert_id,
                            return_success=True,
                        )
                        if success:
                            # Record the parameter only when data was
                            # actually written. Recording it after the loop
                            # reported weights that no loader accepted and
                            # raised NameError when no mapping matched.
                            loaded_params.add(name_mapped)
                            break
                    continue
                elif "attn_sink" in name:
                    if is_pp_missing_parameter(name, self):
                        continue
                    # Keep only this TP rank's heads; the parameter may be
                    # longer than that, so only the leading slots are filled.
                    narrow_weight = loaded_weight[head_rank_start:head_rank_end]
                    n = narrow_weight.shape[0]
                    params_dict[name][:n].copy_(narrow_weight)
                    loaded_params.add(name)
                    continue
                else:
                    if is_pp_missing_parameter(name, self):
                        continue
                    if name not in params_dict:
                        print(f"Skipping weight {name} (not in model params)",
                              flush=True)
                        continue
                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    try:
                        weight_loader(param, loaded_weight)
                    except (AssertionError, RuntimeError) as e:
                        # Diagnostic only; the error is re-raised unchanged.
                        print(
                            f"WEIGHT_LOAD_FAIL: name={name} "
                            f"param_shape={param.data.shape if hasattr(param, 'data') else '?'} "
                            f"loaded_shape={loaded_weight.shape} "
                            f"loaded_dtype={loaded_weight.dtype} "
                            f"error={e}",
                            flush=True,
                        )
                        raise

        # Load buffered compressor/indexer scale params.
        # These are NVFP4 quantization scales that need concatenation
        # across shards (wkv=shard0, wgate=shard1) before loading.
        for name, shards in compressor_scale_buffer.items():
            if is_pp_missing_parameter(name, self):
                continue
            param = params_dict[name]
            if len(shards) == 2:
                # Concatenate shard 0 and shard 1 along dim 0.
                # Scales may be 0-dim scalars (input_scale, weight_scale_2)
                # or N-dim tensors (weight_scale); reshape scalars to 1-d.
                s0, s1 = shards[0], shards[1]
                if s0.ndim == 0:
                    s0 = s0.reshape(1)
                if s1.ndim == 0:
                    s1 = s1.reshape(1)
                stacked = torch.cat([s0, s1], dim=0)
            else:
                stacked = shards[0]
            assert param.data.shape == stacked.shape, (
                f"Scale shape mismatch for {name}: "
                f"param={param.data.shape} loaded={stacked.shape}"
            )
            param.data.copy_(stacked)

        return loaded_params

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return ``(param_name, weight_name, expert_id, shard_id)`` tuples.

        The MegaMoE-specific mapping is used when the first local layer's
        FFN has ``use_mega_moe`` set; otherwise the ``FusedMoE`` mapping.
        """
        first_layer = next(islice(self.layers, self.start_layer, self.end_layer))
        if first_layer.ffn.use_mega_moe:
            return make_deepseek_v4_expert_params_mapping(self.config.n_routed_experts)
        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        return FusedMoE.make_expert_params_mapping(
            self,
            ckpt_gate_proj_name="w1",
            ckpt_down_proj_name="w2",
            ckpt_up_proj_name="w3",
            num_experts=self.config.n_routed_experts,
        )

    def finalize_mega_moe_weights(self) -> None:
        """Post-load fix-ups for every local layer.

        Finalizes each layer's MoE weights and quantizes any ``wo_a`` weight
        still stored as bfloat16 to FP8.
        """
        for layer in islice(self.layers, self.start_layer, self.end_layer):
            layer.ffn.finalize_mega_moe_weights()
            # Quantize wo_a to FP8 (checkpoint has bfloat16, forward expects FP8)
            attn = layer.attn
            if hasattr(attn, "wo_a") and attn.wo_a.weight.dtype == torch.bfloat16:
                self._quantize_wo_a_to_fp8(attn.wo_a)

    @staticmethod
    def _quantize_wo_a_to_fp8(wo_a: ColumnParallelLinear) -> None:
        """Quantize wo_a weight from bfloat16 to float8_e4m3fn.

        The attention forward pass (fused_inv_rope_fp8_quant + einsum)
        expects wo_a.weight as FP8 and wo_a.weight_scale_inv as float32.
        The NVFP4 checkpoint stores wo_a as bfloat16, so we quantize here.
        Uses per-tensor symmetric quantization:

            scale      = amax / fp8_max
            weight_fp8 = weight / scale        (fills the FP8 range)
            weight     ≈ weight_fp8 * weight_scale_inv, weight_scale_inv = scale

        ``weight_scale_inv`` therefore holds the multiplier that dequantizes
        the stored FP8 weight. The product ``weight_fp8 * weight_scale_inv``
        is the same as before this fix; previously the weight was multiplied
        by ``scale`` (and ``1 / scale`` stored), which shrank typical weights
        below the smallest FP8 value and rounded them to zero.
        """
        fp8_dtype = torch.float8_e4m3fn
        fp8_max = torch.finfo(fp8_dtype).max  # 448.0
        weight_bf16 = wo_a.weight.data
        amax = weight_bf16.abs().max().float()
        scale = amax / fp8_max
        # An all-zero weight gives scale == 0; fall back to 1 so the
        # division below stays finite.
        scale = torch.where(scale > 0, scale, torch.ones_like(scale))
        # Clamp guards against rounding pushing the largest value just past
        # the representable range before the cast.
        weight_fp8 = (
            (weight_bf16.float() / scale).clamp(-fp8_max, fp8_max).to(fp8_dtype)
        )
        wo_a.weight = torch.nn.Parameter(weight_fp8, requires_grad=False)
        wo_a.weight_scale_inv = torch.nn.Parameter(
            scale.clone(), requires_grad=False
        )
|
|
|
|
|
|
@torch.compile(backend=current_platform.simple_compile_backend)
def hc_head(
    hidden_states: torch.Tensor,
    hc_fn: torch.Tensor,
    hc_scale: torch.Tensor,
    hc_base: torch.Tensor,
    rms_norm_eps: float,
    hc_eps: float,
) -> torch.Tensor:
    """Collapse the multi-stream hidden states with the fused HC-head kernel.

    The last two dimensions of ``hidden_states`` are read as
    ``(hc_mult, hidden_size)``; all leading dimensions are flattened for the
    kernel call and restored on the result.

    Returns:
        A bfloat16 tensor shaped ``(*leading_dims, hidden_size)``, written
        by ``torch.ops.vllm.hc_head_fused_kernel``.
    """
    leading_dims = hidden_states.shape[:-2]
    n_streams = hidden_states.shape[-2]
    width = hidden_states.shape[-1]

    flat = hidden_states.view(-1, n_streams, width)
    # Output buffer the kernel fills in place.
    result = torch.empty(
        (flat.shape[0], width),
        dtype=torch.bfloat16,
        device=hidden_states.device,
    )
    torch.ops.vllm.hc_head_fused_kernel(
        flat,
        hc_fn,
        hc_scale,
        hc_base,
        result,
        width,
        rms_norm_eps,
        hc_eps,
        n_streams,
    )
    return result.view(*leading_dims, width)
|
|
|
|
|
|
def _make_deepseek_v4_weights_mapper(expert_dtype: str) -> WeightsMapper:
    """Build the checkpoint-to-vLLM name mapper for the default layout.

    Args:
        expert_dtype: ``"fp4"`` selects the MXFP4 expert scale naming; any
            other value selects the FP8 naming.
    """
    scale_rules: list[tuple[re.Pattern, str]] = []
    if expert_dtype == "fp4":
        # MXFP4 experts use Mxfp4MoEMethod, which registers scales as
        # ``w{1,2,3}_weight_scale`` (no _inv suffix). This rule must stay
        # ahead of the catch-all below so expert scales are rewritten first.
        scale_rules.append(
            (
                re.compile(r"(\.experts\.\d+\.w[123])\.scale$"),
                r"\1.weight_scale",
            )
        )
    # Catch-all for every remaining ``.scale`` key. With FP4 experts that is
    # the FP8 linear / shared-expert block scales (Fp8LinearMethod registers
    # ``weight_scale_inv``). With FP8 experts it also covers the expert
    # scales (Fp8MoEMethod with block_quant=True registers
    # ``w{13,2}_weight_scale_inv``).
    scale_rules.append((re.compile(r"\.scale$"), ".weight_scale_inv"))

    prefix_map = {
        "layers.": "model.layers.",
        "embed.": "model.embed.",
        "norm.": "model.norm.",
        "hc_head": "model.hc_head",
        "mtp.": "model.mtp.",
    }
    suffix_map = {
        "head.weight": "lm_head.weight",
        "embed.weight": "embed_tokens.weight",
        ".ffn.gate.bias": ".ffn.gate.e_score_correction_bias",
    }
    substr_map = {
        ".attn.compressor.": ".attn.mla_attn.compressor.",
        ".shared_experts.w2": ".shared_experts.down_proj",
    }
    return WeightsMapper(
        orig_to_new_prefix=prefix_map,
        orig_to_new_regex=dict(scale_rules),
        orig_to_new_suffix=suffix_map,
        orig_to_new_substr=substr_map,
    )
|
|
|
|
|
|
|
|
|
|
def _make_deepseek_v4_nvfp4_weights_mapper() -> WeightsMapper:
    """Build the name mapper for NVFP4 (ModelOpt) DeepSeek-V4 checkpoints.

    Differences from the upstream MXFP4 key layout handled here:
    - Routed experts are named gate_proj/up_proj/down_proj (not w1/w3/w2).
    - Scale keys already end in .weight_scale / .weight_scale_2 /
      .input_scale, so no scale renames are needed.
    - The shared expert already uses down_proj (not w2).
    - Attention lives under ``.self_attn.`` and is renamed to ``.attn.``.

    Use this mapper when the quantization method is modelopt_fp4.
    """
    # Routed-expert projections: checkpoint name -> internal w1/w3/w2.
    expert_rename_regex = {
        re.compile(rf"(\.experts\.\d+\.){ckpt_name}\."): rf"\1{internal}."
        for ckpt_name, internal in (
            ("gate_proj", "w1"),
            ("up_proj", "w3"),
            ("down_proj", "w2"),
        )
    }

    # The NVFP4 checkpoint already uses lm_head / embed_tokens directly,
    # no suffix renames needed (unlike the MXFP4 upstream format).
    suffix_renames: dict[str, str] = {}

    # NOTE: specific renames MUST come before general ones (applied in
    # order). The groups below are merged in exactly that order.

    # Indexer params. The checkpoint nests the indexer under
    # self_attn.compressor.indexer.*, while the model has it at
    # attn.indexer.* (a sibling of the compressor, NOT nested under it).
    # These must precede the `.self_attn.compressor.` rewrite, which would
    # otherwise move indexer keys under mla_attn.compressor.
    indexer_renames = {
        ".self_attn.compressor.indexer.q_b_proj.": ".attn.indexer.wq_b.",
        ".self_attn.compressor.indexer.weights_proj.": ".attn.indexer.weights_proj.",
        ".self_attn.compressor.indexer.kv_norm.": ".attn.indexer.k_norm.",
        ".self_attn.compressor.indexer.kv_proj.": ".attn.indexer.compressor.wkv.",
        ".self_attn.compressor.indexer.gate_proj.": ".attn.indexer.compressor.wgate.",
        ".self_attn.compressor.indexer.position_bias": ".attn.indexer.compressor.ape",
    }

    # Non-indexer compressor: the checkpoint uses kv_proj/gate_proj, the
    # model uses wkv/wgate (for stacking into fused_wkv_wgate). The final
    # entry relocates the remaining `.self_attn.compressor.` keys and so
    # has to follow the indexer group.
    compressor_renames = {
        "compressor.kv_proj.": "compressor.wkv.",
        "compressor.gate_proj.": "compressor.wgate.",
        "compressor.kv_norm.": "compressor.norm.",
        "compressor.position_bias": "compressor.ape",
        ".self_attn.compressor.": ".attn.mla_attn.compressor.",
    }

    # Attention projections (specific before `.self_attn.` -> `.attn.`).
    attention_renames = {
        ".self_attn.q_a_proj.": ".attn.wq_a.",
        ".self_attn.kv_proj.": ".attn.wkv.",
        ".self_attn.q_b_proj.": ".attn.wq_b.",
        ".self_attn.o_a_proj.": ".attn.wo_a.",
        ".self_attn.o_b_proj.": ".attn.wo_b.",
        ".self_attn.q_a_norm.": ".attn.q_norm.",
        ".self_attn.kv_norm.": ".attn.kv_norm.",
        ".self_attn.sinks": ".attn.attn_sink",
    }

    # Shared expert projections (specific before `.mlp.` -> `.ffn.`).
    shared_expert_renames = {
        ".mlp.shared_experts.gate_proj.": ".ffn.shared_experts.w1.",
        ".mlp.shared_experts.up_proj.": ".ffn.shared_experts.w3.",
        ".mlp.shared_experts.down_proj.": ".ffn.shared_experts.down_proj.",
    }

    # General module renames, then the layer norms (checkpoint uses
    # input_layernorm / post_attention_layernorm, model uses
    # attn_norm / ffn_norm).
    general_renames = {
        ".mlp.": ".ffn.",
        ".self_attn.": ".attn.",
        "input_layernorm.": "attn_norm.",
        "post_attention_layernorm.": "ffn_norm.",
    }

    # HC params: the checkpoint uses attn_hc / ffn_hc / hc_head with a dot
    # before fn/base/scale, the model uses underscore-joined names.
    hc_renames = {
        ".attn_hc.fn": ".hc_attn_fn",
        ".attn_hc.base": ".hc_attn_base",
        ".attn_hc.scale": ".hc_attn_scale",
        ".ffn_hc.fn": ".hc_ffn_fn",
        ".ffn_hc.base": ".hc_ffn_base",
        ".ffn_hc.scale": ".hc_ffn_scale",
        "hc_head.fn": "hc_head_fn",
        "hc_head.base": "hc_head_base",
        "hc_head.scale": "hc_head_scale",
    }

    substr_renames = {
        **indexer_renames,
        **compressor_renames,
        **attention_renames,
        **shared_expert_renames,
        **general_renames,
        **hc_renames,
    }

    prefix_renames = {
        "layers.": "model.layers.",
        "embed_tokens.": "model.embed_tokens.",
        "norm.": "model.norm.",
        "hc_head": "model.hc_head",
        "mtp.": "model.mtp.",
    }

    return WeightsMapper(
        orig_to_new_prefix=prefix_renames,
        orig_to_new_regex=expert_rename_regex,
        orig_to_new_suffix=suffix_renames,
        orig_to_new_substr=substr_renames,
    )
|
|
|
|
|
|
class DeepseekV4ForCausalLM(nn.Module, SupportsPP):
    """DeepSeek V4 causal language model: decoder stack plus LM head."""

    model_cls = DeepseekV4Model

    # Default mapper assumes the original FP4-expert checkpoint layout.
    # Replaced per-instance in __init__ according to the quantization
    # method and ``expert_dtype``.
    hf_to_vllm_mapper = _make_deepseek_v4_weights_mapper("fp4")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the model and, on the last pipeline rank, the LM head."""
        super().__init__()

        config = vllm_config.model_config.hf_config

        # Select weight mapper based on quantization method.
        # NVFP4 (modelopt_fp4) checkpoints use different key naming
        # than the default MXFP4 format, so that choice takes precedence.
        # It is no longer overwritten by a trailing ``expert_dtype`` check,
        # which replaced the NVFP4 mapper whenever ``expert_dtype`` was not
        # "fp4".
        quant_config = vllm_config.quant_config
        if (
            quant_config is not None
            and getattr(quant_config, "get_name", lambda: None)() == "modelopt_fp4"
        ):
            self.hf_to_vllm_mapper = _make_deepseek_v4_nvfp4_weights_mapper()
        else:
            # The mapper factory only distinguishes "fp4" from everything
            # else, so the configured value can be passed straight through.
            expert_dtype = getattr(config, "expert_dtype", "fp4")
            self.hf_to_vllm_mapper = _make_deepseek_v4_weights_mapper(expert_dtype)
        self.config = config

        self.model = self.model_cls(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
        if get_pp_group().is_last_rank:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
        else:
            self.lm_head = PPMissingLayer()
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up the token embeddings via the wrapped model."""
        return self.model.embed_input_ids(input_ids)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project ``hidden_states`` to vocabulary logits with the LM head."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the wrapped model and return its output unchanged."""
        hidden_states = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        return hidden_states

    def get_mtp_target_hidden_states(self) -> torch.Tensor | None:
        """Pre-hc_head residual stream buffer (max_num_batched_tokens,
        hc_mult * hidden_size) for the MTP draft model. Populated by
        forward(); valid after each target step."""
        return getattr(self.model, "_mtp_hidden_buffer", None)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights, skipping ``mtp.`` keys, then finalize.

        Returns:
            The names of the parameters reported as loaded.
        """
        loader = AutoWeightsLoader(self, skip_substrs=["mtp."])
        loaded_params = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
        # Post-load processing (MoE weight finalization, wo_a quantization).
        self.model.finalize_mega_moe_weights()
        return loaded_params

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return the expert parameter mapping of the wrapped model."""
        return self.model.get_expert_mapping()
|