Add support for Mistral Large 3 inference with Flashinfer MoE (#33174)
Signed-off-by: Dimitrios Bariamis <12195802+dbari@users.noreply.github.com> Co-authored-by: Dimitrios Bariamis <12195802+dbari@users.noreply.github.com> Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in: committed by GitHub · parent 73419abfae · commit f0bca83ee4
@@ -295,6 +295,14 @@ class DeepseekV2MoE(nn.Module):
                 prefix=f"{prefix}.shared_experts",
             )

+        n_group = getattr(config, "n_group", 1)
+        topk_group = getattr(config, "topk_group", 1)
+        use_grouped_topk = True
+        if (n_group, topk_group) == (1, 1):
+            n_group = None
+            topk_group = None
+            use_grouped_topk = False
+
         self.experts = SharedFusedMoE(
             shared_experts=self.shared_experts,
             gate=self.gate,
@@ -305,9 +313,9 @@ class DeepseekV2MoE(nn.Module):
             reduce_results=False,
             renormalize=config.norm_topk_prob,
             quant_config=quant_config,
-            use_grouped_topk=True,
-            num_expert_group=getattr(config, "n_group", 1),
-            topk_group=getattr(config, "topk_group", 1),
+            use_grouped_topk=use_grouped_topk,
+            num_expert_group=n_group,
+            topk_group=topk_group,
             prefix=f"{prefix}.experts",
             scoring_func=getattr(config, "scoring_func", "softmax"),
             # we do scaling outside, set factor to 1.0 to avoid double mul
Reference in New Issue
Block a user