[V1][TPU] Enable Top K (#15489)
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: Hyesoo Yang <hyeygit@gmail.com>
Co-authored-by: Hyesoo Yang <hyeygit@gmail.com>
@@ -5,7 +5,8 @@ import pytest
 import torch

 from vllm.platforms import current_platform
-from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p_tpu
+from vllm.v1.sample.ops.topk_topp_sampler import (apply_top_k_top_p,
+                                                  apply_top_k_top_p_tpu)

 if not current_platform.is_tpu():
     pytest.skip("This test needs a TPU.", allow_module_level=True)
@@ -16,6 +17,25 @@ VOCAB_SIZE = 128 * 1024
 TOLERANCE = 1e-6


+def test_topk_equivalence_to_native_impl():
+    with torch.device(xm.xla_device()):
+        xm.set_rng_state(seed=33)
+
+        logits = torch.rand((BATCH_SIZE, VOCAB_SIZE))
+
+        # Random top-k values between 1 and 10.
+        k = torch.randint(1, 10, (BATCH_SIZE, ))
+
+        # Set k=vocab_size for ~50% of requests in the batch (top-k disabled).
+        k.masked_fill_(torch.randint(0, 2, (BATCH_SIZE, ), dtype=bool),
+                       VOCAB_SIZE)
+
+        result_tpu = apply_top_k_top_p_tpu(logits=logits.clone(), k=k, p=None)
+
+        result_native = apply_top_k_top_p(logits=logits.clone(), k=k, p=None)
+        assert torch.allclose(result_native, result_tpu)
+
+
 def test_topp_result_sums_past_p():
     with torch.device(xm.xla_device()):
         xm.set_rng_state(seed=33)
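For context on what the new test exercises: a TPU-specific top-k path presumably exists because XLA-compiled code favors fixed-shape, threshold-style operations over data-dependent indexing. The sketch below is an illustrative, threshold-based per-row top-k mask; the helper name top_k_mask_sketch is hypothetical, and this is an assumption about the general technique rather than the actual apply_top_k_top_p_tpu implementation. It shows the behavior the test asserts: entries outside each row's top k are pushed to -inf, and rows with k == vocab_size are left untouched.

import torch


def top_k_mask_sketch(logits: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    # Illustrative sketch only -- not vLLM's apply_top_k_top_p_tpu.
    # Keeps the k largest logits per row and pushes the rest to -inf.
    # Rows where k equals the vocab size are left untouched (top-k disabled),
    # matching the convention used by the test above.
    vocab_size = logits.shape[-1]
    max_k = int(k.max())
    # Value of the k-th largest logit in each row, used as a cutoff threshold.
    topk_values = torch.topk(logits, k=max_k, dim=-1).values
    cutoff = topk_values.gather(-1, (k - 1).unsqueeze(-1).to(torch.int64))
    disabled = (k == vocab_size).unsqueeze(-1)
    return logits.masked_fill((logits < cutoff) & ~disabled, float("-inf"))


if __name__ == "__main__":
    logits = torch.rand((4, 16))
    k = torch.tensor([1, 3, 5, 16])  # last row: top-k disabled
    masked = top_k_mask_sketch(logits, k)
    # Each row should keep exactly k finite entries (assuming no ties).
    print(torch.isfinite(masked).sum(dim=-1))

Expressing top-k as a fixed-shape comparison against a per-row cutoff keeps the operation friendly to XLA, which is presumably why the TPU path is worth checking for equivalence against the native apply_top_k_top_p implementation.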