[Model] Consolidate Deepseek-MoE implementation with DeepSeek-v2 (#28101)
Signed-off-by: Kunshang Ji <kunshang.ji@intel.com>
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Kunshang Ji <kunshang.ji@intel.com>
@@ -58,6 +58,7 @@ from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm
 from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
     MergedColumnParallelLinear,
+    QKVParallelLinear,
     ReplicatedLinear,
     RowParallelLinear,
 )
@@ -104,6 +105,92 @@ elif current_platform.is_xpu():
 logger = init_logger(__name__)
 
 
+class DeepseekAttention(nn.Module):
+    """Normal MHA implementation used by Deepseek v1."""
+
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        config: DeepseekV2Config | DeepseekV3Config,
+        hidden_size: int,
+        num_heads: int,
+        rope_theta: float = 10000,
+        rope_scaling: dict[str, Any] | None = None,
+        max_position_embeddings: int = 8192,
+        cache_config: CacheConfig | None = None,
+        quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
+        **kwargs,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = hidden_size
+        tp_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = num_heads
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        self.total_num_kv_heads = config.num_key_value_heads
+        if self.total_num_kv_heads >= tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+        self.head_dim = hidden_size // self.total_num_heads
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+        self.rope_theta = rope_theta
+        self.max_position_embeddings = max_position_embeddings
+
+        self.qkv_proj = QKVParallelLinear(
+            hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=False,
+            quant_config=quant_config,
+        )
+
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+        )
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=max_position_embeddings,
+            base=rope_theta,
+            rope_scaling=rope_scaling,
+        )
+        self.attn = Attention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            cache_config=cache_config,
+            quant_config=quant_config,
+            prefix=f"{prefix}.attn",
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.rotary_emb(positions, q, k)
+        attn_output = self.attn(q, k, v)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
 class DeepseekV2MLP(nn.Module):
     def __init__(
         self,
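As a reference for reviewers, here is a minimal, self-contained sketch of the tensor-parallel head-partitioning arithmetic that DeepseekAttention.__init__ performs above; partition_heads and the sample numbers are illustrative, not part of the diff:

# Illustrative sketch of the TP head partitioning in DeepseekAttention.
def partition_heads(total_num_heads: int, total_num_kv_heads: int, tp_size: int):
    assert total_num_heads % tp_size == 0
    if total_num_kv_heads >= tp_size:
        # Enough KV heads: partition them across TP ranks.
        assert total_num_kv_heads % tp_size == 0
    else:
        # Fewer KV heads than ranks: replicate so each rank holds one.
        assert tp_size % total_num_kv_heads == 0
    return total_num_heads // tp_size, max(1, total_num_kv_heads // tp_size)

# e.g. a 16-head MHA checkpoint under tp_size=4 gives each rank
# 4 query heads and 4 KV heads.
print(partition_heads(16, 16, 4))  # (4, 4)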
@@ -163,7 +250,7 @@ class DeepseekV2MoE(nn.Module):
         self.tp_size = get_tensor_model_parallel_world_size()
         self.tp_rank = get_tensor_model_parallel_rank()
 
-        self.routed_scaling_factor = config.routed_scaling_factor
+        self.routed_scaling_factor = getattr(config, "routed_scaling_factor", 1.0)
 
         self.ep_group = get_ep_group().device_group
         self.ep_rank = get_ep_group().rank_in_group
@@ -186,7 +273,7 @@ class DeepseekV2MoE(nn.Module):
             quant_config=None,
             prefix=f"{prefix}.gate",
         )
-        if config.topk_method == "noaux_tc":
+        if getattr(config, "topk_method", None) == "noaux_tc":
            self.gate.e_score_correction_bias = nn.Parameter(
                torch.empty(config.n_routed_experts, dtype=torch.float32)
            )
@@ -236,10 +323,10 @@ class DeepseekV2MoE(nn.Module):
             renormalize=config.norm_topk_prob,
             quant_config=quant_config,
             use_grouped_topk=True,
-            num_expert_group=config.n_group,
-            topk_group=config.topk_group,
+            num_expert_group=getattr(config, "n_group", 1),
+            topk_group=getattr(config, "topk_group", 1),
             prefix=f"{prefix}.experts",
-            scoring_func=config.scoring_func,
+            scoring_func=getattr(config, "scoring_func", "softmax"),
             # we do scaling outside, set factor to 1.0 to avoid double mul
             # aiter applies routed_scaling_factor internally
             routed_scaling_factor=1.0
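The getattr fallbacks above are what let a DeepSeek (v1) MoE config, which predates grouped expert routing, flow through the v2 FusedMoE path. A hedged sketch of the effective defaults, with SimpleNamespace standing in for the real config class:

# Sketch of the fallback pattern: v1 MoE configs lack the grouped-routing
# fields introduced in v2/v3, so getattr supplies defaults that reduce
# grouped top-k routing to plain softmax top-k.
from types import SimpleNamespace

v1_like = SimpleNamespace(norm_topk_prob=False)  # no n_group/topk_group/scoring_func

num_expert_group = getattr(v1_like, "n_group", 1)           # -> 1 (single group)
topk_group = getattr(v1_like, "topk_group", 1)              # -> 1
scoring_func = getattr(v1_like, "scoring_func", "softmax")  # -> "softmax"
print(num_expert_group, topk_group, scoring_func)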
@@ -999,7 +1086,19 @@ class DeepseekV2DecoderLayer(nn.Module):
         # with the layer's index.
         layer_idx = int(prefix.split(sep=".")[-1])
         self.layer_idx = layer_idx
-        if model_config.use_mla:
+
+        # verify MLA attention specific fields
+        qk_nope_head_dim = getattr(config, "qk_nope_head_dim", 0)
+        qk_rope_head_dim = getattr(config, "qk_rope_head_dim", 0)
+        v_head_dim = getattr(config, "v_head_dim", 0)
+        kv_lora_rank = getattr(config, "kv_lora_rank", 0)
+        use_mha = config.model_type == "deepseek" or all(
+            dim == 0 for dim in (qk_nope_head_dim, qk_rope_head_dim)
+        )
+
+        if use_mha:
+            attn_cls = DeepseekAttention
+        elif model_config.use_mla:
             attn_cls = DeepseekV2MLAAttention
         else:
             attn_cls = DeepseekV2Attention
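A standalone sketch of the dispatch introduced here, with a toy config object in place of the real DeepseekV2Config; the names and head dims below are illustrative:

# Plain-MHA checkpoints (model_type == "deepseek", or configs with no MLA
# head dims) route to DeepseekAttention; otherwise MLA or the v2 MHA is used.
from types import SimpleNamespace

def pick_attn_cls(config, use_mla: bool) -> str:
    qk_nope = getattr(config, "qk_nope_head_dim", 0)
    qk_rope = getattr(config, "qk_rope_head_dim", 0)
    use_mha = config.model_type == "deepseek" or (qk_nope == 0 and qk_rope == 0)
    if use_mha:
        return "DeepseekAttention"
    return "DeepseekV2MLAAttention" if use_mla else "DeepseekV2Attention"

v1 = SimpleNamespace(model_type="deepseek")
v2 = SimpleNamespace(model_type="deepseek_v2", qk_nope_head_dim=128, qk_rope_head_dim=64)
print(pick_attn_cls(v1, use_mla=False))  # DeepseekAttention
print(pick_attn_cls(v2, use_mla=True))   # DeepseekV2MLAAttention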
@@ -1008,11 +1107,11 @@ class DeepseekV2DecoderLayer(nn.Module):
             config=config,
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
-            qk_nope_head_dim=config.qk_nope_head_dim,
-            qk_rope_head_dim=config.qk_rope_head_dim,
-            v_head_dim=config.v_head_dim,
+            qk_nope_head_dim=qk_nope_head_dim,
+            qk_rope_head_dim=qk_rope_head_dim,
+            v_head_dim=v_head_dim,
             q_lora_rank=config.q_lora_rank if hasattr(config, "q_lora_rank") else None,
-            kv_lora_rank=config.kv_lora_rank,
+            kv_lora_rank=kv_lora_rank,
             rope_theta=rope_theta,
             rope_scaling=rope_scaling,
             max_position_embeddings=max_position_embeddings,
@@ -1045,7 +1144,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         self.post_attention_layernorm = RMSNorm(
             config.hidden_size, eps=config.rms_norm_eps
         )
-        self.routed_scaling_factor = config.routed_scaling_factor
+        self.routed_scaling_factor = getattr(config, "routed_scaling_factor", 1.0)
 
     def forward(
         self,
@@ -1064,7 +1163,10 @@ class DeepseekV2DecoderLayer(nn.Module):
             hidden_states=hidden_states,
         )
 
-        if hidden_states.dtype == torch.float16:
+        if (
+            not isinstance(self.self_attn, DeepseekAttention)
+            and hidden_states.dtype == torch.float16
+        ):
             # Fix FP16 overflow
             # We scale both hidden_states and residual before
             # rmsnorm, and rmsnorm result would not affect by scale.
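The FP16 workaround scales hidden_states and residual down before the next RMSNorm; it relies on RMSNorm being invariant to a uniform rescaling of its input, so the normalized result is unchanged. A quick numerical check of that invariance, using a simplified norm with no learned weight:

import torch

# RMSNorm output is (up to eps) unchanged by a uniform input rescale,
# so dividing by routed_scaling_factor avoids FP16 overflow safely.
def rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

x = torch.randn(4, 8, dtype=torch.float32)
print(torch.allclose(rms_norm(x), rms_norm(x / 2.5), atol=1e-4))  # True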
@@ -1227,6 +1329,15 @@ class DeepseekV2ForCausalLM(
         self.config = config
         self.quant_config = quant_config
 
+        qk_nope_head_dim = getattr(config, "qk_nope_head_dim", 0)
+        qk_rope_head_dim = getattr(config, "qk_rope_head_dim", 0)
+        self.use_mha = config.model_type == "deepseek" or all(
+            dim == 0 for dim in (qk_nope_head_dim, qk_rope_head_dim)
+        )
+
+        if self.use_mha:
+            self.packed_modules_mapping["qkv_proj"] = ["q_proj", "k_proj", "v_proj"]
+
         # `packed_modules_mapping` needs to be modified before
         # initializing DeepseekV2Model, as it is passed inplace to
         # quantization config init and may be used to select the
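For context, a simplified sketch of the mapping being registered; the gate_up_proj entry is assumed here to reflect the existing v2 mapping. Quantization configs consult packed_modules_mapping to resolve per-module overrides written against the checkpoint's unpacked q_proj/k_proj/v_proj names:

# Illustrative: separate checkpoint projections packed into fused modules.
packed_modules_mapping = {
    "gate_up_proj": ["gate_proj", "up_proj"],
}
use_mha = True  # DeepSeek (v1) checkpoints
if use_mha:
    packed_modules_mapping["qkv_proj"] = ["q_proj", "k_proj", "v_proj"]
print(packed_modules_mapping)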
@@ -1265,7 +1376,7 @@ class DeepseekV2ForCausalLM(
     def set_moe_parameters(self):
         self.expert_weights = []
 
-        self.num_expert_groups = self.config.n_group
+        self.num_expert_groups = getattr(self.config, "n_group", 1)
 
         self.moe_layers = []
         self.moe_mlp_layers = []
@@ -1321,9 +1432,20 @@ class DeepseekV2ForCausalLM(
             # (param_name, shard_name, shard_id)
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
-            ("fused_qkv_a_proj", "q_a_proj", 0),
-            ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
         ]
+        mla_params_mapping = [
+            ("fused_qkv_a_proj", "q_a_proj", 0),
+            ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
+        ]
+        mha_params_mapping = [
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+        ]
+        if self.use_mha:
+            stacked_params_mapping.extend(mha_params_mapping)
+        else:
+            stacked_params_mapping.extend(mla_params_mapping)
 
         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
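A simplified sketch of how load_weights consumes these tuples; the resolve helper is illustrative, condensed from vLLM's loading loop. A checkpoint name containing k_proj is rewritten to the fused qkv_proj parameter with shard id "k":

mha_params_mapping = [
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
]

# Map a checkpoint weight name to its fused parameter and shard id.
def resolve(name: str):
    for param_name, weight_name, shard_id in mha_params_mapping:
        if weight_name in name:
            return name.replace(weight_name, param_name), shard_id
    return name, None

print(resolve("model.layers.0.self_attn.k_proj.weight"))
# ('model.layers.0.self_attn.qkv_proj.weight', 'k')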
@@ -1506,6 +1628,10 @@ class DeepseekV2ForCausalLM(
         return loaded_params
 
 
+class DeepseekForCausalLM(DeepseekV2ForCausalLM):
+    pass
+
+
 class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
     pass
 
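With DeepseekForCausalLM now a thin alias of the v2 implementation, v1 MoE checkpoints load through the consolidated path; an illustrative invocation, with the model name assumed:

from vllm import LLM

# Illustrative; any DeepSeek-MoE (v1) checkpoint exercises the new path.
llm = LLM(model="deepseek-ai/deepseek-moe-16b-chat", trust_remote_code=True)
out = llm.generate("Hello, my name is")
print(out[0].outputs[0].text)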