[V1][Usage] Refactor speculative decoding configuration and tests (#14434)
Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
This commit is contained in:
@@ -177,7 +177,10 @@ class EngineArgs:
|
||||
|
||||
guided_decoding_backend: str = 'xgrammar'
|
||||
logits_processor_pattern: Optional[str] = None
|
||||
# Speculative decoding configuration.
|
||||
|
||||
speculative_config: Optional[Union[str, Dict[str, Any]]] = None
|
||||
|
||||
# TODO(Shangming): Deprecate these out-of-date params after next release
|
||||
speculative_model: Optional[str] = None
|
||||
speculative_model_quantization: Optional[str] = None
|
||||
speculative_draft_tensor_parallel_size: Optional[int] = None
|
||||
@@ -190,9 +193,9 @@ class EngineArgs:
|
||||
spec_decoding_acceptance_method: str = 'rejection_sampler'
|
||||
typical_acceptance_sampler_posterior_threshold: Optional[float] = None
|
||||
typical_acceptance_sampler_posterior_alpha: Optional[float] = None
|
||||
qlora_adapter_name_or_path: Optional[str] = None
|
||||
disable_logprobs_during_spec_decoding: Optional[bool] = None
|
||||
|
||||
qlora_adapter_name_or_path: Optional[str] = None
|
||||
show_hidden_metrics_for_version: Optional[str] = None
|
||||
otlp_traces_endpoint: Optional[str] = None
|
||||
collect_detailed_traces: Optional[str] = None
|
||||
@@ -780,7 +783,11 @@ class EngineArgs:
|
||||
const="True",
|
||||
help='If set, the prefill requests can be chunked based on the '
|
||||
'max_num_batched_tokens.')
|
||||
|
||||
parser.add_argument('--speculative-config',
|
||||
type=nullable_str,
|
||||
default=None,
|
||||
help='The configurations for speculative decoding.'
|
||||
' Should be a JSON string.')
|
||||
parser.add_argument(
|
||||
'--speculative-model',
|
||||
type=nullable_str,
|
||||
@@ -1192,6 +1199,82 @@ class EngineArgs:
|
||||
use_tqdm_on_load=self.use_tqdm_on_load,
|
||||
)
|
||||
|
||||
def create_speculative_config(
    self,
    target_model_config: ModelConfig,
    target_parallel_config: ParallelConfig,
    enable_chunked_prefill: bool,
    disable_log_stats: bool,
) -> Optional["SpeculativeConfig"]:
    """Build a SpeculativeConfig from `speculative_config`.

    `speculative_config` may be a dictionary (passed directly when the
    engine is created) or a string (e.g. from the '--speculative-config'
    CLI argument). A string is parsed as JSON first; if that fails it is
    parsed as a Python literal so that inputs such as "{'model': 'x'}"
    keep working.

    If `speculative_config` is not set, a configuration dictionary is
    assembled from the standalone speculative-related parameters
    (`speculative_model`, `num_speculative_tokens`, ...). Those parameters
    are scheduled for deprecation in the next release, after which
    `speculative_config` must be provided.

    As a side effect, `self.speculative_config` is replaced by the
    resulting dictionary, which also contains the four engine-supplied
    entries listed under Args. A dictionary supplied by the caller is
    copied first and is never modified in place.

    Args:
        target_model_config: Stored under the "target_model_config" key.
        target_parallel_config: Stored under the "target_parallel_config"
            key.
        enable_chunked_prefill: Stored under the "enable_chunked_prefill"
            key.
        disable_log_stats: Stored under the "disable_log_stats" key.

    Returns:
        The object returned by `SpeculativeConfig.from_dict`, or None when
        `speculative_config`, `speculative_model` and
        `num_speculative_tokens` are all None.

    Raises:
        ValueError, SyntaxError: Typically raised by `ast.literal_eval`
            when a string value is neither valid JSON nor a valid Python
            literal.
        AssertionError: If the resolved configuration is not a dict.
    """
    if self.speculative_config is None:
        if (self.speculative_model is None
                and self.num_speculative_tokens is None):
            # Speculative decoding was not requested at all.
            return None

        # TODO(Shangming): Deprecate this way of setting SpeculativeConfig,
        # only allow '--speculative-config' after next release
        logger.warning_once(
            "Please use '--speculative-config' to set all configurations "
            "related to speculative decoding. The current method of "
            "specifying the model through '--speculative-model' and "
            "adding related parameters (e.g., '--num-speculative-tokens') "
            "separately will be deprecated in the next release.")

        spec_config_dict = {
            "model": self.speculative_model,
            "quantization": self.speculative_model_quantization,
            "max_model_len": self.speculative_max_model_len,
            "draft_tensor_parallel_size":
            self.speculative_draft_tensor_parallel_size,
            "num_speculative_tokens": self.num_speculative_tokens,
            "disable_mqa_scorer": self.speculative_disable_mqa_scorer,
            "disable_by_batch_size":
            self.speculative_disable_by_batch_size,
            "prompt_lookup_max": self.ngram_prompt_lookup_max,
            "prompt_lookup_min": self.ngram_prompt_lookup_min,
            "acceptance_method": self.spec_decoding_acceptance_method,
            "posterior_threshold":
            self.typical_acceptance_sampler_posterior_threshold,
            "posterior_alpha":
            self.typical_acceptance_sampler_posterior_alpha,
            "disable_logprobs": self.disable_logprobs_during_spec_decoding,
        }
    elif isinstance(self.speculative_config, str):
        import ast
        import json

        raw_config = self.speculative_config
        try:
            # The value is documented as a JSON string, so JSON literals
            # such as true/false/null must be accepted.
            spec_config_dict = json.loads(raw_config)
        except json.JSONDecodeError:
            # Backward compatibility: Python-literal syntax (single
            # quotes, True/False/None) was the only accepted form before.
            spec_config_dict = ast.literal_eval(raw_config)
    else:
        spec_config_dict = self.speculative_config

    assert isinstance(spec_config_dict, dict), (
        "speculative_config must resolve to a dict, got "
        f"{type(spec_config_dict).__name__}")

    # Shallow-copy so that a dict handed in by the caller is not polluted
    # with the engine-internal objects added below.
    spec_config_dict = dict(spec_config_dict)

    # Note(Shangming): These parameters are not obtained from the cli arg
    # '--speculative-config' and must be passed in when creating the engine
    # config.
    spec_config_dict.update({
        "target_model_config": target_model_config,
        "target_parallel_config": target_parallel_config,
        "enable_chunked_prefill": enable_chunked_prefill,
        "disable_log_stats": disable_log_stats,
    })
    self.speculative_config = spec_config_dict

    return SpeculativeConfig.from_dict(self.speculative_config)
|
||||
|
||||
def create_engine_config(
|
||||
self,
|
||||
usage_context: Optional[UsageContext] = None,
|
||||
@@ -1238,6 +1321,8 @@ class EngineArgs:
|
||||
else:
|
||||
self._set_default_args_v0(model_config)
|
||||
|
||||
assert self.enable_chunked_prefill is not None
|
||||
|
||||
cache_config = CacheConfig(
|
||||
block_size=self.block_size,
|
||||
gpu_memory_utilization=self.gpu_memory_utilization,
|
||||
@@ -1280,31 +1365,11 @@ class EngineArgs:
|
||||
worker_extension_cls=self.worker_extension_cls,
|
||||
)
|
||||
|
||||
speculative_config = SpeculativeConfig.maybe_create_spec_config(
|
||||
speculative_config = self.create_speculative_config(
|
||||
target_model_config=model_config,
|
||||
target_parallel_config=parallel_config,
|
||||
target_dtype=self.dtype,
|
||||
speculative_model=self.speculative_model,
|
||||
speculative_model_quantization = \
|
||||
self.speculative_model_quantization,
|
||||
speculative_draft_tensor_parallel_size = \
|
||||
self.speculative_draft_tensor_parallel_size,
|
||||
num_speculative_tokens=self.num_speculative_tokens,
|
||||
speculative_disable_mqa_scorer=self.speculative_disable_mqa_scorer,
|
||||
speculative_disable_by_batch_size=self.
|
||||
speculative_disable_by_batch_size,
|
||||
speculative_max_model_len=self.speculative_max_model_len,
|
||||
enable_chunked_prefill=self.enable_chunked_prefill,
|
||||
disable_log_stats=self.disable_log_stats,
|
||||
ngram_prompt_lookup_max=self.ngram_prompt_lookup_max,
|
||||
ngram_prompt_lookup_min=self.ngram_prompt_lookup_min,
|
||||
draft_token_acceptance_method=\
|
||||
self.spec_decoding_acceptance_method,
|
||||
typical_acceptance_sampler_posterior_threshold=self.
|
||||
typical_acceptance_sampler_posterior_threshold,
|
||||
typical_acceptance_sampler_posterior_alpha=self.
|
||||
typical_acceptance_sampler_posterior_alpha,
|
||||
disable_logprobs=self.disable_logprobs_during_spec_decoding,
|
||||
)
|
||||
|
||||
# Reminder: Please update docs/source/features/compatibility_matrix.md
|
||||
@@ -1569,7 +1634,7 @@ class EngineArgs:
|
||||
if (self.speculative_model is not None
|
||||
or self.num_speculative_tokens is not None):
|
||||
# This is supported but experimental (handled below).
|
||||
if self.speculative_model == "[ngram]":
|
||||
if self.speculative_model in ("ngram", "[ngram]"):
|
||||
pass
|
||||
else:
|
||||
_raise_or_fallback(feature_name="Speculative Decoding",
|
||||
@@ -1617,7 +1682,8 @@ class EngineArgs:
|
||||
return False
|
||||
|
||||
# ngram is supported on V1, but off by default for now.
|
||||
if self.speculative_model == "[ngram]" and _warn_or_fallback("ngram"):
|
||||
if self.speculative_model in (
|
||||
"ngram", "[ngram]") and _warn_or_fallback("ngram"):
|
||||
return False
|
||||
|
||||
# Non-CUDA is supported on V1, but off by default for now.
|
||||
|
||||
Reference in New Issue
Block a user