Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
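Note: ruff replaces both of the previous tools here (ruff's formatter stands in for yapf, and its isort-compatible lint rules take over import sorting), so every hunk below is mechanical reformatting with no behavior change. As a quick sanity check on the most common respelling, the two one-element-tuple spellings are identical at runtime (the value assigned to BATCH_SIZE is arbitrary, for illustration only):

    # yapf wrote one-element tuples as "(BATCH_SIZE, )"; ruff writes
    # "(BATCH_SIZE,)". Both literals build the same tuple object.
    BATCH_SIZE = 32
    assert (BATCH_SIZE, ) == (BATCH_SIZE,)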
@@ -5,8 +5,10 @@ import torch
 from torch import Generator
 
 from vllm.platforms import current_platform
-from vllm.v1.sample.ops.topk_topp_sampler import (apply_top_k_top_p,
-                                                  is_flashinfer_available)
+from vllm.v1.sample.ops.topk_topp_sampler import (
+    apply_top_k_top_p,
+    is_flashinfer_available,
+)
 
 DEVICE = current_platform.device_type
 
@@ -30,19 +32,18 @@ def reset_default_device():
 
 
 def test_topk_impl_equivalence():
-
     torch.set_default_device(DEVICE)
     generator = Generator(device=DEVICE).manual_seed(33)
 
     logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator)
 
     # Random top-k values between 1 and 9.
-    k = torch.randint(1, 10, (BATCH_SIZE, ), generator=generator)
+    k = torch.randint(1, 10, (BATCH_SIZE,), generator=generator)
 
     # Set k=vocab_size for ~50% of requests in the batch (top-k disabled).
     k.masked_fill_(
-        torch.randint(0, 2, (BATCH_SIZE, ), generator=generator, dtype=bool),
-        VOCAB_SIZE)
+        torch.randint(0, 2, (BATCH_SIZE,), generator=generator, dtype=bool), VOCAB_SIZE
+    )
 
     # Top-k only implementation
     result1 = apply_top_k_top_p(logits=logits.clone(), k=k, p=None)
@@ -55,7 +56,7 @@ def test_topk_impl_equivalence():
 
 
 def test_flashinfer_sampler():
-    '''
+    """
     This test verifies that the FlashInfer top-k and top-p sampling
     implementation produces the same results as the Python implementation.
 
@@ -63,11 +64,10 @@ def test_flashinfer_sampler():
     top-p prob renorm (it did provide fused sampling but we cannot compare
     sampling results due to randomness), so we will compare the probability
     renormed consequently by top-k and then top-p of FlashInfer implementation.
-    '''
+    """
 
     if not FLASHINFER_ENABLED:
-        pytest.skip(
-            "FlashInfer not installed or not available on this platform.")
+        pytest.skip("FlashInfer not installed or not available on this platform.")
 
     torch.set_default_device(DEVICE)
     generator = Generator(device=DEVICE).manual_seed(42)
@@ -76,23 +76,21 @@ def test_flashinfer_sampler():
     logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator)
 
     # Generate various top-k and top-p values
-    k_values = torch.randint(1, 1000, (BATCH_SIZE, ), generator=generator)
-    p_values = torch.rand(
-        (BATCH_SIZE, ), generator=generator) * 0.5 + 0.5  # range in [0.5, 1.0]
+    k_values = torch.randint(1, 1000, (BATCH_SIZE,), generator=generator)
+    p_values = (
+        torch.rand((BATCH_SIZE,), generator=generator) * 0.5 + 0.5
+    )  # range in [0.5, 1.0]
 
     # Sometimes disable top-k (k=vocab_size)
     k_values.masked_fill_(
-        torch.randint(0,
-                      2, (BATCH_SIZE, ),
-                      generator=generator,
-                      dtype=torch.bool), VOCAB_SIZE)
+        torch.randint(0, 2, (BATCH_SIZE,), generator=generator, dtype=torch.bool),
+        VOCAB_SIZE,
+    )
 
     # Sometimes disable top-p (p=1.0)
     p_values.masked_fill_(
-        torch.randint(0,
-                      2, (BATCH_SIZE, ),
-                      generator=generator,
-                      dtype=torch.bool), 1.0)
+        torch.randint(0, 2, (BATCH_SIZE,), generator=generator, dtype=torch.bool), 1.0
+    )
 
     python_logits = apply_top_k_top_p(
         logits=logits.clone(),
@@ -113,5 +111,6 @@ def test_flashinfer_sampler():
     )
 
     # Compare the results
-    assert torch.allclose(python_probs, flashinfer_probs, atol=2e-2), \
-        "FlashInfer and Python sampling implementations do not match!"
+    assert torch.allclose(python_probs, flashinfer_probs, atol=2e-2), (
+        "FlashInfer and Python sampling implementations do not match!"
+    )
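For context on what these tests exercise: apply_top_k_top_p masks every logit outside the top-k set, then outside the top-p nucleus, to -inf, so that a later softmax renormalizes whatever survives. A minimal sketch of that masking, assuming only the signature visible in the diff (an illustration, not vLLM's actual implementation):

    import torch


    def sketch_top_k_top_p(logits, k, p):
        # Sort ascending once; both filters work on the sorted view.
        logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
        if k is not None:
            # Per-row cutoff: mask everything strictly below the k-th
            # largest logit. k == vocab_size leaves the row untouched.
            cutoff = logits_sort.gather(-1, (logits_sort.size(-1) - k).unsqueeze(-1))
            logits_sort.masked_fill_(logits_sort < cutoff, float("-inf"))
        if p is not None:
            # Drop the low-probability prefix whose cumulative mass stays
            # within 1 - p, always keeping the most likely token.
            probs_sum = logits_sort.softmax(dim=-1).cumsum(dim=-1)
            mask = probs_sum <= (1 - p).unsqueeze(-1)
            mask[:, -1] = False
            logits_sort.masked_fill_(mask, float("-inf"))
        # Undo the sort so rows line up with the original vocabulary order.
        return torch.empty_like(logits).scatter_(-1, logits_idx, logits_sort)

Softmax over the returned logits gives the renormalized distribution that the FlashInfer comparison above checks with torch.allclose at atol=2e-2.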