Files
vllm/tests/v1/spec_decode/test_synthetic_rejection_sampler_utils.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

35 lines
1.3 KiB
Python
Raw Normal View History

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
from vllm.v1.worker.gpu.spec_decode.synthetic_rejection_sampler_utils import (
compute_synthetic_rejection_sampler_params,
)
NUM_SPECULATIVE_STEPS = [1, 2, 3, 4, 5, 7, 10]
ACCEPTANCE_RATES = [i / 100 for i in range(0, 100)]
@pytest.mark.parametrize("num_speculative_steps", NUM_SPECULATIVE_STEPS)
def test_compute_synthetic_rejection_sampler_params(num_speculative_steps: int):
"""Test that the base acceptance rate and decay factor generated for
synthetic rejection sampling have a mean joint acceptance probability
that matches the desired acceptance rate."""
tol = 1e-9
for desired_acceptance_rate in ACCEPTANCE_RATES:
base_rate, decay_factor = compute_synthetic_rejection_sampler_params(
desired_acceptance_rate, num_speculative_steps, tol=tol
)
# Compute the mean of joint acceptance probabilities across
# all speculative positions.
joint_prob = 1.0
mean_joint = 0.0
for i in range(num_speculative_steps):
joint_prob *= base_rate * decay_factor**i
mean_joint += joint_prob
mean_joint /= num_speculative_steps
assert abs(desired_acceptance_rate - mean_joint) < 10 * tol
assert base_rate <= 1.0