[V1][Usage] Refactor speculative decoding configuration and tests (#14434)
Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
vllm/config.py
@@ -1810,12 +1810,139 @@ class DeviceConfig:
self.device = torch.device(self.device_type)


@dataclass
class SpeculativeConfig:
"""Configuration for speculative decoding.

The configuration is currently specialized to draft-model speculative
decoding with top-1 proposals.
"""
Configuration for speculative decoding.
Configurable parameters include:
- General Speculative Decoding Control:
- num_speculative_tokens (int): The number of speculative
tokens, if provided. It will default to the number in the draft
model config if present; otherwise, it is required.
- model (Optional[str]): The name of the draft model, eagle head,
or additional weights, if provided.
- method (Optional[str]): The name of the speculative method to use.
If users provide and set the `model` param, the speculative method
type will be detected automatically if possible; if the `model` param
is not provided, the method name must be provided.
- Possible values:
- ngram
Related additional configuration:
- prompt_lookup_max (Optional[int]):
Maximum size of ngram token window when using Ngram
proposer, required when method is set to ngram.
- prompt_lookup_min (Optional[int]):
Minimum size of ngram token window when using Ngram
proposer, if provided. Defaults to 1.
- eagle
- medusa
- mlp_speculator
- draft_model
- acceptance_method (str): The method to use for accepting draft
tokens. This can take two possible values: 'rejection_sampler' and
'typical_acceptance_sampler' for RejectionSampler and
TypicalAcceptanceSampler respectively. If not specified, it
defaults to 'rejection_sampler'.
- Possible values:
- rejection_sampler
- typical_acceptance_sampler
Related additional configuration:
- posterior_threshold (Optional[float]):
A threshold value that sets a lower bound on the
posterior probability of a token in the target model
for it to be accepted. This threshold is used only
when we use the TypicalAcceptanceSampler for token
acceptance.
- posterior_alpha (Optional[float]):
Scaling factor for entropy-based threshold, applied
when using TypicalAcceptanceSampler.
- draft_tensor_parallel_size (Optional[int]): The degree of the tensor
parallelism for the draft model. Can only be 1 or the same as the
target model's tensor parallel size.
- disable_logprobs (bool): If set to True, token log probabilities are
not returned during speculative decoding. If set to False, token
log probabilities are returned according to the log probability
settings in SamplingParams. If not specified, it defaults to True.

- Draft Model Configuration:
- quantization (Optional[str]): Quantization method that was used to
quantize the draft model weights. If None, we assume the
model weights are not quantized. Note that it only takes effect
when using the draft-model-based speculative method.
- max_model_len (Optional[int]): The maximum model length of the
draft model. Used when testing the ability to skip
speculation for some sequences.
- revision: The specific model version to use for the draft model. It
can be a branch name, a tag name, or a commit id. If unspecified,
will use the default version.
- code_revision: The specific revision to use for the draft model code
on Hugging Face Hub. It can be a branch name, a tag name, or a
commit id. If unspecified, will use the default version.

- Advanced Control:
- disable_mqa_scorer (bool): Disable the MQA scorer and fall back to
batch expansion for scoring proposals. If not specified, it
defaults to False.
- disable_by_batch_size (Optional[int]): Disable speculative decoding
for new incoming requests when the number of enqueued requests is
larger than this value, if provided.

Although the parameters above are structured hierarchically, there is no
need to nest them during configuration.

Non-configurable internal parameters include:
- Model Configuration:
- target_model_config (ModelConfig): The configuration of the target
model.
- draft_model_config (ModelConfig): The configuration of the draft
model initialized internally.
- Parallelism Configuration:
- target_parallel_config (ParallelConfig): The parallel configuration
for the target model.
- draft_parallel_config (ParallelConfig): The parallel configuration
for the draft model initialized internally.
- Execution Control:
- enable_chunked_prefill (bool): Whether vLLM is configured to use
chunked prefill or not. Used for raising an error since it's not
yet compatible with speculative decoding.
- disable_log_stats (bool): Whether to disable the periodic printing of
stage times in speculative decoding.
"""

# speculative configs from cli args
num_speculative_tokens: int = field(default=None,
init=True)  # type: ignore
method: Optional[str] = None
acceptance_method: str = "rejection_sampler"
draft_tensor_parallel_size: Optional[int] = None
disable_logprobs: bool = True

model: Optional[str] = None
quantization: Optional[str] = None
max_model_len: Optional[int] = None
revision: Optional[str] = None
code_revision: Optional[str] = None

disable_mqa_scorer: bool = False
disable_by_batch_size: Optional[int] = None
prompt_lookup_max: Optional[int] = None
prompt_lookup_min: Optional[int] = None
posterior_threshold: Optional[float] = None
posterior_alpha: Optional[float] = None

# required configuration params passed from engine
target_model_config: ModelConfig = field(default=None,
init=True)  # type: ignore
target_parallel_config: ParallelConfig = field(default=None,
init=True)  # type: ignore
enable_chunked_prefill: bool = field(default=None,
init=True)  # type: ignore
disable_log_stats: bool = field(default=None, init=True)  # type: ignore

# params generated in the post-init stage
draft_model_config: ModelConfig = field(default=None,
init=True)  # type: ignore
draft_parallel_config: ParallelConfig = field(default=None,
init=True)  # type: ignore

def compute_hash(self) -> str:
"""
@@ -1835,6 +1962,11 @@ class SpeculativeConfig:
hash_str = hashlib.md5(str(factors).encode()).hexdigest()
return hash_str

@classmethod
def from_dict(cls, dict_value: dict) -> "SpeculativeConfig":
"""Parse the CLI value for the speculative config."""
return cls(**dict_value)
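Because the refactor flattens everything into one dataclass, a single flat mapping is enough to describe speculative decoding; nothing needs to be nested. The sketch below is illustrative only: the draft-model name and values are made up, and the engine-provided objects (target_model_config, target_parallel_config, enable_chunked_prefill, disable_log_stats) are assumed to be merged in by the caller before __post_init__ runs.

# Hedged sketch only: keys mirror the dataclass fields declared above;
# the EAGLE checkpoint name and all values are illustrative assumptions.
cli_value = {
    "model": "yuhuili/EAGLE-LLaMA3-Instruct-8B",
    "num_speculative_tokens": 5,
    "draft_tensor_parallel_size": 1,
    "disable_logprobs": True,
}
# The caller is assumed to add the engine-provided objects before parsing:
cli_value.update(
    target_model_config=target_model_config,        # placeholder from engine
    target_parallel_config=target_parallel_config,  # placeholder from engine
    enable_chunked_prefill=False,
    disable_log_stats=False,
)
spec_config = SpeculativeConfig.from_dict(cli_value)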

@staticmethod
def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig:
if hf_config.model_type == "deepseek_v3":
@@ -1847,230 +1979,160 @@ class SpeculativeConfig:
})
return hf_config

@staticmethod
def maybe_create_spec_config(
target_model_config: ModelConfig,
target_parallel_config: ParallelConfig,
target_dtype: str,
speculative_model: Optional[str],
speculative_model_quantization: Optional[str],
speculative_draft_tensor_parallel_size: Optional[int],
num_speculative_tokens: Optional[int],
speculative_disable_mqa_scorer: Optional[bool],
speculative_max_model_len: Optional[int],
enable_chunked_prefill: bool,
disable_log_stats: bool,
speculative_disable_by_batch_size: Optional[int],
ngram_prompt_lookup_max: Optional[int],
ngram_prompt_lookup_min: Optional[int],
draft_token_acceptance_method: str,
typical_acceptance_sampler_posterior_threshold: Optional[float],
typical_acceptance_sampler_posterior_alpha: Optional[float],
disable_logprobs: Optional[bool],
) -> Optional["SpeculativeConfig"]:
"""Create a SpeculativeConfig if possible, else return None.
def __post_init__(self):

This function attempts to create a SpeculativeConfig object based on the
provided parameters. If the necessary conditions are met, it returns an
instance of SpeculativeConfig. Otherwise, it returns None.
# Note: After next release, the method parameter will be used to
# specify the speculative method, which helps to extend the
# configuration of non-model-based proposers, and the model parameter
# will be used when the draft model or head is needed.
# If users do not specify the method, the speculative method will
# be detected automatically if possible. If the speculative method can
# not be detected, it will be considered as the draft-model-based
# method by default.

Args:
target_model_config (ModelConfig): The configuration of the target
model.
target_parallel_config (ParallelConfig): The parallel configuration
for the target model.
target_dtype (str): The data type used for the target model.
speculative_model (Optional[str]): The name of the speculative
model, if provided.
speculative_model_quantization (Optional[str]): Quantization method
that was used to quantize the speculative model weights. If
None, we assume the model weights are not quantized.
speculative_draft_tensor_parallel_size (Optional[int]): The degree
of the tensor parallelism for the draft model.
num_speculative_tokens (Optional[int]): The number of speculative
tokens, if provided. Will default to the number in the draft
model config if present; otherwise it is required.
speculative_disable_mqa_scorer (Optional[bool]): Disable the MQA
scorer for the speculative model and fall back to batch
expansion for scoring.
speculative_max_model_len (Optional[int]): The maximum model len of
the speculative model. Used when testing the ability to skip
speculation for some sequences.
enable_chunked_prefill (bool): Whether vLLM is configured to use
chunked prefill or not. Used for raising an error since it's not
yet compatible with spec decode.
speculative_disable_by_batch_size (Optional[int]): Disable
speculative decoding for new incoming requests when the number
of enqueued requests is larger than this value, if provided.
ngram_prompt_lookup_max (Optional[int]): Max size of ngram token
window, if provided.
ngram_prompt_lookup_min (Optional[int]): Min size of ngram token
window, if provided.
draft_token_acceptance_method (str): The method to use for
accepting draft tokens. This can take two possible
values, 'rejection_sampler' and 'typical_acceptance_sampler',
for RejectionSampler and TypicalAcceptanceSampler
respectively.
typical_acceptance_sampler_posterior_threshold (Optional[float]):
A threshold value that sets a lower bound on the posterior
probability of a token in the target model for it to be
accepted. This threshold is used only when we use the
TypicalAcceptanceSampler for token acceptance.
typical_acceptance_sampler_posterior_alpha (Optional[float]):
A scaling factor for the entropy-based threshold in the
TypicalAcceptanceSampler.
disable_logprobs (Optional[bool]): If set to True, token log
probabilities are not returned during speculative decoding.
If set to False, token log probabilities are returned
according to the log probability settings in SamplingParams.
If not specified, it defaults to True.

Returns:
Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
the necessary conditions are met, else None.
"""
if speculative_model is None:
if num_speculative_tokens is not None:
if target_model_config.hf_text_config.model_type \
if self.model is None and self.num_speculative_tokens is not None:
# TODO(Shangming): Refactor mtp configuration logic when supporting
# mtp acceleration for more models besides deepseek_v3
if self.target_model_config.hf_text_config.model_type \
== "deepseek_v3":
# use the draft model from the same model:
speculative_model = target_model_config.model
else:
raise ValueError(
"num_speculative_tokens was provided without "
"speculative_model.")
# use the draft model from the same model:
self.model = self.target_model_config.model
elif self.method in ("ngram", "[ngram]"):
self.model = "ngram"
else:
return None
raise ValueError("num_speculative_tokens was provided without "
"speculative model.")

if (speculative_disable_by_batch_size is not None
and speculative_disable_by_batch_size < 2):
raise ValueError("Expect the batch size threshold of disabling "
"speculative decoding is > 1, but got "
f"{speculative_disable_by_batch_size=}")
if (enable_chunked_prefill and speculative_model == "eagle"):
raise ValueError("Chunked prefill and EAGLE are not compatible.")
# TODO: The user should be able to specify revision/max model len
# for the draft model. It is not currently supported.
draft_revision = None
draft_code_revision = None
draft_quantization = speculative_model_quantization
# Automatically configure the ngram method during configuration
# refactoring to ensure a smooth transition.
if self.method is None and (self.model is not None
and self.model in ("ngram", "[ngram]")):
self.method = "ngram"

if speculative_model == "[ngram]":
if ngram_prompt_lookup_min is None:
ngram_prompt_lookup_min = 1
if ngram_prompt_lookup_max is None or ngram_prompt_lookup_max < 1:
raise ValueError(f"{ngram_prompt_lookup_max=} must be > 0")
if ngram_prompt_lookup_min < 1:
raise ValueError(f"{ngram_prompt_lookup_min=} must be > 0")
if ngram_prompt_lookup_min > ngram_prompt_lookup_max:
raise ValueError(f"{ngram_prompt_lookup_min=} cannot be "
f"larger than {ngram_prompt_lookup_max=}")
if self.method in ("ngram", "[ngram]"):
# Unified to "ngram" internally
self.method = "ngram"
if self.prompt_lookup_min is None:
self.prompt_lookup_min = 1
if self.prompt_lookup_max is None or self.prompt_lookup_max < 1:
raise ValueError("prompt_lookup_max="
f"{self.prompt_lookup_max} must be > 0")
if self.prompt_lookup_min < 1:
raise ValueError("prompt_lookup_min="
f"{self.prompt_lookup_min} must be > 0")
if self.prompt_lookup_min > self.prompt_lookup_max:
raise ValueError(f"prompt_lookup_min={self.prompt_lookup_min} "
"cannot be larger than prompt_lookup_max="
f"{self.prompt_lookup_max}")

# TODO: Currently we still need to extract vocab_size from the target
# model config; in the future, we may try to refactor it out and set the
# draft-related config to None here.
draft_model_config = target_model_config
draft_parallel_config = target_parallel_config
self.draft_model_config = self.target_model_config
self.draft_parallel_config = self.target_parallel_config
else:
ngram_prompt_lookup_max = 0
ngram_prompt_lookup_min = 0
draft_model_config = ModelConfig(
model=speculative_model,
task="draft",
tokenizer=target_model_config.tokenizer,
tokenizer_mode=target_model_config.tokenizer_mode,
trust_remote_code=target_model_config.trust_remote_code,
allowed_local_media_path=target_model_config.
allowed_local_media_path,
dtype=target_model_config.dtype,
seed=target_model_config.seed,
revision=draft_revision,
code_revision=draft_code_revision,
tokenizer_revision=target_model_config.tokenizer_revision,
max_model_len=None,
spec_target_max_model_len=target_model_config.max_model_len,
quantization=draft_quantization,
enforce_eager=target_model_config.enforce_eager,
max_seq_len_to_capture=target_model_config.
max_seq_len_to_capture,
max_logprobs=target_model_config.max_logprobs,
hf_overrides=SpeculativeConfig.hf_config_override,
)
self.prompt_lookup_max = 0
self.prompt_lookup_min = 0

draft_hf_config = draft_model_config.hf_config
if self.model is not None:
self.draft_model_config = ModelConfig(
model=self.model,
task="draft",
tokenizer=self.target_model_config.tokenizer,
tokenizer_mode=self.target_model_config.tokenizer_mode,
trust_remote_code=self.target_model_config.
trust_remote_code,
allowed_local_media_path=self.target_model_config.
allowed_local_media_path,
dtype=self.target_model_config.dtype,
seed=self.target_model_config.seed,
revision=self.revision,
code_revision=self.code_revision,
tokenizer_revision=self.target_model_config.
tokenizer_revision,
max_model_len=None,
spec_target_max_model_len=self.target_model_config.
max_model_len,
quantization=self.quantization,
enforce_eager=self.target_model_config.enforce_eager,
max_seq_len_to_capture=self.target_model_config.
max_seq_len_to_capture,
max_logprobs=self.target_model_config.max_logprobs,
hf_overrides=SpeculativeConfig.hf_config_override,
)

# Detect EAGLE prefix to replace hf_config for EAGLE draft_model
if "eagle-" in draft_model_config.model.lower():
from vllm.transformers_utils.configs.eagle import EAGLEConfig
if isinstance(draft_model_config.hf_config, EAGLEConfig):
pass
# Automatically detect the method
if "eagle-" in self.draft_model_config.model.lower():
self.method = "eagle"
elif self.draft_model_config.hf_config.model_type == "medusa":
self.method = "medusa"
elif (self.draft_model_config.hf_config.model_type ==
"mlp_speculator"):
self.method = "mlp_speculator"
else:
eagle_config = EAGLEConfig(draft_model_config.hf_config)
draft_model_config.hf_config = eagle_config
self.method = "draft_model"

if (num_speculative_tokens is not None
and hasattr(draft_hf_config, "num_lookahead_tokens")):
draft_hf_config.num_lookahead_tokens = num_speculative_tokens
n_predict = getattr(draft_hf_config, "n_predict", None)
if n_predict is not None:
if num_speculative_tokens is None:
# Default to max value defined in draft model config.
num_speculative_tokens = n_predict
elif num_speculative_tokens > n_predict and \
num_speculative_tokens % n_predict != 0:
# Ensure divisibility for MTP module reuse.
raise ValueError(
f"{num_speculative_tokens=} must be divisible by "
f"{n_predict=}")
# Replace hf_config for EAGLE draft_model
if self.method == "eagle":
if self.enable_chunked_prefill:
raise ValueError(
"Chunked prefill and EAGLE are not compatible.")

speculative_draft_tensor_parallel_size = \
SpeculativeConfig._verify_and_get_draft_model_tensor_parallel_size(
target_parallel_config,
speculative_draft_tensor_parallel_size,
draft_hf_config
)
from vllm.transformers_utils.configs.eagle import (
EAGLEConfig)
if isinstance(self.draft_model_config.hf_config,
EAGLEConfig):
pass
else:
eagle_config = EAGLEConfig(
self.draft_model_config.hf_config)
self.draft_model_config.hf_config = eagle_config

draft_model_config.max_model_len = (
SpeculativeConfig._maybe_override_draft_max_model_len(
speculative_max_model_len,
draft_model_config.max_model_len,
target_model_config.max_model_len,
))
if (self.num_speculative_tokens is not None
and hasattr(self.draft_model_config.hf_config,
"num_lookahead_tokens")):
self.draft_model_config.hf_config.num_lookahead_tokens = \
self.num_speculative_tokens

draft_parallel_config = (
SpeculativeConfig.create_draft_parallel_config(
target_parallel_config,
speculative_draft_tensor_parallel_size, draft_hf_config))
n_predict = getattr(self.draft_model_config.hf_config,
"n_predict", None)
if n_predict is not None:
if self.num_speculative_tokens is None:
# Default to max value defined in draft model config.
self.num_speculative_tokens = n_predict
elif self.num_speculative_tokens > n_predict and \
self.num_speculative_tokens % n_predict != 0:
# Ensure divisibility for MTP module reuse.
raise ValueError(
f"num_speculative_tokens:{self.num_speculative_tokens}"
f" must be divisible by {n_predict=}")

if num_speculative_tokens is None:
raise ValueError(
"num_speculative_tokens must be provided with "
"speculative_model unless the draft model config contains an "
"n_predict parameter.")
self.draft_tensor_parallel_size = \
SpeculativeConfig._verify_and_get_draft_tp(
self.target_parallel_config,
self.draft_tensor_parallel_size,
self.draft_model_config.hf_config
)

if typical_acceptance_sampler_posterior_threshold is None:
typical_acceptance_sampler_posterior_threshold = 0.09
if typical_acceptance_sampler_posterior_alpha is None:
typical_acceptance_sampler_posterior_alpha = 0.3
if disable_logprobs is None:
disable_logprobs = True
self.draft_model_config.max_model_len = (
SpeculativeConfig._maybe_override_draft_max_model_len(
self.max_model_len,
self.draft_model_config.max_model_len,
self.target_model_config.max_model_len,
))

return SpeculativeConfig(
draft_model_config,
draft_parallel_config,
num_speculative_tokens,
speculative_disable_mqa_scorer,
speculative_disable_by_batch_size,
ngram_prompt_lookup_max,
ngram_prompt_lookup_min,
draft_token_acceptance_method=draft_token_acceptance_method,
typical_acceptance_sampler_posterior_threshold=\
typical_acceptance_sampler_posterior_threshold,
typical_acceptance_sampler_posterior_alpha=\
typical_acceptance_sampler_posterior_alpha,
disable_logprobs=disable_logprobs,
disable_log_stats=disable_log_stats,
)
self.draft_parallel_config = (
SpeculativeConfig.create_draft_parallel_config(
self.target_parallel_config,
self.draft_tensor_parallel_size))

if self.acceptance_method == "typical_acceptance_sampler":
if self.posterior_threshold is None:
self.posterior_threshold = 0.09
if self.posterior_alpha is None:
self.posterior_alpha = 0.3

self._verify_args()
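One consequence of the defaults set just above: posterior_threshold and posterior_alpha only need to be spelled out when overriding them. Below is a hedged sketch of an acceptance-sampler-focused flat config; the draft model path and values are illustrative placeholders, not part of this diff.

# Hypothetical flat config exercising the typical_acceptance_sampler path;
# omitting posterior_threshold / posterior_alpha would fall back to 0.09 / 0.3.
spec_config = {
    "model": "path/to/draft-model",        # placeholder draft model
    "num_speculative_tokens": 4,
    "acceptance_method": "typical_acceptance_sampler",
    "posterior_threshold": 0.05,
    "posterior_alpha": 0.2,
}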

@staticmethod
def _maybe_override_draft_max_model_len(
@@ -2108,7 +2170,7 @@ class SpeculativeConfig:
)

@staticmethod
def _verify_and_get_draft_model_tensor_parallel_size(
def _verify_and_get_draft_tp(
target_parallel_config: ParallelConfig,
speculative_draft_tensor_parallel_size: Optional[int],
draft_hf_config: PretrainedConfig) -> int:
@@ -2140,7 +2202,6 @@ class SpeculativeConfig:
def create_draft_parallel_config(
target_parallel_config: ParallelConfig,
speculative_draft_tensor_parallel_size: int,
draft_hf_config: PretrainedConfig,
) -> ParallelConfig:
"""Create a parallel config for use by the draft worker.

@@ -2164,74 +2225,13 @@ class SpeculativeConfig:

return draft_parallel_config

def __init__(
self,
draft_model_config: ModelConfig,
draft_parallel_config: ParallelConfig,
num_speculative_tokens: int,
speculative_disable_mqa_scorer: Optional[bool],
speculative_disable_by_batch_size: Optional[int],
ngram_prompt_lookup_max: Optional[int],
ngram_prompt_lookup_min: Optional[int],
draft_token_acceptance_method: str,
typical_acceptance_sampler_posterior_threshold: float,
typical_acceptance_sampler_posterior_alpha: float,
disable_logprobs: bool,
disable_log_stats: bool,
):
"""Create a SpeculativeConfig object.

Args:
draft_model_config: ModelConfig for the draft model.
draft_parallel_config: ParallelConfig for the draft model.
num_speculative_tokens: The number of tokens to sample from the
draft model before scoring with the target model.
speculative_disable_by_batch_size: Disable speculative
decoding for new incoming requests when the number of
enqueued requests is larger than this value.
ngram_prompt_lookup_max: Max size of ngram token window.
ngram_prompt_lookup_min: Min size of ngram token window.
draft_token_acceptance_method (str): The method to use for
accepting draft tokens. This can take two possible
values, 'rejection_sampler' and 'typical_acceptance_sampler',
for RejectionSampler and TypicalAcceptanceSampler
respectively.
typical_acceptance_sampler_posterior_threshold (Optional[float]):
A threshold value that sets a lower bound on the posterior
probability of a token in the target model for it to be
accepted. This threshold is used only when we use the
TypicalAcceptanceSampler for token acceptance.
typical_acceptance_sampler_posterior_alpha (Optional[float]):
A scaling factor for the entropy-based threshold in the
TypicalAcceptanceSampler.
disable_logprobs: If set to True, token log probabilities will not
be returned even if requested by sampling parameters. This
reduces latency by skipping logprob calculation in proposal
sampling, target sampling, and after accepted tokens are
determined. If set to False, log probabilities will be
returned.
disable_log_stats: Whether to disable periodic printing of stage
times in speculative decoding.
"""
self.draft_model_config = draft_model_config
self.draft_parallel_config = draft_parallel_config
self.num_speculative_tokens = num_speculative_tokens
self.speculative_disable_mqa_scorer = speculative_disable_mqa_scorer
self.speculative_disable_by_batch_size = \
speculative_disable_by_batch_size
self.ngram_prompt_lookup_max = ngram_prompt_lookup_max or 0
self.ngram_prompt_lookup_min = ngram_prompt_lookup_min or 0
self.draft_token_acceptance_method = draft_token_acceptance_method
self.typical_acceptance_sampler_posterior_threshold = \
typical_acceptance_sampler_posterior_threshold
self.typical_acceptance_sampler_posterior_alpha = \
typical_acceptance_sampler_posterior_alpha
self.disable_logprobs = disable_logprobs
self.disable_log_stats = disable_log_stats

self._verify_args()

def _verify_args(self) -> None:
if self.num_speculative_tokens is None:
raise ValueError(
"num_speculative_tokens must be provided with "
"speculative model unless the draft model config contains an "
"n_predict parameter.")

if self.num_speculative_tokens <= 0:
raise ValueError("Expected num_speculative_tokens to be greater "
f"than zero ({self.num_speculative_tokens}).")
@@ -2241,29 +2241,34 @@ class SpeculativeConfig:
self.draft_parallel_config)
# Validate and set draft token acceptance related settings.

if (self.draft_token_acceptance_method is None):
raise ValueError("draft_token_acceptance_method is not set. "
if self.acceptance_method is None:
raise ValueError("acceptance_method is not set. "
"Expected values are rejection_sampler or "
"typical_acceptance_sampler.")

if (self.draft_token_acceptance_method != 'rejection_sampler'
and self.draft_token_acceptance_method
!= 'typical_acceptance_sampler'):
if (self.acceptance_method != 'rejection_sampler'
and self.acceptance_method != 'typical_acceptance_sampler'):
raise ValueError(
"Expected draft_token_acceptance_method to be either "
"Expected acceptance_method to be either "
"rejection_sampler or typical_acceptance_sampler. Instead it "
f"is {self.draft_token_acceptance_method}")
f"is {self.acceptance_method}")

if (self.typical_acceptance_sampler_posterior_threshold < 0
or self.typical_acceptance_sampler_posterior_alpha < 0):
if self.acceptance_method == "typical_acceptance_sampler" and (
(self.posterior_threshold is not None
and self.posterior_threshold < 0) or
(self.posterior_alpha is not None and self.posterior_alpha < 0)):
raise ValueError(
"Expected typical_acceptance_sampler_posterior_threshold "
"and typical_acceptance_sampler_posterior_alpha to be > 0. "
"Instead found "
f"typical_acceptance_sampler_posterior_threshold = "
f"{self.typical_acceptance_sampler_posterior_threshold} and "
f"typical_acceptance_sampler_posterior_alpha = "
f"{self.typical_acceptance_sampler_posterior_alpha}")
"Expected the posterior_threshold and posterior_alpha of "
"typical_acceptance_sampler to be > 0. "
"Instead found posterior_threshold = "
f"{self.posterior_threshold} and posterior_alpha = "
f"{self.posterior_alpha}")

if (self.disable_by_batch_size is not None
and self.disable_by_batch_size < 2):
raise ValueError("Expect the batch size threshold of disabling "
"speculative decoding is > 1, but got "
f"{self.disable_by_batch_size=}")

@property
def num_lookahead_slots(self) -> int:
@@ -2276,8 +2281,8 @@ class SpeculativeConfig:
return self.num_speculative_tokens

def __repr__(self) -> str:
if self.ngram_prompt_lookup_max > 0:
draft_model = "[ngram]"
if self.prompt_lookup_max is not None and self.prompt_lookup_max > 0:
draft_model = "ngram"
else:
draft_model = self.draft_model_config.model
num_spec_tokens = self.num_speculative_tokens

@@ -3285,7 +3290,8 @@ class VllmConfig:
init=True)  # type: ignore
load_config: LoadConfig = field(default=None, init=True)  # type: ignore
lora_config: Optional[LoRAConfig] = None
speculative_config: Optional[SpeculativeConfig] = None
speculative_config: SpeculativeConfig = field(default=None,
init=True)  # type: ignore
decoding_config: Optional[DecodingConfig] = None
observability_config: Optional[ObservabilityConfig] = None
prompt_adapter_config: Optional[PromptAdapterConfig] = None
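With speculative_config now declared like the other engine-built sub-configs, downstream code keeps receiving it through VllmConfig. A hedged sketch of the wiring; the sibling fields and the idea that they can be passed directly are assumptions from context, not something this hunk shows.

# Hypothetical assembly; field names besides speculative_config are assumed.
vllm_config = VllmConfig(
    model_config=model_config,
    parallel_config=parallel_config,
    speculative_config=spec_config,   # stays None when speculation is disabled
)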