[Model Runner V2] Fuse probabilistic rejection sample kernels (#38496)
Signed-off-by: Giancarlo Delfin <gdelfin@inferact.ai>
This commit is contained in:
@@ -100,11 +100,13 @@ steps:
|
||||
- vllm/v1/worker/gpu/
|
||||
- vllm/v1/worker/gpu_worker.py
|
||||
- tests/v1/spec_decode/test_max_len.py
|
||||
- tests/v1/spec_decode/test_probabilistic_rejection_sampler_utils.py
|
||||
- tests/v1/spec_decode/test_synthetic_rejection_sampler_utils.py
|
||||
- tests/v1/e2e/spec_decode/test_spec_decode.py
|
||||
commands:
|
||||
- set -x
|
||||
- export VLLM_USE_V2_MODEL_RUNNER=1
|
||||
- pytest -v -s v1/spec_decode/test_max_len.py -k "eagle or mtp"
|
||||
- pytest -v -s v1/spec_decode/test_probabilistic_rejection_sampler_utils.py
|
||||
- pytest -v -s v1/spec_decode/test_synthetic_rejection_sampler_utils.py
|
||||
- pytest -v -s v1/e2e/spec_decode/test_spec_decode.py -k "eagle or mtp"
|
||||
|
||||
@@ -0,0 +1,215 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import math
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.v1.worker.gpu.spec_decode.probabilistic_rejection_sampler_utils import (
|
||||
probabilistic_rejection_sample,
|
||||
)
|
||||
|
||||
# Vocabulary size shared by all synthetic logits built in this test module.
VOCAB_SIZE = 4096


# Skip if no CUDA - Triton kernel requires GPU
pytest.importorskip("triton")
if not torch.cuda.is_available():
    pytest.skip("CUDA required for rejection sampler tests", allow_module_level=True)
|
||||
|
||||
|
||||
def _build_rejection_sample_inputs(
|
||||
target_logits_1d: torch.Tensor,
|
||||
draft_logits_1d: torch.Tensor,
|
||||
num_speculative_steps: int,
|
||||
temperature: float,
|
||||
num_trials: int,
|
||||
) -> dict:
|
||||
device = target_logits_1d.device
|
||||
vocab_size = target_logits_1d.shape[0]
|
||||
K = num_speculative_steps
|
||||
num_logits = num_trials * (K + 1)
|
||||
|
||||
target_logits = target_logits_1d.unsqueeze(0).expand(num_logits, -1).contiguous()
|
||||
draft_logits = (
|
||||
draft_logits_1d.view(1, 1, vocab_size).expand(num_trials, K, -1).contiguous()
|
||||
)
|
||||
|
||||
draft_probs = torch.softmax(draft_logits_1d, dim=0)
|
||||
draft_tokens = torch.multinomial(
|
||||
draft_probs.expand(num_trials, -1), K, replacement=True
|
||||
)
|
||||
draft_sampled_2d = torch.zeros(num_trials, K + 1, dtype=torch.int64, device=device)
|
||||
draft_sampled_2d[:, 1:] = draft_tokens
|
||||
draft_sampled = draft_sampled_2d.reshape(-1)
|
||||
|
||||
cu_num_logits = torch.arange(num_trials + 1, dtype=torch.int32, device=device) * (
|
||||
K + 1
|
||||
)
|
||||
pos = torch.arange(num_logits, dtype=torch.int32, device=device)
|
||||
idx_mapping = torch.arange(num_trials, dtype=torch.int32, device=device)
|
||||
expanded_idx_mapping = torch.arange(
|
||||
num_trials, dtype=torch.int32, device=device
|
||||
).repeat_interleave(K + 1)
|
||||
expanded_local_pos = torch.arange(K + 1, dtype=torch.int32, device=device).repeat(
|
||||
num_trials
|
||||
)
|
||||
temp_tensor = torch.full(
|
||||
(num_trials,), temperature, dtype=torch.float32, device=device
|
||||
)
|
||||
seed = torch.arange(num_trials, dtype=torch.int64, device=device)
|
||||
|
||||
return dict(
|
||||
target_logits=target_logits,
|
||||
draft_logits=draft_logits,
|
||||
draft_sampled=draft_sampled,
|
||||
cu_num_logits=cu_num_logits,
|
||||
pos=pos,
|
||||
idx_mapping=idx_mapping,
|
||||
expanded_idx_mapping=expanded_idx_mapping,
|
||||
expanded_local_pos=expanded_local_pos,
|
||||
temperature=temp_tensor,
|
||||
seed=seed,
|
||||
)
|
||||
|
||||
|
||||
def _assert_distribution_match(
|
||||
sampled_tokens: torch.Tensor,
|
||||
target_probs: torch.Tensor,
|
||||
device: str,
|
||||
label: str = "",
|
||||
min_expected: float = 5.0,
|
||||
):
|
||||
"""
|
||||
Assert sampled tokens match the target distribution via a
|
||||
chi-squared goodness-of-fit test. This is done by computing
|
||||
observed vs expected token counts (target_probs * num_samples),
|
||||
then checking that the chi-squared statistic is below a conservative
|
||||
threshold. The threshold is set at df + 10*sqrt(2*df), which
|
||||
corresponds to ~10 sigma under the chi-squared distribution's
|
||||
normal approximation, effectively disallowing false positives.
|
||||
|
||||
NOTE: Tokens with expected count < min_expected are merged into
|
||||
a single "other" bin to minimize chi-squared noise.
|
||||
"""
|
||||
num_samples = sampled_tokens.shape[0]
|
||||
vocab_size = target_probs.shape[0]
|
||||
|
||||
observed = torch.zeros(vocab_size, device=device, dtype=torch.float32)
|
||||
observed.scatter_add_(0, sampled_tokens, torch.ones(num_samples, device=device))
|
||||
expected = target_probs * num_samples
|
||||
|
||||
sufficient = expected >= min_expected
|
||||
obs_main = observed[sufficient]
|
||||
exp_main = expected[sufficient]
|
||||
|
||||
obs_other = observed[~sufficient].sum().unsqueeze(0)
|
||||
exp_other = expected[~sufficient].sum().unsqueeze(0)
|
||||
|
||||
if exp_other.item() >= min_expected:
|
||||
obs_all = torch.cat([obs_main, obs_other])
|
||||
exp_all = torch.cat([exp_main, exp_other])
|
||||
else:
|
||||
obs_all = obs_main
|
||||
exp_all = exp_main
|
||||
|
||||
chi2 = ((obs_all - exp_all) ** 2 / exp_all).sum().item()
|
||||
df = obs_all.shape[0] - 1
|
||||
if df < 1:
|
||||
# All samples were merged into < 2 bins, which is too
|
||||
# few to evaluate.
|
||||
return
|
||||
|
||||
threshold = df + 10 * math.sqrt(2 * df)
|
||||
prefix = f"[{label}] " if label else ""
|
||||
assert chi2 < threshold, (
|
||||
f"{prefix}Chi-squared test failed: chi2={chi2:.1f}, "
|
||||
f"df={df}, threshold={threshold:.1f}. "
|
||||
f"Output distribution does not match target distribution."
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "num_speculative_steps,temperature",
    [
        (1, 0.6),
        (3, 0.6),
        (1, 1.0),
        (3, 1.0),
    ],
)
def test_stochastic_rejection_sample(num_speculative_steps: int, temperature: float):
    """
    Verify that rejection sampling produces the target distribution.
    This is done by simulating many independent trials of speculative
    decoding (from a fixed target and draft distribution). We then
    run rejection sample on all of the trials (requests), and verify
    that the sampled tokens at every position follow the target
    distribution p(x).
    """

    torch.manual_seed(42)
    device = "cuda"
    # Many trials so the chi-squared test below has enough samples per bin.
    num_trials = 10 * VOCAB_SIZE

    target_logits_1d = torch.randn(VOCAB_SIZE, device=device, dtype=torch.float32)
    draft_logits_1d = torch.randn(VOCAB_SIZE, device=device, dtype=torch.float32)

    if temperature > 0:
        # Scale both logit vectors so the reference softmax computed at the
        # bottom matches the temperature-adjusted target distribution.
        target_logits_1d /= temperature
        draft_logits_1d /= temperature

    inputs = _build_rejection_sample_inputs(
        target_logits_1d,
        draft_logits_1d,
        num_speculative_steps,
        temperature=temperature,
        num_trials=num_trials,
    )

    sampled, num_sampled = probabilistic_rejection_sample(
        **inputs, num_speculative_steps=num_speculative_steps
    )

    target_probs = torch.softmax(target_logits_1d, dim=0)
    for pos in range(num_speculative_steps + 1):
        # Only trials that produced at least pos + 1 tokens contribute
        # samples at this position.
        accepted_mask = num_sampled >= pos + 1
        _assert_distribution_match(
            sampled[accepted_mask, pos], target_probs, device, label=f"position {pos}"
        )
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_speculative_steps", [1, 3])
def test_greedy_rejection_sample(num_speculative_steps: int):
    """
    Verify that greedy (temperature=0) always outputs the target argmax
    at every accepted position.
    """

    torch.manual_seed(42)
    device = "cuda"
    num_trials = 10 * VOCAB_SIZE

    target_logits_1d = torch.randn(VOCAB_SIZE, device=device, dtype=torch.float32)
    draft_logits_1d = torch.randn(VOCAB_SIZE, device=device, dtype=torch.float32)

    # temperature=0.0 selects the greedy path in the sampler.
    inputs = _build_rejection_sample_inputs(
        target_logits_1d,
        draft_logits_1d,
        num_speculative_steps,
        temperature=0.0,
        num_trials=num_trials,
    )

    sampled, num_sampled = probabilistic_rejection_sample(
        **inputs, num_speculative_steps=num_speculative_steps
    )

    target_argmax = target_logits_1d.argmax().item()

    # Mask of positions actually emitted per trial: step s is valid iff
    # s < num_sampled for that trial.
    steps = torch.arange(num_speculative_steps + 1, device=device).unsqueeze(0)
    accepted_mask = steps < num_sampled.unsqueeze(1)

    assert (sampled[accepted_mask] == target_argmax).all(), (
        "Greedy sampling produced tokens that are not the target argmax"
    )
|
||||
@@ -65,36 +65,20 @@ def tl_rand64(seed, offset, includes_zero: tl.constexpr):
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _gumbel_sample_kernel(
|
||||
local_argmax_ptr,
|
||||
local_argmax_stride,
|
||||
local_max_ptr,
|
||||
local_max_stride,
|
||||
processed_logits_ptr,
|
||||
processed_logits_stride,
|
||||
logits_ptr,
|
||||
logits_stride,
|
||||
def gumbel_block_argmax(
|
||||
logits,
|
||||
block,
|
||||
mask,
|
||||
token_idx,
|
||||
expanded_idx_mapping_ptr,
|
||||
temp_ptr,
|
||||
seeds_ptr,
|
||||
pos_ptr,
|
||||
temp_ptr,
|
||||
vocab_size,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
processed_logits_ptr,
|
||||
processed_logits_stride,
|
||||
APPLY_TEMPERATURE: tl.constexpr,
|
||||
):
|
||||
token_idx = tl.program_id(0)
|
||||
req_state_idx = tl.load(expanded_idx_mapping_ptr + token_idx)
|
||||
|
||||
block_idx = tl.program_id(1)
|
||||
block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
mask = block < vocab_size
|
||||
logits = tl.load(
|
||||
logits_ptr + token_idx * logits_stride + block,
|
||||
mask=mask,
|
||||
other=float("-inf"),
|
||||
)
|
||||
logits = logits.to(tl.float32)
|
||||
|
||||
temp = tl.load(temp_ptr + req_state_idx).to(tl.float32)
|
||||
if temp != 0.0 and APPLY_TEMPERATURE:
|
||||
# Apply temperature.
|
||||
@@ -102,8 +86,8 @@ def _gumbel_sample_kernel(
|
||||
# E.g., if the kernel uses tl.div_rn, we should use tl.div_rn here too.
|
||||
logits = logits / temp
|
||||
|
||||
# Store the temperature-applied logits.
|
||||
if processed_logits_ptr is not None:
|
||||
# Store the temperature-applied logits.
|
||||
tl.store(
|
||||
processed_logits_ptr + req_state_idx * processed_logits_stride + block,
|
||||
logits,
|
||||
@@ -126,6 +110,51 @@ def _gumbel_sample_kernel(
|
||||
logits = tl.where(mask, logits + gumbel_noise, float("-inf"))
|
||||
|
||||
value, idx = tl.max(logits, axis=0, return_indices=True)
|
||||
return value, idx
|
||||
|
||||
|
||||
@triton.jit
def _gumbel_sample_kernel(
    # [num_tokens, num_blocks]: per-block argmax token id (output).
    local_argmax_ptr,
    local_argmax_stride,
    # [num_tokens, num_blocks]: per-block max value (output).
    local_max_ptr,
    local_max_stride,
    # Optional buffer for temperature-applied logits (may be unused;
    # forwarded to gumbel_block_argmax).
    processed_logits_ptr,
    processed_logits_stride,
    # [num_tokens, V]: input logits.
    logits_ptr,
    logits_stride,
    # [num_tokens]: token index -> request-state index mapping.
    expanded_idx_mapping_ptr,
    seeds_ptr,
    pos_ptr,
    temp_ptr,
    vocab_size,
    BLOCK_SIZE: tl.constexpr,
    APPLY_TEMPERATURE: tl.constexpr,
):
    # One program per (token, vocab block): load this block of logits and
    # reduce to a per-block (max value, argmax token id) pair. A later pass
    # presumably reduces across blocks — confirm against callers.
    token_idx = tl.program_id(0)
    block_idx = tl.program_id(1)
    block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = block < vocab_size
    logits = tl.load(
        logits_ptr + token_idx * logits_stride + block,
        mask=mask,
        other=float("-inf"),
    )
    logits = logits.to(tl.float32)

    # Gumbel-noised argmax within this vocab block (noise/temperature
    # handling lives in gumbel_block_argmax).
    value, idx = gumbel_block_argmax(
        logits,
        block,
        mask,
        token_idx,
        expanded_idx_mapping_ptr,
        temp_ptr,
        seeds_ptr,
        pos_ptr,
        processed_logits_ptr,
        processed_logits_stride,
        APPLY_TEMPERATURE=APPLY_TEMPERATURE,
    )
    # Convert the in-block index to a global token id before storing.
    token_id = block_idx * BLOCK_SIZE + idx
    tl.store(local_argmax_ptr + token_idx * local_argmax_stride + block_idx, token_id)
    tl.store(local_max_ptr + token_idx * local_max_stride + block_idx, value)
|
||||
|
||||
@@ -0,0 +1,612 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import torch
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.v1.worker.gpu.sample.gumbel import gumbel_block_argmax, tl_rand64
|
||||
|
||||
|
||||
@triton.jit
def _compute_block_max_and_sumexp(logits):
    # Per-block pieces of a numerically stable logsumexp:
    # returns (max(logits), sum(exp(logits - max))).
    block_max = tl.max(logits, axis=0)
    # If the whole block is -inf (fully masked), the shifted exponentials
    # would be NaN; report a sumexp of 0 instead.
    block_sumexp = tl.where(
        block_max > float("-inf"),
        tl.sum(tl.exp(logits - block_max)),
        0.0,
    )
    return block_max, block_sumexp
|
||||
|
||||
|
||||
@triton.jit
def _compute_global_lse(
    # [num_logits, num_blocks]: per-block maxima.
    local_max_ptr,
    local_max_stride,
    # [num_logits, num_blocks]: per-block sum(exp(x - block_max)).
    local_sumexp_ptr,
    local_sumexp_stride,
    logit_idx,
    vocab_num_blocks,
    PADDED_VOCAB_NUM_BLOCKS: tl.constexpr,
):
    # Combine per-block (max, sumexp) pairs into one global logsumexp for
    # row `logit_idx`, using the standard max-shift trick for stability.
    blocks = tl.arange(0, PADDED_VOCAB_NUM_BLOCKS)
    blocks_mask = blocks < vocab_num_blocks
    maxes = tl.load(
        local_max_ptr + logit_idx * local_max_stride + blocks,
        mask=blocks_mask,
        other=float("-inf"),
    )
    sumexps = tl.load(
        local_sumexp_ptr + logit_idx * local_sumexp_stride + blocks,
        mask=blocks_mask,
        other=0.0,
    )
    global_max = tl.max(maxes, axis=0)
    # Re-base each block's sumexp onto the global max before summing.
    global_lse = global_max + tl.log(tl.sum(sumexps * tl.exp(maxes - global_max)))
    return global_lse
|
||||
|
||||
|
||||
@triton.jit
def _compute_block_max_and_sumexp_kernel(
    # First pass of the fused rejection sampler: one program per
    # (logit row, vocab block). For greedy rows (temp == 0) it records the
    # per-block target max/argmax; otherwise it records per-block
    # max and sum-of-exponentials for both target and draft logits,
    # to be reduced into logsumexps later by _compute_global_lse.
    #
    # [num_logits, num_blocks]
    target_local_argmax_ptr,
    target_local_argmax_stride,
    # [num_logits, num_blocks]
    target_local_max_ptr,
    target_local_max_stride,
    # [num_logits, num_blocks]
    target_local_sumexp_ptr,
    target_local_sumexp_stride,
    # [num_logits, num_blocks]
    draft_local_max_ptr,
    draft_local_max_stride,
    # [num_logits, num_blocks]
    draft_local_sumexp_ptr,
    draft_local_sumexp_stride,
    # [num_logits, V]
    target_logits_ptr,
    target_logits_stride,
    # [max_num_reqs, num_speculative_steps, V]
    draft_logits_ptr,
    draft_logits_stride_0,
    draft_logits_stride_1,
    # [num_logits]
    expanded_idx_mapping_ptr,
    # [num_logits]
    expanded_local_pos_ptr,
    # [max_num_reqs]
    temp_ptr,
    vocab_size,
    num_speculative_steps,
    BLOCK_SIZE: tl.constexpr,
):
    logit_idx = tl.program_id(0)
    draft_step_idx = tl.load(expanded_local_pos_ptr + logit_idx)

    if draft_step_idx >= num_speculative_steps:
        # Bonus token. Max/argmax and summed exponentials are not needed.
        return

    req_state_idx = tl.load(expanded_idx_mapping_ptr + logit_idx)
    temp = tl.load(temp_ptr + req_state_idx).to(tl.float32)

    block_idx = tl.program_id(1)
    block_offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = block_offsets < vocab_size

    if temp == 0.0:
        # Greedy sampling. Only the target max/argmax are needed.
        target_logits = tl.load(
            target_logits_ptr + logit_idx * target_logits_stride + block_offsets,
            mask=mask,
            other=float("-inf"),
        ).to(tl.float32)
        value, idx = tl.max(target_logits, axis=0, return_indices=True)
        # Convert the in-block index to a global token id.
        token_id = block_idx * BLOCK_SIZE + idx
        tl.store(
            target_local_argmax_ptr
            + logit_idx * target_local_argmax_stride
            + block_idx,
            token_id,
        )
        tl.store(
            target_local_max_ptr + logit_idx * target_local_max_stride + block_idx,
            value,
        )
    else:
        # Get local draft max and summed exponentials.
        draft_logits = tl.load(
            draft_logits_ptr
            + req_state_idx * draft_logits_stride_0
            + draft_step_idx * draft_logits_stride_1
            + block_offsets,
            mask=mask,
            other=float("-inf"),
        ).to(tl.float32)
        draft_max, draft_sumexp = _compute_block_max_and_sumexp(draft_logits)
        tl.store(
            draft_local_max_ptr + logit_idx * draft_local_max_stride + block_idx,
            draft_max,
        )
        tl.store(
            draft_local_sumexp_ptr + logit_idx * draft_local_sumexp_stride + block_idx,
            draft_sumexp,
        )
        # Get local target max and summed exponentials.
        target_logits = tl.load(
            target_logits_ptr + logit_idx * target_logits_stride + block_offsets,
            mask=mask,
            other=float("-inf"),
        ).to(tl.float32)
        target_max, target_sumexp = _compute_block_max_and_sumexp(target_logits)
        tl.store(
            target_local_max_ptr + logit_idx * target_local_max_stride + block_idx,
            target_max,
        )
        tl.store(
            target_local_sumexp_ptr
            + logit_idx * target_local_sumexp_stride
            + block_idx,
            target_sumexp,
        )
|
||||
|
||||
|
||||
@triton.jit
def _probabilistic_rejection_kernel(
    # Second pass: one program per request. Walks the draft tokens in
    # order, accepting or rejecting each against the target distribution,
    # and records how many were accepted plus the logsumexps at the first
    # rejection point (used later by _resample_kernel).
    #
    # [num_reqs, num_speculative_steps + 1]
    sampled_ptr,
    sampled_stride,
    # [num_reqs]: number of accepted steps per request (output).
    rejected_steps_ptr,
    # [num_reqs]
    target_rejected_logsumexp_ptr,
    # [num_reqs]
    draft_rejected_logsumexp_ptr,
    # [num_logits, V]
    target_logits_ptr,
    target_logits_stride,
    # [num_logits, num_blocks]
    target_local_argmax_ptr,
    target_local_argmax_stride,
    # [num_logits, num_blocks]
    target_local_max_ptr,
    target_local_max_stride,
    # [num_logits, num_blocks]
    target_local_sumexp_ptr,
    target_local_sumexp_stride,
    # [num_logits]
    draft_sampled_ptr,
    # [max_num_reqs, num_speculative_steps, V]
    draft_logits_ptr,
    draft_logits_stride_0,
    draft_logits_stride_1,
    # [num_logits, num_blocks]
    draft_local_max_ptr,
    draft_local_max_stride,
    # [num_logits, num_blocks]
    draft_local_sumexp_ptr,
    draft_local_sumexp_stride,
    # [num_reqs + 1]
    cu_num_logits_ptr,
    # [num_reqs]
    idx_mapping_ptr,
    # [max_num_reqs]
    temp_ptr,
    # [max_num_reqs]
    seed_ptr,
    # [num_logits]
    pos_ptr,
    vocab_num_blocks,
    PADDED_VOCAB_NUM_BLOCKS: tl.constexpr,
):
    req_idx = tl.program_id(0)
    req_state_idx = tl.load(idx_mapping_ptr + req_idx)
    start_idx = tl.load(cu_num_logits_ptr + req_idx)
    end_idx = tl.load(cu_num_logits_ptr + req_idx + 1)
    num_tokens = end_idx - start_idx
    seed = tl.load(seed_ptr + req_state_idx)
    temp = tl.load(temp_ptr + req_state_idx).to(tl.float32)

    rejected_step = 0
    target_lse = 0.0
    draft_lse = 0.0
    accepted = True
    # The last logit is the bonus token; it is never accepted here, hence
    # num_tokens - 1 iterations.
    for i in range(num_tokens - 1):
        if accepted:
            logit_idx = start_idx + i
            # draft_sampled is offset by one: slot 0 of each request is a
            # placeholder.
            draft_sampled = tl.load(draft_sampled_ptr + logit_idx + 1)
            if temp == 0.0:
                # Greedy sampling. Accept IFF draft matches target argmax.
                # NOTE: Target argmax is stored directly so that resampling
                # can be skipped upon rejection.
                target_blocks = tl.arange(0, PADDED_VOCAB_NUM_BLOCKS)
                target_blocks_mask = target_blocks < vocab_num_blocks
                target_local_max = tl.load(
                    target_local_max_ptr
                    + logit_idx * target_local_max_stride
                    + target_blocks,
                    mask=target_blocks_mask,
                    other=float("-inf"),
                )
                max_target_block_idx = tl.argmax(target_local_max, axis=0)
                target_argmax = tl.load(
                    target_local_argmax_ptr
                    + logit_idx * target_local_argmax_stride
                    + max_target_block_idx
                )
                accepted &= target_argmax == draft_sampled
                tl.store(sampled_ptr + req_idx * sampled_stride + i, target_argmax)
            else:
                # Stochastic path: compare log-probabilities of the drafted
                # token under target p and draft q.
                target_logit = tl.load(
                    target_logits_ptr + logit_idx * target_logits_stride + draft_sampled
                ).to(tl.float32)
                draft_logit = tl.load(
                    draft_logits_ptr
                    + req_state_idx * draft_logits_stride_0
                    + i * draft_logits_stride_1
                    + draft_sampled
                ).to(tl.float32)
                target_lse = _compute_global_lse(
                    target_local_max_ptr,
                    target_local_max_stride,
                    target_local_sumexp_ptr,
                    target_local_sumexp_stride,
                    logit_idx,
                    vocab_num_blocks,
                    PADDED_VOCAB_NUM_BLOCKS,
                )
                draft_lse = _compute_global_lse(
                    draft_local_max_ptr,
                    draft_local_max_stride,
                    draft_local_sumexp_ptr,
                    draft_local_sumexp_stride,
                    logit_idx,
                    vocab_num_blocks,
                    PADDED_VOCAB_NUM_BLOCKS,
                )
                target_log_prob = target_logit - target_lse
                draft_log_prob = draft_logit - draft_lse
                pos = tl.load(pos_ptr + logit_idx)
                u = tl_rand64(seed, pos, includes_zero=False)
                # Probability ratio test: p(x) > u * q(x)
                # Equivalent log form: log_p(x) > log(u) + log_q(x)
                accepted &= target_log_prob > tl.log(u) + draft_log_prob
                tl.store(sampled_ptr + req_idx * sampled_stride + i, draft_sampled)
            # Counts accepted steps; the final value is the index of the
            # first rejected step (or num_tokens - 1 if all accepted).
            rejected_step += accepted
    tl.store(rejected_steps_ptr + req_idx, rejected_step)
    # LSEs from the last evaluated step; meaningful for the stochastic
    # resampling path (greedy rows leave them at 0.0).
    tl.store(target_rejected_logsumexp_ptr + req_idx, target_lse)
    tl.store(draft_rejected_logsumexp_ptr + req_idx, draft_lse)
|
||||
|
||||
|
||||
@triton.jit
def _resample_kernel(
    # Third pass: one program per (request, vocab block). Draws a
    # replacement token for the first rejected step (from the residual
    # distribution max(p - q, 0)) or the bonus token (directly from p),
    # producing per-block argmax/max candidates for the final reduction.
    #
    # [num_reqs, num_blocks]
    resampled_local_argmax_ptr,
    resampled_local_argmax_stride,
    # [num_reqs, num_blocks]
    resampled_local_max_ptr,
    resampled_local_max_stride,
    # [num_logits, V]
    target_logits_ptr,
    target_logits_stride,
    # [num_reqs]
    target_rejected_logsumexp_ptr,
    # [max_num_reqs, num_speculative_steps, V]
    draft_logits_ptr,
    draft_logits_stride_0,
    draft_logits_stride_1,
    # [num_reqs]
    draft_rejected_logsumexp_ptr,
    # [num_reqs]
    rejected_step_ptr,
    # [num_reqs + 1]
    cu_num_logits_ptr,
    # [num_logits]
    expanded_idx_mapping_ptr,
    # [max_num_reqs]
    temp_ptr,
    # [max_num_reqs]
    seed_ptr,
    # [num_logits]
    pos_ptr,
    vocab_size,
    BLOCK_SIZE: tl.constexpr,
):
    req_idx = tl.program_id(0)
    resample_idx = tl.load(rejected_step_ptr + req_idx)
    start_idx = tl.load(cu_num_logits_ptr + req_idx)
    end_idx = tl.load(cu_num_logits_ptr + req_idx + 1)
    resample_token_idx = start_idx + resample_idx
    req_state_idx = tl.load(expanded_idx_mapping_ptr + resample_token_idx)

    temp = tl.load(temp_ptr + req_state_idx).to(tl.float32)
    is_bonus = resample_token_idx == end_idx - 1
    if temp == 0.0 and not is_bonus:
        # Greedy + non-bonus token. No resampling needed because
        # the target argmax is already in the sampled tensor.
        return

    block_idx = tl.program_id(1)
    block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = block < vocab_size

    # Compute the residual logits to resample the rejected token
    # from. In the case of no rejections (bonus token), we directly
    # use the target logits.
    if is_bonus:
        residual_logits = tl.load(
            target_logits_ptr + resample_token_idx * target_logits_stride + block,
            mask=mask,
            other=float("-inf"),
        ).to(tl.float32)
    else:
        target_logits = tl.load(
            target_logits_ptr + resample_token_idx * target_logits_stride + block,
            mask=mask,
            other=float("-inf"),
        ).to(tl.float32)
        draft_logits = tl.load(
            draft_logits_ptr
            + req_state_idx * draft_logits_stride_0
            + resample_idx * draft_logits_stride_1
            + block,
            mask=mask,
            other=float("-inf"),
        ).to(tl.float32)
        target_lse = tl.load(target_rejected_logsumexp_ptr + req_idx)
        draft_lse = tl.load(draft_rejected_logsumexp_ptr + req_idx)
        target_log_probs = target_logits - target_lse
        draft_log_probs = draft_logits - draft_lse
        # Compute the residual: max(p(x) - q(x), 0)
        # Equivalent log form: log(max(exp(log_p(x)) - exp(log_q(x)), 0))
        # The more numerically stable form is:
        # log(max(exp(a) - exp(b), 0)) = a + log(max(1 - exp(b - a), 0))
        ratio = tl.exp(draft_log_probs - target_log_probs)
        residual_logits = tl.where(
            ratio < 1.0,
            target_log_probs + tl.log(1 - ratio),
            float("-inf"),
        ).to(tl.float32)

    # Resample the rejected/bonus token.
    value, idx = gumbel_block_argmax(
        residual_logits,
        block,
        mask,
        resample_token_idx,
        expanded_idx_mapping_ptr,
        temp_ptr,
        seed_ptr,
        pos_ptr,
        None,
        0,
        APPLY_TEMPERATURE=False,
    )
    # Convert the in-block index to a global token id before storing.
    token_id = block_idx * BLOCK_SIZE + idx
    tl.store(
        resampled_local_argmax_ptr
        + req_idx * resampled_local_argmax_stride
        + block_idx,
        token_id,
    )
    tl.store(
        resampled_local_max_ptr + req_idx * resampled_local_max_stride + block_idx,
        value,
    )
|
||||
|
||||
|
||||
@triton.jit
def _insert_resampled_kernel(
    # Final pass: one program per request. Reduces the per-block resample
    # candidates to the winning token, writes it into the output, and
    # bumps the per-request sampled-token count.
    #
    # [num_reqs, num_speculative_steps + 1]
    sampled_ptr,
    sampled_stride,
    # [num_reqs]: accepted count on entry; incremented by 1 here.
    num_sampled_ptr,
    # [num_reqs, num_blocks]
    resampled_local_argmax_ptr,
    resampled_local_argmax_stride,
    # [num_reqs, num_blocks]
    resampled_local_max_ptr,
    resampled_local_max_stride,
    resample_num_blocks,
    # [num_reqs + 1]
    cu_num_logits_ptr,
    # [num_reqs]
    expanded_idx_mapping_ptr,
    # [max_num_reqs]
    temp_ptr,
    PADDED_RESAMPLE_NUM_BLOCKS: tl.constexpr,
):
    req_idx = tl.program_id(0)
    num_sampled = tl.load(num_sampled_ptr + req_idx)
    start_idx = tl.load(cu_num_logits_ptr + req_idx)
    end_idx = tl.load(cu_num_logits_ptr + req_idx + 1)
    resample_token_idx = start_idx + num_sampled
    req_state_idx = tl.load(expanded_idx_mapping_ptr + resample_token_idx)

    # Increment the number of sampled tokens.
    tl.store(num_sampled_ptr + req_idx, num_sampled + 1)

    temp = tl.load(temp_ptr + req_state_idx).to(tl.float32)
    is_bonus = resample_token_idx == end_idx - 1
    if temp == 0.0 and not is_bonus:
        # Greedy + non-bonus token. The target argmax is already
        # in the sampled tensor.
        return

    # Insert the resampled token.
    block = tl.arange(0, PADDED_RESAMPLE_NUM_BLOCKS)
    mask = block < resample_num_blocks
    resampled_local_max = tl.load(
        resampled_local_max_ptr + req_idx * resampled_local_max_stride + block,
        mask=mask,
        other=float("-inf"),
    )
    # The block with the highest max holds the globally winning token.
    resampled_max_block_idx = tl.argmax(resampled_local_max, axis=0)
    resampled = tl.load(
        resampled_local_argmax_ptr
        + req_idx * resampled_local_argmax_stride
        + resampled_max_block_idx,
    )
    tl.store(
        sampled_ptr + req_idx * sampled_stride + num_sampled,
        resampled,
    )
|
||||
|
||||
|
||||
def probabilistic_rejection_sample(
    # [num_logits, V]
    target_logits: torch.Tensor,
    # [max_num_reqs, num_speculative_steps, V]
    draft_logits: torch.Tensor,
    # [num_logits]
    draft_sampled: torch.Tensor,
    # [num_reqs + 1]
    cu_num_logits: torch.Tensor,
    # [num_logits]
    pos: torch.Tensor,
    # [num_reqs]
    idx_mapping: torch.Tensor,
    # [num_logits]
    expanded_idx_mapping: torch.Tensor,
    # [num_logits]
    expanded_local_pos: torch.Tensor,
    # [max_num_reqs]
    temperature: torch.Tensor,
    # [max_num_reqs]
    seed: torch.Tensor,
    num_speculative_steps: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Run fused probabilistic rejection sampling for speculative decoding.

    Orchestrates four Triton kernels:
      1. ``_compute_block_max_and_sumexp_kernel`` — per-block max/argmax and
         sum-of-exponentials over both target and draft logits.
      2. ``_probabilistic_rejection_kernel`` — sequentially accept/reject each
         draft token per request.
      3. ``_resample_kernel`` — draw a replacement for the first rejected
         (or bonus) token from the residual distribution.
      4. ``_insert_resampled_kernel`` — reduce resample candidates and write
         the final token, incrementing the per-request count.

    Returns:
        sampled: [num_reqs, num_speculative_steps + 1] int64 token ids; only
            the first ``num_sampled[i]`` entries of row ``i`` are valid.
        num_sampled: [num_reqs] number of valid tokens per request
            (accepted drafts + 1 resampled/bonus token).
    """
    num_reqs = cu_num_logits.shape[0] - 1
    num_logits, vocab_size = target_logits.shape

    # Gather draft logits, compute target argmax for greedy, and
    # compute per-block LSE and max for non-greedy requests.
    VOCAB_BLOCK_SIZE = 8192
    vocab_num_blocks = triton.cdiv(vocab_size, VOCAB_BLOCK_SIZE)
    padded_vocab_num_blocks = triton.next_power_of_2(vocab_num_blocks)
    target_local_argmax = target_logits.new_empty(
        num_logits, vocab_num_blocks, dtype=torch.int64
    )
    target_local_max = target_logits.new_empty(
        num_logits, vocab_num_blocks, dtype=torch.float32
    )
    target_local_sumexp = target_logits.new_empty(
        num_logits, vocab_num_blocks, dtype=torch.float32
    )
    draft_local_max = target_logits.new_empty(
        num_logits, vocab_num_blocks, dtype=torch.float32
    )
    draft_local_sumexp = target_logits.new_empty(
        num_logits, vocab_num_blocks, dtype=torch.float32
    )
    _compute_block_max_and_sumexp_kernel[(num_logits, vocab_num_blocks)](
        target_local_argmax,
        target_local_argmax.stride(0),
        target_local_max,
        target_local_max.stride(0),
        target_local_sumexp,
        target_local_sumexp.stride(0),
        draft_local_max,
        draft_local_max.stride(0),
        draft_local_sumexp,
        draft_local_sumexp.stride(0),
        target_logits,
        target_logits.stride(0),
        draft_logits,
        draft_logits.stride(0),
        draft_logits.stride(1),
        expanded_idx_mapping,
        expanded_local_pos,
        temperature,
        vocab_size,
        num_speculative_steps,
        BLOCK_SIZE=VOCAB_BLOCK_SIZE,
    )

    # Sample up until the first rejected/bonus token, and store
    # the step.
    sampled = draft_sampled.new_empty(
        num_reqs, num_speculative_steps + 1, dtype=torch.int64
    )
    num_sampled = sampled.new_empty(num_reqs)
    target_rejected_logsumexp = target_logits.new_empty(num_reqs, dtype=torch.float32)
    draft_rejected_logsumexp = target_logits.new_empty(num_reqs, dtype=torch.float32)
    _probabilistic_rejection_kernel[(num_reqs,)](
        sampled,
        sampled.stride(0),
        num_sampled,
        target_rejected_logsumexp,
        draft_rejected_logsumexp,
        target_logits,
        target_logits.stride(0),
        target_local_argmax,
        target_local_argmax.stride(0),
        target_local_max,
        target_local_max.stride(0),
        target_local_sumexp,
        target_local_sumexp.stride(0),
        draft_sampled,
        draft_logits,
        draft_logits.stride(0),
        draft_logits.stride(1),
        draft_local_max,
        draft_local_max.stride(0),
        draft_local_sumexp,
        draft_local_sumexp.stride(0),
        cu_num_logits,
        idx_mapping,
        temperature,
        seed,
        pos,
        vocab_num_blocks,
        PADDED_VOCAB_NUM_BLOCKS=padded_vocab_num_blocks,
        num_warps=1,
    )

    # Resample the rejected/bonus tokens.
    RESAMPLE_BLOCK_SIZE = 1024
    resample_num_blocks = triton.cdiv(vocab_size, RESAMPLE_BLOCK_SIZE)
    padded_resample_num_blocks = triton.next_power_of_2(resample_num_blocks)
    resampled_local_argmax = target_logits.new_empty(
        num_reqs, resample_num_blocks, dtype=torch.int64
    )
    # NOTE: float32 for consistency with the other local-max buffers; the
    # kernels compute and store float32 values (was float64, which doubled
    # memory for no benefit).
    resampled_local_max = target_logits.new_empty(
        num_reqs, resample_num_blocks, dtype=torch.float32
    )
    _resample_kernel[(num_reqs, resample_num_blocks)](
        resampled_local_argmax,
        resampled_local_argmax.stride(0),
        resampled_local_max,
        resampled_local_max.stride(0),
        target_logits,
        target_logits.stride(0),
        target_rejected_logsumexp,
        draft_logits,
        draft_logits.stride(0),
        draft_logits.stride(1),
        draft_rejected_logsumexp,
        num_sampled,
        cu_num_logits,
        expanded_idx_mapping,
        temperature,
        seed,
        pos,
        vocab_size,
        BLOCK_SIZE=RESAMPLE_BLOCK_SIZE,
    )

    # Insert the resampled tokens into the output sampled.
    _insert_resampled_kernel[(num_reqs,)](
        sampled,
        sampled.stride(0),
        num_sampled,
        resampled_local_argmax,
        resampled_local_argmax.stride(0),
        resampled_local_max,
        resampled_local_max.stride(0),
        resample_num_blocks,
        cu_num_logits,
        expanded_idx_mapping,
        temperature,
        PADDED_RESAMPLE_NUM_BLOCKS=padded_resample_num_blocks,
    )
    return sampled, num_sampled
|
||||
@@ -7,11 +7,13 @@ from vllm.triton_utils import tl, triton
|
||||
from vllm.v1.outputs import LogprobsTensors
|
||||
from vllm.v1.worker.gpu.input_batch import InputBatch
|
||||
from vllm.v1.worker.gpu.metrics.logits import get_num_nans
|
||||
from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample, tl_rand64
|
||||
from vllm.v1.worker.gpu.sample.logprob import compute_topk_logprobs
|
||||
from vllm.v1.worker.gpu.sample.output import SamplerOutput
|
||||
from vllm.v1.worker.gpu.sample.sampler import Sampler
|
||||
from vllm.v1.worker.gpu.sample.states import NO_LOGPROBS
|
||||
from vllm.v1.worker.gpu.spec_decode.probabilistic_rejection_sampler_utils import (
|
||||
probabilistic_rejection_sample,
|
||||
)
|
||||
from vllm.v1.worker.gpu.spec_decode.synthetic_rejection_sampler_utils import (
|
||||
compute_synthetic_rejection_sampler_params,
|
||||
synthetic_rejection_sample,
|
||||
@@ -75,357 +77,6 @@ def strict_rejection_sample(
|
||||
return sampled, num_sampled
|
||||
|
||||
|
||||
@triton.jit
def _gather_draft_logits_and_target_argmax_kernel(
    # Per-(logit, block) partial argmax of the target logits, used later to
    # reduce a global argmax for greedy acceptance.
    # [num_logits, num_blocks]
    local_target_argmax_ptr,
    local_target_argmax_stride,
    # Per-(logit, block) partial max value matching local_target_argmax.
    # [num_logits, num_blocks]
    local_target_max_ptr,
    local_target_max_stride,
    # Output: draft logits gathered into the same [num_logits, V] layout as
    # the target logits (only written on the non-greedy path).
    # [num_logits, V]
    out_draft_logits_ptr,
    out_draft_logits_stride,
    # [num_logits, V]
    target_logits_ptr,
    target_logits_stride,
    # [max_num_reqs, num_speculative_steps, V]
    draft_logits_ptr,
    draft_logits_stride_0,
    draft_logits_stride_1,
    # Maps each logit row to its request-state slot.
    # [num_logits]
    expanded_idx_mapping_ptr,
    # Draft step index of each logit row within its request.
    # [num_logits]
    expanded_local_pos_ptr,
    # Per-request-state sampling temperature; 0.0 selects the greedy path.
    # [max_num_reqs]
    temp_ptr,
    vocab_size,
    num_speculative_steps,
    BLOCK_SIZE: tl.constexpr,
):
    # Grid: (num_logits, num_blocks). Each program handles one BLOCK_SIZE
    # slice of the vocabulary for one logit row.
    logit_idx = tl.program_id(0)
    req_state_idx = tl.load(expanded_idx_mapping_ptr + logit_idx)
    draft_step_idx = tl.load(expanded_local_pos_ptr + logit_idx)

    block_idx = tl.program_id(1)
    block_offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = block_offsets < vocab_size
    temp = tl.load(temp_ptr + req_state_idx).to(tl.float32)

    if temp == 0.0:
        # Greedy sampling. Get the target logits argmax.
        # Out-of-vocab lanes are padded with -inf so they never win the max.
        target_logits = tl.load(
            target_logits_ptr + logit_idx * target_logits_stride + block_offsets,
            mask=mask,
            other=float("-inf"),
        ).to(tl.float32)
        value, idx = tl.max(target_logits, axis=0, return_indices=True)
        # idx is relative to this block; convert it to a global token id.
        token_id = block_idx * BLOCK_SIZE + idx
        tl.store(
            local_target_argmax_ptr
            + logit_idx * local_target_argmax_stride
            + block_idx,
            token_id,
        )
        tl.store(
            local_target_max_ptr + logit_idx * local_target_max_stride + block_idx,
            value,
        )
    elif draft_step_idx < num_speculative_steps:
        # Probabilistic path: copy this row's draft logits out of the
        # [max_num_reqs, num_speculative_steps, V] buffer into a dense
        # [num_logits, V] buffer aligned with the target logits.
        # Rows at draft_step_idx == num_speculative_steps (the bonus
        # position) have no draft logits and are left unwritten.
        draft_logits = tl.load(
            draft_logits_ptr
            + req_state_idx * draft_logits_stride_0
            + draft_step_idx * draft_logits_stride_1
            + block_offsets,
            mask=mask,
            other=float("-inf"),
        ).to(tl.float32)
        tl.store(
            out_draft_logits_ptr + logit_idx * out_draft_logits_stride + block_offsets,
            draft_logits,
            mask=mask,
        )
|
||||
|
||||
|
||||
@triton.jit
def _probabilistic_rejection_kernel(
    # Output: accepted draft tokens, written per step.
    # [num_reqs, num_speculative_steps + 1]
    sampled_ptr,
    sampled_stride,
    # Output: number of accepted draft tokens per request (i.e. the index
    # of the first rejected / bonus position).
    # [num_reqs]
    rejected_steps_ptr,
    # Output: the `pos` value at the first rejected / bonus position,
    # consumed by the downstream resampling step.
    # [num_reqs]
    rejected_pos_ptr,
    # [num_logits]
    draft_sampled_ptr,
    # [num_logits, V]
    target_probs_ptr,
    target_probs_stride,
    # [num_logits, V]
    draft_probs_ptr,
    draft_probs_stride,
    # Per-block partial argmax/max of the target logits (greedy path).
    # [num_logits, num_blocks]
    local_target_argmax_ptr,
    local_target_argmax_stride,
    # [num_logits, num_blocks]
    local_target_max_ptr,
    local_target_max_stride,
    # [num_reqs + 1]
    cu_num_logits_ptr,
    # [num_logits]
    pos_ptr,
    # [num_reqs]
    idx_mapping_ptr,
    # [max_num_reqs]
    temp_ptr,
    # [max_num_reqs]
    seeds_ptr,
    NUM_BLOCKS: tl.constexpr,
    PADDED_NUM_BLOCKS: tl.constexpr,
):
    # One program per request; the draft tokens of a request are walked
    # sequentially because acceptance at step i gates step i + 1.
    req_idx = tl.program_id(0)
    start_idx = tl.load(cu_num_logits_ptr + req_idx)
    num_tokens = tl.load(cu_num_logits_ptr + req_idx + 1) - start_idx
    req_state_idx = tl.load(idx_mapping_ptr + req_idx)
    seed = tl.load(seeds_ptr + req_state_idx)
    temp = tl.load(temp_ptr + req_state_idx).to(tl.float32)

    rejected_step = 0
    accepted = True
    # The last logit row is the bonus position and has no draft token,
    # hence num_tokens - 1 iterations.
    for i in range(num_tokens - 1):
        if accepted:
            logit_idx = start_idx + i
            # NOTE(review): the draft token is read at logit_idx + 1, i.e.
            # offset by one row relative to the probabilities — presumably
            # draft_sampled[j] holds the token *fed into* logits row j, so
            # row j's prediction lives at j + 1. Confirm against the caller.
            draft_sampled = tl.load(draft_sampled_ptr + logit_idx + 1)
            if temp == 0.0:
                # Greedy sampling. Only accept the sampled draft token if
                # it exactly matches the target argmax.
                # Reduce the per-block partial maxes to a global argmax.
                block_offsets = tl.arange(0, PADDED_NUM_BLOCKS)
                block_mask = block_offsets < NUM_BLOCKS
                local_max = tl.load(
                    local_target_max_ptr
                    + logit_idx * local_target_max_stride
                    + block_offsets,
                    mask=block_mask,
                    other=float("-inf"),
                )
                max_block = tl.argmax(local_max, axis=0)
                target_argmax = tl.load(
                    local_target_argmax_ptr
                    + logit_idx * local_target_argmax_stride
                    + max_block
                )
                accepted &= target_argmax == draft_sampled
            else:
                # Standard rejection test: accept with probability
                # min(1, p_target / p_draft), i.e. accept iff
                # p_target > u * p_draft for u uniform from tl_rand64.
                # float64 keeps the ratio comparison numerically stable.
                target_prob = tl.load(
                    target_probs_ptr + logit_idx * target_probs_stride + draft_sampled
                ).to(tl.float64)
                draft_prob = tl.load(
                    draft_probs_ptr + logit_idx * draft_probs_stride + draft_sampled
                ).to(tl.float64)
                pos = tl.load(pos_ptr + logit_idx)
                u = tl_rand64(seed, pos, includes_zero=False)
                accepted &= target_prob > u * draft_prob
            # Tentatively record the draft token; if it was rejected, this
            # slot is overwritten later by the resampled token.
            tl.store(sampled_ptr + req_idx * sampled_stride + i, draft_sampled)
            # accepted is 0/1, so this counts the accepted prefix length.
            rejected_step += accepted
    tl.store(rejected_steps_ptr + req_idx, rejected_step)
    # Position of the first rejected (or bonus) token, used to seed the
    # resampling RNG downstream.
    pos_val = tl.load(pos_ptr + start_idx + rejected_step)
    tl.store(rejected_pos_ptr + req_idx, pos_val)
|
||||
|
||||
|
||||
@triton.jit
def _compute_residual_logits_kernel(
    # Output: per-request logits to resample the rejected/bonus token from.
    # [num_reqs, V]
    residual_logits_ptr,
    residual_logits_stride,
    # [num_logits, V]
    target_probs_ptr,
    target_probs_stride,
    # [num_logits, V]
    draft_probs_ptr,
    draft_probs_stride,
    # [num_logits, V]
    target_logits_ptr,
    target_logits_stride,
    # Index of the first rejected (or bonus) step per request.
    # [num_reqs]
    rejected_step_ptr,
    # [num_reqs + 1]
    cu_num_logits_ptr,
    # [num_reqs]
    idx_mapping_ptr,
    # [max_num_reqs]
    temp_ptr,
    vocab_size,
    BLOCK_SIZE: tl.constexpr,
):
    # Grid: (num_reqs, num_blocks). Each program fills one BLOCK_SIZE slice
    # of one request's residual-logit row.
    req_idx = tl.program_id(0)
    block_idx = tl.program_id(1)

    req_state_idx = tl.load(idx_mapping_ptr + req_idx)
    start_idx = tl.load(cu_num_logits_ptr + req_idx)
    end_idx = tl.load(cu_num_logits_ptr + req_idx + 1)
    # Row of the logit at which the first rejection (or the bonus token)
    # occurred; that is the distribution we resample from.
    rejected_logit_idx = start_idx + tl.load(rejected_step_ptr + req_idx)
    temp = tl.load(temp_ptr + req_state_idx).to(tl.float32)
    block_offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = block_offsets < vocab_size

    if temp == 0.0 or (rejected_logit_idx == end_idx - 1):
        # Greedy sampling / bonus token. In either case, use the
        # target logits directly to reduce numerical error.
        residual_logits = tl.load(
            target_logits_ptr
            + rejected_logit_idx * target_logits_stride
            + block_offsets,
            mask=mask,
            other=float("-inf"),
        )
    else:
        # Residual distribution of rejection sampling:
        # q_residual ∝ max(p_target - p_draft, 0). Taking log turns the
        # unnormalized probabilities into logits; zero entries map to -inf
        # and thus can never be resampled.
        target_probs = tl.load(
            target_probs_ptr + rejected_logit_idx * target_probs_stride + block_offsets,
            mask=mask,
            other=0.0,
        )
        draft_probs = tl.load(
            draft_probs_ptr + rejected_logit_idx * draft_probs_stride + block_offsets,
            mask=mask,
            other=0.0,
        )
        residual_probs = tl.maximum(target_probs - draft_probs, 0.0)
        residual_logits = tl.log(residual_probs)

    tl.store(
        residual_logits_ptr + req_idx * residual_logits_stride + block_offsets,
        residual_logits,
        mask=mask,
    )
|
||||
|
||||
|
||||
def probabilistic_rejection_sample(
    # [num_logits, V]
    target_logits: torch.Tensor,
    # [max_num_reqs, num_speculative_steps, V]
    draft_logits: torch.Tensor,
    # [num_logits]
    draft_sampled: torch.Tensor,
    # [num_reqs + 1]
    cu_num_logits: torch.Tensor,
    # [num_logits]
    pos: torch.Tensor,
    # [num_reqs]
    idx_mapping: torch.Tensor,
    # [num_logits]
    expanded_idx_mapping: torch.Tensor,
    # [num_logits]
    expanded_local_pos: torch.Tensor,
    # [max_num_reqs]
    temperature: torch.Tensor,
    # [max_num_reqs]
    seed: torch.Tensor,
    num_speculative_steps: int,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Rejection-sample speculative draft tokens against the target model.

    Pipeline (all steps run on the GPU):
      1. Gather each row's draft logits into a dense [num_logits, V] buffer
         and, for greedy (temperature == 0) rows, compute per-block partial
         argmax of the target logits.
      2. Softmax target and (gathered) draft logits into probabilities.
      3. Per request, sequentially accept/reject each draft token
         (``_probabilistic_rejection_kernel``), recording the number of
         accepted tokens and the rejection position.
      4. Build the residual distribution max(p_target - p_draft, 0) (or the
         raw target logits for greedy/bonus cases) and Gumbel-sample one
         replacement token from it, scattering it in after the accepted
         prefix.

    Returns:
        sampled: [num_reqs, num_speculative_steps + 1] int64 token ids;
            only the first ``num_sampled`` entries per row are valid.
        num_sampled: [num_reqs] number of valid tokens per request
            (accepted prefix + the one resampled/bonus token).
    """
    num_reqs = cu_num_logits.shape[0] - 1
    num_logits, vocab_size = target_logits.shape

    BLOCK_SIZE = 1024
    num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)

    # Gather draft logits and target argmax for greedy sampling.
    gathered_draft_logits = target_logits.new_empty(target_logits.shape)
    local_target_argmax = target_logits.new_empty(
        num_logits, num_blocks, dtype=torch.int64
    )
    local_target_max = target_logits.new_empty(
        num_logits, num_blocks, dtype=torch.float32
    )
    _gather_draft_logits_and_target_argmax_kernel[(num_logits, num_blocks)](
        local_target_argmax,
        local_target_argmax.stride(0),
        local_target_max,
        local_target_max.stride(0),
        gathered_draft_logits,
        gathered_draft_logits.stride(0),
        target_logits,
        target_logits.stride(0),
        draft_logits,
        draft_logits.stride(0),
        draft_logits.stride(1),
        expanded_idx_mapping,
        expanded_local_pos,
        temperature,
        vocab_size,
        num_speculative_steps,
        BLOCK_SIZE=BLOCK_SIZE,
    )

    # Compute target and draft probs.
    target_probs = torch.softmax(target_logits, dim=-1)
    draft_probs = torch.softmax(gathered_draft_logits, dim=-1)

    # Rejection sample.
    # [num_reqs, num_speculative_steps + 1]
    sampled = draft_sampled.new_empty(
        num_reqs, num_speculative_steps + 1, dtype=torch.int64
    )
    # [num_reqs] number of accepted draft tokens per request.
    rejected_steps = sampled.new_empty(num_reqs)
    # [num_reqs] `pos` value at the first rejected/bonus position (seeds
    # the resampling RNG below).
    rejected_pos = pos.new_empty(num_reqs)
    # One program per request: the accept/reject walk is inherently
    # sequential, hence num_warps=1.
    _probabilistic_rejection_kernel[(num_reqs,)](
        sampled,
        sampled.stride(0),
        rejected_steps,
        rejected_pos,
        draft_sampled,
        target_probs,
        target_probs.stride(0),
        draft_probs,
        draft_probs.stride(0),
        local_target_argmax,
        local_target_argmax.stride(0),
        local_target_max,
        local_target_max.stride(0),
        cu_num_logits,
        pos,
        idx_mapping,
        temperature,
        seed,
        num_warps=1,
        NUM_BLOCKS=num_blocks,
        PADDED_NUM_BLOCKS=triton.next_power_of_2(num_blocks),
    )

    # Compute the logits and positions to resample the rejected/bonus
    # tokens from.
    # [num_reqs, vocab_size]
    residual_logits = target_logits.new_empty(num_reqs, vocab_size)
    _compute_residual_logits_kernel[(num_reqs, num_blocks)](
        residual_logits,
        residual_logits.stride(0),
        target_probs,
        target_probs.stride(0),
        draft_probs,
        draft_probs.stride(0),
        target_logits,
        target_logits.stride(0),
        rejected_steps,
        cu_num_logits,
        idx_mapping,
        temperature,
        vocab_size,
        BLOCK_SIZE=BLOCK_SIZE,
    )

    # Gumbel sample tokens from the residual distribution.
    # apply_temperature=False: the residual logits are already derived from
    # temperature-adjusted probabilities (or raw greedy logits).
    resampled = gumbel_sample(
        residual_logits,
        idx_mapping,
        temperature,
        seed,
        rejected_pos,
        apply_temperature=False,
    )
    # Overwrite the slot right after the accepted prefix with the
    # resampled token.
    sampled.scatter_(1, rejected_steps.unsqueeze(1), resampled.unsqueeze(1))

    # num_sampled = accepted prefix + 1 resampled/bonus token.
    return sampled, rejected_steps + 1
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _flatten_sampled_kernel(
|
||||
# [num_logits]
|
||||
|
||||
Reference in New Issue
Block a user