[Model Runner V2] Enable forcing a specific acceptance rate during rejection sampling (#38045)

Signed-off-by: Giancarlo Delfin <gdelfin@inferact.ai>
This commit is contained in:
Giancarlo Delfin
2026-03-26 13:38:12 -07:00
committed by GitHub
parent 0904b6550d
commit c32e97602d
7 changed files with 240 additions and 18 deletions

View File

@@ -101,9 +101,11 @@ steps:
- vllm/v1/worker/gpu/
- vllm/v1/worker/gpu_worker.py
- tests/v1/spec_decode/test_max_len.py
- tests/v1/spec_decode/test_synthetic_rejection_sampler_utils.py
- tests/v1/e2e/spec_decode/test_spec_decode.py
commands:
- set -x
- export VLLM_USE_V2_MODEL_RUNNER=1
- pytest -v -s v1/spec_decode/test_max_len.py -k "eagle or mtp"
- pytest -v -s v1/spec_decode/test_synthetic_rejection_sampler_utils.py
- pytest -v -s v1/e2e/spec_decode/test_spec_decode.py -k "eagle or mtp"

View File

@@ -0,0 +1,34 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from vllm.v1.worker.gpu.spec_decode.synthetic_rejection_sampler_utils import (
compute_synthetic_rejection_sampler_params,
)
NUM_SPECULATIVE_STEPS = [1, 2, 3, 4, 5, 7, 10]
ACCEPTANCE_RATES = [i / 100 for i in range(0, 100)]
@pytest.mark.parametrize("num_speculative_steps", NUM_SPECULATIVE_STEPS)
def test_compute_synthetic_rejection_sampler_params(num_speculative_steps: int):
"""Test that the base acceptance rate and decay factor generated for
synthetic rejection sampling have a mean joint acceptance probability
that matches the desired acceptance rate."""
tol = 1e-9
for desired_acceptance_rate in ACCEPTANCE_RATES:
base_rate, decay_factor = compute_synthetic_rejection_sampler_params(
desired_acceptance_rate, num_speculative_steps, tol=tol
)
# Compute the mean of joint acceptance probabilities across
# all speculative positions.
joint_prob = 1.0
mean_joint = 0.0
for i in range(num_speculative_steps):
joint_prob *= base_rate * decay_factor**i
mean_joint += joint_prob
mean_joint /= num_speculative_steps
assert abs(desired_acceptance_rate - mean_joint) < 10 * tol
assert base_rate <= 1.0

View File

@@ -58,7 +58,7 @@ SpeculativeMethod = Literal[
EagleModelTypes,
NgramGPUTypes,
]
RejectionSampleMethod = Literal["strict", "probabilistic"]
RejectionSampleMethod = Literal["strict", "probabilistic", "synthetic"]
@config
@@ -184,6 +184,13 @@ class SpeculativeConfig:
distribution, but the latter yields a higher acceptance rate at the cost
of more memory to cache draft logits."""
synthetic_acceptance_rate: float | None = None
"""Average acceptance rate for synthetic rejection sampling. Draft
tokens are accepted with a position-dependent probability that decays
geometrically, calibrated so that the mean rate across all speculative
positions equals this value. Only used when rejection_sample_method
is 'synthetic'. Must be in [0, 1]."""
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,

View File

@@ -163,12 +163,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.speculator = None
self.num_speculative_steps = 0
self.use_aux_hidden_state_outputs = False
use_strict_rejection_sampling = False
if self.speculative_config is not None:
self.num_speculative_steps = self.speculative_config.num_speculative_tokens
use_strict_rejection_sampling = (
self.speculative_config.rejection_sample_method == "strict"
)
if self.is_last_pp_rank:
self.speculator = init_speculator(self.vllm_config, self.device)
@@ -217,11 +213,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
logprobs_mode=self.model_config.logprobs_mode,
num_speculative_tokens=self.num_speculative_steps + 1,
)
self.rejection_sampler = RejectionSampler(
self.sampler,
num_speculative_steps=self.num_speculative_steps,
use_strict_rejection_sampling=use_strict_rejection_sampling,
)
if self.speculative_config is not None:
self.rejection_sampler = RejectionSampler(
self.sampler,
self.speculative_config,
)
self.prompt_logprobs_worker = PromptLogprobsWorker(self.max_num_reqs)
self.structured_outputs_worker = StructuredOutputsWorker(
max_num_logits=self.max_num_reqs * (self.num_speculative_steps + 1),

View File

@@ -85,9 +85,8 @@ class EagleSpeculator:
self.max_num_tokens, self.hidden_size, dtype=self.dtype, device=device
)
cache_draft_logits = self.speculative_config.rejection_sample_method != "strict"
self.draft_logits: torch.Tensor | None = None
if cache_draft_logits:
if self.speculative_config.rejection_sample_method == "probabilistic":
self.draft_logits = torch.zeros(
self.max_num_reqs,
self.num_speculative_steps,

View File

@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.config import SpeculativeConfig
from vllm.triton_utils import tl, triton
from vllm.v1.outputs import LogprobsTensors
from vllm.v1.worker.gpu.input_batch import InputBatch
@@ -11,6 +12,10 @@ from vllm.v1.worker.gpu.sample.logprob import compute_topk_logprobs
from vllm.v1.worker.gpu.sample.output import SamplerOutput
from vllm.v1.worker.gpu.sample.sampler import Sampler
from vllm.v1.worker.gpu.sample.states import NO_LOGPROBS
from vllm.v1.worker.gpu.spec_decode.synthetic_rejection_sampler_utils import (
compute_synthetic_rejection_sampler_params,
synthetic_rejection_sample,
)
@triton.jit
@@ -445,12 +450,26 @@ class RejectionSampler:
def __init__(
self,
sampler: Sampler,
num_speculative_steps,
use_strict_rejection_sampling: bool = True,
spec_config: SpeculativeConfig,
):
self.sampler = sampler
self.num_speculative_steps = num_speculative_steps
self.use_strict_rejection_sampling = use_strict_rejection_sampling
self.num_speculative_steps = spec_config.num_speculative_tokens
self.rejection_sample_method = spec_config.rejection_sample_method
if self.rejection_sample_method == "synthetic":
synthetic_acceptance_rate = spec_config.synthetic_acceptance_rate
if (
synthetic_acceptance_rate is None
or not 0.0 <= synthetic_acceptance_rate <= 1.0
):
raise ValueError(
f"synthetic_acceptance_rate must be in [0, 1], "
f"but got {synthetic_acceptance_rate}"
)
self.base_acceptance_rate, self.decay_factor = (
compute_synthetic_rejection_sampler_params(
synthetic_acceptance_rate, self.num_speculative_steps
)
)
def _get_logprobs_tensors(
self,
@@ -497,7 +516,7 @@ class RejectionSampler:
# that num_nans is computed before applying penalties and temperature.
num_nans = get_num_nans(logits) if self.sampler.compute_nans else None
if self.use_strict_rejection_sampling:
if self.rejection_sample_method == "strict":
sampler_output = self.sampler(logits, input_batch)
logprobs_tensors = sampler_output.logprobs_tensors
sampled, num_sampled = strict_rejection_sample(
@@ -506,7 +525,7 @@ class RejectionSampler:
input_batch.cu_num_logits,
self.num_speculative_steps,
)
else:
elif self.rejection_sample_method == "probabilistic":
assert draft_logits is not None
pos = input_batch.positions[input_batch.logits_indices]
processed_logits = self.sampler.apply_sampling_params(
@@ -538,6 +557,24 @@ class RejectionSampler:
if self.sampler.logprobs_mode == "processed_logprobs"
else logits,
)
elif self.rejection_sample_method == "synthetic":
sampler_output = self.sampler(logits, input_batch)
logprobs_tensors = sampler_output.logprobs_tensors
sampled, num_sampled = synthetic_rejection_sample(
sampler_output.sampled_token_ids.view(-1),
draft_sampled,
input_batch.cu_num_logits,
input_batch.positions[input_batch.logits_indices],
input_batch.idx_mapping,
self.sampler.sampling_states.seeds.gpu,
self.base_acceptance_rate,
self.decay_factor,
self.num_speculative_steps,
)
else:
raise ValueError(
f"Unknown rejection sample method: {self.rejection_sample_method}"
)
return SamplerOutput(
sampled_token_ids=sampled,

View File

@@ -0,0 +1,147 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.triton_utils import tl, triton
from vllm.v1.worker.gpu.sample.gumbel import tl_rand64
MIN_ACCEPTANCE_DECAY_FACTOR = 0.85
@triton.jit
def _synthetic_rejection_sample_kernel(
# [num_reqs, num_speculative_steps + 1]
sampled_ptr,
sampled_stride,
# [num_reqs]
num_sampled_ptr,
# [num_draft_tokens + num_reqs]
target_sampled_ptr,
# [num_draft_tokens + num_reqs]
input_ids_ptr,
# [num_reqs + 1]
cu_num_logits_ptr,
# [num_logits]
pos_ptr,
# [num_reqs]
idx_mapping_ptr,
# [max_num_reqs]
seeds_ptr,
base_acceptance_rate,
decay_factor,
):
req_idx = tl.program_id(0)
start_idx = tl.load(cu_num_logits_ptr + req_idx)
end_idx = tl.load(cu_num_logits_ptr + req_idx + 1)
num_tokens = end_idx - start_idx
req_state_idx = tl.load(idx_mapping_ptr + req_idx)
seed = tl.load(seeds_ptr + req_state_idx)
num_sampled = 0
acceptance_rate = base_acceptance_rate
rejected = False
for i in range(num_tokens - 1):
if not rejected:
logit_idx = start_idx + i
pos = tl.load(pos_ptr + logit_idx)
u = tl_rand64(seed, pos, includes_zero=False)
if u < acceptance_rate:
sampled = tl.load(input_ids_ptr + logit_idx + 1).to(tl.int64)
else:
sampled = tl.load(target_sampled_ptr + logit_idx)
rejected = True
tl.store(sampled_ptr + req_idx * sampled_stride + i, sampled)
num_sampled += 1
acceptance_rate *= decay_factor
if not rejected:
target_sampled = tl.load(target_sampled_ptr + start_idx + num_tokens - 1)
tl.store(
sampled_ptr + req_idx * sampled_stride + num_tokens - 1, target_sampled
)
num_sampled += 1
tl.store(num_sampled_ptr + req_idx, num_sampled)
def synthetic_rejection_sample(
# [num_draft_tokens + num_reqs]
target_sampled: torch.Tensor,
# [num_draft_tokens + num_reqs]
draft_sampled: torch.Tensor,
# [num_reqs + 1]
cu_num_logits: torch.Tensor,
# [num_logits]
pos: torch.Tensor,
# [num_reqs]
idx_mapping: torch.Tensor,
# [max_num_reqs]
seed: torch.Tensor,
base_acceptance_rate: float,
decay_factor: float,
num_speculative_steps: int,
) -> tuple[torch.Tensor, torch.Tensor]:
num_reqs = cu_num_logits.shape[0] - 1
sampled = target_sampled.new_empty(num_reqs, num_speculative_steps + 1)
num_sampled = target_sampled.new_empty(num_reqs, dtype=torch.int32)
_synthetic_rejection_sample_kernel[(num_reqs,)](
sampled,
sampled.stride(0),
num_sampled,
target_sampled,
draft_sampled,
cu_num_logits,
pos,
idx_mapping,
seed,
base_acceptance_rate,
decay_factor,
num_warps=1,
)
return sampled, num_sampled
def compute_synthetic_rejection_sampler_params(
p_avg: float, n: int, tol: float = 1e-9
) -> tuple[float, float]:
def mean_joint_prob(a_0: float, gamma: float, n: int):
total = 0.0
for i in range(n):
total += a_0 ** (i + 1) * gamma ** (i * (i + 1) // 2)
return total / n
def min_valid_decay_factor(p: float, n: int, tol: float = 1e-9) -> float:
low, high = MIN_ACCEPTANCE_DECAY_FACTOR, 1.0
if mean_joint_prob(1, low, n) >= p:
return low
# Sweep for a gamma decay factor that is guaranteed
# to yield a base acceptance rate <= 1.
while (high - low) > tol:
mid = (low + high) / 2
if mean_joint_prob(1, mid, n) >= p:
high = mid
else:
low = mid
return high
def compute_base_acceptance_rate(
p_avg: float, gamma: float, n: int, tol: float = 1e-9
) -> float:
if p_avg <= 0.0:
return 0.0
if p_avg >= 1.0:
return 1.0
# Sweep for a base acceptance rate that yields
# the desired mean joint probability.
low, high = 0.0, 1.0
while (high - low) > tol:
mid = (low + high) / 2
if mean_joint_prob(mid, gamma, n) >= p_avg:
high = mid
else:
low = mid
return high
decay_factor = min_valid_decay_factor(p_avg, n)
base_rate = compute_base_acceptance_rate(p_avg, decay_factor, n)
return base_rate, decay_factor