[Model Runner V2] Add probabilistic rejection sampling for spec decoding (#35461)
Signed-off-by: Giancarlo Delfin <gdelfin@inferact.ai>
This commit is contained in:
@@ -57,6 +57,10 @@ SpeculativeMethod = Literal[
|
||||
EagleModelTypes,
|
||||
NgramGPUTypes,
|
||||
]
|
||||
RejectionSampleMethod = Literal[
|
||||
"strict",
|
||||
"probabilistic",
|
||||
]
|
||||
|
||||
|
||||
@config
|
||||
@@ -171,6 +175,12 @@ class SpeculativeConfig:
|
||||
"""Load config for the draft model. If not specified, will use the load
|
||||
config from the target model."""
|
||||
|
||||
rejection_sample_method: RejectionSampleMethod = "strict"
|
||||
"""Whether to use strict (target and draft sampled tokens match exactly)
|
||||
or probabilistic rejection sampling. Both respect the target model
|
||||
distribution, but the latter yields a higher acceptance rate at the cost
|
||||
of more memory to cache draft logits."""
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
"""
|
||||
WARNING: Whenever a new field is added to this config,
|
||||
|
||||
Reference in New Issue
Block a user