[BugFix]Fix eagle draft_model_config and add tests (#31753)

Signed-off-by: Xingyu Liu <charlotteliu12x@gmail.com>
This commit is contained in:
Xingyu Liu
2026-01-12 23:09:36 -08:00
committed by GitHub
parent 5e714f7ff4
commit 80221e1884
4 changed files with 40 additions and 4 deletions

View File

@@ -38,7 +38,7 @@
"EagleDeepSeekMTPModel"
],
"model_type": "eagle",
"text_model_type": "deepseek_mtp",
"text_model_type": "eagle",
"hidden_size": 2560,
"total_num_hidden_layers": 1,
"total_num_attention_heads": 32,
@@ -55,7 +55,7 @@
"EagleLlamaForCausalLM"
],
"model_type": "eagle",
"text_model_type": "llama",
"text_model_type": "eagle",
"hidden_size": 4096,
"total_num_hidden_layers": 1,
"total_num_attention_heads": 32,
@@ -72,7 +72,7 @@
"Eagle3LlamaForCausalLM"
],
"model_type": "eagle",
"text_model_type": "llama",
"text_model_type": "eagle",
"hidden_size": 4096,
"total_num_hidden_layers": 1,
"total_num_attention_heads": 32,

View File

@@ -13,8 +13,10 @@ from vllm.compilation.backends import VllmBackend
from vllm.config import (
CompilationConfig,
ModelConfig,
ParallelConfig,
PoolerConfig,
SchedulerConfig,
SpeculativeConfig,
VllmConfig,
update_config,
)
@@ -1105,3 +1107,23 @@ def test_needs_dp_coordination(
vllm_config = VllmConfig(model_config=model_config, parallel_config=parallel_config)
assert vllm_config.needs_dp_coordinator == expected_needs_coordinator
def test_eagle_draft_model_config():
"""Test that EagleDraft model config is correctly set."""
target_model_config = ModelConfig(
"meta-llama/Meta-Llama-3-8B-Instruct", trust_remote_code=True
)
speculative_config = SpeculativeConfig(
model="yuhuili/EAGLE-LLaMA3-Instruct-8B",
num_speculative_tokens=1,
target_model_config=target_model_config,
target_parallel_config=ParallelConfig(),
)
draft_model_config = speculative_config.draft_model_config
assert draft_model_config.hf_config.architectures == ["EagleLlamaForCausalLM"]
assert draft_model_config.hf_text_config.architectures == ["EagleLlamaForCausalLM"]
assert draft_model_config.hf_config.model_type == "eagle"
assert draft_model_config.hf_text_config.model_type == "eagle"
assert draft_model_config.architectures == ["EagleLlamaForCausalLM"]
assert draft_model_config.architecture == "EagleLlamaForCausalLM"

View File

@@ -12,6 +12,7 @@ from vllm.config.model import ModelConfig
from vllm.config.parallel import ParallelConfig
from vllm.config.utils import config
from vllm.logger import init_logger
from vllm.transformers_utils.config import get_hf_text_config
from vllm.utils.hashing import safe_hash
from vllm.utils.import_utils import LazyLoader, has_arctic_inference
@@ -409,10 +410,23 @@ class SpeculativeConfig:
method=self.method,
model_type="eagle",
)
# EAGLEConfig primarily updates architectures, so update
# all architectures-related fields in draft_model_config
self.draft_model_config.hf_config = eagle_config
self.draft_model_config.hf_text_config = get_hf_text_config(
self.draft_model_config.hf_config
)
self.draft_model_config.model_arch_config = (
self.draft_model_config.get_model_arch_config()
)
model_info, arch = (
self.draft_model_config.registry.inspect_model_cls(
self.draft_model_config.architectures,
self.draft_model_config,
)
)
self.draft_model_config._model_info = model_info
self.draft_model_config._architecture = arch
if self.num_speculative_tokens is not None and hasattr(
self.draft_model_config.hf_config, "num_lookahead_tokens"

View File

@@ -201,7 +201,7 @@ class ModelArchConfigConvertorBase:
# underlying architecture
return (
self.hf_text_config.model.model_type
in ("deepseek_v2", "deepseek_v3", "deepseek_v32")
in ("deepseek_v2", "deepseek_v3", "deepseek_v32", "deepseek_mtp")
and self.hf_text_config.kv_lora_rank is not None
)
return False