Add support for Mistral Large 3 inference with Flashinfer MoE (#33174)

Signed-off-by: Dimitrios Bariamis <12195802+dbari@users.noreply.github.com>
Co-authored-by: Dimitrios Bariamis <12195802+dbari@users.noreply.github.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
Author:  Dimitrios Bariamis
Date:    2026-01-31 07:48:27 +01:00 (committed by GitHub)
Parent:  73419abfae
Commit:  f0bca83ee4

16 changed files with 1104 additions and 31 deletions

@@ -295,6 +295,14 @@ class DeepseekV2MoE(nn.Module):
                 prefix=f"{prefix}.shared_experts",
             )
 
+        n_group = getattr(config, "n_group", 1)
+        topk_group = getattr(config, "topk_group", 1)
+        use_grouped_topk = True
+        if (n_group, topk_group) == (1, 1):
+            n_group = None
+            topk_group = None
+            use_grouped_topk = False
+
         self.experts = SharedFusedMoE(
             shared_experts=self.shared_experts,
             gate=self.gate,
@@ -305,9 +313,9 @@ class DeepseekV2MoE(nn.Module):
             reduce_results=False,
             renormalize=config.norm_topk_prob,
             quant_config=quant_config,
-            use_grouped_topk=True,
-            num_expert_group=getattr(config, "n_group", 1),
-            topk_group=getattr(config, "topk_group", 1),
+            use_grouped_topk=use_grouped_topk,
+            num_expert_group=n_group,
+            topk_group=topk_group,
             prefix=f"{prefix}.experts",
             scoring_func=getattr(config, "scoring_func", "softmax"),
             # we do scaling outside, set factor to 1.0 to avoid double mul
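
For context on why the fallback above is safe: with n_group == 1 and topk_group == 1, grouped top-k routing degenerates to plain top-k, because the single group is always selected and its mask keeps every expert. The standalone sketch below (not part of this commit; the grouped_topk helper and tensor shapes are illustrative) demonstrates the equivalence:

# Sketch: grouped top-k with a single, always-selected group equals plain top-k.
import torch

def grouped_topk(scores: torch.Tensor, top_k: int, n_group: int, topk_group: int):
    """Pick the best topk_group groups, mask out the rest, then take top-k experts."""
    num_experts = scores.shape[-1]
    # Score each group by its best expert, then select topk_group groups.
    group_scores = scores.view(-1, n_group, num_experts // n_group).amax(dim=-1)
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1).indices
    group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1.0)
    # Broadcast the group mask back to per-expert granularity.
    expert_mask = (
        group_mask.unsqueeze(-1)
        .expand(-1, n_group, num_experts // n_group)
        .reshape(-1, num_experts)
    )
    masked = scores.masked_fill(expert_mask == 0, float("-inf"))
    return torch.topk(masked, k=top_k, dim=-1).indices

scores = torch.randn(4, 8)  # 4 tokens, 8 experts
plain = torch.topk(scores, k=2, dim=-1).indices
grouped = grouped_topk(scores, top_k=2, n_group=1, topk_group=1)
assert torch.equal(plain, grouped)  # identical routing when (n_group, topk_group) == (1, 1)

Passing n_group=None / topk_group=None with use_grouped_topk=False therefore lets configs without expert groups (such as the Mistral Large 3 setup this commit targets) take the plain top-k path instead of a redundant grouped one.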