[V1][Spec Decode] Remove deprecated spec decode config params (#15466)

Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
This commit is contained in:
shangmingc
2025-04-01 00:19:35 +08:00
committed by GitHub
parent 09e974d483
commit 239b7befdd
10 changed files with 125 additions and 220 deletions

View File

@@ -181,22 +181,7 @@ class EngineArgs:
guided_decoding_backend: str = 'xgrammar'
logits_processor_pattern: Optional[str] = None
speculative_config: Optional[Union[str, Dict[str, Any]]] = None
# TODO(Shangming): Deprecate these out-of-date params after next release
speculative_model: Optional[str] = None
speculative_model_quantization: Optional[str] = None
speculative_draft_tensor_parallel_size: Optional[int] = None
num_speculative_tokens: Optional[int] = None
speculative_disable_mqa_scorer: Optional[bool] = False
speculative_max_model_len: Optional[int] = None
speculative_disable_by_batch_size: Optional[int] = None
ngram_prompt_lookup_max: Optional[int] = None
ngram_prompt_lookup_min: Optional[int] = None
spec_decoding_acceptance_method: str = 'rejection_sampler'
typical_acceptance_sampler_posterior_threshold: Optional[float] = None
typical_acceptance_sampler_posterior_alpha: Optional[float] = None
disable_logprobs_during_spec_decoding: Optional[bool] = None
speculative_config: Optional[Dict[str, Any]] = None
qlora_adapter_name_or_path: Optional[str] = None
show_hidden_metrics_for_version: Optional[str] = None
@@ -793,122 +778,10 @@ class EngineArgs:
help='If set, the prefill requests can be chunked based on the '
'max_num_batched_tokens.')
parser.add_argument('--speculative-config',
type=nullable_str,
type=json.loads,
default=None,
help='The configurations for speculative decoding.'
' Should be a JSON string.')
parser.add_argument(
'--speculative-model',
type=nullable_str,
default=EngineArgs.speculative_model,
help=
'The name of the draft model to be used in speculative decoding.')
# Quantization settings for speculative model.
parser.add_argument(
'--speculative-model-quantization',
type=nullable_str,
choices=[*QUANTIZATION_METHODS, None],
default=EngineArgs.speculative_model_quantization,
help='Method used to quantize the weights of speculative model. '
'If None, we first check the `quantization_config` '
'attribute in the model config file. If that is '
'None, we assume the model weights are not '
'quantized and use `dtype` to determine the data '
'type of the weights.')
parser.add_argument(
'--num-speculative-tokens',
type=int,
default=EngineArgs.num_speculative_tokens,
help='The number of speculative tokens to sample from '
'the draft model in speculative decoding.')
parser.add_argument(
'--speculative-disable-mqa-scorer',
action='store_true',
help=
'If set to True, the MQA scorer will be disabled in speculative '
' and fall back to batch expansion')
parser.add_argument(
'--speculative-draft-tensor-parallel-size',
'-spec-draft-tp',
type=int,
default=EngineArgs.speculative_draft_tensor_parallel_size,
help='Number of tensor parallel replicas for '
'the draft model in speculative decoding.')
parser.add_argument(
'--speculative-max-model-len',
type=int,
default=EngineArgs.speculative_max_model_len,
help='The maximum sequence length supported by the '
'draft model. Sequences over this length will skip '
'speculation.')
parser.add_argument(
'--speculative-disable-by-batch-size',
type=int,
default=EngineArgs.speculative_disable_by_batch_size,
help='Disable speculative decoding for new incoming requests '
'if the number of enqueue requests is larger than this value.')
parser.add_argument(
'--ngram-prompt-lookup-max',
type=int,
default=EngineArgs.ngram_prompt_lookup_max,
help='Max size of window for ngram prompt lookup in speculative '
'decoding.')
parser.add_argument(
'--ngram-prompt-lookup-min',
type=int,
default=EngineArgs.ngram_prompt_lookup_min,
help='Min size of window for ngram prompt lookup in speculative '
'decoding.')
parser.add_argument(
'--spec-decoding-acceptance-method',
type=str,
default=EngineArgs.spec_decoding_acceptance_method,
choices=['rejection_sampler', 'typical_acceptance_sampler'],
help='Specify the acceptance method to use during draft token '
'verification in speculative decoding. Two types of acceptance '
'routines are supported: '
'1) RejectionSampler which does not allow changing the '
'acceptance rate of draft tokens, '
'2) TypicalAcceptanceSampler which is configurable, allowing for '
'a higher acceptance rate at the cost of lower quality, '
'and vice versa.')
parser.add_argument(
'--typical-acceptance-sampler-posterior-threshold',
type=float,
default=EngineArgs.typical_acceptance_sampler_posterior_threshold,
help='Set the lower bound threshold for the posterior '
'probability of a token to be accepted. This threshold is '
'used by the TypicalAcceptanceSampler to make sampling decisions '
'during speculative decoding.')
parser.add_argument(
'--typical-acceptance-sampler-posterior-alpha',
type=float,
default=EngineArgs.typical_acceptance_sampler_posterior_alpha,
help='A scaling factor for the entropy-based threshold for token '
'acceptance in the TypicalAcceptanceSampler. Typically defaults '
'to sqrt of --typical-acceptance-sampler-posterior-threshold '
'i.e. 0.3')
parser.add_argument(
'--disable-logprobs-during-spec-decoding',
action=StoreBoolean,
default=EngineArgs.disable_logprobs_during_spec_decoding,
nargs="?",
const="True",
help='If set to True, token log probabilities are not returned '
'during speculative decoding. If set to False, log probabilities '
'are returned according to the settings in SamplingParams. If '
'not specified, it defaults to True. Disabling log probabilities '
'during speculative decoding reduces latency by skipping logprob '
'calculation in proposal sampling, target sampling, and after '
'accepted tokens are determined.')
parser.add_argument('--model-loader-extra-config',
type=nullable_str,
@@ -1221,58 +1094,14 @@ class EngineArgs:
This function utilizes `speculative_config` to create a
SpeculativeConfig object. The `speculative_config` can either be
provided as a JSON string input via CLI arguments or directly as a
dictionary from the engine. If `speculative_config` is not set, this
function will attempt to construct a configuration dictionary using
certain parameters, which are scheduled for deprecation in the next
release. Note that in next releases, `speculative_config` must be
provided, and the deprecated standalone speculative-related parameters
will be removed.
dictionary from the engine.
"""
if self.speculative_config is None:
if (self.speculative_model is None
and self.num_speculative_tokens is None):
return None
return None
# TODO(Shangming): Deprecate this way of setting SpeculativeConfig,
# only allow '--speculative-config' after next release
logger.warning_once(
"Please use '--speculative-config' to set all configurations "
"related to speculative decoding. The current method of "
"specifying the model through '--speculative-model' and "
"adding related parameters (e.g., '--num-speculative-tokens') "
"separately will be deprecated in the next release.")
spec_config_dict = {
"model": self.speculative_model,
"quantization": self.speculative_model_quantization,
"max_model_len": self.speculative_max_model_len,
"draft_tensor_parallel_size":
self.speculative_draft_tensor_parallel_size,
"num_speculative_tokens": self.num_speculative_tokens,
"disable_mqa_scorer": self.speculative_disable_mqa_scorer,
"disable_by_batch_size":
self.speculative_disable_by_batch_size,
"prompt_lookup_max": self.ngram_prompt_lookup_max,
"prompt_lookup_min": self.ngram_prompt_lookup_min,
"acceptance_method": self.spec_decoding_acceptance_method,
"posterior_threshold":
self.typical_acceptance_sampler_posterior_threshold,
"posterior_alpha":
self.typical_acceptance_sampler_posterior_alpha,
"disable_logprobs": self.disable_logprobs_during_spec_decoding,
}
self.speculative_config = spec_config_dict
else:
if isinstance(self.speculative_config, str):
import ast
self.speculative_config = ast.literal_eval(
self.speculative_config)
# Note(Shangming): These parameters are not obtained from the cli arg
# '--speculative-config' and must be passed in when creating the engine
# config.
assert isinstance(self.speculative_config, dict)
self.speculative_config.update({
"target_model_config": target_model_config,
"target_parallel_config": target_parallel_config,
@@ -1638,11 +1467,15 @@ class EngineArgs:
return False
# Only Ngram speculative decoding so far.
if (self.speculative_model is not None
or self.num_speculative_tokens is not None):
is_ngram_enabled = False
if self.speculative_config is not None:
# This is supported but experimental (handled below).
if self.speculative_model in ("ngram", "[ngram]"):
pass
if (("method" in self.speculative_config
and self.speculative_config["method"] in ("ngram", "[ngram]"))
or
("model" in self.speculative_config and
self.speculative_config["model"] in ("ngram", "[ngram]"))):
is_ngram_enabled = True
else:
_raise_or_fallback(feature_name="Speculative Decoding",
recommend_to_remove=False)
@@ -1691,8 +1524,7 @@ class EngineArgs:
return False
# ngram is supported on V1, but off by default for now.
if self.speculative_model in (
"ngram", "[ngram]") and _warn_or_fallback("ngram"):
if is_ngram_enabled and _warn_or_fallback("ngram"):
return False
# Non-CUDA is supported on V1, but off by default for now.
@@ -1721,7 +1553,7 @@ class EngineArgs:
is_gpu = current_platform.is_cuda()
use_sliding_window = (model_config.get_sliding_window()
is not None)
use_spec_decode = self.speculative_model is not None
use_spec_decode = self.speculative_config is not None
if (is_gpu and not use_sliding_window and not use_spec_decode
and not self.enable_lora