[V1][Sampler] Faster top-k only implementation (#15478)
Signed-off-by: Nick Hill <nhill@redhat.com>
tests/v1/sample/test_topk_topp_sampler.py (new file, 37 lines)
@@ -0,0 +1,37 @@
+# SPDX-License-Identifier: Apache-2.0
+import torch
+from torch import Generator
+
+from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
+
+DEVICE = "cuda"
+
+BATCH_SIZE = 1024
+VOCAB_SIZE = 128 * 1024
+
+
+def test_topk_impl_equivalance():
+
+    with torch.device(DEVICE):
+        generator = Generator(device=DEVICE).manual_seed(33)
+
+        logits = torch.rand((BATCH_SIZE, VOCAB_SIZE), generator=generator)
+
+        # Random top-k values between 1 and 9.
+        k = torch.randint(1, 10, (BATCH_SIZE, ), generator=generator)
+
+        # Set k=vocab_size for ~50% of requests in the batch (top-k disabled).
+        k.masked_fill_(
+            torch.randint(0, 2, (BATCH_SIZE, ), generator=generator,
+                          dtype=bool), VOCAB_SIZE)
+
+        # Top-k only implementation.
+        result1 = apply_top_k_top_p(logits=logits.clone(), k=k, p=None)
+
+        # Combined top-k + top-p implementation, with top-p=1.0 (a no-op)
+        # so that only the top-k filtering takes effect.
+        no_op_top_p = torch.tensor([1.0])
+        result2 = apply_top_k_top_p(logits=logits.clone(), k=k, p=no_op_top_p)
+
+        assert torch.allclose(result1, result2)
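For context, the "top-k only" fast path exercised by this test can be thought of as a single batched torch.topk call followed by a per-row threshold mask, rather than a full sort over the vocabulary dimension. The sketch below is a minimal illustrative reimplementation under that assumption, not the actual vLLM code; the helper name topk_only_mask_sketch is hypothetical.

import torch

def topk_only_mask_sketch(logits: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    """Illustrative top-k-only masking: keep the k largest logits per row.

    logits: (batch, vocab) float tensor.
    k: (batch,) integer tensor; k == vocab_size disables top-k for that row.
    """
    vocab_size = logits.shape[-1]
    no_top_k = k == vocab_size  # rows where top-k is a no-op
    # Clamp disabled rows to k=1 so one topk call covers the whole batch,
    # then take each row's k-th largest value as its cutoff threshold.
    k_eff = k.masked_fill(no_top_k, 1)
    max_k = int(k_eff.max())
    topk_vals = logits.topk(max_k, dim=-1).values  # (batch, max_k), descending
    thresholds = topk_vals.gather(-1, (k_eff - 1).unsqueeze(-1).long())
    # Disabled rows get a -inf threshold so nothing is masked out.
    thresholds = thresholds.masked_fill(no_top_k.unsqueeze(-1), float("-inf"))
    return logits.masked_fill(logits < thresholds, float("-inf"))

One caveat of this formulation: exact ties at the k-th value would keep more than k entries per row. With continuous random logits, as in the test above, ties are vanishingly unlikely, which is why an exact torch.allclose comparison between the two code paths is a reasonable check.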
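To try the new test locally (this assumes a CUDA-capable GPU and a working vLLM development install), a standard pytest invocation should do:

pytest tests/v1/sample/test_topk_topp_sampler.py -q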