diff --git a/vllm/model_executor/models/mistral_large_3_eagle.py b/vllm/model_executor/models/mistral_large_3_eagle.py
index 830f210e7..4567f24fd 100644
--- a/vllm/model_executor/models/mistral_large_3_eagle.py
+++ b/vllm/model_executor/models/mistral_large_3_eagle.py
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+import copy
 from collections.abc import Iterable
 from functools import partial
 
@@ -33,7 +34,9 @@ class EagleMistralLarge3Model(DeepseekV2Model):
     ):
         nn.Module.__init__(self)
 
-        config = vllm_config.model_config.hf_config
+        config = copy.deepcopy(vllm_config.model_config.hf_config)
+        config.first_k_dense_replace += start_layer_id
+
         quant_config = vllm_config.quant_config
         self.config = config
         self.vllm_config = vllm_config
@@ -53,6 +56,7 @@ class EagleMistralLarge3Model(DeepseekV2Model):
                 DeepseekV2DecoderLayer(
                     vllm_config=vllm_config,
                     prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
+                    config=config,
                 )
                 for i in range(self.config.num_hidden_layers)
             ]
diff --git a/vllm/transformers_utils/configs/mistral.py b/vllm/transformers_utils/configs/mistral.py
index aea990b07..1e1e49f7c 100644
--- a/vllm/transformers_utils/configs/mistral.py
+++ b/vllm/transformers_utils/configs/mistral.py
@@ -19,6 +19,10 @@ def adapt_config_dict(
     if bool(config_dict.get("quantization")):
         config_dict = _remap_mistral_quantization_args(config_dict)
 
+    is_mla = bool(config_dict.get("qk_nope_head_dim"))
+    if is_mla:
+        config_dict = _remap_mistral_mla_args(config_dict)
+
     is_moe = bool(config_dict.get("moe"))
     is_mistral_large_3 = (
         is_moe and (config_dict["moe"].get("num_shared_experts") or 0) > 0
@@ -291,3 +295,22 @@ def _remap_moe_args(config: dict) -> dict:
     config["scoring_func"] = "softmax"
 
     return config
+
+
+def _remap_mistral_mla_args(config: dict) -> dict:
+    if not config.get("moe"):
+        moe = {
+            "num_experts": 1,
+            "first_k_dense_replace": config.get("num_hidden_layers"),
+            "route_every_n": 1,
+            "num_shared_experts": 1,
+            "expert_hidden_dim": config.get("intermediate_size"),
+            "num_experts_per_tok": 1,
+            "routed_scale": 1.0,
+            "renorm_strategy": "WEIGHTS",
+            "use_load_balancing_bias": False,
+            "num_expert_groups": 1,
+            "num_expert_groups_per_tok": 1,
+        }
+        config["moe"] = moe
+    return config
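
A minimal sketch of the aliasing hazard the new copy.deepcopy avoids (the class and
numbers below are hypothetical stand-ins, not vLLM APIs): the Eagle draft model shifts
first_k_dense_replace by start_layer_id, and without the copy that in-place mutation
would also be visible to the target model, which shares the same hf_config object.

    import copy

    class HfConfig:
        """Hypothetical stand-in for vllm_config.model_config.hf_config."""
        def __init__(self) -> None:
            self.first_k_dense_replace = 1

    shared = HfConfig()   # config object also read by the target model
    start_layer_id = 61   # hypothetical index of the first draft layer

    draft_config = copy.deepcopy(shared)
    draft_config.first_k_dense_replace += start_layer_id

    assert shared.first_k_dense_replace == 1         # target model unaffected
    assert draft_config.first_k_dense_replace == 62  # draft sees shifted value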
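And a sketch of what the new _remap_mistral_mla_args accomplishes for a dense MLA
checkpoint (the helper is restated here, trimmed to the fields the assertions touch,
and the config values are made up so the snippet runs standalone): setting
first_k_dense_replace to num_hidden_layers keeps every layer a dense FFN, so the
synthesized single-expert "moe" block appears to exist only to normalize dense MLA
configs into the MoE-style shape the downstream DeepseekV2 path expects, without
ever routing to an expert.

    def _remap_mistral_mla_args(config: dict) -> dict:
        # Trimmed restatement of the helper added in the diff.
        if not config.get("moe"):
            config["moe"] = {
                "num_experts": 1,
                "first_k_dense_replace": config.get("num_hidden_layers"),
                "num_shared_experts": 1,
                "expert_hidden_dim": config.get("intermediate_size"),
                "num_experts_per_tok": 1,
            }
        return config

    # Hypothetical dense MLA config; qk_nope_head_dim is what marks it as MLA
    # in adapt_config_dict.
    cfg = {
        "qk_nope_head_dim": 128,
        "num_hidden_layers": 40,
        "intermediate_size": 14336,
    }
    cfg = _remap_mistral_mla_args(cfg)

    # All 40 layer indices fall below first_k_dense_replace, so every layer
    # stays dense and the placeholder expert is never used.
    assert cfg["moe"]["first_k_dense_replace"] == cfg["num_hidden_layers"]
    assert cfg["moe"]["num_experts"] == 1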