MiniMax-M2: add Eagle3 speculative decoding support (#37512)

Signed-off-by: liuchenbing <chenliumail@163.com>
Signed-off-by: liucb <liuchengbao_work@163.com>
Co-authored-by: liuchenbing <chenliumail@163.com>
This commit is contained in:
liuchenbing2026
2026-04-06 10:50:18 +08:00
committed by GitHub
parent 780ba37458
commit f6983f01de
4 changed files with 24 additions and 5 deletions

View File

@@ -1246,6 +1246,12 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
use_original_num_layers=True,
max_model_len=10240,
),
"Eagle3MiniMaxM2ForCausalLM": _HfExamplesInfo(
"MiniMaxAI/MiniMax-M2",
trust_remote_code=True,
speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
tokenizer="MiniMaxAI/MiniMax-M2",
),
"EagleMistralLarge3ForCausalLM": _HfExamplesInfo(
"mistralai/Mistral-Large-3-675B-Instruct-2512",
speculative_model="mistralai/Mistral-Large-3-675B-Instruct-2512-Eagle",

View File

@@ -817,6 +817,7 @@ class SpeculativeConfig:
"deepseek_v3",
"kimi_k2",
"kimi_k25",
"minimax_m2",
]
if (
self.method in ("eagle3", "extract_hidden_states", "dflash")

View File

@@ -24,6 +24,7 @@
"""Inference-only MiniMaxM2 model."""
from collections.abc import Iterable
from itertools import islice
from typing import Any
import torch
@@ -59,7 +60,7 @@ from vllm.model_executor.model_loader.weight_utils import (
)
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsLoRA, SupportsPP
from .interfaces import EagleModelMixin, SupportsEagle3, SupportsLoRA, SupportsPP
from .utils import (
AutoWeightsLoader,
PPMissingLayer,
@@ -313,7 +314,7 @@ class MiniMaxM2DecoderLayer(nn.Module):
@support_torch_compile
class MiniMaxM2Model(nn.Module):
class MiniMaxM2Model(nn.Module, EagleModelMixin):
fall_back_to_pt_during_load = False
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
@@ -366,7 +367,7 @@ class MiniMaxM2Model(nn.Module):
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None,
inputs_embeds: torch.Tensor | None = None,
) -> torch.Tensor | IntermediateTensors:
) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
@@ -378,14 +379,24 @@ class MiniMaxM2Model(nn.Module):
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for layer in self.layers[self.start_layer : self.end_layer]:
aux_hidden_states = self._maybe_add_hidden_state([], 0, hidden_states, residual)
for idx, layer in enumerate(
islice(self.layers, self.start_layer, self.end_layer)
):
hidden_states, residual = layer(positions, hidden_states, residual)
self._maybe_add_hidden_state(
aux_hidden_states, idx + 1, hidden_states, residual
)
if not get_pp_group().is_last_rank:
return IntermediateTensors(
{"hidden_states": hidden_states, "residual": residual}
)
hidden_states, _ = self.norm(hidden_states, residual)
if len(aux_hidden_states) > 0:
return hidden_states, aux_hidden_states
return hidden_states
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
@@ -496,7 +507,7 @@ class MiniMaxM2Model(nn.Module):
return loaded_params
class MiniMaxM2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
class MiniMaxM2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
packed_modules_mapping = {
"qkv_proj": [
"q_proj",

View File

@@ -554,6 +554,7 @@ _SPECULATIVE_DECODING_MODELS = {
"EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
"DFlashDraftModel": ("qwen3_dflash", "DFlashQwen3ForCausalLM"),
"Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
"Eagle3MiniMaxM2ForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
"LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
"Eagle3Qwen2_5vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
"Eagle3Qwen3vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),