[V1][TPU] Speed up top-k on TPU by using torch.topk (#15242)

Signed-off-by: Hyesoo Yang <hyeygit@gmail.com>
This commit is contained in:
Hyesoo Yang
2025-03-20 19:19:40 -07:00
committed by GitHub
parent 6edbfa924d
commit 47195057e9
3 changed files with 29 additions and 4 deletions

View File

@@ -39,7 +39,7 @@ def test_sampler_compilation(model_name: str, monkeypatch):
sampling_params = SamplingParams(
temperature=0.7,
# top_p=0.6, # TODO too slow!
# top_k=10,
top_k=10,
min_p=0.2,
max_tokens=16)
s = time()
@@ -49,6 +49,7 @@ def test_sampler_compilation(model_name: str, monkeypatch):
# Second request with different params, but for which we
# compiled for in previous eager iteration.
sampling_params = SamplingParams(temperature=0.1,
top_k=12,
min_p=0.8,
max_tokens=24)
s = time()