[Speculative Decoding 2/2 ] Integrate typical acceptance sampler into Spec Decode Worker (#5348)

This commit is contained in:
sroy745
2024-07-01 00:33:05 -07:00
committed by GitHub
parent 614aa51203
commit 80ca1e6a3a
14 changed files with 480 additions and 208 deletions

View File

@@ -100,7 +100,9 @@ class EngineArgs:
speculative_disable_by_batch_size: Optional[int] = None
ngram_prompt_lookup_max: Optional[int] = None
ngram_prompt_lookup_min: Optional[int] = None
spec_decoding_acceptance_method: str = 'rejection_sampler'
typical_acceptance_sampler_posterior_threshold: Optional[float] = None
typical_acceptance_sampler_posterior_alpha: Optional[float] = None
qlora_adapter_name_or_path: Optional[str] = None
otlp_traces_endpoint: Optional[str] = None
@@ -577,6 +579,38 @@ class EngineArgs:
help='Min size of window for ngram prompt lookup in speculative '
'decoding.')
parser.add_argument(
'--spec-decoding-acceptance-method',
type=str,
default=EngineArgs.spec_decoding_acceptance_method,
choices=['rejection_sampler', 'typical_acceptance_sampler'],
help='Specify the acceptance method to use during draft token '
'verification in speculative decoding. Two types of acceptance '
'routines are supported: '
'1) RejectionSampler which does not allow changing the '
'acceptance rate of draft tokens, '
'2) TypicalAcceptanceSampler which is configurable, allowing for '
'a higher acceptance rate at the cost of lower quality, '
'and vice versa.')
parser.add_argument(
'--typical-acceptance-sampler-posterior-threshold',
type=float,
default=EngineArgs.typical_acceptance_sampler_posterior_threshold,
help='Set the lower bound threshold for the posterior '
'probability of a token to be accepted. This threshold is '
'used by the TypicalAcceptanceSampler to make sampling decisions '
'during speculative decoding. Defaults to 0.09')
parser.add_argument(
'--typical-acceptance-sampler-posterior-alpha',
type=float,
default=EngineArgs.typical_acceptance_sampler_posterior_alpha,
help='A scaling factor for the entropy-based threshold for token '
'acceptance in the TypicalAcceptanceSampler. Typically defaults '
'to sqrt of --typical-acceptance-sampler-posterior-threshold '
'i.e. 0.3')
parser.add_argument('--model-loader-extra-config',
type=nullable_str,
default=EngineArgs.model_loader_extra_config,
@@ -737,6 +771,12 @@ class EngineArgs:
use_v2_block_manager=self.use_v2_block_manager,
ngram_prompt_lookup_max=self.ngram_prompt_lookup_max,
ngram_prompt_lookup_min=self.ngram_prompt_lookup_min,
draft_token_acceptance_method=\
self.spec_decoding_acceptance_method,
typical_acceptance_sampler_posterior_threshold=self.
typical_acceptance_sampler_posterior_threshold,
typical_acceptance_sampler_posterior_alpha=self.
typical_acceptance_sampler_posterior_alpha,
)
scheduler_config = SchedulerConfig(