diff --git a/tests/models/registry.py b/tests/models/registry.py index 895dc4579..f1f80e639 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -1246,6 +1246,12 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = { use_original_num_layers=True, max_model_len=10240, ), + "Eagle3MiniMaxM2ForCausalLM": _HfExamplesInfo( + "MiniMaxAI/MiniMax-M2", + trust_remote_code=True, + speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B", + tokenizer="MiniMaxAI/MiniMax-M2", + ), "EagleMistralLarge3ForCausalLM": _HfExamplesInfo( "mistralai/Mistral-Large-3-675B-Instruct-2512", speculative_model="mistralai/Mistral-Large-3-675B-Instruct-2512-Eagle", diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index f1fda9afd..0e74501dd 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -817,6 +817,7 @@ class SpeculativeConfig: "deepseek_v3", "kimi_k2", "kimi_k25", + "minimax_m2", ] if ( self.method in ("eagle3", "extract_hidden_states", "dflash") diff --git a/vllm/model_executor/models/minimax_m2.py b/vllm/model_executor/models/minimax_m2.py index 0f43bc0cd..f10452c57 100644 --- a/vllm/model_executor/models/minimax_m2.py +++ b/vllm/model_executor/models/minimax_m2.py @@ -24,6 +24,7 @@ """Inference-only MiniMaxM2 model.""" from collections.abc import Iterable +from itertools import islice from typing import Any import torch @@ -59,7 +60,7 @@ from vllm.model_executor.model_loader.weight_utils import ( ) from vllm.sequence import IntermediateTensors -from .interfaces import SupportsLoRA, SupportsPP +from .interfaces import EagleModelMixin, SupportsEagle3, SupportsLoRA, SupportsPP from .utils import ( AutoWeightsLoader, PPMissingLayer, @@ -313,7 +314,7 @@ class MiniMaxM2DecoderLayer(nn.Module): @support_torch_compile -class MiniMaxM2Model(nn.Module): +class MiniMaxM2Model(nn.Module, EagleModelMixin): fall_back_to_pt_during_load = False def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): @@ -366,7 +367,7 @@ class MiniMaxM2Model(nn.Module): positions: torch.Tensor, intermediate_tensors: IntermediateTensors | None, inputs_embeds: torch.Tensor | None = None, - ) -> torch.Tensor | IntermediateTensors: + ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -378,14 +379,24 @@ class MiniMaxM2Model(nn.Module): hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in self.layers[self.start_layer : self.end_layer]: + aux_hidden_states = self._maybe_add_hidden_state([], 0, hidden_states, residual) + for idx, layer in enumerate( + islice(self.layers, self.start_layer, self.end_layer) + ): hidden_states, residual = layer(positions, hidden_states, residual) + self._maybe_add_hidden_state( + aux_hidden_states, idx + 1, hidden_states, residual + ) if not get_pp_group().is_last_rank: return IntermediateTensors( {"hidden_states": hidden_states, "residual": residual} ) hidden_states, _ = self.norm(hidden_states, residual) + + if len(aux_hidden_states) > 0: + return hidden_states, aux_hidden_states + return hidden_states def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: @@ -496,7 +507,7 @@ class MiniMaxM2Model(nn.Module): return loaded_params -class MiniMaxM2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): +class MiniMaxM2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3): packed_modules_mapping = { "qkv_proj": [ "q_proj", diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 1901381cb..4b354add3 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -554,6 +554,7 @@ _SPECULATIVE_DECODING_MODELS = { "EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"), "DFlashDraftModel": ("qwen3_dflash", "DFlashQwen3ForCausalLM"), "Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"), + "Eagle3MiniMaxM2ForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"), "LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"), "Eagle3Qwen2_5vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"), "Eagle3Qwen3vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),