[Model Runner V2] Spec decode rejection sampler greedy support (#37238)

Signed-off-by: Giancarlo Delfin <gdelfin@inferact.ai>
This commit is contained in:
Giancarlo Delfin
2026-03-18 15:59:03 -07:00
committed by GitHub
parent 9482b0b085
commit 04244fd0e1
2 changed files with 207 additions and 71 deletions

View File

@@ -821,9 +821,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
logits,
input_batch,
# Draft logits are needed for probabilistic rejection sampling.
self.req_states.draft_logits[input_batch.idx_mapping]
if self.req_states.draft_logits is not None
else None,
self.req_states.draft_logits,
)
# Get the number of sampled and rejected tokens.

View File

@@ -68,55 +68,158 @@ def strict_rejection_sample(
@triton.jit
def _probabilistic_rejection_sample_kernel(
def _gather_draft_logits_and_target_argmax_kernel(
local_target_argmax_ptr,
local_target_argmax_stride,
local_target_max_ptr,
local_target_max_stride,
# [num_logits, V]
out_draft_logits_ptr,
out_draft_logits_stride,
# [num_logits, V]
target_logits_ptr,
target_logits_stride,
# [max_num_reqs, num_speculative_steps, V]
draft_logits_ptr,
draft_logits_stride_0,
draft_logits_stride_1,
# [num_logits]
expanded_idx_mapping_ptr,
# [num_logits]
expanded_local_pos_ptr,
# [max_num_reqs]
temp_ptr,
vocab_size,
num_speculative_steps,
BLOCK_SIZE: tl.constexpr,
):
logit_idx = tl.program_id(0)
req_state_idx = tl.load(expanded_idx_mapping_ptr + logit_idx)
draft_step_idx = tl.load(expanded_local_pos_ptr + logit_idx)
block_idx = tl.program_id(1)
block_offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = block_offsets < vocab_size
temp = tl.load(temp_ptr + req_state_idx).to(tl.float32)
if temp == 0.0:
# Greedy sampling. Get the target logits argmax.
target_logits = tl.load(
target_logits_ptr + logit_idx * target_logits_stride + block_offsets,
mask=mask,
other=float("-inf"),
).to(tl.float32)
value, idx = tl.max(target_logits, axis=0, return_indices=True)
token_id = block_idx * BLOCK_SIZE + idx
tl.store(
local_target_argmax_ptr
+ logit_idx * local_target_argmax_stride
+ block_idx,
token_id,
)
tl.store(
local_target_max_ptr + logit_idx * local_target_max_stride + block_idx,
value,
)
elif draft_step_idx < num_speculative_steps:
draft_logits = tl.load(
draft_logits_ptr
+ req_state_idx * draft_logits_stride_0
+ draft_step_idx * draft_logits_stride_1
+ block_offsets,
mask=mask,
other=float("-inf"),
).to(tl.float32)
tl.store(
out_draft_logits_ptr + logit_idx * out_draft_logits_stride + block_offsets,
draft_logits,
mask=mask,
)
@triton.jit
def _probabilistic_rejection_kernel(
# [num_reqs, num_speculative_steps + 1]
sampled_ptr,
sampled_stride,
# [num_reqs]
rejected_steps_ptr,
# [num_reqs]
rejected_pos_ptr,
# [num_logits]
draft_sampled_ptr,
# [num_logits, V]
target_probs_ptr,
target_probs_stride,
# [num_reqs, num_speculative_steps, V]
# [num_logits, V]
draft_probs_ptr,
draft_probs_stride_0,
draft_probs_stride_1,
draft_probs_stride,
# [num_logits, num_blocks]
local_target_argmax_ptr,
local_target_argmax_stride,
# [num_logits, num_blocks]
local_target_max_ptr,
local_target_max_stride,
# [num_reqs + 1]
cu_num_logits_ptr,
# [num_logits]
pos_ptr,
# [num_reqs]
idx_mapping_ptr,
# [num_reqs]
# [max_num_reqs]
temp_ptr,
# [max_num_reqs]
seeds_ptr,
NUM_BLOCKS: tl.constexpr,
PADDED_NUM_BLOCKS: tl.constexpr,
):
req_idx = tl.program_id(0)
start_idx = tl.load(cu_num_logits_ptr + req_idx)
num_tokens = tl.load(cu_num_logits_ptr + req_idx + 1) - start_idx
seed = tl.load(seeds_ptr + tl.load(idx_mapping_ptr + req_idx))
req_state_idx = tl.load(idx_mapping_ptr + req_idx)
seed = tl.load(seeds_ptr + req_state_idx)
temp = tl.load(temp_ptr + req_state_idx).to(tl.float32)
rejected_step = 0
accepted = True
for i in range(num_tokens - 1):
if accepted:
draft_sampled = tl.load(draft_sampled_ptr + start_idx + i + 1)
target_prob = tl.load(
target_probs_ptr + (start_idx + i) * target_probs_stride + draft_sampled
)
draft_prob = tl.load(
draft_probs_ptr
+ req_idx * draft_probs_stride_0
+ i * draft_probs_stride_1
+ draft_sampled
)
pos = tl.load(pos_ptr + start_idx + i)
u = tl.sum(tl.rand(seed, pos + tl.arange(0, 1)))
accepted &= target_prob > u * draft_prob
logit_idx = start_idx + i
draft_sampled = tl.load(draft_sampled_ptr + logit_idx + 1)
if temp == 0.0:
# Greedy sampling. Only accept the sampled draft token if
# it exactly matches the target argmax.
block_offsets = tl.arange(0, PADDED_NUM_BLOCKS)
block_mask = block_offsets < NUM_BLOCKS
local_max = tl.load(
local_target_max_ptr
+ logit_idx * local_target_max_stride
+ block_offsets,
mask=block_mask,
other=float("-inf"),
)
max_block = tl.argmax(local_max, axis=0)
target_argmax = tl.load(
local_target_argmax_ptr
+ logit_idx * local_target_argmax_stride
+ max_block
)
accepted &= target_argmax == draft_sampled
else:
target_prob = tl.load(
target_probs_ptr + logit_idx * target_probs_stride + draft_sampled
)
draft_prob = tl.load(
draft_probs_ptr + logit_idx * draft_probs_stride + draft_sampled
)
pos = tl.load(pos_ptr + logit_idx)
u = tl.sum(tl.rand(seed, pos + tl.arange(0, 1)))
accepted &= target_prob > u * draft_prob
tl.store(sampled_ptr + req_idx * sampled_stride + i, draft_sampled)
rejected_step += accepted
tl.store(rejected_steps_ptr + req_idx, rejected_step)
pos_val = tl.load(pos_ptr + start_idx + rejected_step)
tl.store(rejected_pos_ptr + req_idx, pos_val)
@triton.jit
@@ -124,63 +227,60 @@ def _compute_residual_logits_kernel(
# [num_reqs, V]
residual_logits_ptr,
residual_logits_stride,
# [num_reqs]
residual_pos_ptr,
# [num_logits, V]
target_logits_ptr,
target_logits_stride,
# [num_logits, V]
target_probs_ptr,
target_probs_stride,
# [num_reqs, num_speculative_steps, V]
# [num_logits, V]
draft_probs_ptr,
draft_probs_stride_0,
draft_probs_stride_1,
draft_probs_stride,
# [num_logits, V]
target_logits_ptr,
target_logits_stride,
# [num_reqs]
rejected_step_ptr,
# [num_reqs + 1]
cu_num_logits_ptr,
# [num_logits]
pos_ptr,
# [num_reqs]
idx_mapping_ptr,
# [max_num_reqs]
temp_ptr,
vocab_size,
BLOCK_SIZE: tl.constexpr,
):
req_idx = tl.program_id(0)
block_idx = tl.program_id(1)
req_state_idx = tl.load(idx_mapping_ptr + req_idx)
start_idx = tl.load(cu_num_logits_ptr + req_idx)
end_idx = tl.load(cu_num_logits_ptr + req_idx + 1)
rejected_draft_step = tl.load(rejected_step_ptr + req_idx)
rejected_logit_idx = start_idx + rejected_draft_step
rejected_logit_idx = start_idx + tl.load(rejected_step_ptr + req_idx)
temp = tl.load(temp_ptr + req_state_idx).to(tl.float32)
block_offsets = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = block_offsets < vocab_size
if rejected_logit_idx < end_idx - 1:
if temp == 0.0 or (rejected_logit_idx == end_idx - 1):
# Greedy sampling / bonus token. In either case, use the
# target logits directly to reduce numerical error.
residual_logits = tl.load(
target_logits_ptr
+ rejected_logit_idx * target_logits_stride
+ block_offsets,
mask=mask,
other=float("-inf"),
)
else:
target_probs = tl.load(
target_probs_ptr + rejected_logit_idx * target_probs_stride + block_offsets,
mask=mask,
other=0.0,
)
draft_probs = tl.load(
draft_probs_ptr
+ req_idx * draft_probs_stride_0
+ rejected_draft_step * draft_probs_stride_1
+ block_offsets,
draft_probs_ptr + rejected_logit_idx * draft_probs_stride + block_offsets,
mask=mask,
other=0.0,
)
residual_probs = tl.maximum(target_probs - draft_probs, 0.0)
residual_logits = tl.log(residual_probs)
else:
# This is a bonus token. Directly return the target logits.
residual_logits = tl.load(
target_logits_ptr
+ rejected_logit_idx * target_logits_stride
+ block_offsets,
mask=mask,
other=0.0,
)
tl.store(
residual_logits_ptr + req_idx * residual_logits_stride + block_offsets,
@@ -188,18 +288,13 @@ def _compute_residual_logits_kernel(
mask=mask,
)
# First block computes the residual logit positions.
if block_idx == 0:
pos_val = tl.load(pos_ptr + rejected_logit_idx)
tl.store(residual_pos_ptr + req_idx, pos_val)
def probabilistic_rejection_sample(
# [num_draft_tokens + num_reqs, V]
# [num_logits, V]
target_logits: torch.Tensor,
# [num_reqs, num_speculative_steps, V]
# [max_num_reqs, num_speculative_steps, V]
draft_logits: torch.Tensor,
# [num_draft_tokens + num_reqs]
# [num_logits]
draft_sampled: torch.Tensor,
# [num_reqs + 1]
cu_num_logits: torch.Tensor,
@@ -207,16 +302,53 @@ def probabilistic_rejection_sample(
pos: torch.Tensor,
# [num_reqs]
idx_mapping: torch.Tensor,
# [num_logits]
expanded_idx_mapping: torch.Tensor,
# [num_logits]
expanded_local_pos: torch.Tensor,
# [max_num_reqs]
temperature: torch.Tensor,
# [max_num_reqs]
seed: torch.Tensor,
num_speculative_steps: int,
) -> tuple[torch.Tensor, torch.Tensor]:
num_reqs = cu_num_logits.shape[0] - 1
vocab_size = target_logits.shape[-1]
num_logits, vocab_size = target_logits.shape
BLOCK_SIZE = 1024
num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
# Gather draft logits and target argmax for greedy sampling.
gathered_draft_logits = target_logits.new_empty(target_logits.shape)
local_target_argmax = target_logits.new_empty(
num_logits, num_blocks, dtype=torch.int64
)
local_target_max = target_logits.new_empty(
num_logits, num_blocks, dtype=torch.float32
)
_gather_draft_logits_and_target_argmax_kernel[(num_logits, num_blocks)](
local_target_argmax,
local_target_argmax.stride(0),
local_target_max,
local_target_max.stride(0),
gathered_draft_logits,
gathered_draft_logits.stride(0),
target_logits,
target_logits.stride(0),
draft_logits,
draft_logits.stride(0),
draft_logits.stride(1),
expanded_idx_mapping,
expanded_local_pos,
temperature,
vocab_size,
num_speculative_steps,
BLOCK_SIZE=BLOCK_SIZE,
)
# Compute target and draft probs.
target_probs = torch.softmax(target_logits, dim=-1)
draft_probs = torch.softmax(draft_logits, dim=-1)
draft_probs = torch.softmax(gathered_draft_logits, dim=-1)
# Rejection sample.
# [num_reqs, num_speculative_steps + 1]
@@ -225,45 +357,49 @@ def probabilistic_rejection_sample(
)
# [num_reqs]
rejected_steps = sampled.new_empty(num_reqs)
_probabilistic_rejection_sample_kernel[(num_reqs,)](
# [num_reqs]
rejected_pos = pos.new_empty(num_reqs)
_probabilistic_rejection_kernel[(num_reqs,)](
sampled,
sampled.stride(0),
rejected_steps,
rejected_pos,
draft_sampled,
target_probs,
target_probs.stride(0),
draft_probs,
draft_probs.stride(0),
draft_probs.stride(1),
local_target_argmax,
local_target_argmax.stride(0),
local_target_max,
local_target_max.stride(0),
cu_num_logits,
pos,
idx_mapping,
temperature,
seed,
num_warps=1,
NUM_BLOCKS=num_blocks,
PADDED_NUM_BLOCKS=triton.next_power_of_2(num_blocks),
)
# Compute the logits and positions to resample the rejected/bonus
# tokens from.
# [num_reqs, vocab_size]
residual_logits = target_logits.new_empty(num_reqs, vocab_size)
# [num_reqs]
residual_pos = pos.new_empty(num_reqs)
BLOCK_SIZE = 1024
num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
_compute_residual_logits_kernel[(num_reqs, num_blocks)](
residual_logits,
residual_logits.stride(0),
residual_pos,
target_logits,
target_logits.stride(0),
target_probs,
target_probs.stride(0),
draft_probs,
draft_probs.stride(0),
draft_probs.stride(1),
target_logits,
target_logits.stride(0),
rejected_steps,
cu_num_logits,
pos,
idx_mapping,
temperature,
vocab_size,
BLOCK_SIZE=BLOCK_SIZE,
)
@@ -274,7 +410,7 @@ def probabilistic_rejection_sample(
idx_mapping,
temperature,
seed,
residual_pos,
rejected_pos,
apply_temperature=False,
)
sampled.scatter_(1, rejected_steps.unsqueeze(1), resampled.unsqueeze(1))
@@ -333,6 +469,8 @@ class RejectionSampler:
input_batch.cu_num_logits,
pos,
input_batch.idx_mapping,
input_batch.expanded_idx_mapping,
input_batch.expanded_local_pos,
self.sampler.sampling_states.temperature.gpu,
self.sampler.sampling_states.seeds.gpu,
self.num_speculative_steps,