[BugFix] Fix eagle draft_model_config and add tests (#31753)
Signed-off-by: Xingyu Liu <charlotteliu12x@gmail.com>
This commit is contained in:
@@ -38,7 +38,7 @@
|
||||
"EagleDeepSeekMTPModel"
|
||||
],
|
||||
"model_type": "eagle",
|
||||
"text_model_type": "deepseek_mtp",
|
||||
"text_model_type": "eagle",
|
||||
"hidden_size": 2560,
|
||||
"total_num_hidden_layers": 1,
|
||||
"total_num_attention_heads": 32,
|
||||
@@ -55,7 +55,7 @@
|
||||
"EagleLlamaForCausalLM"
|
||||
],
|
||||
"model_type": "eagle",
|
||||
"text_model_type": "llama",
|
||||
"text_model_type": "eagle",
|
||||
"hidden_size": 4096,
|
||||
"total_num_hidden_layers": 1,
|
||||
"total_num_attention_heads": 32,
|
||||
@@ -72,7 +72,7 @@
|
||||
"Eagle3LlamaForCausalLM"
|
||||
],
|
||||
"model_type": "eagle",
|
||||
"text_model_type": "llama",
|
||||
"text_model_type": "eagle",
|
||||
"hidden_size": 4096,
|
||||
"total_num_hidden_layers": 1,
|
||||
"total_num_attention_heads": 32,
|
||||
|
||||
@@ -13,8 +13,10 @@ from vllm.compilation.backends import VllmBackend
|
||||
from vllm.config import (
|
||||
CompilationConfig,
|
||||
ModelConfig,
|
||||
ParallelConfig,
|
||||
PoolerConfig,
|
||||
SchedulerConfig,
|
||||
SpeculativeConfig,
|
||||
VllmConfig,
|
||||
update_config,
|
||||
)
|
||||
@@ -1105,3 +1107,23 @@ def test_needs_dp_coordination(
|
||||
vllm_config = VllmConfig(model_config=model_config, parallel_config=parallel_config)
|
||||
|
||||
assert vllm_config.needs_dp_coordinator == expected_needs_coordinator
|
||||
|
||||
|
||||
def test_eagle_draft_model_config():
    """Verify the draft model config built for an EAGLE speculative model.

    Builds a ``SpeculativeConfig`` for an EAGLE draft checkpoint on top of a
    Llama-3 target model and checks that the resulting draft ``ModelConfig``
    reports the EAGLE architecture and model type consistently through
    ``hf_config``, ``hf_text_config`` and the architecture accessors.
    """
    expected_arch = "EagleLlamaForCausalLM"

    target_cfg = ModelConfig(
        "meta-llama/Meta-Llama-3-8B-Instruct", trust_remote_code=True
    )
    spec_cfg = SpeculativeConfig(
        model="yuhuili/EAGLE-LLaMA3-Instruct-8B",
        num_speculative_tokens=1,
        target_model_config=target_cfg,
        target_parallel_config=ParallelConfig(),
    )
    draft_cfg = spec_cfg.draft_model_config

    # Both HF config views must agree on the architecture list ...
    hf_views = (draft_cfg.hf_config, draft_cfg.hf_text_config)
    for view in hf_views:
        assert view.architectures == [expected_arch]
    # ... and on the model type.
    for view in hf_views:
        assert view.model_type == "eagle"

    # The ModelConfig-level accessors must reflect the same architecture.
    assert draft_cfg.architectures == [expected_arch]
    assert draft_cfg.architecture == expected_arch
|
||||
|
||||
@@ -12,6 +12,7 @@ from vllm.config.model import ModelConfig
|
||||
from vllm.config.parallel import ParallelConfig
|
||||
from vllm.config.utils import config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.config import get_hf_text_config
|
||||
from vllm.utils.hashing import safe_hash
|
||||
from vllm.utils.import_utils import LazyLoader, has_arctic_inference
|
||||
|
||||
@@ -409,10 +410,23 @@ class SpeculativeConfig:
|
||||
method=self.method,
|
||||
model_type="eagle",
|
||||
)
|
||||
# EAGLEConfig primarily updates architectures, so update
|
||||
# all architectures-related fields in draft_model_config
|
||||
self.draft_model_config.hf_config = eagle_config
|
||||
self.draft_model_config.hf_text_config = get_hf_text_config(
|
||||
self.draft_model_config.hf_config
|
||||
)
|
||||
self.draft_model_config.model_arch_config = (
|
||||
self.draft_model_config.get_model_arch_config()
|
||||
)
|
||||
model_info, arch = (
|
||||
self.draft_model_config.registry.inspect_model_cls(
|
||||
self.draft_model_config.architectures,
|
||||
self.draft_model_config,
|
||||
)
|
||||
)
|
||||
self.draft_model_config._model_info = model_info
|
||||
self.draft_model_config._architecture = arch
|
||||
|
||||
if self.num_speculative_tokens is not None and hasattr(
|
||||
self.draft_model_config.hf_config, "num_lookahead_tokens"
|
||||
|
||||
@@ -201,7 +201,7 @@ class ModelArchConfigConvertorBase:
|
||||
# underlying architecture
|
||||
return (
|
||||
self.hf_text_config.model.model_type
|
||||
in ("deepseek_v2", "deepseek_v3", "deepseek_v32")
|
||||
in ("deepseek_v2", "deepseek_v3", "deepseek_v32", "deepseek_mtp")
|
||||
and self.hf_text_config.kv_lora_rank is not None
|
||||
)
|
||||
return False
|
||||
|
||||
Reference in New Issue
Block a user