[Misc] Remove duplicated DeepSeek V2/V3 model definition (#12793)

This commit is contained in:
Michael Goin
2025-02-06 02:16:20 -05:00
committed by GitHub
parent 1a6fcad4c9
commit 449d1bce02
4 changed files with 36 additions and 821 deletions

View File

@@ -21,7 +21,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only DeepseekV2 model."""
"""Inference-only DeepseekV2/DeepseekV3 model."""
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
import torch
@@ -115,23 +115,32 @@ class DeepseekV2MoE(nn.Module):
raise ValueError(f"Unsupported activation: {config.hidden_act}. "
"Only silu is supported for now.")
self.experts = FusedMoE(num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
use_grouped_topk=True,
num_expert_group=config.n_group,
topk_group=config.topk_group,
prefix=f"{prefix}.experts")
self.gate = ReplicatedLinear(config.hidden_size,
config.n_routed_experts,
bias=False,
quant_config=None,
prefix=f"{prefix}.gate")
if config.topk_method == "noaux_tc":
self.gate.e_score_correction_bias = nn.Parameter(
torch.empty(config.n_routed_experts))
else:
self.gate.e_score_correction_bias = None
self.experts = FusedMoE(
num_experts=config.n_routed_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
use_grouped_topk=True,
num_expert_group=config.n_group,
topk_group=config.topk_group,
prefix=f"{prefix}.experts",
scoring_func=config.scoring_func,
e_score_correction_bias=self.gate.e_score_correction_bias)
if config.n_shared_experts is not None:
intermediate_size = (config.moe_intermediate_size *
config.n_shared_experts)
@@ -732,6 +741,15 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
# TODO(simon): support nextn predict layers
if hasattr(self.config, "num_nextn_predict_layers"
) and self.config.num_nextn_predict_layers > 0:
assert self.config.num_nextn_predict_layers == 1
layer_idx = self.config.num_hidden_layers
if name.startswith(f"model.layers.{layer_idx}"):
continue
for (param_name, weight_name, shard_id) in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
if weight_name not in name:
@@ -793,3 +811,7 @@ class DeepseekV2ForCausalLM(nn.Module, SupportsPP):
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
pass