feat: Enable engine-level arguments with speculators models (#25250)
Signed-off-by: Rahul Tuli <rtuli@redhat.com> Co-authored-by: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -463,15 +463,29 @@ def _maybe_remap_hf_config_attrs(config: PretrainedConfig) -> PretrainedConfig:
|
||||
return config
|
||||
|
||||
|
||||
def maybe_override_with_speculators_target_model(
|
||||
def maybe_override_with_speculators(
|
||||
model: str,
|
||||
tokenizer: str,
|
||||
trust_remote_code: bool,
|
||||
revision: Optional[str] = None,
|
||||
vllm_speculative_config: Optional[dict[str, Any]] = None,
|
||||
**kwargs,
|
||||
) -> tuple[str, str]:
|
||||
) -> tuple[str, str, Optional[dict[str, Any]]]:
|
||||
"""
|
||||
If running a speculators config, override running model with target model
|
||||
Resolve model configuration when speculators are detected.
|
||||
|
||||
Checks if the provided model is a speculators model and if so, extracts
|
||||
the target model configuration and builds the speculative config.
|
||||
|
||||
Args:
|
||||
model: Model name or path
|
||||
tokenizer: Tokenizer name or path
|
||||
trust_remote_code: Whether to trust remote code
|
||||
revision: Model revision
|
||||
vllm_speculative_config: Existing vLLM speculative config
|
||||
|
||||
Returns:
|
||||
Tuple of (resolved_model, resolved_tokenizer, speculative_config)
|
||||
"""
|
||||
is_gguf = check_gguf_file(model)
|
||||
if is_gguf:
|
||||
@@ -487,11 +501,27 @@ def maybe_override_with_speculators_target_model(
|
||||
token=_get_hf_token(),
|
||||
**kwargs,
|
||||
)
|
||||
spec_config = config_dict.get("speculators_config", None)
|
||||
# Return the target model
|
||||
if spec_config is not None:
|
||||
model = tokenizer = spec_config["verifier"]["name_or_path"]
|
||||
return model, tokenizer
|
||||
speculators_config = config_dict.get("speculators_config")
|
||||
|
||||
if speculators_config is None:
|
||||
# No speculators config found, return original values
|
||||
return model, tokenizer, vllm_speculative_config
|
||||
|
||||
# Speculators format detected - process overrides
|
||||
from vllm.transformers_utils.configs.speculators.base import (
|
||||
SpeculatorsConfig)
|
||||
|
||||
vllm_speculative_config = SpeculatorsConfig.extract_vllm_speculative_config(
|
||||
config_dict=config_dict)
|
||||
|
||||
# Set the draft model to the speculators model
|
||||
vllm_speculative_config["model"] = model
|
||||
|
||||
# Override model and tokenizer with the verifier model from config
|
||||
verifier_model = speculators_config["verifier"]["name_or_path"]
|
||||
model = tokenizer = verifier_model
|
||||
|
||||
return model, tokenizer, vllm_speculative_config
|
||||
|
||||
|
||||
def get_config(
|
||||
|
||||
@@ -24,6 +24,12 @@ class SpeculatorsConfig(PretrainedConfig):
|
||||
config_dict, _ = cls.get_config_dict(pretrained_model_name_or_path,
|
||||
**kwargs)
|
||||
|
||||
vllm_config = cls.extract_vllm_speculative_config(config_dict)
|
||||
return cls(**vllm_config)
|
||||
|
||||
@classmethod
|
||||
def extract_vllm_speculative_config(
|
||||
cls, config_dict: dict[str, Any]) -> dict[str, Any]:
|
||||
speculators_model_type = config_dict.get("speculators_model_type")
|
||||
if speculators_model_type not in SUPPORTED_SPECULATORS_TYPES:
|
||||
raise ValueError(
|
||||
@@ -34,11 +40,12 @@ class SpeculatorsConfig(PretrainedConfig):
|
||||
# TODO: @dsikka - use speculators pydantic model to validate
|
||||
cls.validate_speculators_config(config_dict=config_dict)
|
||||
# Convert from speculators config -> format that can be ingested by vLLM
|
||||
vllm_config = cls.convert_speculators_to_vllm(config_dict=config_dict)
|
||||
vllm_config = cls.build_vllm_speculative_config(
|
||||
config_dict=config_dict)
|
||||
# Apply anything specific to the supported algorithm
|
||||
algo_updater = SUPPORTED_SPECULATORS_TYPES[speculators_model_type]
|
||||
algo_updater(config_dict=config_dict, vllm_config=vllm_config)
|
||||
return cls(**vllm_config)
|
||||
return vllm_config
|
||||
|
||||
@classmethod
|
||||
def validate_speculators_config(cls, config_dict: dict[str, Any]) -> None:
|
||||
@@ -60,32 +67,45 @@ class SpeculatorsConfig(PretrainedConfig):
|
||||
"'transformer_layer_config' must be a dictionary if provided")
|
||||
|
||||
@classmethod
|
||||
def convert_speculators_to_vllm(
|
||||
def build_vllm_speculative_config(
|
||||
cls, config_dict: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Convert speculators config format to vLLM format.
|
||||
|
||||
This method handles the translation of field names and structure
|
||||
between speculators and vLLM formats.
|
||||
|
||||
Returns:
|
||||
Dictionary with vLLM-compatible configuration
|
||||
"""
|
||||
# Currently we only support one proposal method
|
||||
spec_config = config_dict["speculators_config"]
|
||||
first_method = spec_config.get("proposal_methods")[0]
|
||||
num_lookahead_tokens = first_method.get("speculative_tokens")
|
||||
Build vLLM-compatible speculative configuration from speculators format.
|
||||
|
||||
if num_lookahead_tokens is None:
|
||||
This method extracts and transforms speculative configuration from the
|
||||
speculators format into the structure expected by vLLM.
|
||||
|
||||
Args:
|
||||
config_dict: Configuration dictionary in speculators format
|
||||
|
||||
Returns:
|
||||
Dictionary with vLLM-compatible speculative configuration
|
||||
"""
|
||||
# Extract speculators configuration
|
||||
spec_config = config_dict["speculators_config"]
|
||||
|
||||
# Currently we only support one proposal method
|
||||
proposal_methods = spec_config.get("proposal_methods")
|
||||
if not proposal_methods:
|
||||
raise ValueError("No proposal methods found in speculators config")
|
||||
|
||||
first_method = proposal_methods[0]
|
||||
num_speculative_tokens = first_method.get("speculative_tokens")
|
||||
|
||||
if num_speculative_tokens is None:
|
||||
raise ValueError(
|
||||
"Missing 'speculative_tokens' in proposal method. "
|
||||
f"Got: {first_method}")
|
||||
|
||||
# Build base vLLM config
|
||||
# Build base vLLM speculative configuration
|
||||
vllm_config = {
|
||||
"method": config_dict.get("speculators_model_type"),
|
||||
"num_lookahead_tokens": num_lookahead_tokens,
|
||||
"num_speculative_tokens": num_speculative_tokens,
|
||||
"target_model": spec_config.get("verifier")["name_or_path"]
|
||||
}
|
||||
vllm_config.update(config_dict["transformer_layer_config"])
|
||||
|
||||
# Merge transformer layer configuration if present
|
||||
transformer_config = config_dict.get("transformer_layer_config", {})
|
||||
vllm_config.update(transformer_config)
|
||||
|
||||
return vllm_config
|
||||
|
||||
Reference in New Issue
Block a user