[Speculative Decoding] Add speculators config support (#21345)
This commit is contained in:
@@ -35,8 +35,9 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, DeepseekVLV2Config,
|
||||
MllamaConfig, MLPSpeculatorConfig,
|
||||
Nemotron_Nano_VL_Config,
|
||||
NemotronConfig, NVLM_D_Config,
|
||||
RWConfig, Step3TextConfig,
|
||||
Step3VLConfig, UltravoxConfig)
|
||||
RWConfig, SpeculatorsConfig,
|
||||
Step3TextConfig, Step3VLConfig,
|
||||
UltravoxConfig)
|
||||
# yapf: enable
|
||||
from vllm.transformers_utils.configs.mistral import adapt_config_dict
|
||||
from vllm.transformers_utils.utils import check_gguf_file
|
||||
@@ -81,6 +82,7 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = {
|
||||
"mlp_speculator": MLPSpeculatorConfig,
|
||||
"medusa": MedusaConfig,
|
||||
"eagle": EAGLEConfig,
|
||||
"speculators": SpeculatorsConfig,
|
||||
"nemotron": NemotronConfig,
|
||||
"NVLM_D": NVLM_D_Config,
|
||||
"ultravox": UltravoxConfig,
|
||||
@@ -287,6 +289,27 @@ def _maybe_remap_hf_config_attrs(config: PretrainedConfig) -> PretrainedConfig:
|
||||
return config
|
||||
|
||||
|
||||
def maybe_override_with_speculators_target_model(
|
||||
model: str,
|
||||
tokenizer: str,
|
||||
trust_remote_code: bool,
|
||||
revision: Optional[str] = None) -> tuple[str, str]:
|
||||
"""
|
||||
If running a speculators config, override running model with target model
|
||||
"""
|
||||
config_dict, _ = PretrainedConfig.get_config_dict(
|
||||
model,
|
||||
revision=revision,
|
||||
trust_remote_code=trust_remote_code,
|
||||
token=_get_hf_token(),
|
||||
)
|
||||
spec_config = config_dict.get("speculators_config")
|
||||
# Return the target model
|
||||
if spec_config is not None:
|
||||
model = tokenizer = spec_config["verifier"]["name_or_path"]
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def get_config(
|
||||
model: Union[str, Path],
|
||||
trust_remote_code: bool,
|
||||
@@ -345,9 +368,12 @@ def get_config(
|
||||
token=_get_hf_token(),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Use custom model class if it's in our registry
|
||||
model_type = config_dict.get("model_type")
|
||||
if model_type is None:
|
||||
model_type = "speculators" if config_dict.get(
|
||||
"speculators_config") is not None else model_type
|
||||
|
||||
if model_type in _CONFIG_REGISTRY:
|
||||
config_class = _CONFIG_REGISTRY[model_type]
|
||||
config = config_class.from_pretrained(
|
||||
|
||||
Reference in New Issue
Block a user