Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -54,7 +54,7 @@ class RejectionSampler(nn.Module):
|
||||
bonus_token_ids: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> torch.Tensor:
|
||||
'''
|
||||
"""
|
||||
Args:
|
||||
metadata:
|
||||
Metadata for spec decoding.
|
||||
@@ -81,7 +81,7 @@ class RejectionSampler(nn.Module):
|
||||
Returns:
|
||||
output_token_ids (torch.Tensor):
|
||||
A tensor containing the final output token IDs.
|
||||
'''
|
||||
"""
|
||||
assert metadata.max_spec_len <= MAX_SPEC_LEN
|
||||
# [num_tokens, vocab_size]
|
||||
# NOTE(woosuk): `target_logits` can be updated in place inside the
|
||||
@@ -123,11 +123,11 @@ class RejectionSampler(nn.Module):
|
||||
"""
|
||||
output_token_ids_np = output_token_ids.cpu().numpy()
|
||||
# Create mask for valid tokens.
|
||||
valid_mask = ((output_token_ids_np != PLACEHOLDER_TOKEN_ID) &
|
||||
(output_token_ids_np < vocab_size))
|
||||
valid_mask = (output_token_ids_np != PLACEHOLDER_TOKEN_ID) & (
|
||||
output_token_ids_np < vocab_size
|
||||
)
|
||||
outputs = [
|
||||
row[valid_mask[i]].tolist()
|
||||
for i, row in enumerate(output_token_ids_np)
|
||||
row[valid_mask[i]].tolist() for i, row in enumerate(output_token_ids_np)
|
||||
]
|
||||
return outputs
|
||||
|
||||
@@ -178,7 +178,7 @@ def rejection_sample(
|
||||
if not sampling_metadata.all_random:
|
||||
# Rejection sampling for greedy sampling requests.
|
||||
target_argmax = target_probs.argmax(dim=-1)
|
||||
rejection_greedy_sample_kernel[(batch_size, )](
|
||||
rejection_greedy_sample_kernel[(batch_size,)](
|
||||
output_token_ids,
|
||||
cu_num_draft_tokens,
|
||||
draft_token_ids,
|
||||
@@ -213,7 +213,7 @@ def rejection_sample(
|
||||
)
|
||||
|
||||
# Rejection sampling for random sampling requests.
|
||||
rejection_random_sample_kernel[(batch_size, )](
|
||||
rejection_random_sample_kernel[(batch_size,)](
|
||||
output_token_ids,
|
||||
cu_num_draft_tokens,
|
||||
draft_token_ids,
|
||||
@@ -320,7 +320,7 @@ def expand_batch_to_tokens(
|
||||
batch_size = x.shape[0]
|
||||
assert cu_num_tokens.shape[0] == batch_size
|
||||
expanded_x = x.new_empty(num_tokens)
|
||||
expand_kernel[(batch_size, )](
|
||||
expand_kernel[(batch_size,)](
|
||||
expanded_x,
|
||||
x,
|
||||
cu_num_tokens,
|
||||
@@ -368,7 +368,7 @@ def generate_uniform_probs(
|
||||
# https://github.com/pytorch/pytorch/issues/16706. Using float64
|
||||
# mitigates the issue.
|
||||
uniform_probs = torch.rand(
|
||||
(num_tokens, ),
|
||||
(num_tokens,),
|
||||
dtype=torch.float64,
|
||||
device=device,
|
||||
)
|
||||
@@ -464,8 +464,10 @@ def rejection_greedy_sample_kernel(
|
||||
if not rejected:
|
||||
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
|
||||
target_argmax_id = tl.load(target_argmax_ptr + start_idx + pos)
|
||||
tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
|
||||
target_argmax_id)
|
||||
tl.store(
|
||||
output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
|
||||
target_argmax_id,
|
||||
)
|
||||
if draft_token_id != target_argmax_id:
|
||||
# Reject.
|
||||
rejected = True
|
||||
@@ -474,8 +476,9 @@ def rejection_greedy_sample_kernel(
|
||||
# If all tokens are accepted, append the bonus token.
|
||||
bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
|
||||
tl.store(
|
||||
output_token_ids_ptr + req_idx * (max_spec_len + 1) +
|
||||
num_draft_tokens, bonus_token_id)
|
||||
output_token_ids_ptr + req_idx * (max_spec_len + 1) + num_draft_tokens,
|
||||
bonus_token_id,
|
||||
)
|
||||
|
||||
|
||||
# NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation.
|
||||
@@ -514,12 +517,12 @@ def rejection_random_sample_kernel(
|
||||
if NO_DRAFT_PROBS:
|
||||
draft_prob = 1
|
||||
else:
|
||||
draft_prob = tl.load(draft_probs_ptr +
|
||||
(start_idx + pos) * vocab_size +
|
||||
draft_token_id)
|
||||
target_prob = tl.load(target_probs_ptr +
|
||||
(start_idx + pos) * vocab_size +
|
||||
draft_token_id)
|
||||
draft_prob = tl.load(
|
||||
draft_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id
|
||||
)
|
||||
target_prob = tl.load(
|
||||
target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id
|
||||
)
|
||||
uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
|
||||
# NOTE(woosuk): While the draft probability should never be 0,
|
||||
# we check it to avoid NaNs. If it happens to be 0, we reject.
|
||||
@@ -530,15 +533,17 @@ def rejection_random_sample_kernel(
|
||||
# Reject. Use recovered token.
|
||||
rejected = True
|
||||
token_id = tl.load(recovered_token_ids_ptr + start_idx + pos)
|
||||
tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
|
||||
token_id)
|
||||
tl.store(
|
||||
output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos, token_id
|
||||
)
|
||||
|
||||
if not rejected:
|
||||
# If all tokens are accepted, append the bonus token.
|
||||
bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
|
||||
tl.store(
|
||||
output_token_ids_ptr + req_idx * (max_spec_len + 1) +
|
||||
num_draft_tokens, bonus_token_id)
|
||||
output_token_ids_ptr + req_idx * (max_spec_len + 1) + num_draft_tokens,
|
||||
bonus_token_id,
|
||||
)
|
||||
|
||||
|
||||
# NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation.
|
||||
@@ -562,9 +567,7 @@ def expand_kernel(
|
||||
src_val = tl.load(input_ptr + req_idx)
|
||||
src_val = tl.where(src_val == replace_from, replace_to, src_val)
|
||||
offset = tl.arange(0, MAX_NUM_TOKENS)
|
||||
tl.store(output_ptr + start_idx + offset,
|
||||
src_val,
|
||||
mask=offset < num_tokens)
|
||||
tl.store(output_ptr + start_idx + offset, src_val, mask=offset < num_tokens)
|
||||
|
||||
|
||||
@triton.jit
|
||||
@@ -595,26 +598,30 @@ def sample_recovered_tokens_kernel(
|
||||
vocab_offset = tl.arange(0, PADDED_VOCAB_SIZE)
|
||||
if NO_DRAFT_PROBS:
|
||||
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
|
||||
prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size +
|
||||
vocab_offset,
|
||||
mask=((vocab_offset < vocab_size) &
|
||||
(vocab_offset != draft_token_id)),
|
||||
other=0)
|
||||
prob = tl.load(
|
||||
target_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset,
|
||||
mask=((vocab_offset < vocab_size) & (vocab_offset != draft_token_id)),
|
||||
other=0,
|
||||
)
|
||||
else:
|
||||
draft_prob = tl.load(draft_probs_ptr + (start_idx + pos) * vocab_size +
|
||||
vocab_offset,
|
||||
mask=vocab_offset < vocab_size,
|
||||
other=0)
|
||||
target_prob = tl.load(target_probs_ptr +
|
||||
(start_idx + pos) * vocab_size + vocab_offset,
|
||||
mask=vocab_offset < vocab_size,
|
||||
other=0)
|
||||
draft_prob = tl.load(
|
||||
draft_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset,
|
||||
mask=vocab_offset < vocab_size,
|
||||
other=0,
|
||||
)
|
||||
target_prob = tl.load(
|
||||
target_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset,
|
||||
mask=vocab_offset < vocab_size,
|
||||
other=0,
|
||||
)
|
||||
prob = tl.maximum(target_prob - draft_prob, 0)
|
||||
# NOTE(woosuk): We don't need `prob = prob / tl.sum(prob)` here because
|
||||
# `tl.argmax` will select the maximum value.
|
||||
|
||||
q = tl.load(q_ptr + req_idx * vocab_size + vocab_offset,
|
||||
mask=vocab_offset < vocab_size,
|
||||
other=float("-inf"))
|
||||
q = tl.load(
|
||||
q_ptr + req_idx * vocab_size + vocab_offset,
|
||||
mask=vocab_offset < vocab_size,
|
||||
other=float("-inf"),
|
||||
)
|
||||
recovered_id = tl.argmax(prob / q, axis=-1)
|
||||
tl.store(output_token_ids_ptr + start_idx + pos, recovered_id)
|
||||
|
||||
Reference in New Issue
Block a user