[BugFix]Fix eagle draft_model_config and add tests (#31753)

Signed-off-by: Xingyu Liu <charlotteliu12x@gmail.com>
This commit is contained in:
Xingyu Liu
2026-01-12 23:09:36 -08:00
committed by GitHub
parent 5e714f7ff4
commit 80221e1884
4 changed files with 40 additions and 4 deletions

View File

@@ -13,8 +13,10 @@ from vllm.compilation.backends import VllmBackend
from vllm.config import (
CompilationConfig,
ModelConfig,
ParallelConfig,
PoolerConfig,
SchedulerConfig,
SpeculativeConfig,
VllmConfig,
update_config,
)
@@ -1105,3 +1107,23 @@ def test_needs_dp_coordination(
vllm_config = VllmConfig(model_config=model_config, parallel_config=parallel_config)
assert vllm_config.needs_dp_coordinator == expected_needs_coordinator
def test_eagle_draft_model_config():
"""Test that EagleDraft model config is correctly set."""
target_model_config = ModelConfig(
"meta-llama/Meta-Llama-3-8B-Instruct", trust_remote_code=True
)
speculative_config = SpeculativeConfig(
model="yuhuili/EAGLE-LLaMA3-Instruct-8B",
num_speculative_tokens=1,
target_model_config=target_model_config,
target_parallel_config=ParallelConfig(),
)
draft_model_config = speculative_config.draft_model_config
assert draft_model_config.hf_config.architectures == ["EagleLlamaForCausalLM"]
assert draft_model_config.hf_text_config.architectures == ["EagleLlamaForCausalLM"]
assert draft_model_config.hf_config.model_type == "eagle"
assert draft_model_config.hf_text_config.model_type == "eagle"
assert draft_model_config.architectures == ["EagleLlamaForCausalLM"]
assert draft_model_config.architecture == "EagleLlamaForCausalLM"