diff --git a/tests/models/registry.py b/tests/models/registry.py
index 17931079c..9b533d8f4 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -1137,6 +1137,18 @@ _SPECULATIVE_DECODING_EXAMPLE_MODELS = {
         speculative_model="yuhuili/EAGLE-LLaMA3-Instruct-8B",
         tokenizer="meta-llama/Meta-Llama-3-8B-Instruct",
     ),
+    "Eagle3DeepseekV2ForCausalLM": _HfExamplesInfo(
+        "moonshotai/Kimi-K2.5",
+        trust_remote_code=True,
+        speculative_model="AQ-MedAI/Kimi-K25-eagle3",
+        tokenizer="moonshotai/Kimi-K2.5",
+    ),
+    "Eagle3DeepseekV3ForCausalLM": _HfExamplesInfo(
+        "moonshotai/Kimi-K2.5",
+        trust_remote_code=True,
+        speculative_model="AQ-MedAI/Kimi-K25-eagle3",
+        tokenizer="moonshotai/Kimi-K2.5",
+    ),
     "Eagle3LlamaForCausalLM": _HfExamplesInfo(
         "meta-llama/Llama-3.1-8B-Instruct",
         trust_remote_code=True,
diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py
index 27b5188eb..ee94ea879 100644
--- a/vllm/config/speculative.py
+++ b/vllm/config/speculative.py
@@ -779,6 +779,10 @@ class SpeculativeConfig:
             "hunyuan_v1_dense",
             "afmoe",
             "nemotron_h",
+            "deepseek_v2",
+            "deepseek_v3",
+            "kimi_k2",
+            "kimi_k25",
         ]
         if (
             self.method in ("eagle3", "extract_hidden_states")
diff --git a/vllm/model_executor/models/deepseek_eagle3.py b/vllm/model_executor/models/deepseek_eagle3.py
new file mode 100644
index 000000000..640ba8991
--- /dev/null
+++ b/vllm/model_executor/models/deepseek_eagle3.py
@@ -0,0 +1,449 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+"""Eagle3 speculative decoding model for DeepseekV2/V3 with MLP (no MoE)."""
+
+import copy
+from collections.abc import Iterable
+
+import torch
+import torch.nn as nn
+from transformers import DeepseekV2Config, DeepseekV3Config
+
+from vllm.compilation.decorators import support_torch_compile
+from vllm.config import VllmConfig, get_current_vllm_config
+from vllm.logger import init_logger
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.linear import ReplicatedLinear
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from vllm.model_executor.model_loader.weight_utils import (
+    default_weight_loader,
+    maybe_remap_kv_scale_name,
+)
+from vllm.model_executor.models.deepseek_v2 import (
+    DeepseekV2ForCausalLM,
+    DeepseekV2MLAAttention,
+    DeepseekV2MLP,
+)
+from vllm.multimodal.inputs import NestedTensors
+
+from .utils import (
+    AutoWeightsLoader,
+    get_draft_quant_config,
+    maybe_prefix,
+    process_eagle_weight,
+)
+
+logger = init_logger(__name__)
+
+
+class DeepseekV2Eagle3DecoderLayer(nn.Module):
+    """
+    Eagle3 decoder layer for Deepseek that:
+    1. Always uses MLP (not MoE)
+    2. First layer accepts concatenated embeds + hidden_states
+    """
+
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+        prefix: str,
+        config: DeepseekV2Config | DeepseekV3Config | None = None,
+        layer_idx: int = 0,
+    ) -> None:
+        super().__init__()
+
+        if config is None:
+            config = vllm_config.model_config.hf_config
+        cache_config = vllm_config.cache_config
+        quant_config = get_draft_quant_config(vllm_config)
+
+        self.hidden_size = config.hidden_size
+        rope_scaling = getattr(config, "rope_scaling", None)
+        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
+
+        self.layer_idx = layer_idx
+
+        # MLA attention parameters
+        qk_nope_head_dim = getattr(config, "qk_nope_head_dim", 0)
+        qk_rope_head_dim = getattr(config, "qk_rope_head_dim", 0)
+        v_head_dim = getattr(config, "v_head_dim", 0)
+        kv_lora_rank = getattr(config, "kv_lora_rank", 0)
+        config = copy.copy(config)
+        if rope_scaling:
+            rope_params = rope_scaling.copy()
+            rope_params["rope_type"] = "deepseek_yarn"
+        else:
+            rope_params = {"rope_type": "default"}
+        config.rope_parameters = rope_params
+        self.self_attn = DeepseekV2MLAAttention(
+            vllm_config=vllm_config,
+            config=config,
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            qk_nope_head_dim=qk_nope_head_dim,
+            qk_rope_head_dim=qk_rope_head_dim,
+            v_head_dim=v_head_dim,
+            q_lora_rank=config.q_lora_rank if hasattr(config, "q_lora_rank") else None,
+            kv_lora_rank=kv_lora_rank,
+            max_position_embeddings=max_position_embeddings,
+            cache_config=cache_config,
+            quant_config=quant_config,
+            prefix=f"{prefix}.self_attn",
+            input_size=2 * self.hidden_size if layer_idx == 0 else self.hidden_size,
+        )
+
+        # Always use MLP (not MoE) for Eagle3
+        self.mlp = DeepseekV2MLP(
+            hidden_size=config.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+            quant_config=quant_config,
+            prefix=f"{prefix}.mlp",
+        )
+
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+
+        self.hidden_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+        if getattr(config, "norm_before_residual", False):
+            self._residual_norm = self._norm_before_residual
+        else:
+            self._residual_norm = self._norm_after_residual
+
+    def _norm_before_residual(
+        self, hidden_states: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        hidden_states = self.hidden_norm(hidden_states)
+        residual = hidden_states
+        return hidden_states, residual
+
+    def _norm_after_residual(
+        self, hidden_states: torch.Tensor
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        residual = hidden_states
+        hidden_states = self.hidden_norm(hidden_states)
+        return hidden_states, residual
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        embeds: torch.Tensor,
+        hidden_states: torch.Tensor,
+        residual: torch.Tensor | None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """Run one Eagle3 draft decoder layer.
+
+        Layer 0 normalizes ``embeds`` and ``hidden_states`` separately and
+        concatenates them on the last dimension before attention; later layers
+        call ``input_layernorm`` with the running ``residual`` instead.
+        """
+        if self.layer_idx == 0:
+            # First layer: concatenate embeds with hidden_states
+            embeds = self.input_layernorm(embeds)
+            hidden_states, residual = self._residual_norm(hidden_states=hidden_states)
+            hidden_states = torch.cat([embeds, hidden_states], dim=-1)
+        else:
+            # Subsequent layers: process hidden_states and residuals only
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)
+
+        # Self Attention
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            llama_4_scaling=None,
+        )
+
+        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+
+        # Fully Connected (MLP, not MoE)
+        hidden_states = self.mlp(hidden_states)
+
+        return hidden_states, residual
+
+
+@support_torch_compile
+class DeepseekV2Eagle3Model(nn.Module):
+    def __init__(
+        self,
+        *,
+        vllm_config: VllmConfig,
+        start_layer_id: int = 0,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.config = vllm_config.speculative_config.draft_model_config.hf_config
+        self.vocab_size = self.config.vocab_size
+
+        # Get drafter's quantization config
+        self.quant_config = get_draft_quant_config(vllm_config)
+
+        current_vllm_config = get_current_vllm_config()
+
+        self.embed_tokens = VocabParallelEmbedding(
+            self.config.vocab_size,
+            self.config.hidden_size,
+            prefix=maybe_prefix(prefix, "embed_tokens"),
+        )
+
+        self.layers = nn.ModuleList(
+            [
+                DeepseekV2Eagle3DecoderLayer(
+                    current_vllm_config,
+                    prefix=maybe_prefix(prefix, f"layers.{layer_idx + start_layer_id}"),
+                    config=self.config,
+                    layer_idx=layer_idx,
+                )
+                for layer_idx in range(self.config.num_hidden_layers)
+            ]
+        )
+
+        # fc layer for combining auxiliary hidden states (3x hidden size input)
+        if hasattr(self.config, "target_hidden_size"):
+            fc_input_size = self.config.target_hidden_size * 3
+        else:
+            fc_input_size = self.config.hidden_size * 3
+
+        self.fc = ReplicatedLinear(
+            input_size=fc_input_size,
+            output_size=self.config.hidden_size,
+            bias=False,
+            params_dtype=vllm_config.model_config.dtype,
+            quant_config=self.quant_config,
+            prefix=maybe_prefix(prefix, "fc"),
+            return_bias=False,
+        )
+
+        self.norm = RMSNorm(
+            self.config.hidden_size,
+            eps=self.config.rms_norm_eps,
+        )
+
+    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.embed_tokens(input_ids)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        input_embeds: torch.Tensor | None = None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        if input_embeds is None:
+            input_embeds = self.embed_input_ids(input_ids)
+        assert hidden_states.shape[-1] == input_embeds.shape[-1]
+
+        residual = None
+        for layer in self.layers:
+            hidden_states, residual = layer(
+                positions=positions,
+                embeds=input_embeds,
+                hidden_states=hidden_states,
+                residual=residual,
+            )
+        hidden_states, hidden_prenorm = self.norm(hidden_states, residual)
+        return hidden_states, hidden_prenorm
+
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
+        """Load draft-model weights and return the names that were loaded.
+
+        ``midlayer.`` is renamed to ``layers.0.``; split ``gate``/``up`` and
+        ``q_a``/``kv_a`` projection weights are routed into their fused
+        parameter when it exists; other names with no matching parameter
+        are skipped.
+        """
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".gate_up_proj", ".gate_proj", 0),
+            (".gate_up_proj", ".up_proj", 1),
+            (".fused_qkv_a_proj", ".q_a_proj", 0),
+            (".fused_qkv_a_proj", ".kv_a_proj_with_mqa", 1),
+        ]
+        params_dict = dict(self.named_parameters())
+        loaded_params: set[str] = set()
+
+        for name, loaded_weight in weights:
+            if "midlayer." in name:
+                name = name.replace("midlayer.", "layers.0.")
+
+            # Handle kv cache quantization scales
+            if self.quant_config is not None and (
+                scale_name := self.quant_config.get_cache_scale(name)
+            ):
+                param = params_dict[scale_name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                loaded_weight = (
+                    loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
+                )
+                weight_loader(param, loaded_weight)
+                loaded_params.add(scale_name)
+                continue
+
+            # Remapping the name FP8 kv-scale
+            if "scale" in name:
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                fused_name = name.replace(weight_name, param_name)
+                # fused_qkv_a_proj is only built when q_lora_rank is set;
+                # otherwise fall through to the plain loader below.
+                if fused_name not in params_dict:
+                    continue
+                name = fused_name
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                if name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+
+        return loaded_params
+
+
+class Eagle3DeepseekV2ForCausalLM(DeepseekV2ForCausalLM):
+    """Eagle3 speculative decoding model for DeepseekV2/V3."""
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        nn.Module.__init__(self)
+        self.config = vllm_config.speculative_config.draft_model_config.hf_config
+
+        # Ensure draft_vocab_size is set
+        if getattr(self.config, "draft_vocab_size", None) is None:
+            base_vocab_size = getattr(self.config, "vocab_size", None)
+            self.config.draft_vocab_size = base_vocab_size
+
+        target_layer_num = vllm_config.model_config.get_num_layers(
+            vllm_config.parallel_config
+        )
+
+        # Store target layer count in draft config
+        self.config.target_layer_count = target_layer_num
+
+        self.model = DeepseekV2Eagle3Model(
+            vllm_config=vllm_config, prefix="model", start_layer_id=target_layer_num
+        )
+
+        logit_scale = getattr(self.config, "logit_scale", 1.0)
+        self.lm_head = ParallelLMHead(
+            self.config.draft_vocab_size,
+            self.config.hidden_size,
+            prefix=maybe_prefix(prefix, "lm_head"),
+        )
+        self.logits_processor = LogitsProcessor(
+            self.config.draft_vocab_size, scale=logit_scale
+        )
+        self.draft_id_to_target_id = nn.Parameter(
+            torch.zeros(self.config.draft_vocab_size, dtype=torch.long),
+            requires_grad=False,
+        )
+
+    def embed_input_ids(
+        self,
+        input_ids: torch.Tensor,
+        multimodal_embeddings: NestedTensors | None = None,
+        is_multimodal: torch.Tensor | None = None,
+    ) -> torch.Tensor:
+        return self.model.embed_input_ids(input_ids)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        inputs_embeds: torch.Tensor | None = None,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        return self.model(input_ids, positions, hidden_states, inputs_embeds)
+
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor | None:
+        """Compute draft logits and scatter them into the target vocabulary.
+
+        Draft token ``i`` is written to column ``i + draft_id_to_target_id[i]``
+        and every other column is ``-inf``. If ``draft_id_to_target_id`` is
+        ``None`` the logits are returned unchanged.
+        """
+        logits = self.logits_processor(self.lm_head, hidden_states)
+        if self.draft_id_to_target_id is None:
+            assert logits.shape[1] == self.config.vocab_size, (
+                "Expected logits to have shape "
+                f"(*, {self.config.vocab_size}), but got {logits.shape}"
+            )
+            return logits
+
+        base = torch.arange(self.config.draft_vocab_size, device=logits.device)
+        targets = base + self.draft_id_to_target_id
+        logits_new = logits.new_full(
+            (
+                logits.shape[0],
+                self.config.vocab_size,
+            ),
+            float("-inf"),
+        )
+        logits_new[:, targets] = logits
+        return logits_new
+
+    def combine_hidden_states(
+        self,
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
+        # Combine multiple auxiliary hidden states returned by Eagle3
+        return self.model.fc(hidden_states)
+
+    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
+        """Rename checkpoint weights and load them with ``AutoWeightsLoader``.
+
+        ``t2d`` entries are dropped, ``d2t`` becomes ``draft_id_to_target_id``,
+        and any other name without ``lm_head`` gains a ``model.`` prefix. The
+        id mapping and ``embed_tokens`` go into ``skip_substrs`` when absent.
+        """
+        model_weights = {}
+        includes_draft_id_mapping = False
+        includes_embed_tokens = False
+
+        for name, loaded_weight in weights:
+            if "t2d" in name:
+                continue
+            if "d2t" in name:
+                name = name.replace("d2t", "draft_id_to_target_id")
+                includes_draft_id_mapping = True
+            elif "lm_head" not in name:
+                name = "model." + name
+            if "embed_tokens" in name:
+                includes_embed_tokens = True
+            model_weights[name] = loaded_weight
+            process_eagle_weight(self, name)
+
+        skip_substrs = []
+        if not includes_draft_id_mapping:
+            skip_substrs.append("draft_id_to_target_id")
+        if not includes_embed_tokens:
+            skip_substrs.append("embed_tokens")
+
+        loader = AutoWeightsLoader(
+            self,
+            skip_prefixes=None,
+            skip_substrs=skip_substrs,
+        )
+        loader.load_weights(model_weights.items())
+
+
+# Aliases for compatibility
+Eagle3DeepseekV3ForCausalLM = Eagle3DeepseekV2ForCausalLM
diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index 8277e99fd..a198f1a0b 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -82,7 +82,13 @@ from vllm.v1.attention.backends.mla.indexer import (
 )
 from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec
 
-from .interfaces import MixtureOfExperts, SupportsEagle, SupportsLoRA, SupportsPP
+from .interfaces import (
+    MixtureOfExperts,
+    SupportsEagle,
+    SupportsEagle3,
+    SupportsLoRA,
+    SupportsPP,
+)
 from .utils import (
     PPMissingLayer,
     is_pp_missing_parameter,
@@ -828,6 +834,7 @@ class DeepseekV2MLAAttention(nn.Module):
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         topk_indices_buffer: torch.Tensor | None = None,
+        input_size: int | None = None,
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -847,16 +854,20 @@ class DeepseekV2MLAAttention(nn.Module):
         self.scaling = self.qk_head_dim**-0.5
         self.max_position_embeddings = max_position_embeddings
 
+        # Use input_size for projection input dimensions if provided,
+        # otherwise default to hidden_size (used in Eagle3 Deepseek with MLA)
+        proj_input_size = input_size if input_size is not None else self.hidden_size
+
         if self.q_lora_rank is not None:
             self.fused_qkv_a_proj = DeepSeekV2FusedQkvAProjLinear(
-                self.hidden_size,
+                proj_input_size,
                 [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
                 quant_config=quant_config,
                 prefix=f"{prefix}.fused_qkv_a_proj",
             )
         else:
             self.kv_a_proj_with_mqa = ReplicatedLinear(
-                self.hidden_size,
+                proj_input_size,
                 self.kv_lora_rank + self.qk_rope_head_dim,
                 bias=False,
                 quant_config=quant_config,
@@ -874,7 +885,7 @@ class DeepseekV2MLAAttention(nn.Module):
             )
         else:
             self.q_proj = ColumnParallelLinear(
-                self.hidden_size,
+                proj_input_size,
                 self.num_heads * self.qk_head_dim,
                 bias=False,
                 quant_config=quant_config,
@@ -1170,6 +1181,8 @@ class DeepseekV2Model(nn.Module):
             ["hidden_states", "residual"], config.hidden_size
         )
 
+        self.aux_hidden_state_layers = tuple[int, ...]()
+
     def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)
 
@@ -1205,7 +1218,13 @@ class DeepseekV2Model(nn.Module):
         else:
             llama_4_scaling = None
 
-        for layer in islice(self.layers, self.start_layer, self.end_layer):
+        aux_hidden_states = []
+        for idx, layer in enumerate(
+            islice(self.layers, self.start_layer, self.end_layer),
+            start=self.start_layer,
+        ):
+            if idx in self.aux_hidden_state_layers:
+                aux_hidden_states.append(hidden_states + residual)
             hidden_states, residual = layer(
                 positions, hidden_states, residual, llama_4_scaling
             )
@@ -1216,6 +1235,8 @@ class DeepseekV2Model(nn.Module):
             )
 
         hidden_states, _ = self.norm(hidden_states, residual)
+        if len(aux_hidden_states) > 0:
+            return hidden_states, aux_hidden_states
         return hidden_states
 
 
@@ -1261,7 +1282,12 @@ class DeepseekV2MixtureOfExperts(MixtureOfExperts):
 
 
 class DeepseekV2ForCausalLM(
-    nn.Module, SupportsPP, DeepseekV2MixtureOfExperts, SupportsLoRA, SupportsEagle
+    nn.Module,
+    SupportsPP,
+    DeepseekV2MixtureOfExperts,
+    SupportsLoRA,
+    SupportsEagle,
+    SupportsEagle3,
 ):
     packed_modules_mapping = {
         "gate_up_proj": ["gate_proj", "up_proj"],
@@ -1340,6 +1366,13 @@ class DeepseekV2ForCausalLM(
 
         self.extract_moe_parameters(example_moe)
 
+    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
+        self.model.aux_hidden_state_layers = layers
+
+    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
+        num_layers = len(self.model.layers)
+        return (2, num_layers // 2, num_layers - 3)
+
     def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.embed_input_ids(input_ids)
 
diff --git a/vllm/model_executor/models/kimi_k25.py b/vllm/model_executor/models/kimi_k25.py
index 35c7576c4..2f809f929 100644
--- a/vllm/model_executor/models/kimi_k25.py
+++ b/vllm/model_executor/models/kimi_k25.py
@@ -28,6 +28,8 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
     CompressedTensorsConfig,
 )
 from vllm.model_executor.models.interfaces import (
+    SupportsEagle,
+    SupportsEagle3,
     SupportsMultiModal,
     SupportsPP,
     SupportsQuant,
@@ -311,7 +313,12 @@ class KimiK25MultiModalProcessor(BaseMultiModalProcessor[KimiK25ProcessingInfo])
     dummy_inputs=KimiK25DummyInputsBuilder,
 )
 class KimiK25ForConditionalGeneration(
-    nn.Module, SupportsMultiModal, SupportsPP, SupportsQuant
+    nn.Module,
+    SupportsMultiModal,
+    SupportsPP,
+    SupportsQuant,
+    SupportsEagle,
+    SupportsEagle3,
 ):
     """Kimi-K2.5 model for conditional generation.
 
@@ -480,6 +487,12 @@ class KimiK25ForConditionalGeneration(
         logits = self.language_model.compute_logits(hidden_states)
         return logits
 
+    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
+        self.language_model.set_aux_hidden_state_layers(layers)
+
+    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
+        return self.language_model.get_eagle3_aux_hidden_state_layers()
+
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
         loader = AutoWeightsLoader(self)
         return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index 00bfa8c65..d5d3bd265 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -551,6 +551,8 @@ _SPECULATIVE_DECODING_MODELS = {
         "mistral_large_3_eagle",
         "EagleMistralLarge3ForCausalLM",
     ),
+    "Eagle3DeepseekV2ForCausalLM": ("deepseek_eagle3", "Eagle3DeepseekV2ForCausalLM"),
+    "Eagle3DeepseekV3ForCausalLM": ("deepseek_eagle3", "Eagle3DeepseekV2ForCausalLM"),
     "EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
     "DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
     "ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),
diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py
index fc8d377da..f03de6015 100644
--- a/vllm/transformers_utils/config.py
+++ b/vllm/transformers_utils/config.py
@@ -87,6 +87,7 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict(
     funaudiochat="FunAudioChatConfig",
     hunyuan_vl="HunYuanVLConfig",
     isaac="IsaacConfig",
+    kimi_k2="DeepseekV3Config",  # Kimi K2 uses same architecture as DeepSeek V3
     kimi_linear="KimiLinearConfig",
     kimi_vl="KimiVLConfig",
     kimi_k25="KimiK25Config",
diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py
index b985176dc..445bb403b 100644
--- a/vllm/v1/spec_decode/eagle.py
+++ b/vllm/v1/spec_decode/eagle.py
@@ -20,6 +20,7 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import supports_multimodal
+from vllm.model_executor.models.deepseek_eagle3 import Eagle3DeepseekV2ForCausalLM
 from vllm.model_executor.models.interfaces import SupportsMultiModal
 from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
 from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -403,7 +404,9 @@ class SpecDecodeBaseProposer:
         batch_size = common_attn_metadata.batch_size()
 
         if self.method == "eagle3":
-            assert isinstance(self.model, Eagle3LlamaForCausalLM)
+            assert isinstance(
+                self.model, (Eagle3LlamaForCausalLM, Eagle3DeepseekV2ForCausalLM)
+            )
             target_hidden_states = self.model.combine_hidden_states(
                 target_hidden_states
             )
@@ -1278,6 +1281,10 @@ class SpecDecodeBaseProposer:
                 self.model.config.image_token_index = (
                     target_model.config.vision_config.image_token_id
                 )
+            elif self.get_model_name(target_model) == "KimiK25ForConditionalGeneration":
+                self.model.config.image_token_index = (
+                    target_model.config.media_placeholder_token_id
+                )
             else:
                 self.model.config.image_token_index = (
                     target_model.config.image_token_index