[Model] Consolidate Deepseek-MoE implementation with DeepSeek-v2 (#28101)

Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Kunshang Ji <kunshang.ji@intel.com>
This commit is contained in:
Isotr0py
2025-11-08 13:01:27 +08:00
committed by GitHub
parent 70af44fd10
commit 934a9c3b79
6 changed files with 144 additions and 548 deletions

View File

@@ -58,6 +58,7 @@ from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
@@ -104,6 +105,92 @@ elif current_platform.is_xpu():
logger = init_logger(__name__)
class DeepseekAttention(nn.Module):
"""Normal MHA implementation used by Deepseek v1."""
def __init__(
self,
vllm_config: VllmConfig,
config: DeepseekV2Config | DeepseekV3Config,
hidden_size: int,
num_heads: int,
rope_theta: float = 10000,
rope_scaling: dict[str, Any] | None = None,
max_position_embeddings: int = 8192,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
**kwargs,
) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
self.total_num_kv_heads = config.num_key_value_heads
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
self.head_dim = hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=False,
quant_config=quant_config,
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=False,
quant_config=quant_config,
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
)
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
return output
class DeepseekV2MLP(nn.Module):
def __init__(
self,
@@ -163,7 +250,7 @@ class DeepseekV2MoE(nn.Module):
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.routed_scaling_factor = config.routed_scaling_factor
self.routed_scaling_factor = getattr(config, "routed_scaling_factor", 1.0)
self.ep_group = get_ep_group().device_group
self.ep_rank = get_ep_group().rank_in_group
@@ -186,7 +273,7 @@ class DeepseekV2MoE(nn.Module):
quant_config=None,
prefix=f"{prefix}.gate",
)
if config.topk_method == "noaux_tc":
if getattr(config, "topk_method", None) == "noaux_tc":
self.gate.e_score_correction_bias = nn.Parameter(
torch.empty(config.n_routed_experts, dtype=torch.float32)
)
@@ -236,10 +323,10 @@ class DeepseekV2MoE(nn.Module):
renormalize=config.norm_topk_prob,
quant_config=quant_config,
use_grouped_topk=True,
num_expert_group=config.n_group,
topk_group=config.topk_group,
num_expert_group=getattr(config, "n_group", 1),
topk_group=getattr(config, "topk_group", 1),
prefix=f"{prefix}.experts",
scoring_func=config.scoring_func,
scoring_func=getattr(config, "scoring_func", "softmax"),
# we do scaling outside, set factor to 1.0 to avoid double mul
# aiter applies routed_scaling_factor internally
routed_scaling_factor=1.0
@@ -999,7 +1086,19 @@ class DeepseekV2DecoderLayer(nn.Module):
# with the layer's index.
layer_idx = int(prefix.split(sep=".")[-1])
self.layer_idx = layer_idx
if model_config.use_mla:
# verify MLA attention specific fields
qk_nope_head_dim = getattr(config, "qk_nope_head_dim", 0)
qk_rope_head_dim = getattr(config, "qk_rope_head_dim", 0)
v_head_dim = getattr(config, "v_head_dim", 0)
kv_lora_rank = getattr(config, "kv_lora_rank", 0)
use_mha = config.model_type == "deepseek" or all(
dim == 0 for dim in (qk_nope_head_dim, qk_rope_head_dim)
)
if use_mha:
attn_cls = DeepseekAttention
elif model_config.use_mla:
attn_cls = DeepseekV2MLAAttention
else:
attn_cls = DeepseekV2Attention
@@ -1008,11 +1107,11 @@ class DeepseekV2DecoderLayer(nn.Module):
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
qk_nope_head_dim=config.qk_nope_head_dim,
qk_rope_head_dim=config.qk_rope_head_dim,
v_head_dim=config.v_head_dim,
qk_nope_head_dim=qk_nope_head_dim,
qk_rope_head_dim=qk_rope_head_dim,
v_head_dim=v_head_dim,
q_lora_rank=config.q_lora_rank if hasattr(config, "q_lora_rank") else None,
kv_lora_rank=config.kv_lora_rank,
kv_lora_rank=kv_lora_rank,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
@@ -1045,7 +1144,7 @@ class DeepseekV2DecoderLayer(nn.Module):
self.post_attention_layernorm = RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
self.routed_scaling_factor = config.routed_scaling_factor
self.routed_scaling_factor = getattr(config, "routed_scaling_factor", 1.0)
def forward(
self,
@@ -1064,7 +1163,10 @@ class DeepseekV2DecoderLayer(nn.Module):
hidden_states=hidden_states,
)
if hidden_states.dtype == torch.float16:
if (
not isinstance(self.self_attn, DeepseekAttention)
and hidden_states.dtype == torch.float16
):
# Fix FP16 overflow
# We scale both hidden_states and residual before
# rmsnorm, and rmsnorm result would not affect by scale.
@@ -1227,6 +1329,15 @@ class DeepseekV2ForCausalLM(
self.config = config
self.quant_config = quant_config
qk_nope_head_dim = getattr(config, "qk_nope_head_dim", 0)
qk_rope_head_dim = getattr(config, "qk_rope_head_dim", 0)
self.use_mha = config.model_type == "deepseek" or all(
dim == 0 for dim in (qk_nope_head_dim, qk_rope_head_dim)
)
if self.use_mha:
self.packed_modules_mapping["qkv_proj"] = ["q_proj", "k_proj", "v_proj"]
# `packed_modules_mapping` needs to be modified before
# initializing DeepseekV2Model, as it is passed inplace to
# quantization config init and may be used to select the
@@ -1265,7 +1376,7 @@ class DeepseekV2ForCausalLM(
def set_moe_parameters(self):
self.expert_weights = []
self.num_expert_groups = self.config.n_group
self.num_expert_groups = getattr(self.config, "n_group", 1)
self.moe_layers = []
self.moe_mlp_layers = []
@@ -1321,9 +1432,20 @@ class DeepseekV2ForCausalLM(
# (param_name, shard_name, shard_id)
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
mla_params_mapping = [
("fused_qkv_a_proj", "q_a_proj", 0),
("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
]
mha_params_mapping = [
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
if self.use_mha:
stacked_params_mapping.extend(mha_params_mapping)
else:
stacked_params_mapping.extend(mla_params_mapping)
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
@@ -1506,6 +1628,10 @@ class DeepseekV2ForCausalLM(
return loaded_params
class DeepseekForCausalLM(DeepseekV2ForCausalLM):
pass
class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
pass