35 lines
1.3 KiB
Python
35 lines
1.3 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
import pytest
|
|
|
|
from vllm.v1.worker.gpu.spec_decode.synthetic_rejection_sampler_utils import (
|
|
compute_synthetic_rejection_sampler_params,
|
|
)
|
|
|
|
NUM_SPECULATIVE_STEPS = [1, 2, 3, 4, 5, 7, 10]
|
|
ACCEPTANCE_RATES = [i / 100 for i in range(0, 100)]
|
|
|
|
|
|
@pytest.mark.parametrize("num_speculative_steps", NUM_SPECULATIVE_STEPS)
|
|
def test_compute_synthetic_rejection_sampler_params(num_speculative_steps: int):
|
|
"""Test that the base acceptance rate and decay factor generated for
|
|
synthetic rejection sampling have a mean joint acceptance probability
|
|
that matches the desired acceptance rate."""
|
|
tol = 1e-9
|
|
for desired_acceptance_rate in ACCEPTANCE_RATES:
|
|
base_rate, decay_factor = compute_synthetic_rejection_sampler_params(
|
|
desired_acceptance_rate, num_speculative_steps, tol=tol
|
|
)
|
|
|
|
# Compute the mean of joint acceptance probabilities across
|
|
# all speculative positions.
|
|
joint_prob = 1.0
|
|
mean_joint = 0.0
|
|
for i in range(num_speculative_steps):
|
|
joint_prob *= base_rate * decay_factor**i
|
|
mean_joint += joint_prob
|
|
mean_joint /= num_speculative_steps
|
|
|
|
assert abs(desired_acceptance_rate - mean_joint) < 10 * tol
|
|
assert base_rate <= 1.0
|