[Model Runner V2] Enable forcing a specific acceptance rate during rejection sampling (#38045)
Signed-off-by: Giancarlo Delfin <gdelfin@inferact.ai>
This commit is contained in:
@@ -101,9 +101,11 @@ steps:
|
||||
- vllm/v1/worker/gpu/
|
||||
- vllm/v1/worker/gpu_worker.py
|
||||
- tests/v1/spec_decode/test_max_len.py
|
||||
- tests/v1/spec_decode/test_synthetic_rejection_sampler_utils.py
|
||||
- tests/v1/e2e/spec_decode/test_spec_decode.py
|
||||
commands:
|
||||
- set -x
|
||||
- export VLLM_USE_V2_MODEL_RUNNER=1
|
||||
- pytest -v -s v1/spec_decode/test_max_len.py -k "eagle or mtp"
|
||||
- pytest -v -s v1/spec_decode/test_synthetic_rejection_sampler_utils.py
|
||||
- pytest -v -s v1/e2e/spec_decode/test_spec_decode.py -k "eagle or mtp"
|
||||
|
||||
@@ -0,0 +1,34 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import pytest
|
||||
|
||||
from vllm.v1.worker.gpu.spec_decode.synthetic_rejection_sampler_utils import (
|
||||
compute_synthetic_rejection_sampler_params,
|
||||
)
|
||||
|
||||
# Speculation depths (number of draft positions) exercised by the test.
NUM_SPECULATIVE_STEPS = [1, 2, 3, 4, 5, 7, 10]
# Target mean acceptance rates in 1% increments. Both endpoints are included
# because the valid range is the closed interval [0, 1]; the previous
# range(0, 100) stopped at 0.99 and never exercised a rate of exactly 1.0.
ACCEPTANCE_RATES = [i / 100 for i in range(101)]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_speculative_steps", NUM_SPECULATIVE_STEPS)
def test_compute_synthetic_rejection_sampler_params(num_speculative_steps: int):
    """Check the calibration of the synthetic rejection sampler parameters.

    Position ``i`` is accepted with probability ``base_rate * decay_factor**i``.
    The running product of those probabilities is the joint probability that
    every position up to ``i`` is accepted; its average over all speculative
    positions must reproduce the requested acceptance rate.
    """
    tol = 1e-9
    for target_rate in ACCEPTANCE_RATES:
        params = compute_synthetic_rejection_sampler_params(
            target_rate, num_speculative_steps, tol=tol
        )
        base_rate, decay_factor = params

        # Accumulate the joint acceptance probability of each prefix.
        prefix_prob = 1.0
        prefix_prob_sum = 0.0
        for step in range(num_speculative_steps):
            prefix_prob *= base_rate * decay_factor**step
            prefix_prob_sum += prefix_prob
        mean_joint = prefix_prob_sum / num_speculative_steps

        assert abs(target_rate - mean_joint) < 10 * tol
        assert base_rate <= 1.0
|
||||
@@ -58,7 +58,7 @@ SpeculativeMethod = Literal[
|
||||
EagleModelTypes,
|
||||
NgramGPUTypes,
|
||||
]
|
||||
RejectionSampleMethod = Literal["strict", "probabilistic"]
|
||||
RejectionSampleMethod = Literal["strict", "probabilistic", "synthetic"]
|
||||
|
||||
|
||||
@config
|
||||
@@ -184,6 +184,13 @@ class SpeculativeConfig:
|
||||
distribution, but the latter yields a higher acceptance rate at the cost
|
||||
of more memory to cache draft logits."""
|
||||
|
||||
synthetic_acceptance_rate: float | None = None
|
||||
"""Average acceptance rate for synthetic rejection sampling. Draft
|
||||
tokens are accepted with a position-dependent probability that decays
|
||||
geometrically, calibrated so that the mean rate across all speculative
|
||||
positions equals this value. Only used when rejection_sample_method
|
||||
is 'synthetic'. Must be in [0, 1]."""
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
"""
|
||||
WARNING: Whenever a new field is added to this config,
|
||||
|
||||
@@ -163,12 +163,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.speculator = None
|
||||
self.num_speculative_steps = 0
|
||||
self.use_aux_hidden_state_outputs = False
|
||||
use_strict_rejection_sampling = False
|
||||
if self.speculative_config is not None:
|
||||
self.num_speculative_steps = self.speculative_config.num_speculative_tokens
|
||||
use_strict_rejection_sampling = (
|
||||
self.speculative_config.rejection_sample_method == "strict"
|
||||
)
|
||||
|
||||
if self.is_last_pp_rank:
|
||||
self.speculator = init_speculator(self.vllm_config, self.device)
|
||||
@@ -217,11 +213,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
logprobs_mode=self.model_config.logprobs_mode,
|
||||
num_speculative_tokens=self.num_speculative_steps + 1,
|
||||
)
|
||||
self.rejection_sampler = RejectionSampler(
|
||||
self.sampler,
|
||||
num_speculative_steps=self.num_speculative_steps,
|
||||
use_strict_rejection_sampling=use_strict_rejection_sampling,
|
||||
)
|
||||
if self.speculative_config is not None:
|
||||
self.rejection_sampler = RejectionSampler(
|
||||
self.sampler,
|
||||
self.speculative_config,
|
||||
)
|
||||
self.prompt_logprobs_worker = PromptLogprobsWorker(self.max_num_reqs)
|
||||
self.structured_outputs_worker = StructuredOutputsWorker(
|
||||
max_num_logits=self.max_num_reqs * (self.num_speculative_steps + 1),
|
||||
|
||||
@@ -85,9 +85,8 @@ class EagleSpeculator:
|
||||
self.max_num_tokens, self.hidden_size, dtype=self.dtype, device=device
|
||||
)
|
||||
|
||||
cache_draft_logits = self.speculative_config.rejection_sample_method != "strict"
|
||||
self.draft_logits: torch.Tensor | None = None
|
||||
if cache_draft_logits:
|
||||
if self.speculative_config.rejection_sample_method == "probabilistic":
|
||||
self.draft_logits = torch.zeros(
|
||||
self.max_num_reqs,
|
||||
self.num_speculative_steps,
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import torch
|
||||
|
||||
from vllm.config import SpeculativeConfig
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.v1.outputs import LogprobsTensors
|
||||
from vllm.v1.worker.gpu.input_batch import InputBatch
|
||||
@@ -11,6 +12,10 @@ from vllm.v1.worker.gpu.sample.logprob import compute_topk_logprobs
|
||||
from vllm.v1.worker.gpu.sample.output import SamplerOutput
|
||||
from vllm.v1.worker.gpu.sample.sampler import Sampler
|
||||
from vllm.v1.worker.gpu.sample.states import NO_LOGPROBS
|
||||
from vllm.v1.worker.gpu.spec_decode.synthetic_rejection_sampler_utils import (
|
||||
compute_synthetic_rejection_sampler_params,
|
||||
synthetic_rejection_sample,
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
@@ -445,12 +450,26 @@ class RejectionSampler:
|
||||
def __init__(
|
||||
self,
|
||||
sampler: Sampler,
|
||||
num_speculative_steps,
|
||||
use_strict_rejection_sampling: bool = True,
|
||||
spec_config: SpeculativeConfig,
|
||||
):
|
||||
self.sampler = sampler
|
||||
self.num_speculative_steps = num_speculative_steps
|
||||
self.use_strict_rejection_sampling = use_strict_rejection_sampling
|
||||
self.num_speculative_steps = spec_config.num_speculative_tokens
|
||||
self.rejection_sample_method = spec_config.rejection_sample_method
|
||||
if self.rejection_sample_method == "synthetic":
|
||||
synthetic_acceptance_rate = spec_config.synthetic_acceptance_rate
|
||||
if (
|
||||
synthetic_acceptance_rate is None
|
||||
or not 0.0 <= synthetic_acceptance_rate <= 1.0
|
||||
):
|
||||
raise ValueError(
|
||||
f"synthetic_acceptance_rate must be in [0, 1], "
|
||||
f"but got {synthetic_acceptance_rate}"
|
||||
)
|
||||
self.base_acceptance_rate, self.decay_factor = (
|
||||
compute_synthetic_rejection_sampler_params(
|
||||
synthetic_acceptance_rate, self.num_speculative_steps
|
||||
)
|
||||
)
|
||||
|
||||
def _get_logprobs_tensors(
|
||||
self,
|
||||
@@ -497,7 +516,7 @@ class RejectionSampler:
|
||||
# that num_nans is computed before applying penalties and temperature.
|
||||
num_nans = get_num_nans(logits) if self.sampler.compute_nans else None
|
||||
|
||||
if self.use_strict_rejection_sampling:
|
||||
if self.rejection_sample_method == "strict":
|
||||
sampler_output = self.sampler(logits, input_batch)
|
||||
logprobs_tensors = sampler_output.logprobs_tensors
|
||||
sampled, num_sampled = strict_rejection_sample(
|
||||
@@ -506,7 +525,7 @@ class RejectionSampler:
|
||||
input_batch.cu_num_logits,
|
||||
self.num_speculative_steps,
|
||||
)
|
||||
else:
|
||||
elif self.rejection_sample_method == "probabilistic":
|
||||
assert draft_logits is not None
|
||||
pos = input_batch.positions[input_batch.logits_indices]
|
||||
processed_logits = self.sampler.apply_sampling_params(
|
||||
@@ -538,6 +557,24 @@ class RejectionSampler:
|
||||
if self.sampler.logprobs_mode == "processed_logprobs"
|
||||
else logits,
|
||||
)
|
||||
elif self.rejection_sample_method == "synthetic":
|
||||
sampler_output = self.sampler(logits, input_batch)
|
||||
logprobs_tensors = sampler_output.logprobs_tensors
|
||||
sampled, num_sampled = synthetic_rejection_sample(
|
||||
sampler_output.sampled_token_ids.view(-1),
|
||||
draft_sampled,
|
||||
input_batch.cu_num_logits,
|
||||
input_batch.positions[input_batch.logits_indices],
|
||||
input_batch.idx_mapping,
|
||||
self.sampler.sampling_states.seeds.gpu,
|
||||
self.base_acceptance_rate,
|
||||
self.decay_factor,
|
||||
self.num_speculative_steps,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown rejection sample method: {self.rejection_sample_method}"
|
||||
)
|
||||
|
||||
return SamplerOutput(
|
||||
sampled_token_ids=sampled,
|
||||
|
||||
@@ -0,0 +1,147 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.v1.worker.gpu.sample.gumbel import tl_rand64
|
||||
|
||||
# Lower bound of the decay-factor search performed by
# compute_synthetic_rejection_sampler_params: this value is used as-is whenever
# it can still reach the requested mean acceptance rate with a base rate <= 1;
# otherwise a larger factor (up to 1.0) is searched for.
MIN_ACCEPTANCE_DECAY_FACTOR = 0.85
|
||||
|
||||
|
||||
# Synthetic rejection sampling kernel. Each program instance handles one
# request (indexed by program_id(0)). It walks the request's logit positions in
# order; at every position except the last it draws a random number and keeps
# the next input token when the draw falls below the current acceptance rate,
# otherwise it emits the target-sampled token and stops. The acceptance rate
# starts at `base_acceptance_rate` and is multiplied by `decay_factor` after
# each processed position. If nothing was rejected, the target-sampled token at
# the request's last logit is appended. The number of tokens written is stored
# in `num_sampled_ptr[req_idx]`.
@triton.jit
def _synthetic_rejection_sample_kernel(
    # [num_reqs, num_speculative_steps + 1]
    sampled_ptr,
    sampled_stride,
    # [num_reqs]
    num_sampled_ptr,
    # [num_draft_tokens + num_reqs]
    target_sampled_ptr,
    # [num_draft_tokens + num_reqs]
    input_ids_ptr,
    # [num_reqs + 1]
    cu_num_logits_ptr,
    # [num_logits]
    pos_ptr,
    # [num_reqs]
    idx_mapping_ptr,
    # [max_num_reqs]
    seeds_ptr,
    base_acceptance_rate,
    decay_factor,
):
    req_idx = tl.program_id(0)
    # This request's logits occupy [start_idx, end_idx) in the flattened arrays.
    start_idx = tl.load(cu_num_logits_ptr + req_idx)
    end_idx = tl.load(cu_num_logits_ptr + req_idx + 1)
    num_tokens = end_idx - start_idx
    # The seed is looked up through idx_mapping rather than by req_idx directly.
    req_state_idx = tl.load(idx_mapping_ptr + req_idx)
    seed = tl.load(seeds_ptr + req_state_idx)

    num_sampled = 0
    acceptance_rate = base_acceptance_rate
    rejected = False
    # The last logit has no following input token to compare against, so only
    # the first num_tokens - 1 positions are accept/reject decisions.
    for i in range(num_tokens - 1):
        # No early exit is used; once a rejection happened the remaining
        # iterations are no-ops.
        if not rejected:
            logit_idx = start_idx + i
            pos = tl.load(pos_ptr + logit_idx)
            # NOTE(review): tl_rand64 is defined elsewhere; presumably it
            # returns a uniform sample in the unit interval that excludes 0
            # when includes_zero=False, so a rate of 0 never accepts. Confirm.
            u = tl_rand64(seed, pos, includes_zero=False)
            if u < acceptance_rate:
                # Accept: keep the input token that follows this logit (the
                # caller passes the draft tokens as input_ids_ptr).
                sampled = tl.load(input_ids_ptr + logit_idx + 1).to(tl.int64)
            else:
                # Reject: fall back to the token sampled from the target model.
                sampled = tl.load(target_sampled_ptr + logit_idx)
                rejected = True
            tl.store(sampled_ptr + req_idx * sampled_stride + i, sampled)
            num_sampled += 1
            # Geometric decay: position k is accepted with rate base * decay**k.
            acceptance_rate *= decay_factor
    if not rejected:
        # Every decision was an accept; also emit the target-sampled token of
        # the final logit.
        target_sampled = tl.load(target_sampled_ptr + start_idx + num_tokens - 1)
        tl.store(
            sampled_ptr + req_idx * sampled_stride + num_tokens - 1, target_sampled
        )
        num_sampled += 1
    tl.store(num_sampled_ptr + req_idx, num_sampled)
|
||||
|
||||
|
||||
def synthetic_rejection_sample(
    # [num_draft_tokens + num_reqs]
    target_sampled: torch.Tensor,
    # [num_draft_tokens + num_reqs]
    draft_sampled: torch.Tensor,
    # [num_reqs + 1]
    cu_num_logits: torch.Tensor,
    # [num_logits]
    pos: torch.Tensor,
    # [num_reqs]
    idx_mapping: torch.Tensor,
    # [max_num_reqs]
    seed: torch.Tensor,
    base_acceptance_rate: float,
    decay_factor: float,
    num_speculative_steps: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run synthetic rejection sampling for a batch of requests.

    Allocates the output buffers and launches
    ``_synthetic_rejection_sample_kernel`` with one program per request.

    Returns:
        A ``(sampled, num_sampled)`` pair. ``sampled`` has shape
        ``[num_reqs, num_speculative_steps + 1]`` and shares dtype and device
        with ``target_sampled``; it is allocated uninitialized, so slots the
        kernel does not write hold arbitrary values. ``num_sampled`` is an
        int32 tensor of shape ``[num_reqs]`` with the per-request count of
        tokens written.
    """
    # cu_num_logits holds num_reqs + 1 cumulative offsets.
    batch_size = cu_num_logits.shape[0] - 1
    out_tokens = target_sampled.new_empty((batch_size, num_speculative_steps + 1))
    out_counts = target_sampled.new_empty((batch_size,), dtype=torch.int32)

    launch_grid = (batch_size,)
    _synthetic_rejection_sample_kernel[launch_grid](
        out_tokens,
        out_tokens.stride(0),
        out_counts,
        target_sampled,
        draft_sampled,
        cu_num_logits,
        pos,
        idx_mapping,
        seed,
        base_acceptance_rate,
        decay_factor,
        num_warps=1,
    )
    return out_tokens, out_counts
|
||||
|
||||
|
||||
def compute_synthetic_rejection_sampler_params(
    p_avg: float, n: int, tol: float = 1e-9
) -> tuple[float, float]:
    """Calibrate the base acceptance rate and decay factor for synthetic
    rejection sampling.

    Position ``i`` (0-based) is accepted with probability ``a_0 * gamma**i``,
    so the joint probability that positions ``0..i`` are all accepted is
    ``a_0**(i + 1) * gamma**(i * (i + 1) / 2)``. This function picks
    ``(a_0, gamma)`` such that the mean of those joint probabilities over the
    ``n`` speculative positions matches ``p_avg``.

    ``gamma`` is MIN_ACCEPTANCE_DECAY_FACTOR whenever that still allows
    ``a_0 <= 1``; otherwise the smallest larger factor (up to 1.0) that does.

    Args:
        p_avg: Desired mean joint acceptance probability. Values <= 0 yield a
            base rate of 0.0 and values >= 1 yield a base rate of 1.0.
        n: Number of speculative positions; must be at least 1.
        tol: Width at which both bisection searches stop; must be positive.

    Returns:
        ``(base_acceptance_rate, decay_factor)``.

    Raises:
        ValueError: If ``n < 1`` or ``tol`` is not positive.
    """
    if n < 1:
        # Previously surfaced as an opaque ZeroDivisionError in the mean.
        raise ValueError(f"n must be a positive integer, but got {n}")
    if not tol > 0.0:
        # A non-positive (or NaN) tolerance can keep the bisection from ending.
        raise ValueError(f"tol must be positive, but got {tol}")

    def mean_joint_prob(a_0: float, gamma: float) -> float:
        # Mean over positions of the probability that the whole prefix up to
        # that position is accepted.
        total = 0.0
        for i in range(n):
            total += a_0 ** (i + 1) * gamma ** (i * (i + 1) // 2)
        return total / n

    def bisect_smallest(reaches_target, low: float, high: float) -> float:
        # Smallest value (within ``tol``) in [low, high] for which the
        # monotone predicate ``reaches_target`` holds. Returns the upper end
        # of the final bracket so the predicate is satisfied whenever it is
        # satisfied at the initial ``high``.
        while (high - low) > tol:
            mid = (low + high) / 2
            if reaches_target(mid):
                high = mid
            else:
                low = mid
        return high

    # Decay factor: choose one that guarantees a base acceptance rate <= 1,
    # i.e. one for which a base rate of exactly 1 already reaches p_avg.
    if mean_joint_prob(1.0, MIN_ACCEPTANCE_DECAY_FACTOR) >= p_avg:
        decay_factor = MIN_ACCEPTANCE_DECAY_FACTOR
    else:
        decay_factor = bisect_smallest(
            lambda gamma: mean_joint_prob(1.0, gamma) >= p_avg,
            MIN_ACCEPTANCE_DECAY_FACTOR,
            1.0,
        )

    # Base rate: the boundary targets are answered exactly; anything in
    # between is found by bisection on [0, 1].
    if p_avg <= 0.0:
        base_rate = 0.0
    elif p_avg >= 1.0:
        base_rate = 1.0
    else:
        base_rate = bisect_smallest(
            lambda a_0: mean_joint_prob(a_0, decay_factor) >= p_avg,
            0.0,
            1.0,
        )
    return base_rate, decay_factor
|
||||
Reference in New Issue
Block a user