[perf] v1/spec_decode: skip softmax for all-greedy rejection sampling (#32852)
Signed-off-by: hdj <1293066020@qq.com>
This commit is contained in:
@@ -136,8 +136,6 @@ class RejectionSampler(nn.Module):
             metadata.cu_num_draft_tokens,
             sampling_metadata,
         )
-        # Compute probability distribution from target logits.
-        target_probs = target_logits.softmax(dim=-1, dtype=torch.float32)
 
         output_token_ids = rejection_sample(
             metadata.draft_token_ids,
@@ -145,7 +143,7 @@ class RejectionSampler(nn.Module):
             metadata.max_spec_len,
             metadata.cu_num_draft_tokens,
             draft_probs,
-            target_probs,
+            target_logits,
             bonus_token_ids,
             sampling_metadata,
         )
@@ -353,7 +351,7 @@ def rejection_sample(
    # [num_tokens, vocab_size]
    draft_probs: torch.Tensor | None,
    # [num_tokens, vocab_size]
-    target_probs: torch.Tensor,
+    target_logits: torch.Tensor,
    # [batch_size, 1]
    bonus_token_ids: torch.Tensor,
    sampling_metadata: SamplingMetadata,
@@ -361,17 +359,16 @@ def rejection_sample(
    assert draft_token_ids.ndim == 1
    assert draft_probs is None or draft_probs.ndim == 2
    assert cu_num_draft_tokens.ndim == 1
-    assert target_probs.ndim == 2
+    assert target_logits.ndim == 2
 
    batch_size = len(num_draft_tokens)
    num_tokens = draft_token_ids.shape[0]
-    vocab_size = target_probs.shape[-1]
-    device = target_probs.device
+    vocab_size = target_logits.shape[-1]
+    device = target_logits.device
    assert draft_token_ids.is_contiguous()
    assert draft_probs is None or draft_probs.is_contiguous()
-    assert target_probs.is_contiguous()
    assert bonus_token_ids.is_contiguous()
-    assert target_probs.shape == (num_tokens, vocab_size)
+    assert target_logits.shape == (num_tokens, vocab_size)
 
    # Create output buffer.
    output_token_ids = torch.full(
@@ -387,7 +384,7 @@ def rejection_sample(
    is_greedy = sampling_metadata.temperature == GREEDY_TEMPERATURE
    if not sampling_metadata.all_random:
        # Rejection sampling for greedy sampling requests.
-        target_argmax = target_probs.argmax(dim=-1)
+        target_argmax = target_logits.argmax(dim=-1)
        rejection_greedy_sample_kernel[(batch_size,)](
            output_token_ids,
            cu_num_draft_tokens,
@@ -400,6 +397,10 @@ def rejection_sample(
    if sampling_metadata.all_greedy:
        return output_token_ids
 
+    # Compute probability distribution from target logits.
+    target_probs = target_logits.softmax(dim=-1, dtype=torch.float32)
+    assert target_probs.is_contiguous()
+
    # Generate uniform probabilities for rejection sampling.
    # [num_tokens]
    uniform_probs = generate_uniform_probs(
Reference in New Issue
Block a user