Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -10,8 +10,7 @@ from vllm.platforms import current_platform
|
||||
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
|
||||
|
||||
# isort: off
|
||||
from vllm.v1.sample.tpu.sampler import (apply_top_k_top_p as
|
||||
apply_top_k_top_p_tpu)
|
||||
from vllm.v1.sample.tpu.sampler import apply_top_k_top_p as apply_top_k_top_p_tpu
|
||||
# isort: on
|
||||
|
||||
if not current_platform.is_tpu():
|
||||
@@ -30,11 +29,10 @@ def test_topk_equivalence_to_native_impl():
|
||||
logits = torch.rand((BATCH_SIZE, VOCAB_SIZE))
|
||||
|
||||
# Random top-k values between 1 and 10.
|
||||
k = torch.randint(1, 10, (BATCH_SIZE, ))
|
||||
k = torch.randint(1, 10, (BATCH_SIZE,))
|
||||
|
||||
# Set k=vocab_size for ~50% of requests in the batch (top-k disabled).
|
||||
k.masked_fill_(torch.randint(0, 2, (BATCH_SIZE, ), dtype=bool),
|
||||
VOCAB_SIZE)
|
||||
k.masked_fill_(torch.randint(0, 2, (BATCH_SIZE,), dtype=bool), VOCAB_SIZE)
|
||||
|
||||
result_tpu = apply_top_k_top_p_tpu(logits=logits.clone(), k=k, p=None)
|
||||
|
||||
@@ -50,15 +48,13 @@ def test_topp_result_sums_past_p():
|
||||
probs = logits.softmax(dim=-1)
|
||||
|
||||
# Random top-p values between 0 and 1.
|
||||
p = torch.rand((BATCH_SIZE, ))
|
||||
p = torch.rand((BATCH_SIZE,))
|
||||
|
||||
# Set p=1 for ~50% of requests in the batch (top-p disabled).
|
||||
p.masked_fill_(torch.randint(0, 2, (BATCH_SIZE, ), dtype=bool), 1)
|
||||
p.masked_fill_(torch.randint(0, 2, (BATCH_SIZE,), dtype=bool), 1)
|
||||
|
||||
no_op_k = torch.tensor([VOCAB_SIZE])
|
||||
logits_masked = apply_top_k_top_p_tpu(logits=logits.clone(),
|
||||
k=no_op_k,
|
||||
p=p)
|
||||
logits_masked = apply_top_k_top_p_tpu(logits=logits.clone(), k=no_op_k, p=p)
|
||||
|
||||
# Verify that the masked logit's probability sums to at least p.
|
||||
probs.masked_fill_(logits_masked.isinf(), 0)
|
||||
@@ -72,16 +68,16 @@ def test_topp_result_sums_past_p():
|
||||
|
||||
def test_topp_basic():
|
||||
with torch.device(xm.xla_device()):
|
||||
logits = torch.tensor([[math.log(0.2),
|
||||
math.log(0.3),
|
||||
math.log(0.5)],
|
||||
[math.log(0.5),
|
||||
math.log(0.1),
|
||||
math.log(0.4)]])
|
||||
logits = torch.tensor(
|
||||
[
|
||||
[math.log(0.2), math.log(0.3), math.log(0.5)],
|
||||
[math.log(0.5), math.log(0.1), math.log(0.4)],
|
||||
]
|
||||
)
|
||||
|
||||
result = apply_top_k_top_p_tpu(logits=logits.clone(),
|
||||
k=torch.tensor([3, 3]),
|
||||
p=torch.tensor([0.79, 0.79]))
|
||||
result = apply_top_k_top_p_tpu(
|
||||
logits=logits.clone(), k=torch.tensor([3, 3]), p=torch.tensor([0.79, 0.79])
|
||||
)
|
||||
|
||||
torch_xla.sync()
|
||||
|
||||
@@ -94,16 +90,16 @@ def test_topp_basic():
|
||||
|
||||
def test_topp_select_all():
|
||||
with torch.device(xm.xla_device()):
|
||||
logits = torch.tensor([[math.log(0.2),
|
||||
math.log(0.3),
|
||||
math.log(0.5)],
|
||||
[math.log(0.5),
|
||||
math.log(0.1),
|
||||
math.log(0.4)]])
|
||||
logits = torch.tensor(
|
||||
[
|
||||
[math.log(0.2), math.log(0.3), math.log(0.5)],
|
||||
[math.log(0.5), math.log(0.1), math.log(0.4)],
|
||||
]
|
||||
)
|
||||
|
||||
result = apply_top_k_top_p_tpu(logits=logits.clone(),
|
||||
k=torch.tensor([3, 3]),
|
||||
p=torch.tensor([1.0, 1.0]))
|
||||
result = apply_top_k_top_p_tpu(
|
||||
logits=logits.clone(), k=torch.tensor([3, 3]), p=torch.tensor([1.0, 1.0])
|
||||
)
|
||||
|
||||
torch_xla.sync()
|
||||
|
||||
@@ -114,14 +110,12 @@ def test_topp_with_ties():
|
||||
with torch.device(xm.xla_device()):
|
||||
# Input has multiple math.log(0.3).
|
||||
logits = torch.tensor(
|
||||
[[math.log(0.3),
|
||||
math.log(0.3),
|
||||
math.log(0.3),
|
||||
math.log(0.1)]])
|
||||
[[math.log(0.3), math.log(0.3), math.log(0.3), math.log(0.1)]]
|
||||
)
|
||||
|
||||
result = apply_top_k_top_p_tpu(logits=logits.clone(),
|
||||
k=torch.tensor([4]),
|
||||
p=torch.tensor([0.2]))
|
||||
result = apply_top_k_top_p_tpu(
|
||||
logits=logits.clone(), k=torch.tensor([4]), p=torch.tensor([0.2])
|
||||
)
|
||||
|
||||
torch_xla.sync()
|
||||
|
||||
@@ -135,17 +129,17 @@ def test_topp_with_ties():
|
||||
|
||||
def test_both_topk_topp():
|
||||
with torch.device(xm.xla_device()):
|
||||
logits = torch.tensor([[math.log(0.2),
|
||||
math.log(0.3),
|
||||
math.log(0.5)],
|
||||
[math.log(0.5),
|
||||
math.log(0.1),
|
||||
math.log(0.4)]])
|
||||
logits = torch.tensor(
|
||||
[
|
||||
[math.log(0.2), math.log(0.3), math.log(0.5)],
|
||||
[math.log(0.5), math.log(0.1), math.log(0.4)],
|
||||
]
|
||||
)
|
||||
|
||||
# Set k=1 for the first batch.
|
||||
result = apply_top_k_top_p_tpu(logits=logits.clone(),
|
||||
k=torch.tensor([1, 3]),
|
||||
p=torch.tensor([0.79, 0.79]))
|
||||
result = apply_top_k_top_p_tpu(
|
||||
logits=logits.clone(), k=torch.tensor([1, 3]), p=torch.tensor([0.79, 0.79])
|
||||
)
|
||||
|
||||
torch_xla.sync()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user