[Model Runner V2] Fuse probabilistic rejection sample kernels (#38496)

Signed-off-by: Giancarlo Delfin <gdelfin@inferact.ai>
This commit is contained in:
Giancarlo Delfin
2026-04-07 17:37:37 -07:00
committed by GitHub
parent ad3304425b
commit 5daf62271d
5 changed files with 886 additions and 377 deletions

View File

@@ -100,11 +100,13 @@ steps:
- vllm/v1/worker/gpu/
- vllm/v1/worker/gpu_worker.py
- tests/v1/spec_decode/test_max_len.py
- tests/v1/spec_decode/test_probabilistic_rejection_sampler_utils.py
- tests/v1/spec_decode/test_synthetic_rejection_sampler_utils.py
- tests/v1/e2e/spec_decode/test_spec_decode.py
commands:
- set -x
- export VLLM_USE_V2_MODEL_RUNNER=1
- pytest -v -s v1/spec_decode/test_max_len.py -k "eagle or mtp"
- pytest -v -s v1/spec_decode/test_probabilistic_rejection_sampler_utils.py
- pytest -v -s v1/spec_decode/test_synthetic_rejection_sampler_utils.py
- pytest -v -s v1/e2e/spec_decode/test_spec_decode.py -k "eagle or mtp"

View File

@@ -0,0 +1,215 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
import pytest
import torch
from vllm.v1.worker.gpu.spec_decode.probabilistic_rejection_sampler_utils import (
probabilistic_rejection_sample,
)
# Vocabulary size shared by all synthetic distributions in this test module.
VOCAB_SIZE = 4096
# Skip if no CUDA - Triton kernel requires GPU
pytest.importorskip("triton")
if not torch.cuda.is_available():
    pytest.skip("CUDA required for rejection sampler tests", allow_module_level=True)
def _build_rejection_sample_inputs(
target_logits_1d: torch.Tensor,
draft_logits_1d: torch.Tensor,
num_speculative_steps: int,
temperature: float,
num_trials: int,
) -> dict:
device = target_logits_1d.device
vocab_size = target_logits_1d.shape[0]
K = num_speculative_steps
num_logits = num_trials * (K + 1)
target_logits = target_logits_1d.unsqueeze(0).expand(num_logits, -1).contiguous()
draft_logits = (
draft_logits_1d.view(1, 1, vocab_size).expand(num_trials, K, -1).contiguous()
)
draft_probs = torch.softmax(draft_logits_1d, dim=0)
draft_tokens = torch.multinomial(
draft_probs.expand(num_trials, -1), K, replacement=True
)
draft_sampled_2d = torch.zeros(num_trials, K + 1, dtype=torch.int64, device=device)
draft_sampled_2d[:, 1:] = draft_tokens
draft_sampled = draft_sampled_2d.reshape(-1)
cu_num_logits = torch.arange(num_trials + 1, dtype=torch.int32, device=device) * (
K + 1
)
pos = torch.arange(num_logits, dtype=torch.int32, device=device)
idx_mapping = torch.arange(num_trials, dtype=torch.int32, device=device)
expanded_idx_mapping = torch.arange(
num_trials, dtype=torch.int32, device=device
).repeat_interleave(K + 1)
expanded_local_pos = torch.arange(K + 1, dtype=torch.int32, device=device).repeat(
num_trials
)
temp_tensor = torch.full(
(num_trials,), temperature, dtype=torch.float32, device=device
)
seed = torch.arange(num_trials, dtype=torch.int64, device=device)
return dict(
target_logits=target_logits,
draft_logits=draft_logits,
draft_sampled=draft_sampled,
cu_num_logits=cu_num_logits,
pos=pos,
idx_mapping=idx_mapping,
expanded_idx_mapping=expanded_idx_mapping,
expanded_local_pos=expanded_local_pos,
temperature=temp_tensor,
seed=seed,
)
def _assert_distribution_match(
sampled_tokens: torch.Tensor,
target_probs: torch.Tensor,
device: str,
label: str = "",
min_expected: float = 5.0,
):
"""
Assert sampled tokens match the target distribution via a
chi-squared goodness-of-fit test. This is done by computing
observed vs expected token counts (target_probs * num_samples),
then checking that the chi-squared statistic is below a conservative
threshold. The threshold is set at df + 10*sqrt(2*df), which
corresponds to ~10 sigma under the chi-squared distribution's
normal approximation, effectively disallowing false positives.
NOTE: Tokens with expected count < min_expected are merged into
a single "other" bin to minimize chi-squared noise.
"""
num_samples = sampled_tokens.shape[0]
vocab_size = target_probs.shape[0]
observed = torch.zeros(vocab_size, device=device, dtype=torch.float32)
observed.scatter_add_(0, sampled_tokens, torch.ones(num_samples, device=device))
expected = target_probs * num_samples
sufficient = expected >= min_expected
obs_main = observed[sufficient]
exp_main = expected[sufficient]
obs_other = observed[~sufficient].sum().unsqueeze(0)
exp_other = expected[~sufficient].sum().unsqueeze(0)
if exp_other.item() >= min_expected:
obs_all = torch.cat([obs_main, obs_other])
exp_all = torch.cat([exp_main, exp_other])
else:
obs_all = obs_main
exp_all = exp_main
chi2 = ((obs_all - exp_all) ** 2 / exp_all).sum().item()
df = obs_all.shape[0] - 1
if df < 1:
# All samples were merged into < 2 bins, which is too
# few to evaluate.
return
threshold = df + 10 * math.sqrt(2 * df)
prefix = f"[{label}] " if label else ""
assert chi2 < threshold, (
f"{prefix}Chi-squared test failed: chi2={chi2:.1f}, "
f"df={df}, threshold={threshold:.1f}. "
f"Output distribution does not match target distribution."
)
@pytest.mark.parametrize(
    "num_speculative_steps,temperature",
    [
        (1, 0.6),
        (3, 0.6),
        (1, 1.0),
        (3, 1.0),
    ],
)
def test_stochastic_rejection_sample(num_speculative_steps: int, temperature: float):
    """
    Verify that rejection sampling produces the target distribution.
    This is done by simulating many independent trials of speculative
    decoding (from a fixed target and draft distribution). We then
    run rejection sample on all of the trials (requests), and verify
    that the sampled tokens at every position follow the target
    distribution p(x).
    """
    torch.manual_seed(42)
    device = "cuda"
    # Many trials so the per-position chi-squared test has statistical power.
    num_trials = 10 * VOCAB_SIZE
    target_logits_1d = torch.randn(VOCAB_SIZE, device=device, dtype=torch.float32)
    draft_logits_1d = torch.randn(VOCAB_SIZE, device=device, dtype=torch.float32)
    if temperature > 0:
        # Pre-apply the temperature to the logits so the reference
        # distribution (softmax below) matches what the sampler draws from.
        # NOTE(review): this assumes the fused sampler uses its temperature
        # tensor only to distinguish greedy (0.0) from stochastic and does
        # not rescale the logits itself — confirm against the kernels.
        target_logits_1d /= temperature
        draft_logits_1d /= temperature
    inputs = _build_rejection_sample_inputs(
        target_logits_1d,
        draft_logits_1d,
        num_speculative_steps,
        temperature=temperature,
        num_trials=num_trials,
    )
    sampled, num_sampled = probabilistic_rejection_sample(
        **inputs, num_speculative_steps=num_speculative_steps
    )
    target_probs = torch.softmax(target_logits_1d, dim=0)
    # A token at position `pos` is only valid for trials that emitted at
    # least pos + 1 tokens; mask the rest out before testing the fit.
    for pos in range(num_speculative_steps + 1):
        accepted_mask = num_sampled >= pos + 1
        _assert_distribution_match(
            sampled[accepted_mask, pos], target_probs, device, label=f"position {pos}"
        )
@pytest.mark.parametrize("num_speculative_steps", [1, 3])
def test_greedy_rejection_sample(num_speculative_steps: int):
    """
    Verify that greedy (temperature=0) always outputs the target argmax
    at every accepted position.
    """
    torch.manual_seed(42)
    device = "cuda"
    num_trials = 10 * VOCAB_SIZE
    target_logits_1d = torch.randn(VOCAB_SIZE, device=device, dtype=torch.float32)
    draft_logits_1d = torch.randn(VOCAB_SIZE, device=device, dtype=torch.float32)
    inputs = _build_rejection_sample_inputs(
        target_logits_1d,
        draft_logits_1d,
        num_speculative_steps,
        temperature=0.0,
        num_trials=num_trials,
    )
    sampled, num_sampled = probabilistic_rejection_sample(
        **inputs, num_speculative_steps=num_speculative_steps
    )
    target_argmax = target_logits_1d.argmax().item()
    # Boolean mask of [trial, step] slots that were actually emitted
    # (step index strictly below that trial's num_sampled).
    steps = torch.arange(num_speculative_steps + 1, device=device).unsqueeze(0)
    accepted_mask = steps < num_sampled.unsqueeze(1)
    assert (sampled[accepted_mask] == target_argmax).all(), (
        "Greedy sampling produced tokens that are not the target argmax"
    )

View File

@@ -65,36 +65,20 @@ def tl_rand64(seed, offset, includes_zero: tl.constexpr):
@triton.jit
def _gumbel_sample_kernel(
local_argmax_ptr,
local_argmax_stride,
local_max_ptr,
local_max_stride,
processed_logits_ptr,
processed_logits_stride,
logits_ptr,
logits_stride,
def gumbel_block_argmax(
logits,
block,
mask,
token_idx,
expanded_idx_mapping_ptr,
temp_ptr,
seeds_ptr,
pos_ptr,
temp_ptr,
vocab_size,
BLOCK_SIZE: tl.constexpr,
processed_logits_ptr,
processed_logits_stride,
APPLY_TEMPERATURE: tl.constexpr,
):
token_idx = tl.program_id(0)
req_state_idx = tl.load(expanded_idx_mapping_ptr + token_idx)
block_idx = tl.program_id(1)
block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = block < vocab_size
logits = tl.load(
logits_ptr + token_idx * logits_stride + block,
mask=mask,
other=float("-inf"),
)
logits = logits.to(tl.float32)
temp = tl.load(temp_ptr + req_state_idx).to(tl.float32)
if temp != 0.0 and APPLY_TEMPERATURE:
# Apply temperature.
@@ -102,8 +86,8 @@ def _gumbel_sample_kernel(
# E.g., if the kernel uses tl.div_rn, we should use tl.div_rn here too.
logits = logits / temp
# Store the temperature-applied logits.
if processed_logits_ptr is not None:
# Store the temperature-applied logits.
tl.store(
processed_logits_ptr + req_state_idx * processed_logits_stride + block,
logits,
@@ -126,6 +110,51 @@ def _gumbel_sample_kernel(
logits = tl.where(mask, logits + gumbel_noise, float("-inf"))
value, idx = tl.max(logits, axis=0, return_indices=True)
return value, idx
@triton.jit
def _gumbel_sample_kernel(
    # [num_tokens, num_blocks]: per-block winning token id (output).
    local_argmax_ptr,
    local_argmax_stride,
    # [num_tokens, num_blocks]: per-block max perturbed logit (output).
    local_max_ptr,
    local_max_stride,
    # Optional buffer for processed logits, forwarded to gumbel_block_argmax.
    processed_logits_ptr,
    processed_logits_stride,
    # [num_tokens, V]: input logits.
    logits_ptr,
    logits_stride,
    expanded_idx_mapping_ptr,
    seeds_ptr,
    pos_ptr,
    temp_ptr,
    vocab_size,
    BLOCK_SIZE: tl.constexpr,
    APPLY_TEMPERATURE: tl.constexpr,
):
    """Per-(token, vocab-block) Gumbel-max sampling pass.

    Launched on a (num_tokens, num_vocab_blocks) grid. Each program loads
    one BLOCK_SIZE slice of one token's logits, delegates temperature/noise
    handling and the block-local argmax to `gumbel_block_argmax`, and writes
    the block's winning token id and perturbed max so a later reduction can
    select the global argmax across blocks.
    """
    token_idx = tl.program_id(0)
    block_idx = tl.program_id(1)
    block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = block < vocab_size
    # Out-of-range lanes get -inf so they can never win the argmax.
    logits = tl.load(
        logits_ptr + token_idx * logits_stride + block,
        mask=mask,
        other=float("-inf"),
    )
    logits = logits.to(tl.float32)
    value, idx = gumbel_block_argmax(
        logits,
        block,
        mask,
        token_idx,
        expanded_idx_mapping_ptr,
        temp_ptr,
        seeds_ptr,
        pos_ptr,
        processed_logits_ptr,
        processed_logits_stride,
        APPLY_TEMPERATURE=APPLY_TEMPERATURE,
    )
    # `idx` is local to this block; convert it to a global vocab index.
    token_id = block_idx * BLOCK_SIZE + idx
    tl.store(local_argmax_ptr + token_idx * local_argmax_stride + block_idx, token_id)
    tl.store(local_max_ptr + token_idx * local_max_stride + block_idx, value)

View File

@@ -0,0 +1,612 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from vllm.triton_utils import tl, triton
from vllm.v1.worker.gpu.sample.gumbel import gumbel_block_argmax, tl_rand64
@triton.jit
def _compute_block_max_and_sumexp(logits):
    """Return (max, sum(exp(logits - max))) for one block of logits.

    When the entire block is -inf (fully masked), the sumexp is forced to
    0.0 so that exp(-inf - (-inf)) = exp(nan) cannot poison the later
    global log-sum-exp reduction.
    """
    block_max = tl.max(logits, axis=0)
    block_sumexp = tl.where(
        block_max > float("-inf"),
        tl.sum(tl.exp(logits - block_max)),
        0.0,
    )
    return block_max, block_sumexp
@triton.jit
def _compute_global_lse(
    local_max_ptr,
    local_max_stride,
    local_sumexp_ptr,
    local_sumexp_stride,
    logit_idx,
    vocab_num_blocks,
    PADDED_VOCAB_NUM_BLOCKS: tl.constexpr,
):
    """Reduce per-block (max, sumexp) pairs into a single global logsumexp.

    Standard streaming-LSE merge: each block's sumexp is rescaled by
    exp(block_max - global_max) before summing, then the log of the total
    is shifted back by the global max.
    """
    blocks = tl.arange(0, PADDED_VOCAB_NUM_BLOCKS)
    blocks_mask = blocks < vocab_num_blocks
    # Padding lanes contribute -inf max and 0 sumexp, i.e. nothing.
    maxes = tl.load(
        local_max_ptr + logit_idx * local_max_stride + blocks,
        mask=blocks_mask,
        other=float("-inf"),
    )
    sumexps = tl.load(
        local_sumexp_ptr + logit_idx * local_sumexp_stride + blocks,
        mask=blocks_mask,
        other=0.0,
    )
    global_max = tl.max(maxes, axis=0)
    global_lse = global_max + tl.log(tl.sum(sumexps * tl.exp(maxes - global_max)))
    return global_lse
@triton.jit
def _compute_block_max_and_sumexp_kernel(
    # [num_logits, num_blocks]
    target_local_argmax_ptr,
    target_local_argmax_stride,
    # [num_logits, num_blocks]
    target_local_max_ptr,
    target_local_max_stride,
    # [num_logits, num_blocks]
    target_local_sumexp_ptr,
    target_local_sumexp_stride,
    # [num_logits, num_blocks]
    draft_local_max_ptr,
    draft_local_max_stride,
    # [num_logits, num_blocks]
    draft_local_sumexp_ptr,
    draft_local_sumexp_stride,
    # [num_logits, V]
    target_logits_ptr,
    target_logits_stride,
    # [max_num_reqs, num_speculative_steps, V]
    draft_logits_ptr,
    draft_logits_stride_0,
    draft_logits_stride_1,
    # [num_logits]
    expanded_idx_mapping_ptr,
    # [num_logits]
    expanded_local_pos_ptr,
    # [max_num_reqs]
    temp_ptr,
    vocab_size,
    num_speculative_steps,
    BLOCK_SIZE: tl.constexpr,
):
    """Per-(logit, vocab-block) reduction pass feeding the rejection kernel.

    Greedy requests (temp == 0.0) only need the target block max/argmax;
    stochastic requests need block max + sum-of-exponentials for both the
    target and the corresponding draft logits (later merged into global
    logsumexps by `_compute_global_lse`). Bonus-token slots are skipped
    entirely.
    """
    logit_idx = tl.program_id(0)
    draft_step_idx = tl.load(expanded_local_pos_ptr + logit_idx)
    if draft_step_idx >= num_speculative_steps:
        # Bonus token. Max/argmax and summed exponentials are not needed.
        return
    req_state_idx = tl.load(expanded_idx_mapping_ptr + logit_idx)
    temp = tl.load(temp_ptr + req_state_idx).to(tl.float32)
    block_idx = tl.program_id(1)
    block_offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = block_offsets < vocab_size
    if temp == 0.0:
        # Greedy sampling. Only the target max/argmax are needed.
        target_logits = tl.load(
            target_logits_ptr + logit_idx * target_logits_stride + block_offsets,
            mask=mask,
            other=float("-inf"),
        ).to(tl.float32)
        value, idx = tl.max(target_logits, axis=0, return_indices=True)
        # Convert the block-local argmax index to a global vocab index.
        token_id = block_idx * BLOCK_SIZE + idx
        tl.store(
            target_local_argmax_ptr
            + logit_idx * target_local_argmax_stride
            + block_idx,
            token_id,
        )
        tl.store(
            target_local_max_ptr + logit_idx * target_local_max_stride + block_idx,
            value,
        )
    else:
        # Get local draft max and summed exponentials.
        draft_logits = tl.load(
            draft_logits_ptr
            + req_state_idx * draft_logits_stride_0
            + draft_step_idx * draft_logits_stride_1
            + block_offsets,
            mask=mask,
            other=float("-inf"),
        ).to(tl.float32)
        draft_max, draft_sumexp = _compute_block_max_and_sumexp(draft_logits)
        tl.store(
            draft_local_max_ptr + logit_idx * draft_local_max_stride + block_idx,
            draft_max,
        )
        tl.store(
            draft_local_sumexp_ptr + logit_idx * draft_local_sumexp_stride + block_idx,
            draft_sumexp,
        )
        # Get local target max and summed exponentials.
        target_logits = tl.load(
            target_logits_ptr + logit_idx * target_logits_stride + block_offsets,
            mask=mask,
            other=float("-inf"),
        ).to(tl.float32)
        target_max, target_sumexp = _compute_block_max_and_sumexp(target_logits)
        tl.store(
            target_local_max_ptr + logit_idx * target_local_max_stride + block_idx,
            target_max,
        )
        tl.store(
            target_local_sumexp_ptr
            + logit_idx * target_local_sumexp_stride
            + block_idx,
            target_sumexp,
        )
@triton.jit
def _probabilistic_rejection_kernel(
    # [num_reqs, num_speculative_steps + 1]
    sampled_ptr,
    sampled_stride,
    # [num_reqs]: output — number of ACCEPTED draft tokens per request
    # (i.e. the index of the first rejected / bonus slot).
    rejected_steps_ptr,
    # [num_reqs]: target logsumexp at the rejected position (for resampling).
    target_rejected_logsumexp_ptr,
    # [num_reqs]: draft logsumexp at the rejected position (for resampling).
    draft_rejected_logsumexp_ptr,
    # [num_logits, V]
    target_logits_ptr,
    target_logits_stride,
    # [num_logits, num_blocks]
    target_local_argmax_ptr,
    target_local_argmax_stride,
    # [num_logits, num_blocks]
    target_local_max_ptr,
    target_local_max_stride,
    # [num_logits, num_blocks]
    target_local_sumexp_ptr,
    target_local_sumexp_stride,
    # [num_logits]
    draft_sampled_ptr,
    # [max_num_reqs, num_speculative_steps, V]
    draft_logits_ptr,
    draft_logits_stride_0,
    draft_logits_stride_1,
    # [num_logits, num_blocks]
    draft_local_max_ptr,
    draft_local_max_stride,
    # [num_logits, num_blocks]
    draft_local_sumexp_ptr,
    draft_local_sumexp_stride,
    # [num_reqs + 1]
    cu_num_logits_ptr,
    # [num_reqs]
    idx_mapping_ptr,
    # [max_num_reqs]
    temp_ptr,
    # [max_num_reqs]
    seed_ptr,
    # [num_logits]
    pos_ptr,
    vocab_num_blocks,
    PADDED_VOCAB_NUM_BLOCKS: tl.constexpr,
):
    """One program per request: sequential accept/reject scan of draft tokens.

    Stochastic path accepts draft token x iff log p(x) > log u + log q(x)
    (the standard rejection-sampling ratio test in log space); greedy path
    accepts iff the draft token equals the target argmax. The scan stops
    contributing once the first rejection occurs; the logsumexps at the last
    evaluated position are stored for `_resample_kernel` to reuse.
    """
    req_idx = tl.program_id(0)
    req_state_idx = tl.load(idx_mapping_ptr + req_idx)
    start_idx = tl.load(cu_num_logits_ptr + req_idx)
    end_idx = tl.load(cu_num_logits_ptr + req_idx + 1)
    num_tokens = end_idx - start_idx
    seed = tl.load(seed_ptr + req_state_idx)
    temp = tl.load(temp_ptr + req_state_idx).to(tl.float32)
    rejected_step = 0
    target_lse = 0.0
    draft_lse = 0.0
    accepted = True
    for i in range(num_tokens - 1):
        if accepted:
            logit_idx = start_idx + i
            # Draft token for step i lives at flat slot i + 1.
            # NOTE(review): presumably slot 0 holds the previously sampled
            # token — confirm against the caller's layout.
            draft_sampled = tl.load(draft_sampled_ptr + logit_idx + 1)
            if temp == 0.0:
                # Greedy sampling. Accept IFF draft matches target argmax.
                # NOTE: Target argmax is stored directly so that resampling
                # can be skipped upon rejection.
                target_blocks = tl.arange(0, PADDED_VOCAB_NUM_BLOCKS)
                target_blocks_mask = target_blocks < vocab_num_blocks
                target_local_max = tl.load(
                    target_local_max_ptr
                    + logit_idx * target_local_max_stride
                    + target_blocks,
                    mask=target_blocks_mask,
                    other=float("-inf"),
                )
                max_target_block_idx = tl.argmax(target_local_max, axis=0)
                target_argmax = tl.load(
                    target_local_argmax_ptr
                    + logit_idx * target_local_argmax_stride
                    + max_target_block_idx
                )
                accepted &= target_argmax == draft_sampled
                tl.store(sampled_ptr + req_idx * sampled_stride + i, target_argmax)
            else:
                # Gather the single target/draft logit of the drafted token.
                target_logit = tl.load(
                    target_logits_ptr + logit_idx * target_logits_stride + draft_sampled
                ).to(tl.float32)
                draft_logit = tl.load(
                    draft_logits_ptr
                    + req_state_idx * draft_logits_stride_0
                    + i * draft_logits_stride_1
                    + draft_sampled
                ).to(tl.float32)
                # Fold the per-block partials into global logsumexps.
                target_lse = _compute_global_lse(
                    target_local_max_ptr,
                    target_local_max_stride,
                    target_local_sumexp_ptr,
                    target_local_sumexp_stride,
                    logit_idx,
                    vocab_num_blocks,
                    PADDED_VOCAB_NUM_BLOCKS,
                )
                draft_lse = _compute_global_lse(
                    draft_local_max_ptr,
                    draft_local_max_stride,
                    draft_local_sumexp_ptr,
                    draft_local_sumexp_stride,
                    logit_idx,
                    vocab_num_blocks,
                    PADDED_VOCAB_NUM_BLOCKS,
                )
                target_log_prob = target_logit - target_lse
                draft_log_prob = draft_logit - draft_lse
                # Deterministic per-position uniform sample from (seed, pos).
                pos = tl.load(pos_ptr + logit_idx)
                u = tl_rand64(seed, pos, includes_zero=False)
                # Probability ratio test: p(x) > u * q(x)
                # Equivalent log form: log_p(x) > log(u) + log_q(x)
                accepted &= target_log_prob > tl.log(u) + draft_log_prob
                tl.store(sampled_ptr + req_idx * sampled_stride + i, draft_sampled)
            # `accepted` is 0 once any step rejects, so this stops counting.
            rejected_step += accepted
    tl.store(rejected_steps_ptr + req_idx, rejected_step)
    # LSEs of the last evaluated position (0.0 if greedy / no steps ran);
    # consumed by _resample_kernel only on the non-greedy rejection path.
    tl.store(target_rejected_logsumexp_ptr + req_idx, target_lse)
    tl.store(draft_rejected_logsumexp_ptr + req_idx, draft_lse)
@triton.jit
def _resample_kernel(
    # [num_reqs, num_blocks]
    resampled_local_argmax_ptr,
    resampled_local_argmax_stride,
    # [num_reqs, num_blocks]
    resampled_local_max_ptr,
    resampled_local_max_stride,
    # [num_logits, V]
    target_logits_ptr,
    target_logits_stride,
    # [num_reqs]
    target_rejected_logsumexp_ptr,
    # [max_num_reqs, num_speculative_steps, V]
    draft_logits_ptr,
    draft_logits_stride_0,
    draft_logits_stride_1,
    # [num_reqs]
    draft_rejected_logsumexp_ptr,
    # [num_reqs]: number of accepted steps = index of the slot to resample.
    rejected_step_ptr,
    # [num_reqs + 1]
    cu_num_logits_ptr,
    # [num_logits]
    expanded_idx_mapping_ptr,
    # [max_num_reqs]
    temp_ptr,
    # [max_num_reqs]
    seed_ptr,
    # [num_logits]
    pos_ptr,
    vocab_size,
    BLOCK_SIZE: tl.constexpr,
):
    """Per-(request, vocab-block) Gumbel-max pass over the residual distribution.

    For a rejected token, samples from the residual max(p - q, 0); for a
    bonus token (all drafts accepted), samples from the target distribution
    directly. Writes per-block (argmax, max) partials that
    `_insert_resampled_kernel` reduces into the final token.
    """
    req_idx = tl.program_id(0)
    resample_idx = tl.load(rejected_step_ptr + req_idx)
    start_idx = tl.load(cu_num_logits_ptr + req_idx)
    end_idx = tl.load(cu_num_logits_ptr + req_idx + 1)
    resample_token_idx = start_idx + resample_idx
    req_state_idx = tl.load(expanded_idx_mapping_ptr + resample_token_idx)
    temp = tl.load(temp_ptr + req_state_idx).to(tl.float32)
    is_bonus = resample_token_idx == end_idx - 1
    if temp == 0.0 and not is_bonus:
        # Greedy + non-bonus token. No resampling needed because
        # the target argmax is already in the sampled tensor.
        return
    block_idx = tl.program_id(1)
    block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = block < vocab_size
    # Compute the residual logits to resample the rejected token
    # from. In the case of no rejections (bonus token), we directly
    # use the target logits.
    if is_bonus:
        residual_logits = tl.load(
            target_logits_ptr + resample_token_idx * target_logits_stride + block,
            mask=mask,
            other=float("-inf"),
        ).to(tl.float32)
    else:
        target_logits = tl.load(
            target_logits_ptr + resample_token_idx * target_logits_stride + block,
            mask=mask,
            other=float("-inf"),
        ).to(tl.float32)
        draft_logits = tl.load(
            draft_logits_ptr
            + req_state_idx * draft_logits_stride_0
            + resample_idx * draft_logits_stride_1
            + block,
            mask=mask,
            other=float("-inf"),
        ).to(tl.float32)
        # Logsumexps were already computed at this position by the
        # rejection kernel; reuse instead of recomputing.
        target_lse = tl.load(target_rejected_logsumexp_ptr + req_idx)
        draft_lse = tl.load(draft_rejected_logsumexp_ptr + req_idx)
        target_log_probs = target_logits - target_lse
        draft_log_probs = draft_logits - draft_lse
        # Compute the residual: max(p(x) - q(x), 0)
        # Equivalent log form: log(max(exp(log_p(x)) - exp(log_q(x)), 0))
        # The more numerically stable form is:
        # log(max(exp(a) - exp(b), 0)) = a + log(max(1 - exp(b - a), 0))
        ratio = tl.exp(draft_log_probs - target_log_probs)
        residual_logits = tl.where(
            ratio < 1.0,
            target_log_probs + tl.log(1 - ratio),
            float("-inf"),
        ).to(tl.float32)
    # Resample the rejected/bonus token.
    value, idx = gumbel_block_argmax(
        residual_logits,
        block,
        mask,
        resample_token_idx,
        expanded_idx_mapping_ptr,
        temp_ptr,
        seed_ptr,
        pos_ptr,
        None,
        0,
        APPLY_TEMPERATURE=False,
    )
    # Convert the block-local argmax index to a global vocab index.
    token_id = block_idx * BLOCK_SIZE + idx
    tl.store(
        resampled_local_argmax_ptr
        + req_idx * resampled_local_argmax_stride
        + block_idx,
        token_id,
    )
    tl.store(
        resampled_local_max_ptr + req_idx * resampled_local_max_stride + block_idx,
        value,
    )
@triton.jit
def _insert_resampled_kernel(
    # [num_reqs, num_speculative_steps + 1]
    sampled_ptr,
    sampled_stride,
    # [num_reqs]: in/out — accepted-step count on entry, final token
    # count (accepted + 1) on exit.
    num_sampled_ptr,
    # [num_reqs, num_blocks]
    resampled_local_argmax_ptr,
    resampled_local_argmax_stride,
    # [num_reqs, num_blocks]
    resampled_local_max_ptr,
    resampled_local_max_stride,
    resample_num_blocks,
    # [num_reqs + 1]
    cu_num_logits_ptr,
    # [num_reqs]
    expanded_idx_mapping_ptr,
    # [max_num_reqs]
    temp_ptr,
    PADDED_RESAMPLE_NUM_BLOCKS: tl.constexpr,
):
    """One program per request: reduce the per-block resample partials and
    write the winning token into the rejected/bonus slot of `sampled`."""
    req_idx = tl.program_id(0)
    num_sampled = tl.load(num_sampled_ptr + req_idx)
    start_idx = tl.load(cu_num_logits_ptr + req_idx)
    end_idx = tl.load(cu_num_logits_ptr + req_idx + 1)
    resample_token_idx = start_idx + num_sampled
    req_state_idx = tl.load(expanded_idx_mapping_ptr + resample_token_idx)
    # Increment the number of sampled tokens.
    tl.store(num_sampled_ptr + req_idx, num_sampled + 1)
    temp = tl.load(temp_ptr + req_state_idx).to(tl.float32)
    is_bonus = resample_token_idx == end_idx - 1
    if temp == 0.0 and not is_bonus:
        # Greedy + non-bonus token. The target argmax is already
        # in the sampled tensor.
        return
    # Insert the resampled token.
    block = tl.arange(0, PADDED_RESAMPLE_NUM_BLOCKS)
    mask = block < resample_num_blocks
    resampled_local_max = tl.load(
        resampled_local_max_ptr + req_idx * resampled_local_max_stride + block,
        mask=mask,
        other=float("-inf"),
    )
    # The block with the largest perturbed max holds the global argmax.
    resampled_max_block_idx = tl.argmax(resampled_local_max, axis=0)
    resampled = tl.load(
        resampled_local_argmax_ptr
        + req_idx * resampled_local_argmax_stride
        + resampled_max_block_idx,
    )
    tl.store(
        sampled_ptr + req_idx * sampled_stride + num_sampled,
        resampled,
    )
def probabilistic_rejection_sample(
    # [num_logits, V]
    target_logits: torch.Tensor,
    # [max_num_reqs, num_speculative_steps, V]
    draft_logits: torch.Tensor,
    # [num_logits]
    draft_sampled: torch.Tensor,
    # [num_reqs + 1]
    cu_num_logits: torch.Tensor,
    # [num_logits]
    pos: torch.Tensor,
    # [num_reqs]
    idx_mapping: torch.Tensor,
    # [num_logits]
    expanded_idx_mapping: torch.Tensor,
    # [num_logits]
    expanded_local_pos: torch.Tensor,
    # [max_num_reqs]
    temperature: torch.Tensor,
    # [max_num_reqs]
    seed: torch.Tensor,
    num_speculative_steps: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Fused probabilistic rejection sampling for speculative decoding.

    Pipeline of four Triton kernels:
      1. `_compute_block_max_and_sumexp_kernel`: per-vocab-block max/argmax
         and sum-of-exponentials partials for target and draft logits.
      2. `_probabilistic_rejection_kernel`: per-request sequential
         accept/reject scan; records accepted tokens, the accepted-step
         count, and the logsumexps at the rejected position.
      3. `_resample_kernel`: per-block Gumbel-max partials over the residual
         distribution max(p - q, 0) (or the target logits for bonus tokens).
      4. `_insert_resampled_kernel`: reduces those partials and writes the
         final resampled/bonus token, bumping the per-request count.

    Returns:
        sampled: [num_reqs, num_speculative_steps + 1] int64 token ids.
        num_sampled: [num_reqs] number of valid tokens per request
            (accepted draft tokens + 1 resampled/bonus token).
    """
    num_reqs = cu_num_logits.shape[0] - 1
    num_logits, vocab_size = target_logits.shape
    # Gather draft logits, compute target argmax for greedy, and
    # compute per-block LSE and max for non-greedy requests.
    VOCAB_BLOCK_SIZE = 8192
    vocab_num_blocks = triton.cdiv(vocab_size, VOCAB_BLOCK_SIZE)
    # Padded to a power of 2 because tl.arange requires it.
    padded_vocab_num_blocks = triton.next_power_of_2(vocab_num_blocks)
    target_local_argmax = target_logits.new_empty(
        num_logits, vocab_num_blocks, dtype=torch.int64
    )
    target_local_max = target_logits.new_empty(
        num_logits, vocab_num_blocks, dtype=torch.float32
    )
    target_local_sumexp = target_logits.new_empty(
        num_logits, vocab_num_blocks, dtype=torch.float32
    )
    draft_local_max = target_logits.new_empty(
        num_logits, vocab_num_blocks, dtype=torch.float32
    )
    draft_local_sumexp = target_logits.new_empty(
        num_logits, vocab_num_blocks, dtype=torch.float32
    )
    _compute_block_max_and_sumexp_kernel[(num_logits, vocab_num_blocks)](
        target_local_argmax,
        target_local_argmax.stride(0),
        target_local_max,
        target_local_max.stride(0),
        target_local_sumexp,
        target_local_sumexp.stride(0),
        draft_local_max,
        draft_local_max.stride(0),
        draft_local_sumexp,
        draft_local_sumexp.stride(0),
        target_logits,
        target_logits.stride(0),
        draft_logits,
        draft_logits.stride(0),
        draft_logits.stride(1),
        expanded_idx_mapping,
        expanded_local_pos,
        temperature,
        vocab_size,
        num_speculative_steps,
        BLOCK_SIZE=VOCAB_BLOCK_SIZE,
    )
    # Sample up until the first rejected/bonus token, and store
    # the step.
    sampled = draft_sampled.new_empty(
        num_reqs, num_speculative_steps + 1, dtype=torch.int64
    )
    # Receives the ACCEPTED-step count here; _insert_resampled_kernel
    # below increments it to the final per-request token count.
    num_sampled = sampled.new_empty(num_reqs)
    target_rejected_logsumexp = target_logits.new_empty(num_reqs, dtype=torch.float32)
    draft_rejected_logsumexp = target_logits.new_empty(num_reqs, dtype=torch.float32)
    _probabilistic_rejection_kernel[(num_reqs,)](
        sampled,
        sampled.stride(0),
        num_sampled,
        target_rejected_logsumexp,
        draft_rejected_logsumexp,
        target_logits,
        target_logits.stride(0),
        target_local_argmax,
        target_local_argmax.stride(0),
        target_local_max,
        target_local_max.stride(0),
        target_local_sumexp,
        target_local_sumexp.stride(0),
        draft_sampled,
        draft_logits,
        draft_logits.stride(0),
        draft_logits.stride(1),
        draft_local_max,
        draft_local_max.stride(0),
        draft_local_sumexp,
        draft_local_sumexp.stride(0),
        cu_num_logits,
        idx_mapping,
        temperature,
        seed,
        pos,
        vocab_num_blocks,
        PADDED_VOCAB_NUM_BLOCKS=padded_vocab_num_blocks,
        num_warps=1,
    )
    # Resample the rejected/bonus tokens.
    RESAMPLE_BLOCK_SIZE = 1024
    resample_num_blocks = triton.cdiv(vocab_size, RESAMPLE_BLOCK_SIZE)
    padded_resample_num_blocks = triton.next_power_of_2(resample_num_blocks)
    resampled_local_argmax = target_logits.new_empty(
        num_reqs, resample_num_blocks, dtype=torch.int64
    )
    # NOTE(review): float64 here while the other max buffers are float32 —
    # confirm whether the extra precision is intentional.
    resampled_local_max = target_logits.new_empty(
        num_reqs, resample_num_blocks, dtype=torch.float64
    )
    _resample_kernel[(num_reqs, resample_num_blocks)](
        resampled_local_argmax,
        resampled_local_argmax.stride(0),
        resampled_local_max,
        resampled_local_max.stride(0),
        target_logits,
        target_logits.stride(0),
        target_rejected_logsumexp,
        draft_logits,
        draft_logits.stride(0),
        draft_logits.stride(1),
        draft_rejected_logsumexp,
        num_sampled,
        cu_num_logits,
        expanded_idx_mapping,
        temperature,
        seed,
        pos,
        vocab_size,
        BLOCK_SIZE=RESAMPLE_BLOCK_SIZE,
    )
    # Insert the resampled tokens into the output sampled.
    _insert_resampled_kernel[(num_reqs,)](
        sampled,
        sampled.stride(0),
        num_sampled,
        resampled_local_argmax,
        resampled_local_argmax.stride(0),
        resampled_local_max,
        resampled_local_max.stride(0),
        resample_num_blocks,
        cu_num_logits,
        expanded_idx_mapping,
        temperature,
        PADDED_RESAMPLE_NUM_BLOCKS=padded_resample_num_blocks,
    )
    return sampled, num_sampled

View File

@@ -7,11 +7,13 @@ from vllm.triton_utils import tl, triton
from vllm.v1.outputs import LogprobsTensors
from vllm.v1.worker.gpu.input_batch import InputBatch
from vllm.v1.worker.gpu.metrics.logits import get_num_nans
from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample, tl_rand64
from vllm.v1.worker.gpu.sample.logprob import compute_topk_logprobs
from vllm.v1.worker.gpu.sample.output import SamplerOutput
from vllm.v1.worker.gpu.sample.sampler import Sampler
from vllm.v1.worker.gpu.sample.states import NO_LOGPROBS
from vllm.v1.worker.gpu.spec_decode.probabilistic_rejection_sampler_utils import (
probabilistic_rejection_sample,
)
from vllm.v1.worker.gpu.spec_decode.synthetic_rejection_sampler_utils import (
compute_synthetic_rejection_sampler_params,
synthetic_rejection_sample,
@@ -75,357 +77,6 @@ def strict_rejection_sample(
return sampled, num_sampled
@triton.jit
def _gather_draft_logits_and_target_argmax_kernel(
    local_target_argmax_ptr,
    local_target_argmax_stride,
    local_target_max_ptr,
    local_target_max_stride,
    # [num_logits, V]
    out_draft_logits_ptr,
    out_draft_logits_stride,
    # [num_logits, V]
    target_logits_ptr,
    target_logits_stride,
    # [max_num_reqs, num_speculative_steps, V]
    draft_logits_ptr,
    draft_logits_stride_0,
    draft_logits_stride_1,
    # [num_logits]
    expanded_idx_mapping_ptr,
    # [num_logits]
    expanded_local_pos_ptr,
    # [max_num_reqs]
    temp_ptr,
    vocab_size,
    num_speculative_steps,
    BLOCK_SIZE: tl.constexpr,
):
    """Per-(logit, vocab-block) preparation pass.

    Greedy requests (temp == 0.0): compute the target block max/argmax.
    Stochastic requests: copy the draft logits for this step into a dense
    per-logit buffer (bonus-token slots, with step >= num_speculative_steps,
    are left untouched).
    """
    logit_idx = tl.program_id(0)
    req_state_idx = tl.load(expanded_idx_mapping_ptr + logit_idx)
    draft_step_idx = tl.load(expanded_local_pos_ptr + logit_idx)
    block_idx = tl.program_id(1)
    block_offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = block_offsets < vocab_size
    temp = tl.load(temp_ptr + req_state_idx).to(tl.float32)
    if temp == 0.0:
        # Greedy sampling. Get the target logits argmax.
        target_logits = tl.load(
            target_logits_ptr + logit_idx * target_logits_stride + block_offsets,
            mask=mask,
            other=float("-inf"),
        ).to(tl.float32)
        value, idx = tl.max(target_logits, axis=0, return_indices=True)
        # Convert block-local argmax index to a global vocab index.
        token_id = block_idx * BLOCK_SIZE + idx
        tl.store(
            local_target_argmax_ptr
            + logit_idx * local_target_argmax_stride
            + block_idx,
            token_id,
        )
        tl.store(
            local_target_max_ptr + logit_idx * local_target_max_stride + block_idx,
            value,
        )
    elif draft_step_idx < num_speculative_steps:
        draft_logits = tl.load(
            draft_logits_ptr
            + req_state_idx * draft_logits_stride_0
            + draft_step_idx * draft_logits_stride_1
            + block_offsets,
            mask=mask,
            other=float("-inf"),
        ).to(tl.float32)
        tl.store(
            out_draft_logits_ptr + logit_idx * out_draft_logits_stride + block_offsets,
            draft_logits,
            mask=mask,
        )
@triton.jit
def _probabilistic_rejection_kernel(
    # [num_reqs, num_speculative_steps + 1]
    sampled_ptr,
    sampled_stride,
    # [num_reqs]: output — number of accepted steps per request.
    rejected_steps_ptr,
    # [num_reqs]: output — `pos` value at the rejected/bonus slot.
    rejected_pos_ptr,
    # [num_logits]
    draft_sampled_ptr,
    # [num_logits, V]
    target_probs_ptr,
    target_probs_stride,
    # [num_logits, V]
    draft_probs_ptr,
    draft_probs_stride,
    # [num_logits, num_blocks]
    local_target_argmax_ptr,
    local_target_argmax_stride,
    # [num_logits, num_blocks]
    local_target_max_ptr,
    local_target_max_stride,
    # [num_reqs + 1]
    cu_num_logits_ptr,
    # [num_logits]
    pos_ptr,
    # [num_reqs]
    idx_mapping_ptr,
    # [max_num_reqs]
    temp_ptr,
    # [max_num_reqs]
    seeds_ptr,
    NUM_BLOCKS: tl.constexpr,
    PADDED_NUM_BLOCKS: tl.constexpr,
):
    """One program per request: sequential accept/reject scan over the
    draft tokens, operating on precomputed probabilities (p(x) > u * q(x)
    in float64) rather than on logits."""
    req_idx = tl.program_id(0)
    start_idx = tl.load(cu_num_logits_ptr + req_idx)
    num_tokens = tl.load(cu_num_logits_ptr + req_idx + 1) - start_idx
    req_state_idx = tl.load(idx_mapping_ptr + req_idx)
    seed = tl.load(seeds_ptr + req_state_idx)
    temp = tl.load(temp_ptr + req_state_idx).to(tl.float32)
    rejected_step = 0
    accepted = True
    for i in range(num_tokens - 1):
        if accepted:
            logit_idx = start_idx + i
            # Draft token for step i lives at flat slot i + 1.
            draft_sampled = tl.load(draft_sampled_ptr + logit_idx + 1)
            if temp == 0.0:
                # Greedy sampling. Only accept the sampled draft token if
                # it exactly matches the target argmax.
                block_offsets = tl.arange(0, PADDED_NUM_BLOCKS)
                block_mask = block_offsets < NUM_BLOCKS
                local_max = tl.load(
                    local_target_max_ptr
                    + logit_idx * local_target_max_stride
                    + block_offsets,
                    mask=block_mask,
                    other=float("-inf"),
                )
                max_block = tl.argmax(local_max, axis=0)
                target_argmax = tl.load(
                    local_target_argmax_ptr
                    + logit_idx * local_target_argmax_stride
                    + max_block
                )
                accepted &= target_argmax == draft_sampled
            else:
                target_prob = tl.load(
                    target_probs_ptr + logit_idx * target_probs_stride + draft_sampled
                ).to(tl.float64)
                draft_prob = tl.load(
                    draft_probs_ptr + logit_idx * draft_probs_stride + draft_sampled
                ).to(tl.float64)
                # Deterministic per-position uniform from (seed, pos).
                pos = tl.load(pos_ptr + logit_idx)
                u = tl_rand64(seed, pos, includes_zero=False)
                accepted &= target_prob > u * draft_prob
            tl.store(sampled_ptr + req_idx * sampled_stride + i, draft_sampled)
            # `accepted` is 0 once any step rejects, so counting stops.
            rejected_step += accepted
    tl.store(rejected_steps_ptr + req_idx, rejected_step)
    # Position id of the slot that needs resampling.
    pos_val = tl.load(pos_ptr + start_idx + rejected_step)
    tl.store(rejected_pos_ptr + req_idx, pos_val)
@triton.jit
def _compute_residual_logits_kernel(
    # [num_reqs, V]
    residual_logits_ptr,
    residual_logits_stride,
    # [num_logits, V]
    target_probs_ptr,
    target_probs_stride,
    # [num_logits, V]
    draft_probs_ptr,
    draft_probs_stride,
    # [num_logits, V]
    target_logits_ptr,
    target_logits_stride,
    # [num_reqs]
    rejected_step_ptr,
    # [num_reqs + 1]
    cu_num_logits_ptr,
    # [num_reqs]
    idx_mapping_ptr,
    # [max_num_reqs]
    temp_ptr,
    vocab_size,
    BLOCK_SIZE: tl.constexpr,
):
    """Per-(request, vocab-block): materialize the logits to resample from.

    Greedy or bonus-token slots copy the raw target logits; otherwise the
    residual distribution max(p - q, 0) is computed in probability space
    and converted back via log (zero residual yields -inf, i.e. excluded)."""
    req_idx = tl.program_id(0)
    block_idx = tl.program_id(1)
    req_state_idx = tl.load(idx_mapping_ptr + req_idx)
    start_idx = tl.load(cu_num_logits_ptr + req_idx)
    end_idx = tl.load(cu_num_logits_ptr + req_idx + 1)
    rejected_logit_idx = start_idx + tl.load(rejected_step_ptr + req_idx)
    temp = tl.load(temp_ptr + req_state_idx).to(tl.float32)
    block_offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = block_offsets < vocab_size
    if temp == 0.0 or (rejected_logit_idx == end_idx - 1):
        # Greedy sampling / bonus token. In either case, use the
        # target logits directly to reduce numerical error.
        residual_logits = tl.load(
            target_logits_ptr
            + rejected_logit_idx * target_logits_stride
            + block_offsets,
            mask=mask,
            other=float("-inf"),
        )
    else:
        target_probs = tl.load(
            target_probs_ptr + rejected_logit_idx * target_probs_stride + block_offsets,
            mask=mask,
            other=0.0,
        )
        draft_probs = tl.load(
            draft_probs_ptr + rejected_logit_idx * draft_probs_stride + block_offsets,
            mask=mask,
            other=0.0,
        )
        residual_probs = tl.maximum(target_probs - draft_probs, 0.0)
        # log(0) = -inf removes zero-residual tokens from the sample space.
        residual_logits = tl.log(residual_probs)
    tl.store(
        residual_logits_ptr + req_idx * residual_logits_stride + block_offsets,
        residual_logits,
        mask=mask,
    )
def probabilistic_rejection_sample(
    # [num_logits, V]
    target_logits: torch.Tensor,
    # [max_num_reqs, num_speculative_steps, V]
    draft_logits: torch.Tensor,
    # [num_logits]
    draft_sampled: torch.Tensor,
    # [num_reqs + 1]
    cu_num_logits: torch.Tensor,
    # [num_logits]
    pos: torch.Tensor,
    # [num_reqs]
    idx_mapping: torch.Tensor,
    # [num_logits]
    expanded_idx_mapping: torch.Tensor,
    # [num_logits]
    expanded_local_pos: torch.Tensor,
    # [max_num_reqs]
    temperature: torch.Tensor,
    # [max_num_reqs]
    seed: torch.Tensor,
    num_speculative_steps: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Probabilistic rejection sampling for speculative decoding.

    Gathers draft logits into a dense buffer, materializes full target and
    draft softmax probabilities, runs the per-request accept/reject scan,
    builds dense residual logits, and Gumbel-samples the rejected/bonus
    token via `gumbel_sample`, scattering it into the output.

    Returns:
        sampled: [num_reqs, num_speculative_steps + 1] int64 token ids.
        rejected_steps + 1: [num_reqs] number of valid tokens per request.
    """
    num_reqs = cu_num_logits.shape[0] - 1
    num_logits, vocab_size = target_logits.shape
    BLOCK_SIZE = 1024
    num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
    # Gather draft logits and target argmax for greedy sampling.
    gathered_draft_logits = target_logits.new_empty(target_logits.shape)
    local_target_argmax = target_logits.new_empty(
        num_logits, num_blocks, dtype=torch.int64
    )
    local_target_max = target_logits.new_empty(
        num_logits, num_blocks, dtype=torch.float32
    )
    _gather_draft_logits_and_target_argmax_kernel[(num_logits, num_blocks)](
        local_target_argmax,
        local_target_argmax.stride(0),
        local_target_max,
        local_target_max.stride(0),
        gathered_draft_logits,
        gathered_draft_logits.stride(0),
        target_logits,
        target_logits.stride(0),
        draft_logits,
        draft_logits.stride(0),
        draft_logits.stride(1),
        expanded_idx_mapping,
        expanded_local_pos,
        temperature,
        vocab_size,
        num_speculative_steps,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    # Compute target and draft probs.
    target_probs = torch.softmax(target_logits, dim=-1)
    draft_probs = torch.softmax(gathered_draft_logits, dim=-1)
    # Rejection sample.
    # [num_reqs, num_speculative_steps + 1]
    sampled = draft_sampled.new_empty(
        num_reqs, num_speculative_steps + 1, dtype=torch.int64
    )
    # [num_reqs]
    rejected_steps = sampled.new_empty(num_reqs)
    # [num_reqs]
    rejected_pos = pos.new_empty(num_reqs)
    _probabilistic_rejection_kernel[(num_reqs,)](
        sampled,
        sampled.stride(0),
        rejected_steps,
        rejected_pos,
        draft_sampled,
        target_probs,
        target_probs.stride(0),
        draft_probs,
        draft_probs.stride(0),
        local_target_argmax,
        local_target_argmax.stride(0),
        local_target_max,
        local_target_max.stride(0),
        cu_num_logits,
        pos,
        idx_mapping,
        temperature,
        seed,
        num_warps=1,
        NUM_BLOCKS=num_blocks,
        PADDED_NUM_BLOCKS=triton.next_power_of_2(num_blocks),
    )
    # Compute the logits and positions to resample the rejected/bonus
    # tokens from.
    # [num_reqs, vocab_size]
    residual_logits = target_logits.new_empty(num_reqs, vocab_size)
    _compute_residual_logits_kernel[(num_reqs, num_blocks)](
        residual_logits,
        residual_logits.stride(0),
        target_probs,
        target_probs.stride(0),
        draft_probs,
        draft_probs.stride(0),
        target_logits,
        target_logits.stride(0),
        rejected_steps,
        cu_num_logits,
        idx_mapping,
        temperature,
        vocab_size,
        BLOCK_SIZE=BLOCK_SIZE,
    )
    # Gumbel sample tokens from the residual distribution.
    resampled = gumbel_sample(
        residual_logits,
        idx_mapping,
        temperature,
        seed,
        rejected_pos,
        apply_temperature=False,
    )
    # Write the resampled token into each request's rejected/bonus slot.
    sampled.scatter_(1, rejected_steps.unsqueeze(1), resampled.unsqueeze(1))
    return sampled, rejected_steps + 1
@triton.jit
def _flatten_sampled_kernel(
# [num_logits]