[Model Runner V2] Add probabilistic rejection sampling for spec decoding (#35461)

Signed-off-by: Giancarlo Delfin <gdelfin@inferact.ai>
This commit is contained in:
Giancarlo Delfin
2026-03-11 14:04:32 -07:00
committed by GitHub
parent 12001f2ebc
commit c77181e534
9 changed files with 494 additions and 112 deletions

View File

@@ -57,6 +57,10 @@ SpeculativeMethod = Literal[
EagleModelTypes,
NgramGPUTypes,
]
RejectionSampleMethod = Literal[
"strict",
"probabilistic",
]
@config
@@ -171,6 +175,12 @@ class SpeculativeConfig:
"""Load config for the draft model. If not specified, will use the load
config from the target model."""
rejection_sample_method: RejectionSampleMethod = "strict"
"""Whether to use strict (target and draft sampled tokens match exactly)
or probabilistic rejection sampling. Both respect the target model
distribution, but the latter yields a higher acceptance rate at the cost
of more memory to cache draft logits."""
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,