Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-05 15:06:22 +01:00
committed by GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions

View File

@@ -54,7 +54,7 @@ class RejectionSampler(nn.Module):
bonus_token_ids: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
'''
"""
Args:
metadata:
Metadata for spec decoding.
@@ -81,7 +81,7 @@ class RejectionSampler(nn.Module):
Returns:
output_token_ids (torch.Tensor):
A tensor containing the final output token IDs.
'''
"""
assert metadata.max_spec_len <= MAX_SPEC_LEN
# [num_tokens, vocab_size]
# NOTE(woosuk): `target_logits` can be updated in place inside the
@@ -123,11 +123,11 @@ class RejectionSampler(nn.Module):
"""
output_token_ids_np = output_token_ids.cpu().numpy()
# Create mask for valid tokens.
valid_mask = ((output_token_ids_np != PLACEHOLDER_TOKEN_ID) &
(output_token_ids_np < vocab_size))
valid_mask = (output_token_ids_np != PLACEHOLDER_TOKEN_ID) & (
output_token_ids_np < vocab_size
)
outputs = [
row[valid_mask[i]].tolist()
for i, row in enumerate(output_token_ids_np)
row[valid_mask[i]].tolist() for i, row in enumerate(output_token_ids_np)
]
return outputs
@@ -178,7 +178,7 @@ def rejection_sample(
if not sampling_metadata.all_random:
# Rejection sampling for greedy sampling requests.
target_argmax = target_probs.argmax(dim=-1)
rejection_greedy_sample_kernel[(batch_size, )](
rejection_greedy_sample_kernel[(batch_size,)](
output_token_ids,
cu_num_draft_tokens,
draft_token_ids,
@@ -213,7 +213,7 @@ def rejection_sample(
)
# Rejection sampling for random sampling requests.
rejection_random_sample_kernel[(batch_size, )](
rejection_random_sample_kernel[(batch_size,)](
output_token_ids,
cu_num_draft_tokens,
draft_token_ids,
@@ -320,7 +320,7 @@ def expand_batch_to_tokens(
batch_size = x.shape[0]
assert cu_num_tokens.shape[0] == batch_size
expanded_x = x.new_empty(num_tokens)
expand_kernel[(batch_size, )](
expand_kernel[(batch_size,)](
expanded_x,
x,
cu_num_tokens,
@@ -368,7 +368,7 @@ def generate_uniform_probs(
# https://github.com/pytorch/pytorch/issues/16706. Using float64
# mitigates the issue.
uniform_probs = torch.rand(
(num_tokens, ),
(num_tokens,),
dtype=torch.float64,
device=device,
)
@@ -464,8 +464,10 @@ def rejection_greedy_sample_kernel(
if not rejected:
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
target_argmax_id = tl.load(target_argmax_ptr + start_idx + pos)
tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
target_argmax_id)
tl.store(
output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
target_argmax_id,
)
if draft_token_id != target_argmax_id:
# Reject.
rejected = True
@@ -474,8 +476,9 @@ def rejection_greedy_sample_kernel(
# If all tokens are accepted, append the bonus token.
bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
tl.store(
output_token_ids_ptr + req_idx * (max_spec_len + 1) +
num_draft_tokens, bonus_token_id)
output_token_ids_ptr + req_idx * (max_spec_len + 1) + num_draft_tokens,
bonus_token_id,
)
# NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation.
@@ -514,12 +517,12 @@ def rejection_random_sample_kernel(
if NO_DRAFT_PROBS:
draft_prob = 1
else:
draft_prob = tl.load(draft_probs_ptr +
(start_idx + pos) * vocab_size +
draft_token_id)
target_prob = tl.load(target_probs_ptr +
(start_idx + pos) * vocab_size +
draft_token_id)
draft_prob = tl.load(
draft_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id
)
target_prob = tl.load(
target_probs_ptr + (start_idx + pos) * vocab_size + draft_token_id
)
uniform_prob = tl.load(uniform_probs_ptr + start_idx + pos)
# NOTE(woosuk): While the draft probability should never be 0,
# we check it to avoid NaNs. If it happens to be 0, we reject.
@@ -530,15 +533,17 @@ def rejection_random_sample_kernel(
# Reject. Use recovered token.
rejected = True
token_id = tl.load(recovered_token_ids_ptr + start_idx + pos)
tl.store(output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos,
token_id)
tl.store(
output_token_ids_ptr + req_idx * (max_spec_len + 1) + pos, token_id
)
if not rejected:
# If all tokens are accepted, append the bonus token.
bonus_token_id = tl.load(bonus_token_ids_ptr + req_idx)
tl.store(
output_token_ids_ptr + req_idx * (max_spec_len + 1) +
num_draft_tokens, bonus_token_id)
output_token_ids_ptr + req_idx * (max_spec_len + 1) + num_draft_tokens,
bonus_token_id,
)
# NOTE(woosuk): Avoid specialization to prevent unnecessary recompilation.
@@ -562,9 +567,7 @@ def expand_kernel(
src_val = tl.load(input_ptr + req_idx)
src_val = tl.where(src_val == replace_from, replace_to, src_val)
offset = tl.arange(0, MAX_NUM_TOKENS)
tl.store(output_ptr + start_idx + offset,
src_val,
mask=offset < num_tokens)
tl.store(output_ptr + start_idx + offset, src_val, mask=offset < num_tokens)
@triton.jit
@@ -595,26 +598,30 @@ def sample_recovered_tokens_kernel(
vocab_offset = tl.arange(0, PADDED_VOCAB_SIZE)
if NO_DRAFT_PROBS:
draft_token_id = tl.load(draft_token_ids_ptr + start_idx + pos)
prob = tl.load(target_probs_ptr + (start_idx + pos) * vocab_size +
vocab_offset,
mask=((vocab_offset < vocab_size) &
(vocab_offset != draft_token_id)),
other=0)
prob = tl.load(
target_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset,
mask=((vocab_offset < vocab_size) & (vocab_offset != draft_token_id)),
other=0,
)
else:
draft_prob = tl.load(draft_probs_ptr + (start_idx + pos) * vocab_size +
vocab_offset,
mask=vocab_offset < vocab_size,
other=0)
target_prob = tl.load(target_probs_ptr +
(start_idx + pos) * vocab_size + vocab_offset,
mask=vocab_offset < vocab_size,
other=0)
draft_prob = tl.load(
draft_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset,
mask=vocab_offset < vocab_size,
other=0,
)
target_prob = tl.load(
target_probs_ptr + (start_idx + pos) * vocab_size + vocab_offset,
mask=vocab_offset < vocab_size,
other=0,
)
prob = tl.maximum(target_prob - draft_prob, 0)
# NOTE(woosuk): We don't need `prob = prob / tl.sum(prob)` here because
# `tl.argmax` will select the maximum value.
q = tl.load(q_ptr + req_idx * vocab_size + vocab_offset,
mask=vocab_offset < vocab_size,
other=float("-inf"))
q = tl.load(
q_ptr + req_idx * vocab_size + vocab_offset,
mask=vocab_offset < vocab_size,
other=float("-inf"),
)
recovered_id = tl.argmax(prob / q, axis=-1)
tl.store(output_token_ids_ptr + start_idx + pos, recovered_id)