Kimi K2.5 MLA-based EAGLE3 (#36361)

Signed-off-by: Izzy Putterman <iputterman@nvidia.com>
Signed-off-by: Jhao-Ting Chen <jhaotingc@nvidia.com>
Co-authored-by: Izzy Putterman <iputterman@nvidia.com>
Authored by Jhao-Ting Chen on 2026-03-11 08:36:11 -07:00; committed by GitHub.
Parent: d5816c8c2f
Commit: 5573894737
8 changed files with 499 additions and 8 deletions

@@ -1137,6 +1137,18 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
speculative_model="yuhuili/EAGLE-LLaMA3-Instruct-8B",
tokenizer="meta-llama/Meta-Llama-3-8B-Instruct",
),
"Eagle3DeepseekV2ForCausalLM": _HfExamplesInfo(
"moonshotai/Kimi-K2.5",
trust_remote_code=True,
speculative_model="AQ-MedAI/Kimi-K25-eagle3",
tokenizer="moonshotai/Kimi-K2.5",
),
"Eagle3DeepseekV3ForCausalLM": _HfExamplesInfo(
"moonshotai/Kimi-K2.5",
trust_remote_code=True,
speculative_model="AQ-MedAI/Kimi-K25-eagle3",
tokenizer="moonshotai/Kimi-K2.5",
),
"Eagle3LlamaForCausalLM": _HfExamplesInfo(
"meta-llama/Llama-3.1-8B-Instruct",
trust_remote_code=True,
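
As a usage sketch (not part of this commit; num_speculative_tokens and parallelism settings are deployment choices, not taken from this diff), the new target/drafter pairing can be exercised through vLLM's offline API:

from vllm import LLM

llm = LLM(
    model="moonshotai/Kimi-K2.5",
    trust_remote_code=True,
    speculative_config={
        "method": "eagle3",
        "model": "AQ-MedAI/Kimi-K25-eagle3",
        "num_speculative_tokens": 3,  # illustrative value; tune per workload
    },
)
outputs = llm.generate("The capital of France is")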

@@ -779,6 +779,10 @@ class SpeculativeConfig:
"hunyuan_v1_dense",
"afmoe",
"nemotron_h",
"deepseek_v2",
"deepseek_v3",
"kimi_k2",
"kimi_k25",
]
if (
self.method in ("eagle3", "extract_hidden_states")

@@ -0,0 +1,419 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Eagle3 speculative decoding model for DeepseekV2/V3 with MLP (no MoE)."""
import copy
from collections.abc import Iterable
import torch
import torch.nn as nn
from transformers import DeepseekV2Config, DeepseekV3Config
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader,
maybe_remap_kv_scale_name,
)
from vllm.model_executor.models.deepseek_v2 import (
DeepseekV2ForCausalLM,
DeepseekV2MLAAttention,
DeepseekV2MLP,
)
from vllm.multimodal.inputs import NestedTensors
from .utils import (
AutoWeightsLoader,
get_draft_quant_config,
maybe_prefix,
process_eagle_weight,
)
logger = init_logger(__name__)
class DeepseekV2Eagle3DecoderLayer(nn.Module):
"""
Eagle3 decoder layer for Deepseek that:
1. Always uses MLP (not MoE)
2. First layer accepts concatenated embeds + hidden_states
"""
def __init__(
self,
vllm_config: VllmConfig,
prefix: str,
config: DeepseekV2Config | DeepseekV3Config | None = None,
layer_idx: int = 0,
) -> None:
super().__init__()
if config is None:
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = get_draft_quant_config(vllm_config)
self.hidden_size = config.hidden_size
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
self.layer_idx = layer_idx
# MLA attention parameters
qk_nope_head_dim = getattr(config, "qk_nope_head_dim", 0)
qk_rope_head_dim = getattr(config, "qk_rope_head_dim", 0)
v_head_dim = getattr(config, "v_head_dim", 0)
kv_lora_rank = getattr(config, "kv_lora_rank", 0)
config = copy.copy(config)
if rope_scaling:
rope_params = rope_scaling.copy()
rope_params["rope_type"] = "deepseek_yarn"
else:
rope_params = {"rope_type": "default"}
config.rope_parameters = rope_params
self.self_attn = DeepseekV2MLAAttention(
vllm_config=vllm_config,
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
qk_nope_head_dim=qk_nope_head_dim,
qk_rope_head_dim=qk_rope_head_dim,
v_head_dim=v_head_dim,
q_lora_rank=config.q_lora_rank if hasattr(config, "q_lora_rank") else None,
kv_lora_rank=kv_lora_rank,
max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
input_size=2 * self.hidden_size if layer_idx == 0 else self.hidden_size,
)
# Always use MLP (not MoE) for Eagle3
self.mlp = DeepseekV2MLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
self.hidden_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
if getattr(config, "norm_before_residual", False):
self._residual_norm = self._norm_before_residual
else:
self._residual_norm = self._norm_after_residual
def _norm_before_residual(
self, hidden_states: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
hidden_states = self.hidden_norm(hidden_states)
residual = hidden_states
return hidden_states, residual
def _norm_after_residual(
self, hidden_states: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
residual = hidden_states
hidden_states = self.hidden_norm(hidden_states)
return hidden_states, residual
def forward(
self,
positions: torch.Tensor,
embeds: torch.Tensor,
hidden_states: torch.Tensor,
residual: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor]:
if self.layer_idx == 0:
# First layer: concatenate embeds with hidden_states
embeds = self.input_layernorm(embeds)
hidden_states, residual = self._residual_norm(hidden_states=hidden_states)
hidden_states = torch.cat([embeds, hidden_states], dim=-1)
else:
# Subsequent layers: operate on hidden_states and residual only
hidden_states, residual = self.input_layernorm(hidden_states, residual)
# Self Attention
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
llama_4_scaling=None,
)
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
# Fully Connected (MLP, not MoE)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
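
# Hedged aside: a standalone toy shape-check of the layer-0 path above
# (sizes are hypothetical; DeepSeek-V3-class configs use hidden_size=7168):
_tokens, _hidden = 4, 7168
_embeds = torch.randn(_tokens, _hidden)          # post-input_layernorm embeds
_states = torch.randn(_tokens, _hidden)          # post-_residual_norm hidden states
_attn_in = torch.cat([_embeds, _states], dim=-1)
assert _attn_in.shape == (_tokens, 2 * _hidden)  # motivates input_size=2*hidden_size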
@support_torch_compile
class DeepseekV2Eagle3Model(nn.Module):
def __init__(
self,
*,
vllm_config: VllmConfig,
start_layer_id: int = 0,
prefix: str = "",
) -> None:
super().__init__()
self.config = vllm_config.speculative_config.draft_model_config.hf_config
self.vocab_size = self.config.vocab_size
# Get drafter's quantization config
self.quant_config = get_draft_quant_config(vllm_config)
current_vllm_config = get_current_vllm_config()
self.embed_tokens = VocabParallelEmbedding(
self.config.vocab_size,
self.config.hidden_size,
prefix=maybe_prefix(prefix, "embed_tokens"),
)
self.layers = nn.ModuleList(
[
DeepseekV2Eagle3DecoderLayer(
current_vllm_config,
prefix=maybe_prefix(prefix, f"layers.{layer_idx + start_layer_id}"),
config=self.config,
layer_idx=layer_idx,
)
for layer_idx in range(self.config.num_hidden_layers)
]
)
# fc layer for combining auxiliary hidden states (3x hidden size input)
if hasattr(self.config, "target_hidden_size"):
fc_input_size = self.config.target_hidden_size * 3
else:
fc_input_size = self.config.hidden_size * 3
self.fc = ReplicatedLinear(
input_size=fc_input_size,
output_size=self.config.hidden_size,
bias=False,
params_dtype=vllm_config.model_config.dtype,
quant_config=self.quant_config,
prefix=maybe_prefix(prefix, "fc"),
return_bias=False,
)
self.norm = RMSNorm(
self.config.hidden_size,
eps=self.config.rms_norm_eps,
)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_embeds: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
if input_embeds is None:
input_embeds = self.embed_input_ids(input_ids)
assert hidden_states.shape[-1] == input_embeds.shape[-1]
residual = None
for layer in self.layers:
hidden_states, residual = layer(
positions=positions,
embeds=input_embeds,
hidden_states=hidden_states,
residual=residual,
)
hidden_states, hidden_prenorm = self.norm(hidden_states, residual)
return hidden_states, hidden_prenorm
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
(".fused_qkv_a_proj", ".q_a_proj", 0),
(".fused_qkv_a_proj", ".kv_a_proj_with_mqa", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if "midlayer." in name:
name = name.replace("midlayer.", "layers.0.")
# Handle kv cache quantization scales
if self.quant_config is not None and (
scale_name := self.quant_config.get_cache_scale(name)
):
param = params_dict[scale_name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
loaded_weight = (
loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
)
weight_loader(param, loaded_weight)
loaded_params.add(scale_name)
continue
# Remap FP8 kv-scale parameter names
if "scale" in name:
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
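
# Hedged aside: how the stacked mapping above rewrites a (hypothetical)
# checkpoint name before lookup in params_dict:
_name = "layers.0.self_attn.q_a_proj.weight"
for _param, _weight, _shard in [
    (".fused_qkv_a_proj", ".q_a_proj", 0),
    (".fused_qkv_a_proj", ".kv_a_proj_with_mqa", 1),
]:
    if _weight in _name:
        _name = _name.replace(_weight, _param)  # -> ...fused_qkv_a_proj.weight
        break                                   # loaded as shard 0 of the fused proj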
class Eagle3DeepseekV2ForCausalLM(DeepseekV2ForCausalLM):
"""Eagle3 speculative decoding model for DeepseekV2/V3."""
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
nn.Module.__init__(self)
self.config = vllm_config.speculative_config.draft_model_config.hf_config
# Ensure draft_vocab_size is set
if getattr(self.config, "draft_vocab_size", None) is None:
base_vocab_size = getattr(self.config, "vocab_size", None)
self.config.draft_vocab_size = base_vocab_size
target_layer_num = vllm_config.model_config.get_num_layers(
vllm_config.parallel_config
)
# Store target layer count in draft config
self.config.target_layer_count = target_layer_num
self.model = DeepseekV2Eagle3Model(
vllm_config=vllm_config, prefix="model", start_layer_id=target_layer_num
)
logit_scale = getattr(self.config, "logit_scale", 1.0)
self.lm_head = ParallelLMHead(
self.config.draft_vocab_size,
self.config.hidden_size,
prefix=maybe_prefix(prefix, "lm_head"),
)
self.logits_processor = LogitsProcessor(
self.config.draft_vocab_size, scale=logit_scale
)
self.draft_id_to_target_id = nn.Parameter(
torch.zeros(self.config.draft_vocab_size, dtype=torch.long),
requires_grad=False,
)
def embed_input_ids(
self,
input_ids: torch.Tensor,
multimodal_embeddings: NestedTensors | None = None,
is_multimodal: torch.Tensor | None = None,
) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
hidden_states: torch.Tensor,
inputs_embeds: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
return self.model(input_ids, positions, hidden_states, inputs_embeds)
def compute_logits(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor | None:
logits = self.logits_processor(self.lm_head, hidden_states)
if self.draft_id_to_target_id is None:
assert logits.shape[1] == self.config.vocab_size, (
"Expected logits to have shape "
f"(*, {self.config.vocab_size}), but got {logits.shape}"
)
return logits
base = torch.arange(self.config.draft_vocab_size, device=logits.device)
targets = base + self.draft_id_to_target_id
logits_new = logits.new_full(
(
logits.shape[0],
self.config.vocab_size,
),
float("-inf"),
)
logits_new[:, targets] = logits
return logits_new
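
# Hedged aside: worked toy example of the draft-to-target scatter above
# (real checkpoints ship d2t, and draft_vocab_size << vocab_size):
_d2t = torch.tensor([0, 2, 3, 5])       # per-draft-id offsets
_targets = torch.arange(4) + _d2t       # draft ids 0..3 -> target ids [0, 3, 5, 8]
_full = torch.full((1, 10), float("-inf"))
_full[:, _targets] = torch.randn(1, 4)  # every other target-vocab slot stays -inf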
def combine_hidden_states(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor:
# Combine the target model's auxiliary hidden states for the Eagle3 drafter
return self.model.fc(hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
model_weights = {}
includes_draft_id_mapping = False
includes_embed_tokens = False
for name, loaded_weight in weights:
if "t2d" in name:
continue
if "d2t" in name:
name = name.replace("d2t", "draft_id_to_target_id")
includes_draft_id_mapping = True
elif "lm_head" not in name:
name = "model." + name
if "embed_tokens" in name:
includes_embed_tokens = True
model_weights[name] = loaded_weight
process_eagle_weight(self, name)
skip_substrs = []
if not includes_draft_id_mapping:
skip_substrs.append("draft_id_to_target_id")
if not includes_embed_tokens:
skip_substrs.append("embed_tokens")
loader = AutoWeightsLoader(
self,
skip_prefixes=None,
skip_substrs=skip_substrs,
)
loader.load_weights(model_weights.items())
# Alias for compatibility: V3 shares the V2 Eagle3 implementation
Eagle3DeepseekV3ForCausalLM = Eagle3DeepseekV2ForCausalLM
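
The drafter's contract with its target is thin: three auxiliary hidden states in, one combined state out. A hedged, standalone sketch (toy sizes; nn.Linear stands in for the ReplicatedLinear fc):

import torch
import torch.nn as nn

tokens, hidden = 5, 16                                  # toy sizes
aux = [torch.randn(tokens, hidden) for _ in range(3)]   # three target aux layers
fc = nn.Linear(3 * hidden, hidden, bias=False)          # stand-in for model.fc
combined = fc(torch.cat(aux, dim=-1))                   # == combine_hidden_states(...)
assert combined.shape == (tokens, hidden)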

@@ -82,7 +82,13 @@ from vllm.v1.attention.backends.mla.indexer import (
)
from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec
from .interfaces import MixtureOfExperts, SupportsEagle, SupportsLoRA, SupportsPP
from .interfaces import (
MixtureOfExperts,
SupportsEagle,
SupportsEagle3,
SupportsLoRA,
SupportsPP,
)
from .utils import (
PPMissingLayer,
is_pp_missing_parameter,
@@ -828,6 +834,7 @@ class DeepseekV2MLAAttention(nn.Module):
quant_config: QuantizationConfig | None = None,
prefix: str = "",
topk_indices_buffer: torch.Tensor | None = None,
input_size: int | None = None,
) -> None:
super().__init__()
self.hidden_size = hidden_size
@@ -847,16 +854,20 @@ class DeepseekV2MLAAttention(nn.Module):
self.scaling = self.qk_head_dim**-0.5
self.max_position_embeddings = max_position_embeddings
# Use input_size as the projection input dimension when provided,
# otherwise default to hidden_size. The Eagle3 DeepSeek drafter passes
# input_size=2 * hidden_size for its first (concatenating) layer.
proj_input_size = input_size if input_size is not None else self.hidden_size
if self.q_lora_rank is not None:
self.fused_qkv_a_proj = DeepSeekV2FusedQkvAProjLinear(
self.hidden_size,
proj_input_size,
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
quant_config=quant_config,
prefix=f"{prefix}.fused_qkv_a_proj",
)
else:
self.kv_a_proj_with_mqa = ReplicatedLinear(
self.hidden_size,
proj_input_size,
self.kv_lora_rank + self.qk_rope_head_dim,
bias=False,
quant_config=quant_config,
@@ -874,7 +885,7 @@ class DeepseekV2MLAAttention(nn.Module):
)
else:
self.q_proj = ColumnParallelLinear(
self.hidden_size,
proj_input_size,
self.num_heads * self.qk_head_dim,
bias=False,
quant_config=quant_config,
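
The effect of the new input_size argument is easiest to see with toy numbers; a hedged sketch (hypothetical sizes, plain nn.Linear standing in for the parallel linear layers):

import torch.nn as nn

hidden, q_lora, kv_lora, rope = 16, 8, 4, 2
proj_input_size = 2 * hidden   # Eagle3 layer 0 passes input_size=2*hidden_size
fused_qkv_a = nn.Linear(proj_input_size, q_lora + kv_lora + rope, bias=False)
# All other callers leave input_size=None, so proj_input_size == hidden_size.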
@@ -1170,6 +1181,8 @@ class DeepseekV2Model(nn.Module):
["hidden_states", "residual"], config.hidden_size
)
self.aux_hidden_state_layers = tuple[int, ...]()
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
@@ -1205,7 +1218,13 @@ class DeepseekV2Model(nn.Module):
else:
llama_4_scaling = None
for layer in islice(self.layers, self.start_layer, self.end_layer):
aux_hidden_states = []
for idx, layer in enumerate(
islice(self.layers, self.start_layer, self.end_layer),
start=self.start_layer,
):
if idx in self.aux_hidden_state_layers:
aux_hidden_states.append(hidden_states + residual)
hidden_states, residual = layer(
positions, hidden_states, residual, llama_4_scaling
)
@@ -1216,6 +1235,8 @@ class DeepseekV2Model(nn.Module):
)
hidden_states, _ = self.norm(hidden_states, residual)
if len(aux_hidden_states) > 0:
return hidden_states, aux_hidden_states
return hidden_states
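
Note the tap is taken before the selected layer runs, on hidden_states + residual rather than on any post-norm output. A hedged toy re-creation of the loop (four stand-in layers, capturing before layers 1 and 3):

import torch

aux_layers, aux = {1, 3}, []
hidden = residual = torch.zeros(2, 8)
for idx in range(4):
    if idx in aux_layers:
        aux.append(hidden + residual)        # pre-layer sum, as above
    hidden, residual = hidden + 1, residual  # stand-in for layer(...)
assert len(aux) == 2                         # returned with the final hidden states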
@@ -1261,7 +1282,12 @@ class DeepseekV2MixtureOfExperts(MixtureOfExperts):
class DeepseekV2ForCausalLM(
nn.Module, SupportsPP, DeepseekV2MixtureOfExperts, SupportsLoRA, SupportsEagle
nn.Module,
SupportsPP,
DeepseekV2MixtureOfExperts,
SupportsLoRA,
SupportsEagle,
SupportsEagle3,
):
packed_modules_mapping = {
"gate_up_proj": ["gate_proj", "up_proj"],
@@ -1340,6 +1366,13 @@ class DeepseekV2ForCausalLM(
self.extract_moe_parameters(example_moe)
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
self.model.aux_hidden_state_layers = layers
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
num_layers = len(self.model.layers)
return (2, num_layers // 2, num_layers - 3)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)
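
The default taps spread early, middle, and late. Assuming a 61-layer target (DeepSeek-V3's depth; whether Kimi K2.5 matches it is an assumption here), the formula works out to:

num_layers = 61                               # assumed target depth
print((2, num_layers // 2, num_layers - 3))   # -> (2, 30, 58)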

@@ -28,6 +28,8 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
CompressedTensorsConfig,
)
from vllm.model_executor.models.interfaces import (
SupportsEagle,
SupportsEagle3,
SupportsMultiModal,
SupportsPP,
SupportsQuant,
@@ -311,7 +313,12 @@ class KimiK25MultiModalProcessor(BaseMultiModalProcessor[KimiK25ProcessingInfo])
dummy_inputs=KimiK25DummyInputsBuilder,
)
class KimiK25ForConditionalGeneration(
nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant
nn.Module,
SupportsMultiModal,
SupportsPP,
SupportsQuant,
SupportsEagle,
SupportsEagle3,
):
"""Kimi-K2.5 model for conditional generation.
@@ -480,6 +487,12 @@ class KimiK25ForConditionalGeneration(
logits = self.language_model.compute_logits(hidden_states)
return logits
def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
self.language_model.set_aux_hidden_state_layers(layers)
def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
return self.language_model.get_eagle3_aux_hidden_state_layers()
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

@@ -551,6 +551,8 @@ _SPECULATIVE_DECODING_MODELS = {
"mistral_large_3_eagle",
"EagleMistralLarge3ForCausalLM",
),
"Eagle3DeepseekV2ForCausalLM": ("deepseek_eagle3", "Eagle3DeepseekV2ForCausalLM"),
"Eagle3DeepseekV3ForCausalLM": ("deepseek_eagle3", "Eagle3DeepseekV2ForCausalLM"),
"EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
"DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
"ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),

@@ -87,6 +87,7 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict(
funaudiochat="FunAudioChatConfig",
hunyuan_vl="HunYuanVLConfig",
isaac="IsaacConfig",
kimi_k2="DeepseekV3Config", # Kimi K2 uses same architecture as DeepSeek V3
kimi_linear="KimiLinearConfig",
kimi_vl="KimiVLConfig",
kimi_k25="KimiK25Config",

@@ -20,6 +20,7 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import supports_multimodal
from vllm.model_executor.models.deepseek_eagle3 import Eagle3DeepseekV2ForCausalLM
from vllm.model_executor.models.interfaces import SupportsMultiModal
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -403,7 +404,9 @@ class SpecDecodeBaseProposer:
batch_size = common_attn_metadata.batch_size()
if self.method == "eagle3":
assert isinstance(self.model, Eagle3LlamaForCausalLM)
assert isinstance(
self.model, (Eagle3LlamaForCausalLM, Eagle3DeepseekV2ForCausalLM)
)
target_hidden_states = self.model.combine_hidden_states(
target_hidden_states
)
@@ -1278,6 +1281,10 @@ class SpecDecodeBaseProposer:
self.model.config.image_token_index = (
target_model.config.vision_config.image_token_id
)
elif self.get_model_name(target_model) == "KimiK25ForConditionalGeneration":
self.model.config.image_token_index = (
target_model.config.media_placeholder_token_id
)
else:
self.model.config.image_token_index = (
target_model.config.image_token_index