[Model Runner V2] Enable forcing a specific acceptance rate during rejection sampling (#38045)

Signed-off-by: Giancarlo Delfin <gdelfin@inferact.ai>
This commit is contained in:
Giancarlo Delfin
2026-03-26 13:38:12 -07:00
committed by GitHub
parent 0904b6550d
commit c32e97602d
7 changed files with 240 additions and 18 deletions

View File

@@ -0,0 +1,34 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from vllm.v1.worker.gpu.spec_decode.synthetic_rejection_sampler_utils import (
compute_synthetic_rejection_sampler_params,
)
NUM_SPECULATIVE_STEPS = [1, 2, 3, 4, 5, 7, 10]
ACCEPTANCE_RATES = [i / 100 for i in range(0, 100)]
@pytest.mark.parametrize("num_speculative_steps", NUM_SPECULATIVE_STEPS)
def test_compute_synthetic_rejection_sampler_params(num_speculative_steps: int):
"""Test that the base acceptance rate and decay factor generated for
synthetic rejection sampling have a mean joint acceptance probability
that matches the desired acceptance rate."""
tol = 1e-9
for desired_acceptance_rate in ACCEPTANCE_RATES:
base_rate, decay_factor = compute_synthetic_rejection_sampler_params(
desired_acceptance_rate, num_speculative_steps, tol=tol
)
# Compute the mean of joint acceptance probabilities across
# all speculative positions.
joint_prob = 1.0
mean_joint = 0.0
for i in range(num_speculative_steps):
joint_prob *= base_rate * decay_factor**i
mean_joint += joint_prob
mean_joint /= num_speculative_steps
assert abs(desired_acceptance_rate - mean_joint) < 10 * tol
assert base_rate <= 1.0