[Model Runner V2] Add probabilistic rejection sampling for spec decoding (#35461)

Signed-off-by: Giancarlo Delfin <gdelfin@inferact.ai>
This commit is contained in:
Giancarlo Delfin
2026-03-11 14:04:32 -07:00
committed by GitHub
parent 12001f2ebc
commit c77181e534
9 changed files with 494 additions and 112 deletions

View File

@@ -57,6 +57,10 @@ SpeculativeMethod = Literal[
EagleModelTypes,
NgramGPUTypes,
]
RejectionSampleMethod = Literal[
"strict",
"probabilistic",
]
@config
@@ -171,6 +175,12 @@ class SpeculativeConfig:
"""Load config for the draft model. If not specified, will use the load
config from the target model."""
rejection_sample_method: RejectionSampleMethod = "strict"
"""Whether to use strict (target and draft sampled tokens match exactly)
or probabilistic rejection sampling. Both respect the target model
distribution, but the latter yields a higher acceptance rate at the cost
of more memory to cache draft logits."""
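For readers unfamiliar with the distinction, the sketch below (not part of this commit; plain PyTorch with hypothetical names) shows the single-step acceptance rule that "probabilistic" refers to: accept the draft token with probability min(1, p_target/p_draft), otherwise resample from the clamped residual distribution. The Triton kernels added later in this commit implement a fused, batched version of the same rule.

import torch

def accept_or_resample(
    p_target: torch.Tensor,   # [vocab] target probabilities at this position
    p_draft: torch.Tensor,    # [vocab] draft probabilities at this position
    draft_token: int,
    generator: torch.Generator | None = None,
) -> tuple[bool, int]:
    # Accept the draft token with probability min(1, p_target[t] / p_draft[t]).
    u = torch.rand((), generator=generator)
    if u * p_draft[draft_token] < p_target[draft_token]:
        return True, draft_token
    # Rejected: resample from the residual distribution max(p_target - p_draft, 0),
    # renormalized. This is what keeps the output distributed exactly like the target.
    residual = torch.clamp(p_target - p_draft, min=0.0)
    residual = residual / residual.sum()
    return False, int(torch.multinomial(residual, 1, generator=generator))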
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,

View File

@@ -90,7 +90,7 @@ from vllm.v1.worker.gpu.spec_decode import init_speculator
from vllm.v1.worker.gpu.spec_decode.eagle.eagle3_utils import (
set_eagle3_aux_hidden_state_layers,
)
from vllm.v1.worker.gpu.spec_decode.rejection_sample import rejection_sample
from vllm.v1.worker.gpu.spec_decode.rejection_sampler import RejectionSampler
from vllm.v1.worker.gpu.spec_decode.utils import DraftTokensHandler
from vllm.v1.worker.gpu.states import RequestState
from vllm.v1.worker.gpu.structured_outputs import StructuredOutputsWorker
@@ -162,6 +162,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.speculator = None
self.num_speculative_steps = 0
self.use_aux_hidden_state_outputs = False
use_strict_rejection_sampling = False
if self.speculative_config is not None:
self.num_speculative_steps = self.speculative_config.num_speculative_tokens
if self.is_last_pp_rank:
@@ -172,6 +173,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.use_aux_hidden_state_outputs = True
if self.pp_size > 1:
raise ValueError("EAGLE3 with pipeline parallel is not supported.")
use_strict_rejection_sampling = (
self.speculative_config.rejection_sample_method == "strict"
)
# Draft tokens propagation - for spec-dec + struct outputs.
self.draft_tokens_handler = DraftTokensHandler(self.device)
@@ -183,6 +187,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_speculative_steps=self.num_speculative_steps,
vocab_size=self.vocab_size,
device=self.device,
model_dtype=self.dtype,
cache_draft_logits=not use_strict_rejection_sampling,
)
self.input_buffers = InputBuffers(
max_num_reqs=self.max_num_reqs,
@@ -197,6 +203,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
logprobs_mode=self.model_config.logprobs_mode,
num_speculative_tokens=self.num_speculative_steps + 1,
)
self.rejection_sampler = RejectionSampler(
self.sampler,
num_speculative_steps=self.num_speculative_steps,
use_strict_rejection_sampling=use_strict_rejection_sampling,
)
self.prompt_logprobs_worker = PromptLogprobsWorker(self.max_num_reqs)
# CUDA graphs.
@@ -412,6 +423,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
next_prefill_tokens=self.req_states.next_prefill_tokens,
temperature=self.sampler.sampling_states.temperature.gpu,
seeds=self.sampler.sampling_states.seeds.gpu,
draft_logits_out=self.req_states.draft_logits,
num_tokens_across_dp=num_tokens_across_dp,
dummy_run=True,
skip_attn_for_dummy_run=skip_attn,
@@ -425,24 +437,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def _dummy_sampler_run(self, hidden_states: torch.Tensor) -> None:
num_reqs = hidden_states.shape[0]
logits = self.model.compute_logits(hidden_states)
idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=self.device)
idx_mapping_np = np.arange(num_reqs, dtype=np.int32)
pos = torch.zeros(num_reqs, dtype=torch.int64, device=self.device)
dummy_input_ids = torch.zeros(num_reqs, dtype=torch.int32, device=self.device)
expanded_local_pos = torch.zeros(
num_reqs, dtype=torch.int32, device=self.device
dummy_input_batch = InputBatch.make_dummy(
num_reqs, num_reqs, self.input_buffers
)
# NOTE(woosuk): During the initial memory profiling, the sampler may skip
# top_k, top_p, and logprobs, using less GPU memory than what is possible
# during actual execution.
self.sampler(
logits,
idx_mapping,
idx_mapping_np,
idx_mapping_np,
pos,
dummy_input_ids,
expanded_local_pos,
dummy_input_batch,
)
@torch.inference_mode()
@@ -768,8 +772,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
grammar_output: GrammarOutput | None,
) -> tuple[SamplerOutput, torch.Tensor, torch.Tensor]:
sample_hidden_states = hidden_states[input_batch.logits_indices]
sample_pos = input_batch.positions[input_batch.logits_indices]
input_ids = input_batch.input_ids[input_batch.logits_indices]
logits = self.model.compute_logits(sample_hidden_states)
if grammar_output is not None:
# Apply grammar bitmask to the logits in-place.
@@ -780,34 +782,27 @@ class GPUModelRunner(LoRAModelRunnerMixin):
grammar_output.grammar_bitmask,
)
# Sample tokens and compute logprobs (if needed).
sampler_output = self.sampler(
logits,
input_batch.expanded_idx_mapping,
input_batch.idx_mapping_np,
input_batch.cu_num_logits_np,
sample_pos,
input_ids,
input_batch.expanded_local_pos,
)
if input_batch.num_draft_tokens == 0:
# No draft tokens (common case).
num_sampled = input_batch.seq_lens.new_ones(input_batch.num_reqs)
sampler_output = self.sampler(
logits,
input_batch,
)
else:
# Rejection sampling for spec decoding.
sampled_tokens, num_sampled = rejection_sample(
sampler_output.sampled_token_ids,
input_ids,
input_batch.cu_num_logits,
self.num_speculative_steps,
sampler_output = self.rejection_sampler(
logits,
input_batch,
# Draft logits are needed for probabilistic rejection sampling.
self.req_states.draft_logits[input_batch.idx_mapping]
if self.req_states.draft_logits is not None
else None,
)
sampler_output.sampled_token_ids = sampled_tokens
# Get the number of sampled and rejected tokens.
# For chunked prefills, num_sampled and num_rejected are both 0.
num_sampled, num_rejected = get_num_sampled_and_rejected(
num_sampled,
sampler_output.num_sampled,
input_batch.seq_lens,
input_batch.cu_num_logits,
input_batch.idx_mapping,
@@ -1105,6 +1100,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.req_states.next_prefill_tokens,
self.sampler.sampling_states.temperature.gpu,
self.sampler.sampling_states.seeds.gpu,
self.req_states.draft_logits,
num_tokens_across_dp=num_tokens_across_dp,
)
self.req_states.draft_tokens[input_batch.idx_mapping] = draft_tokens
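As a sanity check on the bookkeeping above (get_num_sampled_and_rejected itself is not shown in this diff), here is a toy illustration of the arithmetic it relies on for decode requests. The helper name is hypothetical, and the chunked-prefill special case (reported as 0/0 per the comment) is handled by the real function but ignored here.

import torch

# A request with k draft tokens produces k + 1 logits (k verified positions plus
# one bonus/resampled position), so num_rejected = (k + 1) - num_sampled.
def toy_num_rejected(cu_num_logits: torch.Tensor, num_sampled: torch.Tensor) -> torch.Tensor:
    num_logits = cu_num_logits[1:] - cu_num_logits[:-1]   # k + 1 per request
    return num_logits - num_sampled

cu_num_logits = torch.tensor([0, 4, 8])   # two requests, 3 draft tokens each
num_sampled = torch.tensor([2, 4])        # req 0 rejected after 2 tokens; req 1 accepted all + bonus
print(toy_num_rejected(cu_num_logits, num_sampled))   # tensor([2, 0])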

View File

@@ -55,6 +55,8 @@ def _gumbel_sample_kernel(
local_argmax_stride,
local_max_ptr,
local_max_stride,
processed_logits_ptr,
processed_logits_stride,
logits_ptr,
logits_stride,
expanded_idx_mapping_ptr,
@@ -79,6 +81,20 @@ def _gumbel_sample_kernel(
logits = logits.to(tl.float32)
temp = tl.load(temp_ptr + req_state_idx).to(tl.float32)
if (temp != 0.0) and APPLY_TEMPERATURE:
# Apply temperature.
# NOTE(woosuk): Match the behavior of _temperature_kernel.
# E.g., if the kernel uses tl.div_rn, we should use tl.div_rn here too.
logits = logits / temp
# Store the temperature-applied logits.
if processed_logits_ptr is not None:
tl.store(
processed_logits_ptr + req_state_idx * processed_logits_stride + block,
logits,
mask=mask,
)
if temp != 0.0:
# Calculate the seed for gumbel noise.
seed = tl.load(seeds_ptr + req_state_idx)
@@ -90,12 +106,6 @@ def _gumbel_sample_kernel(
u = tl.maximum(u, 1e-7)
gumbel_noise = -tl.log(-tl.log(u))
# Apply temperature.
if APPLY_TEMPERATURE:
# NOTE(woosuk): Match the behavior of _temperature_kernel.
# E.g., if the kernel uses tl.div_rn, we should use tl.div_rn here too.
logits = logits / temp
# Apply gumbel noise.
logits = tl.where(mask, logits + gumbel_noise, float("-inf"))
@@ -112,6 +122,7 @@ def gumbel_sample(
seed: torch.Tensor, # [max_num_reqs]
pos: torch.Tensor, # [num_tokens]
apply_temperature: bool,
processed_logits_out: torch.Tensor | None = None, # [num_reqs, vocab_size]
) -> torch.Tensor:
num_tokens, vocab_size = logits.shape
BLOCK_SIZE = 1024
@@ -133,6 +144,8 @@ def gumbel_sample(
local_argmax.stride(0),
local_max,
local_max.stride(0),
processed_logits_out,
processed_logits_out.stride(0) if processed_logits_out is not None else 0,
logits,
logits.stride(0),
expanded_idx_mapping,
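The reason the kernel can write out the temperature-scaled logits before adding noise is the Gumbel-max trick it implements: softmax of those "processed" logits is exactly the distribution the draft token is drawn from, which is what probabilistic rejection sampling needs to cache. A minimal single-request reference (an assumed equivalent, not the Triton kernel):

import torch

def gumbel_sample_reference(
    logits: torch.Tensor,
    temperature: float,
    generator: torch.Generator | None = None,
) -> int:
    # Gumbel-max trick: argmax(logits / T + G), G ~ Gumbel(0, 1), is an exact sample
    # from softmax(logits / T). T == 0 degenerates to greedy argmax, as in the kernel.
    if temperature == 0.0:
        return int(torch.argmax(logits))
    processed = logits / temperature          # what processed_logits_out now stores
    u = torch.rand(processed.shape, generator=generator).clamp_min(1e-7)
    gumbel_noise = -torch.log(-torch.log(u))
    return int(torch.argmax(processed + gumbel_noise))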

View File

@@ -12,3 +12,4 @@ class SamplerOutput:
sampled_token_ids: torch.Tensor
logprobs_tensors: LogprobsTensors | None
num_nans: torch.Tensor | None
num_sampled: torch.Tensor | None

View File

@@ -7,6 +7,7 @@ import torch
import vllm.envs as envs
from vllm.config.model import LogprobsMode
from vllm.sampling_params import SamplingParams
from vllm.v1.worker.gpu.input_batch import InputBatch
from vllm.v1.worker.gpu.metrics.logits import get_num_nans
from vllm.v1.worker.gpu.sample.bad_words import BadWordsState
from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
@@ -56,13 +57,15 @@ class Sampler:
def __call__(
self,
logits: torch.Tensor,
expanded_idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray,
cu_num_logits_np: np.ndarray,
pos: torch.Tensor,
input_ids: torch.Tensor,
expanded_local_pos: torch.Tensor,
input_batch: InputBatch,
) -> SamplerOutput:
expanded_idx_mapping = input_batch.expanded_idx_mapping
idx_mapping_np = input_batch.idx_mapping_np
cu_num_logits_np = input_batch.cu_num_logits_np
expanded_local_pos = input_batch.expanded_local_pos
pos = input_batch.positions[input_batch.logits_indices]
input_ids = input_batch.input_ids[input_batch.logits_indices]
# NOTE(woosuk): We intentionally compute num_nans before sampling to make clear
# that num_nans is computed before applying penalties and temperature.
num_nans = get_num_nans(logits) if self.compute_nans else None
@@ -95,10 +98,11 @@ class Sampler:
sampled_token_ids=sampled.view(-1, 1),
logprobs_tensors=logprobs_tensors,
num_nans=num_nans,
num_sampled=input_batch.seq_lens.new_ones(input_batch.num_reqs),
)
return sampler_output
def sample(
def apply_sampling_params(
self,
logits: torch.Tensor,
expanded_idx_mapping: torch.Tensor,
@@ -106,7 +110,7 @@ class Sampler:
pos: torch.Tensor,
input_ids: torch.Tensor,
expanded_local_pos: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
# Copy logits to a new FP32 tensor.
logits = torch.empty_like(logits, dtype=torch.float32).copy_(logits)
@@ -143,13 +147,31 @@ class Sampler:
self.sampling_states.apply_min_p(logits, expanded_idx_mapping, idx_mapping_np)
# Apply top_k and/or top_p. This might or might not return a new tensor.
logits = self.sampling_states.apply_top_k_top_p(
return self.sampling_states.apply_top_k_top_p(
logits, expanded_idx_mapping, idx_mapping_np
)
def sample(
self,
logits: torch.Tensor,
expanded_idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray,
pos: torch.Tensor,
input_ids: torch.Tensor,
expanded_local_pos: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
processed_logits = self.apply_sampling_params(
logits,
expanded_idx_mapping,
idx_mapping_np,
pos,
input_ids,
expanded_local_pos,
)
# Sample the next token.
sampled = gumbel_sample(
logits,
processed_logits,
expanded_idx_mapping,
self.sampling_states.temperature.gpu,
self.sampling_states.seeds.gpu,
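The refactor above factors the logits post-processing (penalties, temperature, top-k/top-p) out of sample() so that the new rejection sampler can reuse it without drawing a token. A standalone toy of the same decomposition (not vLLM's classes; the names are illustrative):

import torch

def process(logits: torch.Tensor, temperature: float, top_k: int) -> torch.Tensor:
    # Analogue of apply_sampling_params(): produce the "processed" logits.
    scaled = logits / temperature
    kth = torch.topk(scaled, top_k).values[..., -1, None]
    return scaled.masked_fill(scaled < kth, float("-inf"))

def sample(processed: torch.Tensor, generator: torch.Generator | None = None) -> int:
    # Analogue of sample(): add Gumbel noise and take the argmax.
    u = torch.rand(processed.shape, generator=generator).clamp_min(1e-7)
    return int(torch.argmax(processed - torch.log(-torch.log(u))))

logits = torch.randn(32_000)
processed = process(logits, temperature=0.8, top_k=50)   # reusable by the rejection sampler
token = sample(processed)                                 # ordinary sampling path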

View File

@@ -140,6 +140,7 @@ class EagleSpeculator:
slot_mappings: dict[str, torch.Tensor] | None,
num_tokens_across_dp: torch.Tensor | None,
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
draft_logits_out: torch.Tensor | None = None,
) -> None:
pos = self.input_buffers.positions[:num_reqs]
query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
@@ -166,6 +167,9 @@ class EagleSpeculator:
self.seeds,
pos + 1,
apply_temperature=True,
processed_logits_out=draft_logits_out[:, step]
if draft_logits_out is not None
else None,
)
self.draft_tokens[:num_reqs, step] = draft_tokens
@@ -219,6 +223,8 @@ class EagleSpeculator:
temperature: torch.Tensor,
# [max_num_reqs]
seeds: torch.Tensor,
# [max_num_reqs, num_speculative_steps, vocab_size]
draft_logits_out: torch.Tensor | None,
num_tokens_across_dp: torch.Tensor | None = None,
dummy_run: bool = False,
skip_attn_for_dummy_run: bool = False,
@@ -271,6 +277,7 @@ class EagleSpeculator:
idx_mapping.copy_(input_batch.idx_mapping)
self.temperature.copy_(temperature)
self.seeds.copy_(seeds)
# Gather the values and copy them to the pre-allocated buffers.
pos = self.input_buffers.positions[:num_reqs]
torch.gather(input_batch.positions, 0, last_token_indices, out=pos)
@@ -283,7 +290,11 @@ class EagleSpeculator:
self.seeds,
pos + 1,
apply_temperature=True,
processed_logits_out=draft_logits_out[:, 0]
if draft_logits_out is not None
else None,
)
if self.num_speculative_steps == 1:
# Early exit.
return draft_tokens.view(-1, 1)
@@ -365,6 +376,7 @@ class EagleSpeculator:
slot_mappings_updated,
num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=batch_desc.cg_mode,
draft_logits_out=draft_logits_out,
)
return self.draft_tokens[:num_reqs]

View File

@@ -1,62 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.triton_utils import tl, triton
@triton.jit
def _rejection_sample_kernel(
sampled_ptr, # [num_reqs, num_speculative_steps + 1]
sampled_stride,
num_sampled_ptr, # [num_reqs]
target_sampled_ptr, # [num_draft_tokens + num_reqs]
input_ids_ptr, # [num_draft_tokens + num_reqs]
cu_num_logits_ptr, # [num_reqs + 1]
):
req_idx = tl.program_id(0)
start_idx = tl.load(cu_num_logits_ptr + req_idx)
end_idx = tl.load(cu_num_logits_ptr + req_idx + 1)
num_tokens = end_idx - start_idx
num_sampled = 0
rejected = False
for i in range(num_tokens - 1):
if not rejected:
target_sampled = tl.load(target_sampled_ptr + start_idx + i)
draft_sampled = tl.load(input_ids_ptr + start_idx + i + 1)
tl.store(sampled_ptr + req_idx * sampled_stride + i, target_sampled)
num_sampled += 1
if target_sampled != draft_sampled:
rejected = True
if not rejected:
target_sampled = tl.load(target_sampled_ptr + start_idx + num_tokens - 1)
tl.store(
sampled_ptr + req_idx * sampled_stride + num_tokens - 1, target_sampled
)
num_sampled += 1
tl.store(num_sampled_ptr + req_idx, num_sampled)
def rejection_sample(
# [num_draft_tokens + num_reqs]
target_sampled: torch.Tensor,
# [num_draft_tokens + num_reqs]
input_ids: torch.Tensor,
# [num_reqs + 1]
cu_num_logits: torch.Tensor,
num_speculative_steps: int,
) -> tuple[torch.Tensor, torch.Tensor]:
num_reqs = cu_num_logits.shape[0] - 1
sampled = target_sampled.new_empty(num_reqs, num_speculative_steps + 1)
num_sampled = cu_num_logits.new_empty(num_reqs)
_rejection_sample_kernel[(num_reqs,)](
sampled,
sampled.stride(0),
num_sampled,
target_sampled,
input_ids,
cu_num_logits,
num_warps=1,
)
return sampled, num_sampled

View File

@@ -0,0 +1,375 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.triton_utils import tl, triton
from vllm.v1.worker.gpu.input_batch import InputBatch
from vllm.v1.worker.gpu.metrics.logits import get_num_nans
from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
from vllm.v1.worker.gpu.sample.output import SamplerOutput
from vllm.v1.worker.gpu.sample.sampler import Sampler
@triton.jit
def _strict_rejection_sample_kernel(
sampled_ptr, # [num_reqs, num_speculative_steps + 1]
sampled_stride,
num_sampled_ptr, # [num_reqs]
target_sampled_ptr, # [num_draft_tokens + num_reqs]
input_ids_ptr, # [num_draft_tokens + num_reqs]
cu_num_logits_ptr, # [num_reqs + 1]
):
req_idx = tl.program_id(0)
start_idx = tl.load(cu_num_logits_ptr + req_idx)
end_idx = tl.load(cu_num_logits_ptr + req_idx + 1)
num_tokens = end_idx - start_idx
num_sampled = 0
rejected = False
for i in range(num_tokens - 1):
if not rejected:
target_sampled = tl.load(target_sampled_ptr + start_idx + i)
draft_sampled = tl.load(input_ids_ptr + start_idx + i + 1)
tl.store(sampled_ptr + req_idx * sampled_stride + i, target_sampled)
num_sampled += 1
if target_sampled != draft_sampled:
rejected = True
if not rejected:
target_sampled = tl.load(target_sampled_ptr + start_idx + num_tokens - 1)
tl.store(
sampled_ptr + req_idx * sampled_stride + num_tokens - 1, target_sampled
)
num_sampled += 1
tl.store(num_sampled_ptr + req_idx, num_sampled)
def strict_rejection_sample(
# [num_draft_tokens + num_reqs]
target_sampled: torch.Tensor,
# [num_draft_tokens + num_reqs]
draft_sampled: torch.Tensor,
# [num_reqs + 1]
cu_num_logits: torch.Tensor,
num_speculative_steps,
) -> tuple[torch.Tensor, torch.Tensor]:
num_reqs = cu_num_logits.shape[0] - 1
sampled = torch.empty(
num_reqs,
num_speculative_steps + 1,
dtype=target_sampled.dtype,
device=target_sampled.device,
)
num_sampled = torch.empty(
num_reqs,
dtype=torch.int32,
device=target_sampled.device,
)
_strict_rejection_sample_kernel[(num_reqs,)](
sampled,
sampled.stride(0),
num_sampled,
target_sampled,
draft_sampled,
cu_num_logits,
num_warps=1,
)
return sampled, num_sampled
@triton.jit
def _probabilistic_rejection_sample_kernel(
# [num_reqs, num_speculative_steps + 1]
sampled_ptr,
sampled_stride,
# [num_reqs]
rejected_steps_ptr,
# [num_logits]
draft_sampled_ptr,
# [num_logits, V]
target_probs_ptr,
target_probs_stride,
# [num_reqs, num_speculative_steps, V]
draft_probs_ptr,
draft_probs_stride_0,
draft_probs_stride_1,
# [num_reqs + 1]
cu_num_logits_ptr,
# [num_logits]
pos_ptr,
# [num_reqs]
idx_mapping_ptr,
# [num_reqs]
seeds_ptr,
):
req_idx = tl.program_id(0)
start_idx = tl.load(cu_num_logits_ptr + req_idx)
num_tokens = tl.load(cu_num_logits_ptr + req_idx + 1) - start_idx
seed = tl.load(seeds_ptr + tl.load(idx_mapping_ptr + req_idx))
rejected_step = 0
accepted = True
for i in range(num_tokens - 1):
if accepted:
draft_sampled = tl.load(draft_sampled_ptr + start_idx + i + 1)
target_prob = tl.load(
target_probs_ptr + (start_idx + i) * target_probs_stride + draft_sampled
)
draft_prob = tl.load(
draft_probs_ptr
+ req_idx * draft_probs_stride_0
+ i * draft_probs_stride_1
+ draft_sampled
)
pos = tl.load(pos_ptr + start_idx + i)
u = tl.sum(tl.rand(seed, pos + tl.arange(0, 1)))
accepted &= target_prob > u * draft_prob
tl.store(sampled_ptr + req_idx * sampled_stride + i, draft_sampled)
rejected_step += accepted
tl.store(rejected_steps_ptr + req_idx, rejected_step)
@triton.jit
def _compute_residual_logits_kernel(
# [num_reqs, V]
residual_logits_ptr,
residual_logits_stride,
# [num_reqs]
residual_pos_ptr,
# [num_logits, V]
target_logits_ptr,
target_logits_stride,
# [num_logits, V]
target_probs_ptr,
target_probs_stride,
# [num_reqs, num_speculative_steps, V]
draft_probs_ptr,
draft_probs_stride_0,
draft_probs_stride_1,
# [num_reqs]
rejected_step_ptr,
# [num_reqs + 1]
cu_num_logits_ptr,
# [num_logits]
pos_ptr,
vocab_size,
BLOCK_SIZE: tl.constexpr,
):
req_idx = tl.program_id(0)
block_idx = tl.program_id(1)
start_idx = tl.load(cu_num_logits_ptr + req_idx)
end_idx = tl.load(cu_num_logits_ptr + req_idx + 1)
rejected_draft_step = tl.load(rejected_step_ptr + req_idx)
rejected_logit_idx = start_idx + rejected_draft_step
block_offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = block_offsets < vocab_size
if rejected_logit_idx < end_idx - 1:
target_probs = tl.load(
target_probs_ptr + rejected_logit_idx * target_probs_stride + block_offsets,
mask=mask,
other=0.0,
)
draft_probs = tl.load(
draft_probs_ptr
+ req_idx * draft_probs_stride_0
+ rejected_draft_step * draft_probs_stride_1
+ block_offsets,
mask=mask,
other=0.0,
)
residual_probs = tl.maximum(target_probs - draft_probs, 0.0)
residual_logits = tl.log(residual_probs)
else:
# This is a bonus token. Directly return the target logits.
residual_logits = tl.load(
target_logits_ptr
+ rejected_logit_idx * target_logits_stride
+ block_offsets,
mask=mask,
other=0.0,
)
tl.store(
residual_logits_ptr + req_idx * residual_logits_stride + block_offsets,
residual_logits,
mask=mask,
)
# First block computes the residual logit positions.
if block_idx == 0:
pos_val = tl.load(pos_ptr + rejected_logit_idx)
tl.store(residual_pos_ptr + req_idx, pos_val)
def probabilistic_rejection_sample(
# [num_draft_tokens + num_reqs, V]
target_logits: torch.Tensor,
# [num_reqs, num_speculative_steps, V]
draft_logits: torch.Tensor,
# [num_draft_tokens + num_reqs]
draft_sampled: torch.Tensor,
# [num_reqs + 1]
cu_num_logits: torch.Tensor,
# [num_logits]
pos: torch.Tensor,
# [num_reqs]
idx_mapping: torch.Tensor,
temperature,
seeds,
num_speculative_steps,
) -> tuple[torch.Tensor, torch.Tensor]:
num_reqs = cu_num_logits.shape[0] - 1
device = target_logits.device
vocab_size = target_logits.shape[-1]
# Compute target and draft probs.
target_probs = torch.softmax(target_logits, dim=-1)
draft_probs = torch.softmax(draft_logits, dim=-1)
# Rejection sample.
# [num_reqs, num_speculative_steps + 1]
sampled = torch.empty(
num_reqs,
num_speculative_steps + 1,
dtype=torch.int64,
device=device,
)
# [num_reqs]
rejected_steps = torch.empty(
num_reqs,
dtype=torch.int64,
device=device,
)
_probabilistic_rejection_sample_kernel[(num_reqs,)](
sampled,
sampled.stride(0),
rejected_steps,
draft_sampled,
target_probs,
target_probs.stride(0),
draft_probs,
draft_probs.stride(0),
draft_probs.stride(1),
cu_num_logits,
pos,
idx_mapping,
seeds,
num_warps=1,
)
# Compute the logits and positions to resample the rejected/bonus
# tokens from.
# [num_reqs, vocab_size]
residual_logits = torch.empty(
num_reqs,
vocab_size,
dtype=target_logits.dtype,
device=device,
)
# [num_reqs]
residual_pos = torch.empty(
num_reqs,
dtype=pos.dtype,
device=device,
)
BLOCK_SIZE = 1024
num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
_compute_residual_logits_kernel[(num_reqs, num_blocks)](
residual_logits,
residual_logits.stride(0),
residual_pos,
target_logits,
target_logits.stride(0),
target_probs,
target_probs.stride(0),
draft_probs,
draft_probs.stride(0),
draft_probs.stride(1),
rejected_steps,
cu_num_logits,
pos,
vocab_size,
BLOCK_SIZE=BLOCK_SIZE,
)
# Gumbel sample tokens from the residual distribution.
resampled = gumbel_sample(
residual_logits,
idx_mapping,
temperature,
seeds,
residual_pos,
apply_temperature=False,
)
sampled.scatter_(1, rejected_steps.unsqueeze(1), resampled.unsqueeze(1))
return sampled, rejected_steps + 1
class RejectionSampler:
def __init__(
self,
sampler: Sampler,
num_speculative_steps,
use_strict_rejection_sampling: bool = True,
):
self.sampler = sampler
self.num_speculative_steps = num_speculative_steps
self.use_strict_rejection_sampling = use_strict_rejection_sampling
def __call__(
self,
logits: torch.Tensor,
input_batch: InputBatch,
draft_logits: torch.Tensor | None = None,
) -> SamplerOutput:
draft_sampled = input_batch.input_ids[input_batch.logits_indices]
# NOTE(woosuk): We intentionally compute num_nans before sampling to make clear
# that num_nans is computed before applying penalties and temperature.
num_nans = get_num_nans(logits) if self.sampler.compute_nans else None
if self.use_strict_rejection_sampling:
sampler_output = self.sampler(
logits,
input_batch,
)
logprobs_tensors = sampler_output.logprobs_tensors
sampled, num_sampled = strict_rejection_sample(
sampler_output.sampled_token_ids.view(-1),
draft_sampled,
input_batch.cu_num_logits,
self.num_speculative_steps,
)
else:
assert draft_logits is not None
pos = input_batch.positions[input_batch.logits_indices]
processed_logits = self.sampler.apply_sampling_params(
logits,
input_batch.expanded_idx_mapping,
input_batch.idx_mapping_np,
pos,
draft_sampled,
input_batch.expanded_local_pos,
)
# TODO (TheEpicDolphin): Return logprobs for sampled token ids.
logprobs_tensors = None
sampled, num_sampled = probabilistic_rejection_sample(
processed_logits,
draft_logits,
draft_sampled,
input_batch.cu_num_logits,
pos,
input_batch.idx_mapping,
self.sampler.sampling_states.temperature.gpu,
self.sampler.sampling_states.seeds.gpu,
self.num_speculative_steps,
)
return SamplerOutput(
sampled_token_ids=sampled,
logprobs_tensors=logprobs_tensors,
num_nans=num_nans,
num_sampled=num_sampled,
)
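For testing or review, a sequential single-request reference that mirrors the semantics of the two Triton kernels above may be useful. It is not part of the commit: it assumes the probabilities are already "processed" (temperature/top-k/top-p applied), and it uses torch.multinomial in place of the seeded Gumbel resampling (which is why the kernel path passes apply_temperature=False to gumbel_sample for the residual).

import torch

def reference_probabilistic_rejection_sample(
    target_probs: torch.Tensor,   # [k + 1, vocab] processed target probabilities
    draft_probs: torch.Tensor,    # [k, vocab] processed draft probabilities
    draft_tokens: torch.Tensor,   # [k] draft token ids
    generator: torch.Generator | None = None,
) -> torch.Tensor:
    # Accept a prefix of the draft tokens, then emit exactly one more token:
    # a resample from the residual distribution at the first rejected step, or
    # the bonus token when every draft token was accepted.
    k = draft_tokens.shape[0]
    out: list[int] = []
    for i in range(k):
        t = int(draft_tokens[i])
        u = torch.rand((), generator=generator)
        if u * draft_probs[i, t] >= target_probs[i, t]:
            residual = torch.clamp(target_probs[i] - draft_probs[i], min=0.0)
            residual = residual / residual.sum()
            out.append(int(torch.multinomial(residual, 1, generator=generator)))
            return torch.tensor(out)   # between 1 and k tokens
        out.append(t)
    # All draft tokens accepted: the bonus token comes straight from the last
    # target distribution, matching the "bonus token" branch of the residual kernel.
    out.append(int(torch.multinomial(target_probs[k], 1, generator=generator)))
    return torch.tensor(out)           # k + 1 tokens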

View File

@@ -15,6 +15,8 @@ class RequestState:
num_speculative_steps: int,
vocab_size: int,
device: torch.device,
model_dtype: torch.dtype,
cache_draft_logits: bool,
):
self.max_num_reqs = max_num_reqs
self.max_model_len = max_model_len
@@ -70,6 +72,19 @@ class RequestState:
dtype=torch.int64,
device=device,
)
# Draft token logits.
# NOTE: This tensor caches the "processed" draft logits, i.e. the logits after
# temperature, top-p, etc. have been applied.
self.draft_logits: torch.Tensor | None = None
if cache_draft_logits:
self.draft_logits = torch.zeros(
self.max_num_reqs,
self.num_speculative_steps,
self.vocab_size,
dtype=model_dtype,
device=device,
)
self.next_prefill_tokens = torch.zeros(
self.max_num_reqs, dtype=torch.int32, device=device
)
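The extra memory cost mentioned in the SpeculativeConfig docstring comes from this buffer. A back-of-envelope estimate with purely illustrative numbers (the actual values depend on the deployment config):

max_num_reqs = 1024
num_speculative_steps = 3
vocab_size = 128_256          # e.g. a Llama-3-style vocabulary
bytes_per_elem = 2            # fp16 / bf16 model dtype

size_bytes = max_num_reqs * num_speculative_steps * vocab_size * bytes_per_elem
print(f"draft_logits buffer: {size_bytes / 2**30:.2f} GiB")   # ~0.73 GiB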