[perf] v1/spec_decode: skip softmax for all-greedy rejection sampling (#32852)

Signed-off-by: hdj <1293066020@qq.com>
This commit is contained in:
caozuoba
2026-01-31 17:51:26 +08:00
committed by GitHub
parent 527bcd14d4
commit 8980001c93

View File

@@ -136,8 +136,6 @@ class RejectionSampler(nn.Module):
metadata.cu_num_draft_tokens,
sampling_metadata,
)
# Compute probability distribution from target logits.
target_probs = target_logits.softmax(dim=-1, dtype=torch.float32)
output_token_ids = rejection_sample(
metadata.draft_token_ids,
@@ -145,7 +143,7 @@ class RejectionSampler(nn.Module):
metadata.max_spec_len,
metadata.cu_num_draft_tokens,
draft_probs,
target_probs,
target_logits,
bonus_token_ids,
sampling_metadata,
)
@@ -353,7 +351,7 @@ def rejection_sample(
# [num_tokens, vocab_size]
draft_probs: torch.Tensor | None,
# [num_tokens, vocab_size]
target_probs: torch.Tensor,
target_logits: torch.Tensor,
# [batch_size, 1]
bonus_token_ids: torch.Tensor,
sampling_metadata: SamplingMetadata,
@@ -361,17 +359,16 @@ def rejection_sample(
assert draft_token_ids.ndim == 1
assert draft_probs is None or draft_probs.ndim == 2
assert cu_num_draft_tokens.ndim == 1
assert target_probs.ndim == 2
assert target_logits.ndim == 2
batch_size = len(num_draft_tokens)
num_tokens = draft_token_ids.shape[0]
vocab_size = target_probs.shape[-1]
device = target_probs.device
vocab_size = target_logits.shape[-1]
device = target_logits.device
assert draft_token_ids.is_contiguous()
assert draft_probs is None or draft_probs.is_contiguous()
assert target_probs.is_contiguous()
assert bonus_token_ids.is_contiguous()
assert target_probs.shape == (num_tokens, vocab_size)
assert target_logits.shape == (num_tokens, vocab_size)
# Create output buffer.
output_token_ids = torch.full(
@@ -387,7 +384,7 @@ def rejection_sample(
is_greedy = sampling_metadata.temperature == GREEDY_TEMPERATURE
if not sampling_metadata.all_random:
# Rejection sampling for greedy sampling requests.
target_argmax = target_probs.argmax(dim=-1)
target_argmax = target_logits.argmax(dim=-1)
rejection_greedy_sample_kernel[(batch_size,)](
output_token_ids,
cu_num_draft_tokens,
@@ -400,6 +397,10 @@ def rejection_sample(
if sampling_metadata.all_greedy:
return output_token_ids
# Compute probability distribution from target logits.
target_probs = target_logits.softmax(dim=-1, dtype=torch.float32)
assert target_probs.is_contiguous()
# Generate uniform probabilities for rejection sampling.
# [num_tokens]
uniform_probs = generate_uniform_probs(