[Speculative Decoding] Add speculators config support (#21345)
This commit is contained in:
@@ -978,8 +978,28 @@ class EngineArgs:
|
||||
provided as a JSON string input via CLI arguments or directly as a
|
||||
dictionary from the engine.
|
||||
"""
|
||||
|
||||
from vllm.transformers_utils.config import get_config
|
||||
from vllm.transformers_utils.configs.speculators.base import (
|
||||
SpeculatorsConfig)
|
||||
|
||||
if self.speculative_config is None:
|
||||
return None
|
||||
hf_config = get_config(self.hf_config_path or self.model,
|
||||
self.trust_remote_code, self.revision,
|
||||
self.code_revision, self.config_format)
|
||||
|
||||
# if loading a SpeculatorsConfig, load the specualtive_config
|
||||
# details from the config directly
|
||||
# no user input required / expected
|
||||
if isinstance(hf_config, SpeculatorsConfig):
|
||||
# We create one since we dont create one
|
||||
self.speculative_config = {}
|
||||
self.speculative_config[
|
||||
"num_speculative_tokens"] = hf_config.num_lookahead_tokens
|
||||
self.speculative_config["model"] = self.model
|
||||
self.speculative_config["method"] = hf_config.method
|
||||
else:
|
||||
return None
|
||||
|
||||
# Note(Shangming): These parameters are not obtained from the cli arg
|
||||
# '--speculative-config' and must be passed in when creating the engine
|
||||
|
||||
Reference in New Issue
Block a user