feat: Enable engine-level arguments with speculators models (#25250)

Signed-off-by: Rahul Tuli <rtuli@redhat.com>
Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
Rahul Tuli
2025-09-21 22:34:45 +05:30
committed by GitHub
parent 0ff8ebb2d7
commit c438b2951c
5 changed files with 128 additions and 85 deletions

View File

@@ -463,15 +463,29 @@ def _maybe_remap_hf_config_attrs(config: PretrainedConfig) -> PretrainedConfig:
return config
def maybe_override_with_speculators_target_model(
def maybe_override_with_speculators(
model: str,
tokenizer: str,
trust_remote_code: bool,
revision: Optional[str] = None,
vllm_speculative_config: Optional[dict[str, Any]] = None,
**kwargs,
) -> tuple[str, str]:
) -> tuple[str, str, Optional[dict[str, Any]]]:
"""
If running a speculators config, override running model with target model
Resolve model configuration when speculators are detected.
Checks if the provided model is a speculators model and if so, extracts
the target model configuration and builds the speculative config.
Args:
model: Model name or path
tokenizer: Tokenizer name or path
trust_remote_code: Whether to trust remote code
revision: Model revision
vllm_speculative_config: Existing vLLM speculative config
Returns:
Tuple of (resolved_model, resolved_tokenizer, speculative_config)
"""
is_gguf = check_gguf_file(model)
if is_gguf:
@@ -487,11 +501,27 @@ def maybe_override_with_speculators_target_model(
token=_get_hf_token(),
**kwargs,
)
spec_config = config_dict.get("speculators_config", None)
# Return the target model
if spec_config is not None:
model = tokenizer = spec_config["verifier"]["name_or_path"]
return model, tokenizer
speculators_config = config_dict.get("speculators_config")
if speculators_config is None:
# No speculators config found, return original values
return model, tokenizer, vllm_speculative_config
# Speculators format detected - process overrides
from vllm.transformers_utils.configs.speculators.base import (
SpeculatorsConfig)
vllm_speculative_config = SpeculatorsConfig.extract_vllm_speculative_config(
config_dict=config_dict)
# Set the draft model to the speculators model
vllm_speculative_config["model"] = model
# Override model and tokenizer with the verifier model from config
verifier_model = speculators_config["verifier"]["name_or_path"]
model = tokenizer = verifier_model
return model, tokenizer, vllm_speculative_config
def get_config(

View File

@@ -24,6 +24,12 @@ class SpeculatorsConfig(PretrainedConfig):
config_dict, _ = cls.get_config_dict(pretrained_model_name_or_path,
**kwargs)
vllm_config = cls.extract_vllm_speculative_config(config_dict)
return cls(**vllm_config)
@classmethod
def extract_vllm_speculative_config(
cls, config_dict: dict[str, Any]) -> dict[str, Any]:
speculators_model_type = config_dict.get("speculators_model_type")
if speculators_model_type not in SUPPORTED_SPECULATORS_TYPES:
raise ValueError(
@@ -34,11 +40,12 @@ class SpeculatorsConfig(PretrainedConfig):
# TODO: @dsikka - use speculators pydantic model to validate
cls.validate_speculators_config(config_dict=config_dict)
# Convert from speculators config -> format that can be ingested by vLLM
vllm_config = cls.convert_speculators_to_vllm(config_dict=config_dict)
vllm_config = cls.build_vllm_speculative_config(
config_dict=config_dict)
# Apply anything specific to the supported algorithm
algo_updater = SUPPORTED_SPECULATORS_TYPES[speculators_model_type]
algo_updater(config_dict=config_dict, vllm_config=vllm_config)
return cls(**vllm_config)
return vllm_config
@classmethod
def validate_speculators_config(cls, config_dict: dict[str, Any]) -> None:
@@ -60,32 +67,45 @@ class SpeculatorsConfig(PretrainedConfig):
"'transformer_layer_config' must be a dictionary if provided")
@classmethod
def build_vllm_speculative_config(
        cls, config_dict: dict[str, Any]) -> dict[str, Any]:
    """
    Build vLLM-compatible speculative configuration from speculators format.

    This method extracts and transforms speculative configuration from the
    speculators format into the structure expected by vLLM.

    Args:
        config_dict: Configuration dictionary in speculators format. Must
            contain a "speculators_config" mapping with a non-empty
            "proposal_methods" list and a "verifier" mapping that has a
            "name_or_path" entry.

    Returns:
        Dictionary with vLLM-compatible speculative configuration: the keys
        "method", "num_speculative_tokens" and "target_model", plus every
        entry of "transformer_layer_config" when that mapping is present.

    Raises:
        KeyError: If "speculators_config" is missing from ``config_dict``.
        ValueError: If no proposal method is configured, the first proposal
            method has no "speculative_tokens", or the verifier has no
            "name_or_path".
    """
    # Extract speculators configuration
    spec_config = config_dict["speculators_config"]

    # Currently we only support one proposal method
    proposal_methods = spec_config.get("proposal_methods")
    if not proposal_methods:
        raise ValueError("No proposal methods found in speculators config")

    first_method = proposal_methods[0]
    num_speculative_tokens = first_method.get("speculative_tokens")
    if num_speculative_tokens is None:
        raise ValueError(
            "Missing 'speculative_tokens' in proposal method. "
            f"Got: {first_method}")

    # Fail with a descriptive error instead of an opaque TypeError/KeyError
    # when the verifier section is absent or incomplete.
    verifier = spec_config.get("verifier")
    if not isinstance(verifier, dict) or "name_or_path" not in verifier:
        raise ValueError(
            "Missing 'name_or_path' for 'verifier' in speculators config. "
            f"Got: {verifier}")

    # Build base vLLM speculative configuration
    vllm_config = {
        "method": config_dict.get("speculators_model_type"),
        "num_speculative_tokens": num_speculative_tokens,
        "target_model": verifier["name_or_path"],
    }

    # Merge transformer layer configuration if present. `or {}` also covers
    # an explicit null value, which dict.update() would reject.
    transformer_config = config_dict.get("transformer_layer_config") or {}
    vllm_config.update(transformer_config)
    return vllm_config