[V1][Spec Decode] Always use argmax for sampling draft tokens (#16899)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon
2025-04-23 14:57:43 -07:00
committed by GitHub
parent 32d4b669d0
commit 41fb013d29
3 changed files with 18 additions and 23 deletions

View File

@@ -226,7 +226,7 @@ def rejection_sample(
is_greedy,
max_spec_len,
vocab_size,
IS_NGRAM=draft_probs is None,
NO_DRAFT_PROBS=draft_probs is None,
num_warps=1,
)
return output_token_ids
@@ -423,7 +423,7 @@ def sample_recovered_tokens(
q,
vocab_size,
triton.next_power_of_2(vocab_size),
IS_NGRAM=draft_probs is None,
NO_DRAFT_PROBS=draft_probs is None,
)
return recovered_token_ids
@@ -490,7 +490,7 @@ def rejection_random_sample_kernel(
is_greedy_ptr, # [batch_size]
max_spec_len,
vocab_size,
IS_NGRAM: tl.constexpr,
NO_DRAFT_PROBS: tl.constexpr,
):
req_idx = tl.program_id(0)
is_greedy = tl.load(is_greedy_ptr + req_idx)
@@ -509,7 +509,7 @@ def rejection_random_sample_kernel(
for pos in range(num_draft_tokens):
if not rejected:
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
if IS_NGRAM:
if NO_DRAFT_PROBS:
draft_prob = 1
else:
draft_prob = tl.load(draft_probs_ptr +
@@ -575,7 +575,7 @@ def sample_recovered_tokens_kernel(
q_ptr, # [batch_size, vocab_size]
vocab_size,
PADDED_VOCAB_SIZE: tl.constexpr,
IS_NGRAM: tl.constexpr,
NO_DRAFT_PROBS: tl.constexpr,
):
req_idx = tl.program_id(0)
if req_idx == 0:
@@ -591,7 +591,7 @@ def sample_recovered_tokens_kernel(
return
vocab_offset = tl.arange(0, PADDED_VOCAB_SIZE)
if IS_NGRAM:
if NO_DRAFT_PROBS:
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
orig_prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size +
draft_token_id)
@@ -624,7 +624,7 @@ def sample_recovered_tokens_kernel(
recovered_id = tl.argmax(prob / q, axis=-1)
tl.store(output_token_ids_ptr + start_idx + pos, recovered_id)
if IS_NGRAM:
if NO_DRAFT_PROBS:
# Restore the original probability.
tl.store(
target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id,