From 8980001c9333c617f2bc3b2c60926b02434067a7 Mon Sep 17 00:00:00 2001
From: caozuoba <44251931+caozuoba@users.noreply.github.com>
Date: Sat, 31 Jan 2026 17:51:26 +0800
Subject: [PATCH] [perf] v1/spec_decode: skip softmax for all-greedy rejection
 sampling (#32852)

Signed-off-by: hdj <1293066020@qq.com>
---
 vllm/v1/sample/rejection_sampler.py | 21 +++++++++++----------
 1 file changed, 11 insertions(+), 10 deletions(-)

diff --git a/vllm/v1/sample/rejection_sampler.py b/vllm/v1/sample/rejection_sampler.py
index 16d01d518..b57c93e29 100644
--- a/vllm/v1/sample/rejection_sampler.py
+++ b/vllm/v1/sample/rejection_sampler.py
@@ -136,8 +136,6 @@ class RejectionSampler(nn.Module):
             metadata.cu_num_draft_tokens,
             sampling_metadata,
         )
-        # Compute probability distribution from target logits.
-        target_probs = target_logits.softmax(dim=-1, dtype=torch.float32)
 
         output_token_ids = rejection_sample(
             metadata.draft_token_ids,
@@ -145,7 +143,7 @@
             metadata.max_spec_len,
             metadata.cu_num_draft_tokens,
             draft_probs,
-            target_probs,
+            target_logits,
             bonus_token_ids,
             sampling_metadata,
         )
@@ -353,7 +351,7 @@ def rejection_sample(
     # [num_tokens, vocab_size]
     draft_probs: torch.Tensor | None,
     # [num_tokens, vocab_size]
-    target_probs: torch.Tensor,
+    target_logits: torch.Tensor,
     # [batch_size, 1]
     bonus_token_ids: torch.Tensor,
     sampling_metadata: SamplingMetadata,
@@ -361,17 +359,16 @@
     assert draft_token_ids.ndim == 1
     assert draft_probs is None or draft_probs.ndim == 2
     assert cu_num_draft_tokens.ndim == 1
-    assert target_probs.ndim == 2
+    assert target_logits.ndim == 2
 
     batch_size = len(num_draft_tokens)
     num_tokens = draft_token_ids.shape[0]
-    vocab_size = target_probs.shape[-1]
-    device = target_probs.device
+    vocab_size = target_logits.shape[-1]
+    device = target_logits.device
     assert draft_token_ids.is_contiguous()
     assert draft_probs is None or draft_probs.is_contiguous()
-    assert target_probs.is_contiguous()
     assert bonus_token_ids.is_contiguous()
-    assert target_probs.shape == (num_tokens, vocab_size)
+    assert target_logits.shape == (num_tokens, vocab_size)
 
     # Create output buffer.
     output_token_ids = torch.full(
@@ -387,7 +384,7 @@
     is_greedy = sampling_metadata.temperature == GREEDY_TEMPERATURE
     if not sampling_metadata.all_random:
         # Rejection sampling for greedy sampling requests.
-        target_argmax = target_probs.argmax(dim=-1)
+        target_argmax = target_logits.argmax(dim=-1)
         rejection_greedy_sample_kernel[(batch_size,)](
             output_token_ids,
             cu_num_draft_tokens,
@@ -400,6 +397,10 @@
         if sampling_metadata.all_greedy:
             return output_token_ids
 
+    # Compute probability distribution from target logits.
+    target_probs = target_logits.softmax(dim=-1, dtype=torch.float32)
+    assert target_probs.is_contiguous()
+
     # Generate uniform probabilities for rejection sampling.
     # [num_tokens]
     uniform_probs = generate_uniform_probs(
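
Reviewer note (not part of the patch): the reordering is safe for the greedy
path because softmax is a strictly monotonic, row-wise transform, so
argmax(softmax(logits)) equals argmax(logits) for every row. The
[num_tokens, vocab_size] float32 softmax is therefore only needed by the
random-sampling path and can be deferred until after the all-greedy early
return. A minimal standalone sketch of that invariant, using only PyTorch
with toy shapes (this is illustrative code, not vLLM code):

    import torch

    torch.manual_seed(0)

    # Toy stand-in for the target model's logits:
    # 4 speculative tokens over a 32-entry vocabulary.
    target_logits = torch.randn(4, 32)
    target_probs = target_logits.softmax(dim=-1, dtype=torch.float32)

    # Greedy selection agrees with or without the softmax, which is why
    # the patch can take argmax over the raw logits and skip the softmax
    # entirely when sampling_metadata.all_greedy is true.
    assert torch.equal(
        target_logits.argmax(dim=-1), target_probs.argmax(dim=-1)
    )

For an all-greedy batch this saves one full-vocabulary softmax kernel launch
plus the float32 [num_tokens, vocab_size] intermediate it would allocate;
mixed and all-random batches still compute target_probs exactly as before,
just later in the function.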