MiniMax-M2: add Eagle3 speculative decoding support (#37512)
Signed-off-by: liuchenbing <chenliumail@163.com>
Signed-off-by: liucb <liuchengbao_work@163.com>
Co-authored-by: liuchenbing <chenliumail@163.com>
This commit is contained in:
@@ -1246,6 +1246,12 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
|
||||
use_original_num_layers=True,
|
||||
max_model_len=10240,
|
||||
),
|
||||
+"Eagle3MiniMaxM2ForCausalLM": _HfExamplesInfo(
|
||||
+"MiniMaxAI/MiniMax-M2",
|
||||
+trust_remote_code=True,
|
||||
+speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
|
||||
+tokenizer="MiniMaxAI/MiniMax-M2",
|
||||
+),
|
||||
"EagleMistralLarge3ForCausalLM": _HfExamplesInfo(
|
||||
"mistralai/Mistral-Large-3-675B-Instruct-2512",
|
||||
speculative_model="mistralai/Mistral-Large-3-675B-Instruct-2512-Eagle",
|
||||
|
||||
@@ -817,6 +817,7 @@ class SpeculativeConfig:
|
||||
"deepseek_v3",
|
||||
"kimi_k2",
|
||||
"kimi_k25",
|
||||
+"minimax_m2",
|
||||
]
|
||||
if (
|
||||
self.method in ("eagle3", "extract_hidden_states", "dflash")
|
||||
|
||||
@@ -24,6 +24,7 @@
|
||||
"""Inference-only MiniMaxM2 model."""
|
||||
|
||||
from collections.abc import Iterable
|
||||
+from itertools import islice
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
@@ -59,7 +60,7 @@ from vllm.model_executor.model_loader.weight_utils import (
|
||||
)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
-from .interfaces import SupportsLoRA, SupportsPP
|
||||
+from .interfaces import EagleModelMixin, SupportsEagle3, SupportsLoRA, SupportsPP
|
||||
from .utils import (
|
||||
AutoWeightsLoader,
|
||||
PPMissingLayer,
|
||||
@@ -313,7 +314,7 @@ class MiniMaxM2DecoderLayer(nn.Module):
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
-class MiniMaxM2Model(nn.Module):
|
||||
+class MiniMaxM2Model(nn.Module, EagleModelMixin):
|
||||
fall_back_to_pt_during_load = False
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
@@ -366,7 +367,7 @@ class MiniMaxM2Model(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: IntermediateTensors | None,
|
||||
inputs_embeds: torch.Tensor | None = None,
|
||||
-) -> torch.Tensor | IntermediateTensors:
|
||||
+) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]:
|
||||
if get_pp_group().is_first_rank:
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
@@ -378,14 +379,24 @@ class MiniMaxM2Model(nn.Module):
|
||||
hidden_states = intermediate_tensors["hidden_states"]
|
||||
residual = intermediate_tensors["residual"]
|
||||
|
||||
-for layer in self.layers[self.start_layer : self.end_layer]:
|
||||
+aux_hidden_states = self._maybe_add_hidden_state([], 0, hidden_states, residual)
|
||||
+for idx, layer in enumerate(
|
||||
+islice(self.layers, self.start_layer, self.end_layer)
|
||||
+):
|
||||
hidden_states, residual = layer(positions, hidden_states, residual)
|
||||
+self._maybe_add_hidden_state(
|
||||
+aux_hidden_states, idx + 1, hidden_states, residual
|
||||
+)
|
||||
|
||||
if not get_pp_group().is_last_rank:
|
||||
return IntermediateTensors(
|
||||
{"hidden_states": hidden_states, "residual": residual}
|
||||
)
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
|
||||
+if len(aux_hidden_states) > 0:
|
||||
+return hidden_states, aux_hidden_states
|
||||
|
||||
return hidden_states
|
||||
|
||||
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
|
||||
@@ -496,7 +507,7 @@ class MiniMaxM2Model(nn.Module):
|
||||
return loaded_params
|
||||
|
||||
|
||||
-class MiniMaxM2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
+class MiniMaxM2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP, SupportsEagle3):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
|
||||
@@ -554,6 +554,7 @@ _SPECULATIVE_DECODING_MODELS = {
|
||||
"EagleMiniCPMForCausalLM": ("minicpm_eagle", "EagleMiniCPMForCausalLM"),
|
||||
"DFlashDraftModel": ("qwen3_dflash", "DFlashQwen3ForCausalLM"),
|
||||
"Eagle3LlamaForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
|
||||
+"Eagle3MiniMaxM2ForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
|
||||
"LlamaForCausalLMEagle3": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
|
||||
"Eagle3Qwen2_5vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
|
||||
"Eagle3Qwen3vlForCausalLM": ("llama_eagle3", "Eagle3LlamaForCausalLM"),
|
||||
|
||||
Reference in New Issue
Block a user