From 7a5adad48026d130348064ae7d41072ff999d1bf Mon Sep 17 00:00:00 2001
From: Xin Yang <105740670+xyang16@users.noreply.github.com>
Date: Fri, 20 Feb 2026 19:59:06 -0800
Subject: [PATCH] [Kernel] Optimize sample_recovered_tokens_kernel (#34974)

Signed-off-by: Xin Yang
---
 tests/v1/sample/test_rejection_sampler.py | 127 +++++++++++++++++++++-
 vllm/v1/sample/rejection_sampler.py       |  86 +++++++++------
 2 files changed, 179 insertions(+), 34 deletions(-)

diff --git a/tests/v1/sample/test_rejection_sampler.py b/tests/v1/sample/test_rejection_sampler.py
index d8ae57984..38ffc58e2 100644
--- a/tests/v1/sample/test_rejection_sampler.py
+++ b/tests/v1/sample/test_rejection_sampler.py
@@ -11,7 +11,11 @@ from tests.v1.sample.utils import create_allowed_token_ids
 from vllm.platforms import current_platform
 from vllm.v1.sample.logits_processor import LogitsProcessors
 from vllm.v1.sample.metadata import SamplingMetadata
-from vllm.v1.sample.rejection_sampler import PLACEHOLDER_TOKEN_ID, RejectionSampler
+from vllm.v1.sample.rejection_sampler import (
+    PLACEHOLDER_TOKEN_ID,
+    RejectionSampler,
+    sample_recovered_tokens,
+)
 from vllm.v1.sample.sampler import Sampler, SamplerOutput
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 
@@ -518,6 +522,70 @@ def estimate_rejection_sampling_pdf(
     return hist.hist
 
 
+def native_sample_recovered_tokens(
+    max_spec_len: int,
+    num_draft_tokens: list[int],
+    cu_num_draft_tokens: torch.Tensor,  # [batch_size]
+    draft_token_ids: torch.Tensor,  # [num_tokens]
+    draft_probs: torch.Tensor | None,  # [num_tokens, vocab_size]
+    target_probs: torch.Tensor,  # [num_tokens, vocab_size]
+    sampling_metadata: SamplingMetadata,
+    device: torch.device,
+) -> torch.Tensor:
+    batch_size = len(num_draft_tokens)
+    vocab_size = target_probs.shape[-1]
+
+    q = torch.empty(
+        (batch_size, vocab_size),
+        dtype=torch.float32,
+        device=device,
+    )
+    q.exponential_()
+
+    states = {
+        i: generator.get_state()
+        for i, generator in sampling_metadata.generators.items()
+    }
+    for i, generator in sampling_metadata.generators.items():
+        # Do not generate random numbers for requests with no draft tokens.
+        # This can be important for reproducibility.
+        if num_draft_tokens[i] > 0:
+            q[i].exponential_(generator=generator)
+
+        # In order to generate the same exponential later, reset the CUDA RNG
+        # state because RNG state advances after each call.
+        generator.set_state(states[i])
+
+    inv_q = q.reciprocal()
+
+    out = torch.empty_like(draft_token_ids)
+
+    for req_idx in range(batch_size):
+        start_idx = 0 if req_idx == 0 else int(cu_num_draft_tokens[req_idx - 1].item())
+        end_idx = int(cu_num_draft_tokens[req_idx].item())
+        num_tokens = end_idx - start_idx
+
+        for pos in range(max_spec_len):
+            if pos >= num_tokens:
+                continue
+            token_idx = start_idx + pos
+
+            if draft_probs is None:
+                # prob is target_probs[token_idx] except draft_token_id is zeroed
+                prob = target_probs[token_idx].clone()
+                draft_token_id = draft_token_ids[token_idx]
+                prob[draft_token_id] = 0.0
+            else:
+                prob = (target_probs[token_idx] - draft_probs[token_idx]).clamp_min_(
+                    0.0
+                )
+
+            score = prob * inv_q[req_idx]
+            recovered_id = torch.argmax(score, dim=-1)
+            out[token_idx] = recovered_id
+    return out
+
+
 def _test_masked_logits(
     rejection_sampler,
     batch_size: int,
@@ -778,3 +846,60 @@ def test_allowed_token_ids(rejection_sampler):
         device=logits.device,
     )
     assert torch.equal(output.sampled_token_ids, expected)
+
+
+@pytest.mark.parametrize("batch_size", [1, 100])
+@pytest.mark.parametrize("vocab_size", [100, 8192, 10000])
+@pytest.mark.parametrize("max_spec_len", [1, 3])
+@pytest.mark.parametrize("no_draft_probs", [True, False])
+def test_sample_recovered_tokens(
+    batch_size: int, vocab_size: int, max_spec_len: int, no_draft_probs: bool
+):
+    num_tokens = batch_size * max_spec_len
+
+    # Create random draft probabilities.
+    draft_probs = torch.rand(num_tokens, vocab_size, dtype=torch.float32, device=DEVICE)
+    draft_probs = F.softmax(draft_probs, dim=-1)
+
+    # Create random target probabilities.
+    target_logits = torch.rand(
+        num_tokens, vocab_size, dtype=torch.float32, device=DEVICE
+    )
+    target_probs = F.softmax(target_logits, dim=-1)
+
+    # Randomly sample draft token ids from draft probs
+    draft_token_ids = torch.multinomial(draft_probs, num_samples=1).to(torch.int32)
+
+    temperature = torch.ones(batch_size, dtype=torch.float32, device=DEVICE)
+    generators = {
+        i: torch.Generator(device=DEVICE).manual_seed(i) for i in range(batch_size)
+    }
+    sampling_metadata = create_sampling_metadata(
+        all_greedy=False, temperature=temperature, generators=generators
+    )
+
+    spec_decode_metadata = create_spec_decode_metadata(
+        draft_token_ids.reshape(batch_size, max_spec_len).tolist(), target_logits
+    )
+
+    ref_recovered_token_ids = native_sample_recovered_tokens(
+        max_spec_len,
+        spec_decode_metadata.num_draft_tokens,
+        spec_decode_metadata.cu_num_draft_tokens,
+        draft_token_ids,
+        None if no_draft_probs else draft_probs,
+        target_probs,
+        sampling_metadata,
+        device=DEVICE,
+    )
+    recovered_token_ids = sample_recovered_tokens(
+        max_spec_len,
+        spec_decode_metadata.num_draft_tokens,
+        spec_decode_metadata.cu_num_draft_tokens,
+        draft_token_ids,
+        None if no_draft_probs else draft_probs,
+        target_probs,
+        sampling_metadata,
+        device=DEVICE,
+    )
+    assert torch.equal(recovered_token_ids, ref_recovered_token_ids)
diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py
index b57c93e29..1efceba38 100644
--- a/vllm/v1/sample/rejection_sampler.py
+++ b/vllm/v1/sample/rejection_sampler.py
@@ -623,16 +623,19 @@ def sample_recovered_tokens(
         if num_draft_tokens[i] > 0:
             q[i].exponential_(generator=generator)
 
+    inv_q = q.reciprocal()
+
     recovered_token_ids = torch.empty_like(draft_token_ids)
+    BLOCK_SIZE = 8192
     sample_recovered_tokens_kernel[(batch_size, max_spec_len)](
         recovered_token_ids,
         cu_num_draft_tokens,
         draft_token_ids,
         draft_probs,
         target_probs,
-        q,
+        inv_q,
         vocab_size,
-        triton.next_power_of_2(vocab_size),
+        BLOCK_SIZE,
         NO_DRAFT_PROBS=draft_probs is None,
     )
     return recovered_token_ids
@@ -776,9 +779,9 @@ def sample_recovered_tokens_kernel(
     draft_token_ids_ptr,  # [num_tokens]
     draft_probs_ptr,  # [num_tokens, vocab_size] or None
    target_probs_ptr,  # [num_tokens, vocab_size]
-    q_ptr,  # [batch_size, vocab_size]
+    inv_q_ptr,  # [batch_size, vocab_size]
     vocab_size,
-    PADDED_VOCAB_SIZE: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
     NO_DRAFT_PROBS: tl.constexpr,
 ):
     req_idx = tl.program_id(0)
@@ -791,33 +794,50 @@
     if pos >= num_draft_tokens:
         return
 
-    vocab_offset = tl.arange(0, PADDED_VOCAB_SIZE)
-    if NO_DRAFT_PROBS:
-        draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
-        prob = tl.load(
-            target_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset,
-            mask=((vocab_offset < vocab_size) & (vocab_offset != draft_token_id)),
-            other=0,
-        )
-    else:
-        draft_prob = tl.load(
-            draft_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset,
-            mask=vocab_offset < vocab_size,
-            other=0,
-        )
-        target_prob = tl.load(
-            target_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset,
-            mask=vocab_offset < vocab_size,
-            other=0,
-        )
-        prob = tl.maximum(target_prob - draft_prob, 0)
-        # NOTE(woosuk): We don't need `prob = prob / tl.sum(prob)` here because
-        # `tl.argmax` will select the maximum value.
+    token_idx = start_idx + pos
 
-    q = tl.load(
-        q_ptr + req_idx * vocab_size + vocab_offset,
-        mask=vocab_offset < vocab_size,
-        other=float("-inf"),
-    )
-    recovered_id = tl.argmax(prob / q, axis=-1)
-    tl.store(output_token_ids_ptr + start_idx + pos, recovered_id)
+    if NO_DRAFT_PROBS:
+        draft_token_id = tl.load(draft_token_ids_ptr + token_idx)
+
+    max_val = float("-inf")
+    recovered_id = 0
+    for v in range(0, vocab_size, BLOCK_SIZE):
+        vocab_offset = v + tl.arange(0, BLOCK_SIZE)
+        vocab_mask = vocab_offset < vocab_size
+
+        if NO_DRAFT_PROBS:
+            prob = tl.load(
+                target_probs_ptr + token_idx * vocab_size + vocab_offset,
+                mask=(vocab_mask & (vocab_offset != draft_token_id)),
+                other=0.0,
+            )
+        else:
+            draft_prob = tl.load(
+                draft_probs_ptr + token_idx * vocab_size + vocab_offset,
+                mask=vocab_mask,
+                other=0.0,
+            )
+            target_prob = tl.load(
+                target_probs_ptr + token_idx * vocab_size + vocab_offset,
+                mask=vocab_mask,
+                other=0.0,
+            )
+            prob = tl.maximum(target_prob - draft_prob, 0.0)
+            # NOTE(woosuk): We don't need `prob = prob / tl.sum(prob)` here because
+            # `tl.argmax` will select the maximum value.
+
+        inv_q = tl.load(
+            inv_q_ptr + req_idx * vocab_size + vocab_offset,
+            mask=vocab_mask,
+            other=0.0,
+        )
+
+        # Local tile reduction
+        score = prob * inv_q
+        local_max, local_id = tl.max(score, axis=0, return_indices=True)
+
+        if local_max > max_val:
+            max_val = local_max
+            recovered_id = v + local_id
+
+    tl.store(output_token_ids_ptr + token_idx, recovered_id)
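
Note (illustrative sketch, not part of the diff): the rewrite keeps the exponential-race sampling identity the kernel has always relied on -- with q ~ Exp(1) drawn per vocabulary entry, argmax(prob / q) returns token i with probability prob_i / sum(prob) -- and only changes how that argmax is computed. The reciprocal is taken once on the host (inv_q = q.reciprocal()) so the kernel multiplies instead of divides, and the vocabulary is scanned in fixed BLOCK_SIZE-wide tiles with a running (max, index) instead of one tile padded to next_power_of_2(vocab_size). A minimal PyTorch sketch of the tiled reduction; vocab_size, BLOCK_SIZE, and the random inputs below are illustrative values, not taken from the patch:

    import torch

    torch.manual_seed(0)
    vocab_size = 20_000   # anything, including non powers of two
    BLOCK_SIZE = 8192     # same tile width the kernel uses

    prob = torch.rand(vocab_size)
    prob /= prob.sum()                                            # target distribution
    inv_q = torch.empty(vocab_size).exponential_().reciprocal()   # 1 / Exp(1) noise
    score = prob * inv_q                                          # what the kernel maximizes

    # Running (max, argmax) over BLOCK_SIZE tiles, mirroring the new kernel loop.
    max_val, recovered_id = float("-inf"), 0
    for v in range(0, vocab_size, BLOCK_SIZE):
        local_max, local_id = score[v : v + BLOCK_SIZE].max(dim=0)
        if local_max > max_val:
            max_val, recovered_id = local_max, v + int(local_id)

    # The tiled result matches a single full-width argmax.
    assert recovered_id == int(score.argmax())

The tiled form bounds the per-program working set at BLOCK_SIZE lanes regardless of vocabulary size, which is presumably the motivation for dropping the PADDED_VOCAB_SIZE constexpr.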