[BugFix]Fix eagle draft_model_config and add tests (#31753)
Signed-off-by: Xingyu Liu <charlotteliu12x@gmail.com>
This commit is contained in:
@@ -13,8 +13,10 @@ from vllm.compilation.backends import VllmBackend
|
||||
from vllm.config import (
|
||||
CompilationConfig,
|
||||
ModelConfig,
|
||||
ParallelConfig,
|
||||
PoolerConfig,
|
||||
SchedulerConfig,
|
||||
SpeculativeConfig,
|
||||
VllmConfig,
|
||||
update_config,
|
||||
)
|
||||
@@ -1105,3 +1107,23 @@ def test_needs_dp_coordination(
|
||||
vllm_config = VllmConfig(model_config=model_config, parallel_config=parallel_config)
|
||||
|
||||
assert vllm_config.needs_dp_coordinator == expected_needs_coordinator
|
||||
|
||||
|
||||
def test_eagle_draft_model_config():
|
||||
"""Test that EagleDraft model config is correctly set."""
|
||||
target_model_config = ModelConfig(
|
||||
"meta-llama/Meta-Llama-3-8B-Instruct", trust_remote_code=True
|
||||
)
|
||||
speculative_config = SpeculativeConfig(
|
||||
model="yuhuili/EAGLE-LLaMA3-Instruct-8B",
|
||||
num_speculative_tokens=1,
|
||||
target_model_config=target_model_config,
|
||||
target_parallel_config=ParallelConfig(),
|
||||
)
|
||||
draft_model_config = speculative_config.draft_model_config
|
||||
assert draft_model_config.hf_config.architectures == ["EagleLlamaForCausalLM"]
|
||||
assert draft_model_config.hf_text_config.architectures == ["EagleLlamaForCausalLM"]
|
||||
assert draft_model_config.hf_config.model_type == "eagle"
|
||||
assert draft_model_config.hf_text_config.model_type == "eagle"
|
||||
assert draft_model_config.architectures == ["EagleLlamaForCausalLM"]
|
||||
assert draft_model_config.architecture == "EagleLlamaForCausalLM"
|
||||
|
||||
Reference in New Issue
Block a user