[Speculative Decoding 2/2] Integrate typical acceptance sampler into Spec Decode Worker (#5348)
This commit is contained in:
@@ -753,7 +753,6 @@ class SchedulerConfig:
|
||||
self.chunked_prefill_enabled = enable_chunked_prefill
|
||||
self.embedding_mode = embedding_mode
|
||||
self.preemption_mode = preemption_mode
|
||||
|
||||
self._verify_args()
|
||||
|
||||
def _verify_args(self) -> None:
|
||||
@@ -834,6 +833,9 @@ class SpeculativeConfig:
|
||||
speculative_disable_by_batch_size: Optional[int],
|
||||
ngram_prompt_lookup_max: Optional[int],
|
||||
ngram_prompt_lookup_min: Optional[int],
|
||||
draft_token_acceptance_method: str,
|
||||
typical_acceptance_sampler_posterior_threshold: Optional[float],
|
||||
typical_acceptance_sampler_posterior_alpha: Optional[float],
|
||||
) -> Optional["SpeculativeConfig"]:
|
||||
"""Create a SpeculativeConfig if possible, else return None.
|
||||
|
||||
@@ -870,7 +872,20 @@ class SpeculativeConfig:
|
||||
window, if provided.
|
||||
ngram_prompt_lookup_min (Optional[int]): Min size of ngram token
|
||||
window, if provided.
|
||||
|
||||
draft_token_acceptance_method (str): The method to use for
|
||||
accepting draft tokens. This can take two possible
|
||||
values 'rejection_sampler' and 'typical_acceptance_sampler'
|
||||
for RejectionSampler and TypicalAcceptanceSampler
|
||||
respectively.
|
||||
typical_acceptance_sampler_posterior_threshold (Optional[float]):
|
||||
A threshold value that sets a lower bound on the posterior
|
||||
probability of a token in the target model for it to be
|
||||
accepted. This threshold is used only when we use the
|
||||
TypicalAcceptanceSampler for token acceptance.
|
||||
typical_acceptance_sampler_posterior_alpha (Optional[float]):
|
||||
A scaling factor for the entropy-based threshold in the
|
||||
TypicalAcceptanceSampler.
|
||||
|
||||
Returns:
|
||||
Optional["SpeculativeConfig"]: An instance of SpeculativeConfig if
|
||||
the necessary conditions are met, else None.
|
||||
@@ -984,6 +999,11 @@ class SpeculativeConfig:
|
||||
"speculative_model unless the draft model config contains an "
|
||||
"n_predict parameter.")
|
||||
|
||||
if typical_acceptance_sampler_posterior_threshold is None:
|
||||
typical_acceptance_sampler_posterior_threshold = 0.09
|
||||
if typical_acceptance_sampler_posterior_alpha is None:
|
||||
typical_acceptance_sampler_posterior_alpha = 0.3
|
||||
|
||||
return SpeculativeConfig(
|
||||
draft_model_config,
|
||||
draft_parallel_config,
|
||||
@@ -991,6 +1011,11 @@ class SpeculativeConfig:
|
||||
speculative_disable_by_batch_size,
|
||||
ngram_prompt_lookup_max,
|
||||
ngram_prompt_lookup_min,
|
||||
draft_token_acceptance_method=draft_token_acceptance_method,
|
||||
typical_acceptance_sampler_posterior_threshold=\
|
||||
typical_acceptance_sampler_posterior_threshold,
|
||||
typical_acceptance_sampler_posterior_alpha=\
|
||||
typical_acceptance_sampler_posterior_alpha,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -1072,6 +1097,9 @@ class SpeculativeConfig:
|
||||
speculative_disable_by_batch_size: Optional[int],
|
||||
ngram_prompt_lookup_max: Optional[int],
|
||||
ngram_prompt_lookup_min: Optional[int],
|
||||
draft_token_acceptance_method: str,
|
||||
typical_acceptance_sampler_posterior_threshold: float,
|
||||
typical_acceptance_sampler_posterior_alpha: float,
|
||||
):
|
||||
"""Create a SpeculativeConfig object.
|
||||
|
||||
@@ -1085,6 +1113,19 @@ class SpeculativeConfig:
|
||||
enqueue requests is larger than this value.
|
||||
ngram_prompt_lookup_max: Max size of ngram token window.
|
||||
ngram_prompt_lookup_min: Min size of ngram token window.
|
||||
draft_token_acceptance_method (str): The method to use for
|
||||
accepting draft tokens. This can take two possible
|
||||
values 'rejection_sampler' and 'typical_acceptance_sampler'
|
||||
for RejectionSampler and TypicalAcceptanceSampler
|
||||
respectively.
|
||||
typical_acceptance_sampler_posterior_threshold (float):
|
||||
A threshold value that sets a lower bound on the posterior
|
||||
probability of a token in the target model for it to be
|
||||
accepted. This threshold is used only when we use the
|
||||
TypicalAcceptanceSampler for token acceptance.
|
||||
typical_acceptance_sampler_posterior_alpha (float):
|
||||
A scaling factor for the entropy-based threshold in the
|
||||
TypicalAcceptanceSampler.
|
||||
"""
|
||||
self.draft_model_config = draft_model_config
|
||||
self.draft_parallel_config = draft_parallel_config
|
||||
@@ -1093,6 +1134,11 @@ class SpeculativeConfig:
|
||||
speculative_disable_by_batch_size
|
||||
self.ngram_prompt_lookup_max = ngram_prompt_lookup_max or 0
|
||||
self.ngram_prompt_lookup_min = ngram_prompt_lookup_min or 0
|
||||
self.draft_token_acceptance_method = draft_token_acceptance_method
|
||||
self.typical_acceptance_sampler_posterior_threshold = \
|
||||
typical_acceptance_sampler_posterior_threshold
|
||||
self.typical_acceptance_sampler_posterior_alpha = \
|
||||
typical_acceptance_sampler_posterior_alpha
|
||||
|
||||
self._verify_args()
|
||||
|
||||
@@ -1104,6 +1150,31 @@ class SpeculativeConfig:
|
||||
if self.draft_model_config:
|
||||
self.draft_model_config.verify_with_parallel_config(
|
||||
self.draft_parallel_config)
|
||||
# Validate draft token acceptance related settings.
|
||||
|
||||
if (self.draft_token_acceptance_method is None):
|
||||
raise ValueError("draft_token_acceptance_method is not set. "
|
||||
"Expected values are rejection_sampler or "
|
||||
"typical_acceptance_sampler.")
|
||||
|
||||
if (self.draft_token_acceptance_method != 'rejection_sampler'
|
||||
and self.draft_token_acceptance_method !=
|
||||
'typical_acceptance_sampler'):
|
||||
raise ValueError(
|
||||
"Expected draft_token_acceptance_method to be either "
|
||||
"rejection_sampler or typical_acceptance_sampler. Instead it "
|
||||
f"is {self.draft_token_acceptance_method}")
|
||||
|
||||
if (self.typical_acceptance_sampler_posterior_threshold < 0
|
||||
or self.typical_acceptance_sampler_posterior_alpha < 0):
|
||||
raise ValueError(
|
||||
"Expected typical_acceptance_sampler_posterior_threshold "
|
||||
"and typical_acceptance_sampler_posterior_alpha to be > 0. "
|
||||
"Instead found "
|
||||
f"typical_acceptance_sampler_posterior_threshold = "
|
||||
f"{self.typical_acceptance_sampler_posterior_threshold} and "
|
||||
f"typical_acceptance_sampler_posterior_alpha = "
|
||||
f"{self.typical_acceptance_sampler_posterior_alpha}")
|
||||
|
||||
@property
|
||||
def num_lookahead_slots(self) -> int:
|
||||
|
||||
Reference in New Issue
Block a user