[Model Runner V2] Add probabilistic rejection sampling for spec decoding (#35461)

Signed-off-by: Giancarlo Delfin <gdelfin@inferact.ai>
Author: Giancarlo Delfin
Date: 2026-03-11 14:04:32 -07:00
Committed by: GitHub
Parent: 12001f2ebc
Commit: c77181e534
9 changed files with 494 additions and 112 deletions

View File

@@ -57,6 +57,10 @@ SpeculativeMethod = Literal[
EagleModelTypes,
NgramGPUTypes,
]
RejectionSampleMethod = Literal[
"strict",
"probabilistic",
]
@config
@@ -171,6 +175,12 @@ class SpeculativeConfig:
"""Load config for the draft model. If not specified, will use the load
config from the target model."""
rejection_sample_method: RejectionSampleMethod = "strict"
"""Whether to use strict (target and draft sampled tokens match exactly)
or probabilistic rejection sampling. Both respect the target model
distribution, but the latter yields a higher acceptance rate at the cost
of more memory to cache draft logits."""
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
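For reference, a minimal sketch of selecting the new mode through the engine arguments. This is illustrative only: the target and EAGLE draft model names and the token count are placeholders, not part of this change; only the rejection_sample_method key is new here.

from vllm import LLM

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",
    speculative_config={
        "method": "eagle3",
        "model": "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
        "num_speculative_tokens": 3,
        # Added by this commit; defaults to "strict".
        "rejection_sample_method": "probabilistic",
    },
)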

View File

@@ -90,7 +90,7 @@ from vllm.v1.worker.gpu.spec_decode import init_speculator
from vllm.v1.worker.gpu.spec_decode.eagle.eagle3_utils import (
set_eagle3_aux_hidden_state_layers,
)
from vllm.v1.worker.gpu.spec_decode.rejection_sample import rejection_sample
from vllm.v1.worker.gpu.spec_decode.rejection_sampler import RejectionSampler
from vllm.v1.worker.gpu.spec_decode.utils import DraftTokensHandler
from vllm.v1.worker.gpu.states import RequestState
from vllm.v1.worker.gpu.structured_outputs import StructuredOutputsWorker
@@ -162,6 +162,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.speculator = None
self.num_speculative_steps = 0
self.use_aux_hidden_state_outputs = False
use_strict_rejection_sampling = False
if self.speculative_config is not None:
self.num_speculative_steps = self.speculative_config.num_speculative_tokens
if self.is_last_pp_rank:
@@ -172,6 +173,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.use_aux_hidden_state_outputs = True
if self.pp_size > 1:
raise ValueError("EAGLE3 with pipeline parallel is not supported.")
use_strict_rejection_sampling = (
self.speculative_config.rejection_sample_method == "strict"
)
# Draft tokens propagation - for spec-dec + struct outputs.
self.draft_tokens_handler = DraftTokensHandler(self.device)
@@ -183,6 +187,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_speculative_steps=self.num_speculative_steps,
vocab_size=self.vocab_size,
device=self.device,
model_dtype=self.dtype,
cache_draft_logits=not use_strict_rejection_sampling,
)
self.input_buffers = InputBuffers(
max_num_reqs=self.max_num_reqs,
@@ -197,6 +203,11 @@ class GPUModelRunner(LoRAModelRunnerMixin):
logprobs_mode=self.model_config.logprobs_mode,
num_speculative_tokens=self.num_speculative_steps + 1,
)
self.rejection_sampler = RejectionSampler(
self.sampler,
num_speculative_steps=self.num_speculative_steps,
use_strict_rejection_sampling=use_strict_rejection_sampling,
)
self.prompt_logprobs_worker = PromptLogprobsWorker(self.max_num_reqs)
# CUDA graphs.
@@ -412,6 +423,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
next_prefill_tokens=self.req_states.next_prefill_tokens,
temperature=self.sampler.sampling_states.temperature.gpu,
seeds=self.sampler.sampling_states.seeds.gpu,
draft_logits_out=self.req_states.draft_logits,
num_tokens_across_dp=num_tokens_across_dp,
dummy_run=True,
skip_attn_for_dummy_run=skip_attn,
@@ -425,24 +437,16 @@ class GPUModelRunner(LoRAModelRunnerMixin):
def _dummy_sampler_run(self, hidden_states: torch.Tensor) -> None:
num_reqs = hidden_states.shape[0]
logits = self.model.compute_logits(hidden_states)
idx_mapping = torch.arange(num_reqs, dtype=torch.int32, device=self.device)
idx_mapping_np = np.arange(num_reqs, dtype=np.int32)
pos = torch.zeros(num_reqs, dtype=torch.int64, device=self.device)
dummy_input_ids = torch.zeros(num_reqs, dtype=torch.int32, device=self.device)
expanded_local_pos = torch.zeros(
num_reqs, dtype=torch.int32, device=self.device
)
dummy_input_batch = InputBatch.make_dummy(
num_reqs, num_reqs, self.input_buffers
)
# NOTE(woosuk): During the initial memory profiling, the sampler may skip
# top_k, top_p, and logprobs, using less GPU memory than what is possible
# during actual execution.
self.sampler(
logits,
idx_mapping,
idx_mapping_np,
idx_mapping_np,
pos,
dummy_input_ids,
expanded_local_pos,
dummy_input_batch,
)
@torch.inference_mode()
@@ -768,8 +772,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
grammar_output: GrammarOutput | None,
) -> tuple[SamplerOutput, torch.Tensor, torch.Tensor]:
sample_hidden_states = hidden_states[input_batch.logits_indices]
sample_pos = input_batch.positions[input_batch.logits_indices]
input_ids = input_batch.input_ids[input_batch.logits_indices]
logits = self.model.compute_logits(sample_hidden_states)
if grammar_output is not None:
# Apply grammar bitmask to the logits in-place.
@@ -780,34 +782,27 @@ class GPUModelRunner(LoRAModelRunnerMixin):
grammar_output.grammar_bitmask,
)
# Sample tokens and compute logprobs (if needed).
sampler_output = self.sampler(
logits,
input_batch.expanded_idx_mapping,
input_batch.idx_mapping_np,
input_batch.cu_num_logits_np,
sample_pos,
input_ids,
input_batch.expanded_local_pos,
)
if input_batch.num_draft_tokens == 0:
# No draft tokens (common case).
num_sampled = input_batch.seq_lens.new_ones(input_batch.num_reqs)
sampler_output = self.sampler(
logits,
input_batch,
)
else:
# Rejection sampling for spec decoding.
sampled_tokens, num_sampled = rejection_sample(
sampler_output.sampled_token_ids,
input_ids,
input_batch.cu_num_logits,
self.num_speculative_steps,
)
sampler_output = self.rejection_sampler(
logits,
input_batch,
# Draft logits are needed for probabilistic rejection sampling.
self.req_states.draft_logits[input_batch.idx_mapping]
if self.req_states.draft_logits is not None
else None,
)
sampler_output.sampled_token_ids = sampled_tokens
# Get the number of sampled and rejected tokens.
# For chunked prefills, num_sampled and num_rejected are both 0.
num_sampled, num_rejected = get_num_sampled_and_rejected(
num_sampled,
sampler_output.num_sampled,
num_sampled, sampler_output.num_sampled,
input_batch.seq_lens, input_batch.seq_lens,
input_batch.cu_num_logits, input_batch.cu_num_logits,
input_batch.idx_mapping, input_batch.idx_mapping,
@@ -1105,6 +1100,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.req_states.next_prefill_tokens,
self.sampler.sampling_states.temperature.gpu,
self.sampler.sampling_states.seeds.gpu,
self.req_states.draft_logits,
num_tokens_across_dp=num_tokens_across_dp,
)
self.req_states.draft_tokens[input_batch.idx_mapping] = draft_tokens

View File

@@ -55,6 +55,8 @@ def _gumbel_sample_kernel(
local_argmax_stride,
local_max_ptr,
local_max_stride,
processed_logits_ptr,
processed_logits_stride,
logits_ptr,
logits_stride,
expanded_idx_mapping_ptr,
@@ -79,6 +81,20 @@ def _gumbel_sample_kernel(
logits = logits.to(tl.float32)
temp = tl.load(temp_ptr + req_state_idx).to(tl.float32)
if (temp != 0.0) and APPLY_TEMPERATURE:
# Apply temperature.
# NOTE(woosuk): Match the behavior of _temperature_kernel.
# E.g., if the kernel uses tl.div_rn, we should use tl.div_rn here too.
logits = logits / temp
# Store the temperature-applied logits.
if processed_logits_ptr is not None:
tl.store(
processed_logits_ptr + req_state_idx * processed_logits_stride + block,
logits,
mask=mask,
)
if temp != 0.0:
# Calculate the seed for gumbel noise.
seed = tl.load(seeds_ptr + req_state_idx)
@@ -90,12 +106,6 @@ def _gumbel_sample_kernel(
u = tl.maximum(u, 1e-7)
gumbel_noise = -tl.log(-tl.log(u))
# Apply temperature.
if APPLY_TEMPERATURE:
# NOTE(woosuk): Match the behavior of _temperature_kernel.
# E.g., if the kernel uses tl.div_rn, we should use tl.div_rn here too.
logits = logits / temp
# Apply gumbel noise.
logits = tl.where(mask, logits + gumbel_noise, float("-inf"))
@@ -112,6 +122,7 @@ def gumbel_sample(
seed: torch.Tensor, # [max_num_reqs]
pos: torch.Tensor, # [num_tokens]
apply_temperature: bool,
processed_logits_out: torch.Tensor | None = None, # [num_reqs, vocab_size]
) -> torch.Tensor:
num_tokens, vocab_size = logits.shape
BLOCK_SIZE = 1024
@@ -133,6 +144,8 @@ def gumbel_sample(
local_argmax.stride(0),
local_max,
local_max.stride(0),
processed_logits_out,
processed_logits_out.stride(0) if processed_logits_out is not None else 0,
logits,
logits.stride(0),
expanded_idx_mapping,

View File

@@ -12,3 +12,4 @@ class SamplerOutput:
sampled_token_ids: torch.Tensor
logprobs_tensors: LogprobsTensors | None
num_nans: torch.Tensor | None
num_sampled: torch.Tensor | None

View File

@@ -7,6 +7,7 @@ import torch
import vllm.envs as envs
from vllm.config.model import LogprobsMode
from vllm.sampling_params import SamplingParams
from vllm.v1.worker.gpu.input_batch import InputBatch
from vllm.v1.worker.gpu.metrics.logits import get_num_nans
from vllm.v1.worker.gpu.sample.bad_words import BadWordsState
from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
@@ -56,13 +57,15 @@ class Sampler:
def __call__(
self,
logits: torch.Tensor,
expanded_idx_mapping: torch.Tensor,
input_batch: InputBatch,
idx_mapping_np: np.ndarray,
cu_num_logits_np: np.ndarray,
pos: torch.Tensor,
input_ids: torch.Tensor,
expanded_local_pos: torch.Tensor,
) -> SamplerOutput:
expanded_idx_mapping = input_batch.expanded_idx_mapping
idx_mapping_np = input_batch.idx_mapping_np
cu_num_logits_np = input_batch.cu_num_logits_np
expanded_local_pos = input_batch.expanded_local_pos
pos = input_batch.positions[input_batch.logits_indices]
input_ids = input_batch.input_ids[input_batch.logits_indices]
# NOTE(woosuk): We intentionally compute num_nans before sampling to make clear
# that num_nans is computed before applying penalties and temperature.
num_nans = get_num_nans(logits) if self.compute_nans else None
@@ -95,10 +98,11 @@ class Sampler:
sampled_token_ids=sampled.view(-1, 1),
logprobs_tensors=logprobs_tensors,
num_nans=num_nans,
num_sampled=input_batch.seq_lens.new_ones(input_batch.num_reqs),
)
return sampler_output
def sample(
def apply_sampling_params(
self,
logits: torch.Tensor,
expanded_idx_mapping: torch.Tensor,
@@ -106,7 +110,7 @@ class Sampler:
pos: torch.Tensor,
input_ids: torch.Tensor,
expanded_local_pos: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
# Copy logits to a new FP32 tensor.
logits = torch.empty_like(logits, dtype=torch.float32).copy_(logits)
@@ -143,13 +147,31 @@ class Sampler:
self.sampling_states.apply_min_p(logits, expanded_idx_mapping, idx_mapping_np)
# Apply top_k and/or top_p. This might or might not return a new tensor.
logits = self.sampling_states.apply_top_k_top_p(
return self.sampling_states.apply_top_k_top_p(
logits, expanded_idx_mapping, idx_mapping_np
)
def sample(
self,
logits: torch.Tensor,
expanded_idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray,
pos: torch.Tensor,
input_ids: torch.Tensor,
expanded_local_pos: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
processed_logits = self.apply_sampling_params(
logits,
expanded_idx_mapping,
idx_mapping_np,
pos,
input_ids,
expanded_local_pos,
)
# Sample the next token.
sampled = gumbel_sample(
logits,
processed_logits,
expanded_idx_mapping,
self.sampling_states.temperature.gpu,
self.sampling_states.seeds.gpu,

View File

@@ -140,6 +140,7 @@ class EagleSpeculator:
slot_mappings: dict[str, torch.Tensor] | None,
num_tokens_across_dp: torch.Tensor | None,
cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
draft_logits_out: torch.Tensor | None = None,
) -> None:
pos = self.input_buffers.positions[:num_reqs]
query_start_loc = self.input_buffers.query_start_loc[: num_reqs + 1]
@@ -166,6 +167,9 @@ class EagleSpeculator:
self.seeds,
pos + 1,
apply_temperature=True,
processed_logits_out=draft_logits_out[:, step]
if draft_logits_out is not None
else None,
)
self.draft_tokens[:num_reqs, step] = draft_tokens
@@ -219,6 +223,8 @@ class EagleSpeculator:
temperature: torch.Tensor,
# [max_num_reqs]
seeds: torch.Tensor,
# [max_num_reqs, num_speculative_steps, vocab_size]
draft_logits_out: torch.Tensor | None,
num_tokens_across_dp: torch.Tensor | None = None,
dummy_run: bool = False,
skip_attn_for_dummy_run: bool = False,
@@ -271,6 +277,7 @@ class EagleSpeculator:
idx_mapping.copy_(input_batch.idx_mapping)
self.temperature.copy_(temperature)
self.seeds.copy_(seeds)
# Gather the values and copy them to the pre-allocated buffers.
pos = self.input_buffers.positions[:num_reqs]
torch.gather(input_batch.positions, 0, last_token_indices, out=pos)
@@ -283,7 +290,11 @@ class EagleSpeculator:
self.seeds,
pos + 1,
apply_temperature=True,
processed_logits_out=draft_logits_out[:, 0]
if draft_logits_out is not None
else None,
)
if self.num_speculative_steps == 1:
# Early exit.
return draft_tokens.view(-1, 1)
@@ -365,6 +376,7 @@ class EagleSpeculator:
slot_mappings_updated,
num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=batch_desc.cg_mode,
draft_logits_out=draft_logits_out,
)
return self.draft_tokens[:num_reqs]

View File

@@ -1,62 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.triton_utils import tl, triton
@triton.jit
def _rejection_sample_kernel(
sampled_ptr, # [num_reqs, num_speculative_steps + 1]
sampled_stride,
num_sampled_ptr, # [num_reqs]
target_sampled_ptr, # [num_draft_tokens + num_reqs]
input_ids_ptr, # [num_draft_tokens + num_reqs]
cu_num_logits_ptr, # [num_reqs + 1]
):
req_idx = tl.program_id(0)
start_idx = tl.load(cu_num_logits_ptr + req_idx)
end_idx = tl.load(cu_num_logits_ptr + req_idx + 1)
num_tokens = end_idx - start_idx
num_sampled = 0
rejected = False
for i in range(num_tokens - 1):
if not rejected:
target_sampled = tl.load(target_sampled_ptr + start_idx + i)
draft_sampled = tl.load(input_ids_ptr + start_idx + i + 1)
tl.store(sampled_ptr + req_idx * sampled_stride + i, target_sampled)
num_sampled += 1
if target_sampled != draft_sampled:
rejected = True
if not rejected:
target_sampled = tl.load(target_sampled_ptr + start_idx + num_tokens - 1)
tl.store(
sampled_ptr + req_idx * sampled_stride + num_tokens - 1, target_sampled
)
num_sampled += 1
tl.store(num_sampled_ptr + req_idx, num_sampled)
def rejection_sample(
# [num_draft_tokens + num_reqs]
target_sampled: torch.Tensor,
# [num_draft_tokens + num_reqs]
input_ids: torch.Tensor,
# [num_reqs + 1]
cu_num_logits: torch.Tensor,
num_speculative_steps: int,
) -> tuple[torch.Tensor, torch.Tensor]:
num_reqs = cu_num_logits.shape[0] - 1
sampled = target_sampled.new_empty(num_reqs, num_speculative_steps + 1)
num_sampled = cu_num_logits.new_empty(num_reqs)
_rejection_sample_kernel[(num_reqs,)](
sampled,
sampled.stride(0),
num_sampled,
target_sampled,
input_ids,
cu_num_logits,
num_warps=1,
)
return sampled, num_sampled

View File

@@ -0,0 +1,375 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.triton_utils import tl, triton
from vllm.v1.worker.gpu.input_batch import InputBatch
from vllm.v1.worker.gpu.metrics.logits import get_num_nans
from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
from vllm.v1.worker.gpu.sample.output import SamplerOutput
from vllm.v1.worker.gpu.sample.sampler import Sampler
@triton.jit
def _strict_rejection_sample_kernel(
sampled_ptr, # [num_reqs, num_speculative_steps + 1]
sampled_stride,
num_sampled_ptr, # [num_reqs]
target_sampled_ptr, # [num_draft_tokens + num_reqs]
input_ids_ptr, # [num_draft_tokens + num_reqs]
cu_num_logits_ptr, # [num_reqs + 1]
):
req_idx = tl.program_id(0)
start_idx = tl.load(cu_num_logits_ptr + req_idx)
end_idx = tl.load(cu_num_logits_ptr + req_idx + 1)
num_tokens = end_idx - start_idx
num_sampled = 0
rejected = False
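# Emit the target-sampled token at each draft position until it disagrees with
# the draft token. The first disagreeing target token is still emitted (it
# replaces the rejected draft), and everything after it is dropped. If every
# draft token matches, the bonus token at the final position is emitted too.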
for i in range(num_tokens - 1):
if not rejected:
target_sampled = tl.load(target_sampled_ptr + start_idx + i)
draft_sampled = tl.load(input_ids_ptr + start_idx + i + 1)
tl.store(sampled_ptr + req_idx * sampled_stride + i, target_sampled)
num_sampled += 1
if target_sampled != draft_sampled:
rejected = True
if not rejected:
target_sampled = tl.load(target_sampled_ptr + start_idx + num_tokens - 1)
tl.store(
sampled_ptr + req_idx * sampled_stride + num_tokens - 1, target_sampled
)
num_sampled += 1
tl.store(num_sampled_ptr + req_idx, num_sampled)
def strict_rejection_sample(
# [num_draft_tokens + num_reqs]
target_sampled: torch.Tensor,
# [num_draft_tokens + num_reqs]
draft_sampled: torch.Tensor,
# [num_reqs + 1]
cu_num_logits: torch.Tensor,
num_speculative_steps,
) -> tuple[torch.Tensor, torch.Tensor]:
num_reqs = cu_num_logits.shape[0] - 1
sampled = torch.empty(
num_reqs,
num_speculative_steps + 1,
dtype=target_sampled.dtype,
device=target_sampled.device,
)
num_sampled = torch.empty(
num_reqs,
dtype=torch.int32,
device=target_sampled.device,
)
_strict_rejection_sample_kernel[(num_reqs,)](
sampled,
sampled.stride(0),
num_sampled,
target_sampled,
draft_sampled,
cu_num_logits,
num_warps=1,
)
return sampled, num_sampled
@triton.jit
def _probabilistic_rejection_sample_kernel(
# [num_reqs, num_speculative_steps + 1]
sampled_ptr,
sampled_stride,
# [num_reqs]
rejected_steps_ptr,
# [num_logits]
draft_sampled_ptr,
# [num_logits, V]
target_probs_ptr,
target_probs_stride,
# [num_reqs, num_speculative_steps, V]
draft_probs_ptr,
draft_probs_stride_0,
draft_probs_stride_1,
# [num_reqs + 1]
cu_num_logits_ptr,
# [num_logits]
pos_ptr,
# [num_reqs]
idx_mapping_ptr,
# [num_reqs]
seeds_ptr,
):
req_idx = tl.program_id(0)
start_idx = tl.load(cu_num_logits_ptr + req_idx)
num_tokens = tl.load(cu_num_logits_ptr + req_idx + 1) - start_idx
seed = tl.load(seeds_ptr + tl.load(idx_mapping_ptr + req_idx))
rejected_step = 0
accepted = True
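# Scan the draft tokens left to right; after the first rejection the remaining
# ones are ignored. rejected_step ends up as the index of the first rejected
# draft token, or num_tokens - 1 (the bonus slot) when every draft is accepted.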
for i in range(num_tokens - 1):
if accepted:
draft_sampled = tl.load(draft_sampled_ptr + start_idx + i + 1)
target_prob = tl.load(
target_probs_ptr + (start_idx + i) * target_probs_stride + draft_sampled
)
draft_prob = tl.load(
draft_probs_ptr
+ req_idx * draft_probs_stride_0
+ i * draft_probs_stride_1
+ draft_sampled
)
pos = tl.load(pos_ptr + start_idx + i)
u = tl.sum(tl.rand(seed, pos + tl.arange(0, 1)))
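# u ~ Uniform[0, 1): accepting when target_prob > u * draft_prob accepts the
# draft token with probability min(1, target_prob / draft_prob).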
accepted &= target_prob > u * draft_prob
tl.store(sampled_ptr + req_idx * sampled_stride + i, draft_sampled)
rejected_step += accepted
tl.store(rejected_steps_ptr + req_idx, rejected_step)
@triton.jit
def _compute_residual_logits_kernel(
# [num_reqs, V]
residual_logits_ptr,
residual_logits_stride,
# [num_reqs]
residual_pos_ptr,
# [num_logits, V]
target_logits_ptr,
target_logits_stride,
# [num_logits, V]
target_probs_ptr,
target_probs_stride,
# [num_reqs, num_speculative_steps, V]
draft_probs_ptr,
draft_probs_stride_0,
draft_probs_stride_1,
# [num_reqs]
rejected_step_ptr,
# [num_reqs + 1]
cu_num_logits_ptr,
# [num_logits]
pos_ptr,
vocab_size,
BLOCK_SIZE: tl.constexpr,
):
req_idx = tl.program_id(0)
block_idx = tl.program_id(1)
start_idx = tl.load(cu_num_logits_ptr + req_idx)
end_idx = tl.load(cu_num_logits_ptr + req_idx + 1)
rejected_draft_step = tl.load(rejected_step_ptr + req_idx)
rejected_logit_idx = start_idx + rejected_draft_step
block_offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = block_offsets < vocab_size
if rejected_logit_idx < end_idx - 1:
target_probs = tl.load(
target_probs_ptr + rejected_logit_idx * target_probs_stride + block_offsets,
mask=mask,
other=0.0,
)
draft_probs = tl.load(
draft_probs_ptr
+ req_idx * draft_probs_stride_0
+ rejected_draft_step * draft_probs_stride_1
+ block_offsets,
mask=mask,
other=0.0,
)
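# A rejected draft token is resampled from the residual distribution
# norm(max(p_target - p_draft, 0)). The unnormalized log-residual is enough:
# Gumbel-max sampling is invariant to the constant shift normalization adds.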
residual_probs = tl.maximum(target_probs - draft_probs, 0.0)
residual_logits = tl.log(residual_probs)
else:
# This is a bonus token. Directly return the target logits.
residual_logits = tl.load(
target_logits_ptr
+ rejected_logit_idx * target_logits_stride
+ block_offsets,
mask=mask,
other=0.0,
)
tl.store(
residual_logits_ptr + req_idx * residual_logits_stride + block_offsets,
residual_logits,
mask=mask,
)
# First block computes the residual logit positions.
if block_idx == 0:
pos_val = tl.load(pos_ptr + rejected_logit_idx)
tl.store(residual_pos_ptr + req_idx, pos_val)
def probabilistic_rejection_sample(
# [num_draft_tokens + num_reqs, V]
target_logits: torch.Tensor,
# [num_reqs, num_speculative_steps, V]
draft_logits: torch.Tensor,
# [num_draft_tokens + num_reqs]
draft_sampled: torch.Tensor,
# [num_reqs + 1]
cu_num_logits: torch.Tensor,
# [num_logits]
pos: torch.Tensor,
# [num_reqs]
idx_mapping: torch.Tensor,
temperature,
seeds,
num_speculative_steps,
) -> tuple[torch.Tensor, torch.Tensor]:
num_reqs = cu_num_logits.shape[0] - 1
device = target_logits.device
vocab_size = target_logits.shape[-1]
# Compute target and draft probs.
target_probs = torch.softmax(target_logits, dim=-1)
draft_probs = torch.softmax(draft_logits, dim=-1)
# Rejection sample.
# [num_reqs, num_speculative_steps + 1]
sampled = torch.empty(
num_reqs,
num_speculative_steps + 1,
dtype=torch.int64,
device=device,
)
# [num_reqs]
rejected_steps = torch.empty(
num_reqs,
dtype=torch.int64,
device=device,
)
_probabilistic_rejection_sample_kernel[(num_reqs,)](
sampled,
sampled.stride(0),
rejected_steps,
draft_sampled,
target_probs,
target_probs.stride(0),
draft_probs,
draft_probs.stride(0),
draft_probs.stride(1),
cu_num_logits,
pos,
idx_mapping,
seeds,
num_warps=1,
)
# Compute the logits and positions to resample the rejected/bonus
# tokens from.
# [num_reqs, vocab_size]
residual_logits = torch.empty(
num_reqs,
vocab_size,
dtype=target_logits.dtype,
device=device,
)
# [num_reqs]
residual_pos = torch.empty(
num_reqs,
dtype=pos.dtype,
device=device,
)
BLOCK_SIZE = 1024
num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
_compute_residual_logits_kernel[(num_reqs, num_blocks)](
residual_logits,
residual_logits.stride(0),
residual_pos,
target_logits,
target_logits.stride(0),
target_probs,
target_probs.stride(0),
draft_probs,
draft_probs.stride(0),
draft_probs.stride(1),
rejected_steps,
cu_num_logits,
pos,
vocab_size,
BLOCK_SIZE=BLOCK_SIZE,
)
# Gumbel sample tokens from the residual distribution.
resampled = gumbel_sample(
residual_logits,
idx_mapping,
temperature,
seeds,
residual_pos,
apply_temperature=False,
)
sampled.scatter_(1, rejected_steps.unsqueeze(1), resampled.unsqueeze(1))
return sampled, rejected_steps + 1
class RejectionSampler:
def __init__(
self,
sampler: Sampler,
num_speculative_steps,
use_strict_rejection_sampling: bool = True,
):
self.sampler = sampler
self.num_speculative_steps = num_speculative_steps
self.use_strict_rejection_sampling = use_strict_rejection_sampling
def __call__(
self,
logits: torch.Tensor,
input_batch: InputBatch,
draft_logits: torch.Tensor | None = None,
) -> SamplerOutput:
draft_sampled = input_batch.input_ids[input_batch.logits_indices]
# NOTE(woosuk): We intentionally compute num_nans before sampling to make clear
# that num_nans is computed before applying penalties and temperature.
num_nans = get_num_nans(logits) if self.sampler.compute_nans else None
if self.use_strict_rejection_sampling:
sampler_output = self.sampler(
logits,
input_batch,
)
logprobs_tensors = sampler_output.logprobs_tensors
sampled, num_sampled = strict_rejection_sample(
sampler_output.sampled_token_ids.view(-1),
draft_sampled,
input_batch.cu_num_logits,
self.num_speculative_steps,
)
else:
assert draft_logits is not None
pos = input_batch.positions[input_batch.logits_indices]
processed_logits = self.sampler.apply_sampling_params(
logits,
input_batch.expanded_idx_mapping,
input_batch.idx_mapping_np,
pos,
draft_sampled,
input_batch.expanded_local_pos,
)
# TODO (TheEpicDolphin): Return logprobs for sampled token ids.
logprobs_tensors = None
sampled, num_sampled = probabilistic_rejection_sample(
processed_logits,
draft_logits,
draft_sampled,
input_batch.cu_num_logits,
pos,
input_batch.idx_mapping,
self.sampler.sampling_states.temperature.gpu,
self.sampler.sampling_states.seeds.gpu,
self.num_speculative_steps,
)
return SamplerOutput(
sampled_token_ids=sampled,
logprobs_tensors=logprobs_tensors,
num_nans=num_nans,
num_sampled=num_sampled,
)
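For intuition, here is a small self-contained PyTorch sketch of the same accept/reject-and-resample scheme for a single sequence. It uses plain torch.multinomial instead of the Gumbel kernels above, so it mirrors the algorithm rather than this file's exact implementation; names and shapes are illustrative.

import torch

def probabilistic_rejection_sample_reference(
    target_probs: torch.Tensor,  # [num_draft + 1, vocab_size], target distribution per position
    draft_probs: torch.Tensor,   # [num_draft, vocab_size], draft distribution per position
    draft_tokens: torch.Tensor,  # [num_draft], tokens proposed by the draft model
) -> torch.Tensor:
    out: list[int] = []
    for i, tok in enumerate(draft_tokens.tolist()):
        p = target_probs[i, tok]
        q = draft_probs[i, tok]
        # Accept the draft token with probability min(1, p / q).
        if torch.rand(()) < torch.clamp(p / q, max=1.0):
            out.append(tok)
            continue
        # First rejection: resample from the residual norm(max(p - q, 0)).
        residual = torch.clamp(target_probs[i] - draft_probs[i], min=0.0)
        out.append(int(torch.multinomial(residual / residual.sum(), 1)))
        return torch.tensor(out)
    # All draft tokens accepted: sample the bonus token from the last target distribution.
    out.append(int(torch.multinomial(target_probs[-1], 1)))
    return torch.tensor(out)

By the standard speculative-sampling argument, the emitted tokens are distributed exactly as if they had been drawn from the target distributions directly, which is why this mode can accept more draft tokens than strict token matching without changing the output distribution.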

View File

@@ -15,6 +15,8 @@ class RequestState:
num_speculative_steps: int,
vocab_size: int,
device: torch.device,
model_dtype: torch.dtype,
cache_draft_logits: bool,
):
self.max_num_reqs = max_num_reqs
self.max_model_len = max_model_len
@@ -70,6 +72,19 @@ class RequestState:
dtype=torch.int64,
device=device,
)
# Draft token logits.
# NOTE: This tensor maintains the "processed" logits after applying temperature,
# top-p, etc.
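# Its footprint is max_num_reqs * num_speculative_steps * vocab_size elements
# of model_dtype, which is why it is only allocated when probabilistic
# rejection sampling actually needs the cached draft distributions.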
self.draft_logits: torch.Tensor | None = None
if cache_draft_logits:
self.draft_logits = torch.zeros(
self.max_num_reqs,
self.num_speculative_steps,
self.vocab_size,
dtype=model_dtype,
device=device,
)
self.next_prefill_tokens = torch.zeros(
self.max_num_reqs, dtype=torch.int32, device=device
)