Add support for Mistral Large 3 inference with Flashinfer MoE (#33174)

Signed-off-by: Dimitrios Bariamis <12195802+dbari@users.noreply.github.com>
Co-authored-by: Dimitrios Bariamis <12195802+dbari@users.noreply.github.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
Author: Dimitrios Bariamis
Date: 2026-01-31 07:48:27 +01:00
Committed by: GitHub
Parent: 73419abfae
Commit: f0bca83ee4
16 changed files with 1104 additions and 31 deletions

@@ -27,7 +27,6 @@ from vllm.model_executor.layers.fused_moe.fused_moe import *
 from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
     TritonOrDeepGemmExperts,
 )
-from vllm.platforms import current_platform
 from vllm.transformers_utils.config import get_config
 from vllm.triton_utils import triton
 from vllm.utils.argparse_utils import FlexibleArgumentParser
@@ -482,6 +481,8 @@ class BenchmarkWorker:
         block_quant_shape: list[int] = None,
         use_deep_gemm: bool = False,
     ) -> tuple[dict[str, int], float]:
+        # local import to allow serialization by ray
+        from vllm.platforms import current_platform
         set_random_seed(self.seed)
         dtype_str = _get_config_dtype_str(
             dtype, use_int8_w8a16=use_int8_w8a16, use_fp8_w8a8=use_fp8_w8a8
@@ -535,6 +536,9 @@ class BenchmarkWorker:
         block_quant_shape: list[int],
         use_deep_gemm: bool,
     ) -> dict[str, int]:
+        # local import to allow serialization by ray
+        from vllm.platforms import current_platform
+
         best_config = None
         best_time = float("inf")
         if current_platform.is_rocm():
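The two hunks above replace the module-level `current_platform` import with function-local ones. A plausible reading of the "serialization by ray" comment, sketched below under the assumption that the worker class is shipped to Ray actors via cloudpickle: a module-level import runs (and may probe for accelerators) in every process that unpickles the module, while a call-time import is resolved lazily inside the actor after Ray has placed it. A minimal sketch, not vLLM's actual `BenchmarkWorker`:

```python
import ray


@ray.remote(num_gpus=1)
class WorkerSketch:
    """Illustrative stand-in for BenchmarkWorker (hypothetical)."""

    def probe(self) -> bool:
        # local import to allow serialization by ray (the pattern from the
        # diff): vllm.platforms is only imported inside the actor process,
        # after Ray has assigned it a GPU.
        from vllm.platforms import current_platform

        return current_platform.is_rocm()


if __name__ == "__main__":
    ray.init()
    print(ray.get(WorkerSketch.remote().probe.remote()))
```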
@@ -646,20 +650,28 @@ def save_configs(
         f.write("\n")


+def get_compressed_tensors_block_structure(config, default_value=None):
+    config_groups = config.get("config_groups", {})
+    if len(config_groups) != 1:
+        return default_value
+    group = next(iter(config_groups.values()))
+    weights = group.get("weights", {})
+    block_structure = weights.get("block_structure", default_value)
+    return block_structure
+
+
 def get_weight_block_size_safety(config, default_value=None):
     quantization_config = getattr(config, "quantization_config", {})
     if isinstance(quantization_config, dict):
-        return quantization_config.get("weight_block_size", default_value)
+        if "weight_block_size" in quantization_config:
+            return quantization_config["weight_block_size"]
+        return get_compressed_tensors_block_structure(
+            quantization_config, default_value
+        )
     return default_value


-def main(args: argparse.Namespace):
-    print(args)
-    config = get_config(model=args.model, trust_remote_code=args.trust_remote_code)
-    if args.model_prefix:
-        config = getattr(config, args.model_prefix)
-
+def get_model_params(config):
     if config.architectures[0] == "DbrxForCausalLM":
         E = config.ffn_config.moe_num_experts
         topk = config.ffn_config.moe_top_k
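To make the new fallback concrete, here is a hedged example of the two quantization-config shapes `get_weight_block_size_safety` now handles, assuming the helpers above are in scope. The compressed-tensors dict is a plausible shape reverse-engineered from the helper's own key lookups, not a schema taken from this PR:

```python
from types import SimpleNamespace

# Shape 1: an explicit weight_block_size key (first branch).
fp8_model = SimpleNamespace(quantization_config={"weight_block_size": [128, 128]})

# Shape 2: compressed-tensors style config_groups (new fallback; schema assumed).
ct_model = SimpleNamespace(
    quantization_config={
        "config_groups": {"group_0": {"weights": {"block_structure": [128, 128]}}},
    }
)

assert get_weight_block_size_safety(fp8_model) == [128, 128]
assert get_weight_block_size_safety(ct_model) == [128, 128]
# No quantization_config, more than one config group, or a missing
# block_structure all fall back to default_value:
assert get_weight_block_size_safety(SimpleNamespace(), default_value=None) is None
```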
@@ -677,6 +689,7 @@ def main(args: argparse.Namespace):
         "Glm4MoeForCausalLM",
         "Glm4MoeLiteForCausalLM",
         "NemotronHForCausalLM",
+        "MistralLarge3ForCausalLM",
     ):
         E = config.n_routed_experts
         topk = config.num_experts_per_tok
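The new `MistralLarge3ForCausalLM` entry means Mistral Large 3 checkpoints are dispatched like the other DeepSeek-style routed-expert configs in this branch. A hypothetical stub (field names from the code above; the numeric values are invented, real ones come from the model's config.json):

```python
from types import SimpleNamespace

# Invented values for illustration only.
config = SimpleNamespace(
    architectures=["MistralLarge3ForCausalLM"],
    n_routed_experts=128,   # -> E
    num_experts_per_tok=4,  # -> topk
)
E, topk = config.n_routed_experts, config.num_experts_per_tok
```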
@@ -697,16 +710,20 @@ def main(args: argparse.Namespace):
         topk = text_config.num_experts_per_tok
         intermediate_size = text_config.moe_intermediate_size
         hidden_size = text_config.hidden_size
-    elif config.architectures[0] in ("HunYuanMoEV1ForCausalLM"):
+    elif config.architectures[0] == "HunYuanMoEV1ForCausalLM":
         E = config.num_experts
         topk = config.moe_topk[0]
         intermediate_size = config.moe_intermediate_size[0]
         hidden_size = config.hidden_size
-    elif config.architectures[0] in ["Qwen3OmniMoeForConditionalGeneration"]:
+    elif config.architectures[0] == "Qwen3OmniMoeForConditionalGeneration":
         E = config.thinker_config.text_config.num_experts
         topk = config.thinker_config.text_config.num_experts_per_tok
         intermediate_size = config.thinker_config.text_config.moe_intermediate_size
         hidden_size = config.thinker_config.text_config.hidden_size
+    elif config.architectures[0] == "PixtralForConditionalGeneration":
+        # Pixtral can contain different LLM architectures,
+        # recurse to get their parameters
+        return get_model_params(config.get_text_config())
     else:
         # Support for llama4
         config = config.get_text_config()
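The Pixtral branch is what lets the tuner see Mistral Large 3 behind a multimodal wrapper: the `PixtralForConditionalGeneration` config carries the language model's config inside, so the code recurses on `config.get_text_config()` and dispatches on the inner `architectures[0]`. A rough sketch of the flow, with a hypothetical wrapper object standing in for the real HF config:

```python
from types import SimpleNamespace

# Hypothetical objects: the real Pixtral config exposes its inner
# language-model config via get_text_config().
inner = SimpleNamespace(architectures=["MistralLarge3ForCausalLM"])
outer = SimpleNamespace(
    architectures=["PixtralForConditionalGeneration"],
    get_text_config=lambda: inner,
)


def dispatch(config):
    if config.architectures[0] == "PixtralForConditionalGeneration":
        return dispatch(config.get_text_config())  # one level of recursion
    return config.architectures[0]


assert dispatch(outer) == "MistralLarge3ForCausalLM"
```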
@@ -715,6 +732,16 @@ def main(args: argparse.Namespace):
         topk = config.num_experts_per_tok
         intermediate_size = config.intermediate_size
         hidden_size = config.hidden_size
+    return E, topk, intermediate_size, hidden_size
+
+
+def main(args: argparse.Namespace):
+    print(args)
+    config = get_config(model=args.model, trust_remote_code=args.trust_remote_code)
+    if args.model_prefix:
+        config = getattr(config, args.model_prefix)
+
+    E, topk, intermediate_size, hidden_size = get_model_params(config)
     enable_ep = bool(args.enable_expert_parallel)
     if enable_ep:
         ensure_divisibility(E, args.tp_size, "Number of experts")
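With `get_model_params` factored out, `main` reduces to config loading plus the tuning loop, and the expert-parallel path first checks that the expert count splits evenly across tensor-parallel ranks. A sketch of that check; `ensure_divisibility`'s exact error behavior is assumed here, not taken from vLLM:

```python
# Sketch of the expert-parallel sanity check (error type assumed).
def ensure_divisibility(numerator: int, denominator: int, name: str) -> None:
    if numerator % denominator != 0:
        raise ValueError(f"{name} ({numerator}) is not divisible by {denominator}")


ensure_divisibility(128, 8, "Number of experts")  # ok: 16 experts per rank
try:
    ensure_divisibility(128, 6, "Number of experts")
except ValueError as e:
    print(e)  # Number of experts (128) is not divisible by 6
```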