[Speculative Decoding 2/2 ] Integrate typical acceptance sampler into Spec Decode Worker (#5348)
This commit is contained in:
@@ -100,7 +100,9 @@ class EngineArgs:
|
||||
speculative_disable_by_batch_size: Optional[int] = None
|
||||
ngram_prompt_lookup_max: Optional[int] = None
|
||||
ngram_prompt_lookup_min: Optional[int] = None
|
||||
|
||||
spec_decoding_acceptance_method: str = 'rejection_sampler'
|
||||
typical_acceptance_sampler_posterior_threshold: Optional[float] = None
|
||||
typical_acceptance_sampler_posterior_alpha: Optional[float] = None
|
||||
qlora_adapter_name_or_path: Optional[str] = None
|
||||
|
||||
otlp_traces_endpoint: Optional[str] = None
|
||||
@@ -577,6 +579,38 @@ class EngineArgs:
|
||||
help='Min size of window for ngram prompt lookup in speculative '
|
||||
'decoding.')
|
||||
|
||||
parser.add_argument(
|
||||
'--spec-decoding-acceptance-method',
|
||||
type=str,
|
||||
default=EngineArgs.spec_decoding_acceptance_method,
|
||||
choices=['rejection_sampler', 'typical_acceptance_sampler'],
|
||||
help='Specify the acceptance method to use during draft token '
|
||||
'verification in speculative decoding. Two types of acceptance '
|
||||
'routines are supported: '
|
||||
'1) RejectionSampler which does not allow changing the '
|
||||
'acceptance rate of draft tokens, '
|
||||
'2) TypicalAcceptanceSampler which is configurable, allowing for '
|
||||
'a higher acceptance rate at the cost of lower quality, '
|
||||
'and vice versa.')
|
||||
|
||||
parser.add_argument(
|
||||
'--typical-acceptance-sampler-posterior-threshold',
|
||||
type=float,
|
||||
default=EngineArgs.typical_acceptance_sampler_posterior_threshold,
|
||||
help='Set the lower bound threshold for the posterior '
|
||||
'probability of a token to be accepted. This threshold is '
|
||||
'used by the TypicalAcceptanceSampler to make sampling decisions '
|
||||
'during speculative decoding. Defaults to 0.09')
|
||||
|
||||
parser.add_argument(
|
||||
'--typical-acceptance-sampler-posterior-alpha',
|
||||
type=float,
|
||||
default=EngineArgs.typical_acceptance_sampler_posterior_alpha,
|
||||
help='A scaling factor for the entropy-based threshold for token '
|
||||
'acceptance in the TypicalAcceptanceSampler. Typically defaults '
|
||||
'to sqrt of --typical-acceptance-sampler-posterior-threshold '
|
||||
'i.e. 0.3')
|
||||
|
||||
parser.add_argument('--model-loader-extra-config',
|
||||
type=nullable_str,
|
||||
default=EngineArgs.model_loader_extra_config,
|
||||
@@ -737,6 +771,12 @@ class EngineArgs:
|
||||
use_v2_block_manager=self.use_v2_block_manager,
|
||||
ngram_prompt_lookup_max=self.ngram_prompt_lookup_max,
|
||||
ngram_prompt_lookup_min=self.ngram_prompt_lookup_min,
|
||||
draft_token_acceptance_method=\
|
||||
self.spec_decoding_acceptance_method,
|
||||
typical_acceptance_sampler_posterior_threshold=self.
|
||||
typical_acceptance_sampler_posterior_threshold,
|
||||
typical_acceptance_sampler_posterior_alpha=self.
|
||||
typical_acceptance_sampler_posterior_alpha,
|
||||
)
|
||||
|
||||
scheduler_config = SchedulerConfig(
|
||||
|
||||
Reference in New Issue
Block a user