[Spec Decode] Disable Log Prob serialization to CPU for spec decoding for both draft and target models. (#6485)

This commit is contained in:
sroy745
2024-07-20 23:58:58 -07:00
committed by GitHub
parent d7f4178dd9
commit 14f91fe67c
8 changed files with 333 additions and 64 deletions

View File

@@ -110,6 +110,7 @@ class EngineArgs:
typical_acceptance_sampler_posterior_threshold: Optional[float] = None
typical_acceptance_sampler_posterior_alpha: Optional[float] = None
qlora_adapter_name_or_path: Optional[str] = None
disable_logprobs_during_spec_decoding: Optional[bool] = None
otlp_traces_endpoint: Optional[str] = None
@@ -592,6 +593,18 @@ class EngineArgs:
'to sqrt of --typical-acceptance-sampler-posterior-threshold '
'i.e. 0.3')
parser.add_argument(
'--disable-logprobs-during-spec-decoding',
type=bool,
default=EngineArgs.disable_logprobs_during_spec_decoding,
help='If set to True, token log probabilities are not returned '
'during speculative decoding. If set to False, log probabilities '
'are returned according to the settings in SamplingParams. If '
'not specified, it defaults to True. Disabling log probabilities '
'during speculative decoding reduces latency by skipping logprob '
'calculation in proposal sampling, target sampling, and after '
'accepted tokens are determined.')
parser.add_argument('--model-loader-extra-config',
type=nullable_str,
default=EngineArgs.model_loader_extra_config,
@@ -736,6 +749,7 @@ class EngineArgs:
typical_acceptance_sampler_posterior_threshold,
typical_acceptance_sampler_posterior_alpha=self.
typical_acceptance_sampler_posterior_alpha,
disable_logprobs=self.disable_logprobs_during_spec_decoding,
)
scheduler_config = SchedulerConfig(