[Kernel] Optimize sample_recovered_tokens_kernel (#34974)

Signed-off-by: Xin Yang <xyangx@amazon.com>
Author: Xin Yang
Date: 2026-02-20 19:59:06 -08:00
Committed by: GitHub
Parent: 59c6233297
Commit: 7a5adad480
2 changed files with 179 additions and 34 deletions
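
Both the old and the new kernel draw the recovered token with the exponential-race form of the Gumbel-max trick: with q_i ~ Exp(1), argmax_i(prob_i / q_i) is distributed according to the normalized prob, so neither normalization nor a cumulative-sum sampler is needed. The optimization below keeps that math but walks the vocabulary in fixed-size tiles with a running argmax and multiplies by a precomputed reciprocal of q. A minimal standalone sketch of the sampling identity (not part of the diff; sizes and seed are illustrative):

import torch

# With q_i ~ Exp(1), argmax(prob / q) selects index i with probability
# prob_i / prob.sum() -- the identity sample_recovered_tokens relies on.
torch.manual_seed(0)
prob = torch.tensor([0.1, 0.2, 0.7])
counts = torch.zeros(3)
for _ in range(20_000):
    q = torch.empty(3).exponential_()   # q_i ~ Exp(1)
    counts[torch.argmax(prob / q)] += 1
print(counts / counts.sum())  # empirical frequencies close to prob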


@@ -11,7 +11,11 @@ from tests.v1.sample.utils import create_allowed_token_ids
from vllm.platforms import current_platform
from vllm.v1.sample.logits_processor import LogitsProcessors
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import PLACEHOLDER_TOKEN_ID, RejectionSampler
from vllm.v1.sample.rejection_sampler import (
PLACEHOLDER_TOKEN_ID,
RejectionSampler,
sample_recovered_tokens,
)
from vllm.v1.sample.sampler import Sampler, SamplerOutput
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
@@ -518,6 +522,70 @@ def estimate_rejection_sampling_pdf(
return hist.hist
def native_sample_recovered_tokens(
max_spec_len: int,
num_draft_tokens: list[int],
cu_num_draft_tokens: torch.Tensor, # [batch_size]
draft_token_ids: torch.Tensor, # [num_tokens]
draft_probs: torch.Tensor | None, # [num_tokens, vocab_size]
target_probs: torch.Tensor, # [num_tokens, vocab_size]
sampling_metadata: SamplingMetadata,
device: torch.device,
) -> torch.Tensor:
batch_size = len(num_draft_tokens)
vocab_size = target_probs.shape[-1]
q = torch.empty(
(batch_size, vocab_size),
dtype=torch.float32,
device=device,
)
q.exponential_()
states = {
i: generator.get_state()
for i, generator in sampling_metadata.generators.items()
}
for i, generator in sampling_metadata.generators.items():
# Do not generate random numbers for requests with no draft tokens.
# This can be important for reproducibility.
if num_draft_tokens[i] > 0:
q[i].exponential_(generator=generator)
# In order to generate the same exponential later, reset the CUDA RNG
# state because RNG state advances after each call.
generator.set_state(states[i])
inv_q = q.reciprocal()
out = torch.empty_like(draft_token_ids)
for req_idx in range(batch_size):
start_idx = 0 if req_idx == 0 else int(cu_num_draft_tokens[req_idx - 1].item())
end_idx = int(cu_num_draft_tokens[req_idx].item())
num_tokens = end_idx - start_idx
for pos in range(max_spec_len):
if pos >= num_tokens:
continue
token_idx = start_idx + pos
if draft_probs is None:
# prob is target_probs[token_idx] except draft_token_id is zeroed
prob = target_probs[token_idx].clone()
draft_token_id = draft_token_ids[token_idx]
prob[draft_token_id] = 0.0
else:
prob = (target_probs[token_idx] - draft_probs[token_idx]).clamp_min_(
0.0
)
score = prob * inv_q[req_idx]
recovered_id = torch.argmax(score, dim=-1)
out[token_idx] = recovered_id
return out
def _test_masked_logits(
rejection_sampler,
batch_size: int,
@@ -778,3 +846,60 @@ def test_allowed_token_ids(rejection_sampler):
device=logits.device,
)
assert torch.equal(output.sampled_token_ids, expected)
@pytest.mark.parametrize("batch_size", [1, 100])
@pytest.mark.parametrize("vocab_size", [100, 8192, 10000])
@pytest.mark.parametrize("max_spec_len", [1, 3])
@pytest.mark.parametrize("no_draft_probs", [True, False])
def test_sample_recovered_tokens(
batch_size: int, vocab_size: int, max_spec_len: int, no_draft_probs: bool
):
num_tokens = batch_size * max_spec_len
# Create random draft probabilities.
draft_probs = torch.rand(num_tokens, vocab_size, dtype=torch.float32, device=DEVICE)
draft_probs = F.softmax(draft_probs, dim=-1)
# Create random target probabilities.
target_logits = torch.rand(
num_tokens, vocab_size, dtype=torch.float32, device=DEVICE
)
target_probs = F.softmax(target_logits, dim=-1)
# Randomly sample draft token ids from draft probs
draft_token_ids = torch.multinomial(draft_probs, num_samples=1).to(torch.int32)
temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE)
generators = {
i: torch.Generator(device=DEVICE).manual_seed(i) for i in range(batch_size)
}
sampling_metadata = create_sampling_metadata(
all_greedy=False, temperature=temperature, generators=generators
)
spec_decode_metadata = create_spec_decode_metadata(
draft_token_ids.reshape(batch_size, max_spec_len).tolist(), target_logits
)
ref_recovered_token_ids = native_sample_recovered_tokens(
max_spec_len,
spec_decode_metadata.num_draft_tokens,
spec_decode_metadata.cu_num_draft_tokens,
draft_token_ids,
None if no_draft_probs else draft_probs,
target_probs,
sampling_metadata,
device=DEVICE,
)
recovered_token_ids = sample_recovered_tokens(
max_spec_len,
spec_decode_metadata.num_draft_tokens,
spec_decode_metadata.cu_num_draft_tokens,
draft_token_ids,
None if no_draft_probs else draft_probs,
target_probs,
sampling_metadata,
device=DEVICE,
)
assert torch.equal(recovered_token_ids, ref_recovered_token_ids)
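
The reference implementation above saves each seeded generator's state and restores it after drawing the per-request exponentials, so the production path that runs afterwards sees the same RNG state and draws identical values. A small standalone sketch of that save-and-replay pattern (illustrative shapes; a CPU generator is used here for portability):

import torch

gen = torch.Generator().manual_seed(0)
state = gen.get_state()
a = torch.empty(8).exponential_(generator=gen)
gen.set_state(state)   # rewind so the next call replays the same draw
b = torch.empty(8).exponential_(generator=gen)
assert torch.equal(a, b)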


@@ -623,16 +623,19 @@ def sample_recovered_tokens(
if num_draft_tokens[i] > 0:
q[i].exponential_(generator=generator)
inv_q = q.reciprocal()
recovered_token_ids = torch.empty_like(draft_token_ids)
BLOCK_SIZE = 8192
sample_recovered_tokens_kernel[(batch_size, max_spec_len)](
recovered_token_ids,
cu_num_draft_tokens,
draft_token_ids,
draft_probs,
target_probs,
q,
inv_q,
vocab_size,
triton.next_power_of_2(vocab_size),
BLOCK_SIZE,
NO_DRAFT_PROBS=draft_probs is None,
)
return recovered_token_ids
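
The host-side change in this hunk precomputes inv_q = q.reciprocal() once per batch so the kernel's inner loop can multiply (prob * inv_q) instead of dividing by q for every vocabulary element, and it launches with a fixed BLOCK_SIZE rather than padding the vocabulary to the next power of two. A standalone sketch of the reciprocal rewrite (shapes are illustrative, not the actual launch configuration):

import torch

batch_size, vocab_size = 4, 32_000
q = torch.empty(batch_size, vocab_size).exponential_()
inv_q = q.reciprocal()                       # computed once on the host
prob = torch.rand(vocab_size)
# prob * inv_q replaces prob / q element-wise (equal up to rounding).
assert torch.allclose(prob / q[0], prob * inv_q[0])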
@@ -776,9 +779,9 @@ def sample_recovered_tokens_kernel(
draft_token_ids_ptr, # [num_tokens]
draft_probs_ptr, # [num_tokens, vocab_size] or None
target_probs_ptr, # [num_tokens, vocab_size]
q_ptr, # [batch_size, vocab_size]
inv_q_ptr, # [batch_size, vocab_size]
vocab_size,
PADDED_VOCAB_SIZE: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
NO_DRAFT_PROBS: tl.constexpr,
):
req_idx = tl.program_id(0)
@@ -791,33 +794,50 @@ def sample_recovered_tokens_kernel(
if pos >= num_draft_tokens:
return
vocab_offset = tl.arange(0, PADDED_VOCAB_SIZE)
if NO_DRAFT_PROBS:
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
prob = tl.load(
target_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset,
mask=((vocab_offset < vocab_size) & (vocab_offset != draft_token_id)),
other=0,
)
else:
draft_prob = tl.load(
draft_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset,
mask=vocab_offset < vocab_size,
other=0,
)
target_prob = tl.load(
target_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset,
mask=vocab_offset < vocab_size,
other=0,
)
prob = tl.maximum(target_prob - draft_prob, 0)
# NOTE(woosuk): We don't need `prob = prob / tl.sum(prob)` here because
# `tl.argmax` will select the maximum value.
token_idx = start_idx + pos
q = tl.load(
q_ptr + req_idx * vocab_size + vocab_offset,
mask=vocab_offset < vocab_size,
other=float("-inf"),
)
recovered_id = tl.argmax(prob / q, axis=-1)
tl.store(output_token_ids_ptr + start_idx + pos, recovered_id)
if NO_DRAFT_PROBS:
draft_token_id = tl.load(draft_token_ids_ptr + token_idx)
max_val = float("-inf")
recovered_id = 0
for v in range(0, vocab_size, BLOCK_SIZE):
vocab_offset = v + tl.arange(0, BLOCK_SIZE)
vocab_mask = vocab_offset < vocab_size
if NO_DRAFT_PROBS:
prob = tl.load(
target_probs_ptr + token_idx * vocab_size + vocab_offset,
mask=(vocab_mask & (vocab_offset != draft_token_id)),
other=0.0,
)
else:
draft_prob = tl.load(
draft_probs_ptr + token_idx * vocab_size + vocab_offset,
mask=vocab_mask,
other=0.0,
)
target_prob = tl.load(
target_probs_ptr + token_idx * vocab_size + vocab_offset,
mask=vocab_mask,
other=0.0,
)
prob = tl.maximum(target_prob - draft_prob, 0.0)
# NOTE(woosuk): We don't need `prob = prob / tl.sum(prob)` here because
# `tl.argmax` will select the maximum value.
inv_q = tl.load(
inv_q_ptr + req_idx * vocab_size + vocab_offset,
mask=vocab_mask,
other=0.0,
)
# Local tile reduction
score = prob * inv_q
local_max, local_id = tl.max(score, axis=0, return_indices=True)
if local_max > max_val:
max_val = local_max
recovered_id = v + local_id
tl.store(output_token_ids_ptr + token_idx, recovered_id)
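
Instead of loading one tile padded to the next power of two of the vocabulary, the rewritten kernel walks the vocabulary in BLOCK_SIZE tiles, reduces each tile with tl.max(..., return_indices=True), and keeps a running maximum and its global index. The same reduction in plain PyTorch, as a standalone correctness sketch (block size and tensor size are illustrative):

import torch

def blocked_argmax(score: torch.Tensor, block_size: int = 8192) -> int:
    """Scan fixed-size tiles and keep a running max, like the kernel's loop."""
    best_val, best_id = float("-inf"), 0
    for v in range(0, score.numel(), block_size):
        local_max, local_id = score[v : v + block_size].max(dim=0)
        if local_max.item() > best_val:
            best_val, best_id = local_max.item(), v + int(local_id)
    return best_id

score = torch.rand(100_000)
assert blocked_argmax(score, 8192) == int(torch.argmax(score))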