[Model Runner V2] Enable forcing a specific acceptance rate during rejection sampling (#38045)
Signed-off-by: Giancarlo Delfin <gdelfin@inferact.ai>
This commit is contained in:
@@ -0,0 +1,34 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import pytest
|
||||
|
||||
from vllm.v1.worker.gpu.spec_decode.synthetic_rejection_sampler_utils import (
|
||||
compute_synthetic_rejection_sampler_params,
|
||||
)
|
||||
|
||||
# Speculative step counts the test is parametrized over.
NUM_SPECULATIVE_STEPS = [1, 2, 3, 4, 5, 7, 10]
# Desired acceptance rates swept inside each test case: 0.00, 0.01, ..., 0.99.
ACCEPTANCE_RATES = [percent / 100 for percent in range(100)]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_speculative_steps", NUM_SPECULATIVE_STEPS)
def test_compute_synthetic_rejection_sampler_params(num_speculative_steps: int):
    """Check that the synthetic rejection sampler parameters hit the target.

    For every rate in ``ACCEPTANCE_RATES`` the returned base rate and decay
    factor are expanded into per-position acceptance probabilities
    (``base_rate * decay_factor**position``). Their running product is the
    joint acceptance probability at each position, and the mean of those
    joint probabilities over all speculative positions must equal the
    desired acceptance rate. The base rate must also not exceed 1.0.
    """
    tol = 1e-9
    # The comparison is ten times looser than the `tol` passed to the helper.
    max_abs_error = 10 * tol

    for target_rate in ACCEPTANCE_RATES:
        params = compute_synthetic_rejection_sampler_params(
            target_rate, num_speculative_steps, tol=tol
        )
        base_rate, decay_factor = params

        # Accumulate the joint acceptance probability position by position
        # and sum it up so the mean can be taken afterwards.
        running_joint = 1.0
        joint_sum = 0.0
        for position in range(num_speculative_steps):
            position_rate = base_rate * decay_factor**position
            running_joint = running_joint * position_rate
            joint_sum = joint_sum + running_joint
        observed_mean = joint_sum / num_speculative_steps

        assert abs(target_rate - observed_mean) < max_abs_error
        assert base_rate <= 1.0
|
||||
Reference in New Issue
Block a user