[Speculative Decoding] Add speculators config support (#21345)

This commit is contained in:
Dipika Sikka
2025-08-01 08:25:18 -04:00
committed by GitHub
parent 87c94bc879
commit dfbc1f8880
9 changed files with 232 additions and 11 deletions

View File

@@ -35,8 +35,9 @@ from vllm.transformers_utils.configs import (ChatGLMConfig, DeepseekVLV2Config,
MllamaConfig, MLPSpeculatorConfig,
Nemotron_Nano_VL_Config,
NemotronConfig, NVLM_D_Config,
RWConfig, Step3TextConfig,
Step3VLConfig, UltravoxConfig)
RWConfig, SpeculatorsConfig,
Step3TextConfig, Step3VLConfig,
UltravoxConfig)
# yapf: enable
from vllm.transformers_utils.configs.mistral import adapt_config_dict
from vllm.transformers_utils.utils import check_gguf_file
@@ -81,6 +82,7 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = {
"mlp_speculator": MLPSpeculatorConfig,
"medusa": MedusaConfig,
"eagle": EAGLEConfig,
"speculators": SpeculatorsConfig,
"nemotron": NemotronConfig,
"NVLM_D": NVLM_D_Config,
"ultravox": UltravoxConfig,
@@ -287,6 +289,27 @@ def _maybe_remap_hf_config_attrs(config: PretrainedConfig) -> PretrainedConfig:
return config
def maybe_override_with_speculators_target_model(
model: str,
tokenizer: str,
trust_remote_code: bool,
revision: Optional[str] = None) -> tuple[str, str]:
"""
If running a speculators config, override running model with target model
"""
config_dict, _ = PretrainedConfig.get_config_dict(
model,
revision=revision,
trust_remote_code=trust_remote_code,
token=_get_hf_token(),
)
spec_config = config_dict.get("speculators_config")
# Return the target model
if spec_config is not None:
model = tokenizer = spec_config["verifier"]["name_or_path"]
return model, tokenizer
def get_config(
model: Union[str, Path],
trust_remote_code: bool,
@@ -345,9 +368,12 @@ def get_config(
token=_get_hf_token(),
**kwargs,
)
# Use custom model class if it's in our registry
model_type = config_dict.get("model_type")
if model_type is None:
model_type = "speculators" if config_dict.get(
"speculators_config") is not None else model_type
if model_type in _CONFIG_REGISTRY:
config_class = _CONFIG_REGISTRY[model_type]
config = config_class.from_pretrained(