diff --git a/.buildkite/test_areas/model_runner_v2.yaml b/.buildkite/test_areas/model_runner_v2.yaml index dd64a0d23..b39b00d0c 100644 --- a/.buildkite/test_areas/model_runner_v2.yaml +++ b/.buildkite/test_areas/model_runner_v2.yaml @@ -101,9 +101,11 @@ steps: - vllm/v1/worker/gpu/ - vllm/v1/worker/gpu_worker.py - tests/v1/spec_decode/test_max_len.py + - tests/v1/spec_decode/test_synthetic_rejection_sampler_utils.py - tests/v1/e2e/spec_decode/test_spec_decode.py commands: - set -x - export VLLM_USE_V2_MODEL_RUNNER=1 - pytest -v -s v1/spec_decode/test_max_len.py -k "eagle or mtp" + - pytest -v -s v1/spec_decode/test_synthetic_rejection_sampler_utils.py - pytest -v -s v1/e2e/spec_decode/test_spec_decode.py -k "eagle or mtp" diff --git a/tests/v1/spec_decode/test_synthetic_rejection_sampler_utils.py b/tests/v1/spec_decode/test_synthetic_rejection_sampler_utils.py new file mode 100644 index 000000000..d817bc1b8 --- /dev/null +++ b/tests/v1/spec_decode/test_synthetic_rejection_sampler_utils.py @@ -0,0 +1,34 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import pytest + +from vllm.v1.worker.gpu.spec_decode.synthetic_rejection_sampler_utils import ( + compute_synthetic_rejection_sampler_params, +) + +NUM_SPECULATIVE_STEPS = [1, 2, 3, 4, 5, 7, 10] +ACCEPTANCE_RATES = [i / 100 for i in range(0, 100)] + + +@pytest.mark.parametrize("num_speculative_steps", NUM_SPECULATIVE_STEPS) +def test_compute_synthetic_rejection_sampler_params(num_speculative_steps: int): + """Test that the base acceptance rate and decay factor generated for + synthetic rejection sampling have a mean joint acceptance probability + that matches the desired acceptance rate.""" + tol = 1e-9 + for desired_acceptance_rate in ACCEPTANCE_RATES: + base_rate, decay_factor = compute_synthetic_rejection_sampler_params( + desired_acceptance_rate, num_speculative_steps, tol=tol + ) + + # Compute the mean of joint acceptance probabilities across + # all 
speculative positions. + joint_prob = 1.0 + mean_joint = 0.0 + for i in range(num_speculative_steps): + joint_prob *= base_rate * decay_factor**i + mean_joint += joint_prob + mean_joint /= num_speculative_steps + + assert abs(desired_acceptance_rate - mean_joint) < 10 * tol + assert base_rate <= 1.0 diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index 8c81b36a8..375793941 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -58,7 +58,7 @@ SpeculativeMethod = Literal[ EagleModelTypes, NgramGPUTypes, ] -RejectionSampleMethod = Literal["strict", "probabilistic"] +RejectionSampleMethod = Literal["strict", "probabilistic", "synthetic"] @config @@ -184,6 +184,13 @@ class SpeculativeConfig: distribution, but the latter yields a higher acceptance rate at the cost of more memory to cache draft logits.""" + synthetic_acceptance_rate: float | None = None + """Average acceptance rate for synthetic rejection sampling. Draft + tokens are accepted with a position-dependent probability that decays + geometrically, calibrated so that the mean rate across all speculative + positions equals this value. Only used when rejection_sample_method + is 'synthetic'. 
Must be in [0, 1].""" + def compute_hash(self) -> str: """ WARNING: Whenever a new field is added to this config, diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py index ea92e5aea..a2f83c52e 100644 --- a/vllm/v1/worker/gpu/model_runner.py +++ b/vllm/v1/worker/gpu/model_runner.py @@ -163,12 +163,8 @@ class GPUModelRunner(LoRAModelRunnerMixin): self.speculator = None self.num_speculative_steps = 0 self.use_aux_hidden_state_outputs = False - use_strict_rejection_sampling = False if self.speculative_config is not None: self.num_speculative_steps = self.speculative_config.num_speculative_tokens - use_strict_rejection_sampling = ( - self.speculative_config.rejection_sample_method == "strict" - ) if self.is_last_pp_rank: self.speculator = init_speculator(self.vllm_config, self.device) @@ -217,11 +213,11 @@ class GPUModelRunner(LoRAModelRunnerMixin): logprobs_mode=self.model_config.logprobs_mode, num_speculative_tokens=self.num_speculative_steps + 1, ) - self.rejection_sampler = RejectionSampler( - self.sampler, - num_speculative_steps=self.num_speculative_steps, - use_strict_rejection_sampling=use_strict_rejection_sampling, - ) + if self.speculative_config is not None: + self.rejection_sampler = RejectionSampler( + self.sampler, + self.speculative_config, + ) self.prompt_logprobs_worker = PromptLogprobsWorker(self.max_num_reqs) self.structured_outputs_worker = StructuredOutputsWorker( max_num_logits=self.max_num_reqs * (self.num_speculative_steps + 1), diff --git a/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py b/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py index bc001db8e..69249e610 100644 --- a/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py +++ b/vllm/v1/worker/gpu/spec_decode/eagle/speculator.py @@ -85,9 +85,8 @@ class EagleSpeculator: self.max_num_tokens, self.hidden_size, dtype=self.dtype, device=device ) - cache_draft_logits = self.speculative_config.rejection_sample_method != "strict" self.draft_logits: torch.Tensor 
| None = None - if cache_draft_logits: + if self.speculative_config.rejection_sample_method == "probabilistic": self.draft_logits = torch.zeros( self.max_num_reqs, self.num_speculative_steps, diff --git a/vllm/v1/worker/gpu/spec_decode/rejection_sampler.py b/vllm/v1/worker/gpu/spec_decode/rejection_sampler.py index 0c6e26aaa..abb2b90f0 100644 --- a/vllm/v1/worker/gpu/spec_decode/rejection_sampler.py +++ b/vllm/v1/worker/gpu/spec_decode/rejection_sampler.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import torch +from vllm.config import SpeculativeConfig from vllm.triton_utils import tl, triton from vllm.v1.outputs import LogprobsTensors from vllm.v1.worker.gpu.input_batch import InputBatch @@ -11,6 +12,10 @@ from vllm.v1.worker.gpu.sample.logprob import compute_topk_logprobs from vllm.v1.worker.gpu.sample.output import SamplerOutput from vllm.v1.worker.gpu.sample.sampler import Sampler from vllm.v1.worker.gpu.sample.states import NO_LOGPROBS +from vllm.v1.worker.gpu.spec_decode.synthetic_rejection_sampler_utils import ( + compute_synthetic_rejection_sampler_params, + synthetic_rejection_sample, +) @triton.jit @@ -445,12 +450,26 @@ class RejectionSampler: def __init__( self, sampler: Sampler, - num_speculative_steps, - use_strict_rejection_sampling: bool = True, + spec_config: SpeculativeConfig, ): self.sampler = sampler - self.num_speculative_steps = num_speculative_steps - self.use_strict_rejection_sampling = use_strict_rejection_sampling + self.num_speculative_steps = spec_config.num_speculative_tokens + self.rejection_sample_method = spec_config.rejection_sample_method + if self.rejection_sample_method == "synthetic": + synthetic_acceptance_rate = spec_config.synthetic_acceptance_rate + if ( + synthetic_acceptance_rate is None + or not 0.0 <= synthetic_acceptance_rate <= 1.0 + ): + raise ValueError( + f"synthetic_acceptance_rate must be in [0, 1], " + f"but got {synthetic_acceptance_rate}" + ) + 
self.base_acceptance_rate, self.decay_factor = ( + compute_synthetic_rejection_sampler_params( + synthetic_acceptance_rate, self.num_speculative_steps + ) + ) def _get_logprobs_tensors( self, @@ -497,7 +516,7 @@ class RejectionSampler: # that num_nans is computed before applying penalties and temperature. num_nans = get_num_nans(logits) if self.sampler.compute_nans else None - if self.use_strict_rejection_sampling: + if self.rejection_sample_method == "strict": sampler_output = self.sampler(logits, input_batch) logprobs_tensors = sampler_output.logprobs_tensors sampled, num_sampled = strict_rejection_sample( @@ -506,7 +525,7 @@ class RejectionSampler: input_batch.cu_num_logits, self.num_speculative_steps, ) - else: + elif self.rejection_sample_method == "probabilistic": assert draft_logits is not None pos = input_batch.positions[input_batch.logits_indices] processed_logits = self.sampler.apply_sampling_params( @@ -538,6 +557,24 @@ class RejectionSampler: if self.sampler.logprobs_mode == "processed_logprobs" else logits, ) + elif self.rejection_sample_method == "synthetic": + sampler_output = self.sampler(logits, input_batch) + logprobs_tensors = sampler_output.logprobs_tensors + sampled, num_sampled = synthetic_rejection_sample( + sampler_output.sampled_token_ids.view(-1), + draft_sampled, + input_batch.cu_num_logits, + input_batch.positions[input_batch.logits_indices], + input_batch.idx_mapping, + self.sampler.sampling_states.seeds.gpu, + self.base_acceptance_rate, + self.decay_factor, + self.num_speculative_steps, + ) + else: + raise ValueError( + f"Unknown rejection sample method: {self.rejection_sample_method}" + ) return SamplerOutput( sampled_token_ids=sampled, diff --git a/vllm/v1/worker/gpu/spec_decode/synthetic_rejection_sampler_utils.py b/vllm/v1/worker/gpu/spec_decode/synthetic_rejection_sampler_utils.py new file mode 100644 index 000000000..f5388575b --- /dev/null +++ b/vllm/v1/worker/gpu/spec_decode/synthetic_rejection_sampler_utils.py @@ -0,0 +1,147 
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Utilities for synthetic rejection sampling.

Draft tokens are accepted with a position-dependent probability that decays
geometrically: at speculative position i the acceptance probability is
``base_acceptance_rate * decay_factor**i``. The parameters are calibrated so
that the mean joint acceptance probability across all speculative positions
matches a desired target acceptance rate.
"""
import torch

from vllm.triton_utils import tl, triton
from vllm.v1.worker.gpu.sample.gumbel import tl_rand64

# Lower bound for the geometric decay factor searched by
# compute_synthetic_rejection_sampler_params.
MIN_ACCEPTANCE_DECAY_FACTOR = 0.85


@triton.jit
def _synthetic_rejection_sample_kernel(
    # [num_reqs, num_speculative_steps + 1]
    sampled_ptr,
    sampled_stride,
    # [num_reqs]
    num_sampled_ptr,
    # [num_draft_tokens + num_reqs]
    target_sampled_ptr,
    # [num_draft_tokens + num_reqs]
    input_ids_ptr,
    # [num_reqs + 1]
    cu_num_logits_ptr,
    # [num_logits]
    pos_ptr,
    # [num_reqs]
    idx_mapping_ptr,
    # [max_num_reqs]
    seeds_ptr,
    base_acceptance_rate,
    decay_factor,
):
    # One program per request: walk its draft tokens in order, accepting
    # each with a geometrically decaying probability until the first
    # rejection (or all are accepted, in which case the bonus target token
    # is appended).
    req_idx = tl.program_id(0)
    start_idx = tl.load(cu_num_logits_ptr + req_idx)
    end_idx = tl.load(cu_num_logits_ptr + req_idx + 1)
    num_tokens = end_idx - start_idx
    req_state_idx = tl.load(idx_mapping_ptr + req_idx)
    seed = tl.load(seeds_ptr + req_state_idx)

    num_sampled = 0
    acceptance_rate = base_acceptance_rate
    rejected = False
    for i in range(num_tokens - 1):
        if not rejected:
            logit_idx = start_idx + i
            pos = tl.load(pos_ptr + logit_idx)
            # Deterministic per (seed, position) uniform draw so the
            # decision is reproducible for a given request seed.
            u = tl_rand64(seed, pos, includes_zero=False)
            if u < acceptance_rate:
                # Accept: keep the draft token (the next fed input id).
                sampled = tl.load(input_ids_ptr + logit_idx + 1).to(tl.int64)
            else:
                # Reject: fall back to the target model's sample and stop.
                sampled = tl.load(target_sampled_ptr + logit_idx)
                rejected = True
            tl.store(sampled_ptr + req_idx * sampled_stride + i, sampled)
            num_sampled += 1
        # Acceptance probability decays geometrically with position.
        acceptance_rate *= decay_factor
    if not rejected:
        # All drafts accepted: also emit the bonus token sampled from the
        # target model at the final position.
        target_sampled = tl.load(target_sampled_ptr + start_idx + num_tokens - 1)
        tl.store(
            sampled_ptr + req_idx * sampled_stride + num_tokens - 1, target_sampled
        )
        num_sampled += 1
    tl.store(num_sampled_ptr + req_idx, num_sampled)


def synthetic_rejection_sample(
    # [num_draft_tokens + num_reqs]
    target_sampled: torch.Tensor,
    # [num_draft_tokens + num_reqs]
    draft_sampled: torch.Tensor,
    # [num_reqs + 1]
    cu_num_logits: torch.Tensor,
    # [num_logits]
    pos: torch.Tensor,
    # [num_reqs]
    idx_mapping: torch.Tensor,
    # [max_num_reqs]
    seed: torch.Tensor,
    base_acceptance_rate: float,
    decay_factor: float,
    num_speculative_steps: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run synthetic rejection sampling for a batch of requests.

    Returns:
        A tuple ``(sampled, num_sampled)`` where ``sampled`` is a
        ``[num_reqs, num_speculative_steps + 1]`` tensor of accepted token
        ids and ``num_sampled`` is a ``[num_reqs]`` int32 tensor with the
        number of valid tokens per request.
    """
    num_reqs = cu_num_logits.shape[0] - 1
    sampled = target_sampled.new_empty(num_reqs, num_speculative_steps + 1)
    num_sampled = target_sampled.new_empty(num_reqs, dtype=torch.int32)
    _synthetic_rejection_sample_kernel[(num_reqs,)](
        sampled,
        sampled.stride(0),
        num_sampled,
        target_sampled,
        draft_sampled,
        cu_num_logits,
        pos,
        idx_mapping,
        seed,
        base_acceptance_rate,
        decay_factor,
        num_warps=1,
    )
    return sampled, num_sampled


def compute_synthetic_rejection_sampler_params(
    p_avg: float, n: int, tol: float = 1e-9
) -> tuple[float, float]:
    """Calibrate ``(base_acceptance_rate, decay_factor)`` for synthetic sampling.

    Finds parameters such that the mean of the joint acceptance
    probabilities over ``n`` speculative positions equals ``p_avg``, where
    the per-position acceptance probability at step ``i`` is
    ``base_acceptance_rate * decay_factor**i``.

    Args:
        p_avg: Desired mean joint acceptance rate, in [0, 1].
        n: Number of speculative steps (>= 1).
        tol: Convergence tolerance for the binary searches.

    Returns:
        Tuple ``(base_acceptance_rate, decay_factor)`` with
        ``base_acceptance_rate <= 1.0``.
    """

    def mean_joint_prob(a_0: float, gamma: float, n: int) -> float:
        # Joint probability of accepting the first (i + 1) tokens is
        # a_0**(i + 1) * gamma**(0 + 1 + ... + i); average over positions.
        total = 0.0
        for i in range(n):
            total += a_0 ** (i + 1) * gamma ** (i * (i + 1) // 2)
        return total / n

    def min_valid_decay_factor(p: float, n: int, tol: float = 1e-9) -> float:
        low, high = MIN_ACCEPTANCE_DECAY_FACTOR, 1.0
        if mean_joint_prob(1, low, n) >= p:
            return low

        # Sweep for a gamma decay factor that is guaranteed
        # to yield a base acceptance rate <= 1.
        while (high - low) > tol:
            mid = (low + high) / 2
            if mean_joint_prob(1, mid, n) >= p:
                high = mid
            else:
                low = mid
        return high

    def compute_base_acceptance_rate(
        p_avg: float, gamma: float, n: int, tol: float = 1e-9
    ) -> float:
        if p_avg <= 0.0:
            return 0.0
        if p_avg >= 1.0:
            return 1.0

        # Sweep for a base acceptance rate that yields
        # the desired mean joint probability.
        low, high = 0.0, 1.0
        while (high - low) > tol:
            mid = (low + high) / 2
            if mean_joint_prob(mid, gamma, n) >= p_avg:
                high = mid
            else:
                low = mid
        return high

    # Fix: forward the caller-supplied tolerance to both searches; the
    # declared `tol` parameter was previously ignored (the helpers silently
    # fell back to their own 1e-9 defaults).
    decay_factor = min_valid_decay_factor(p_avg, n, tol)
    base_rate = compute_base_acceptance_rate(p_avg, decay_factor, n, tol)
    return base_rate, decay_factor