[V1][TPU] Enable Top K (#15489)
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: Hyesoo Yang <hyeygit@gmail.com>
Co-authored-by: Hyesoo Yang <hyeygit@gmail.com>
This commit is contained in:
@@ -1,4 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import random
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm import LLM, envs
|
||||
@@ -39,3 +41,19 @@ def test_sampler_different(model_name: str):
|
||||
# Unsupported `seed` param.
|
||||
sampling_params = SamplingParams(temperature=0.3, seed=42)
|
||||
output2 = llm.generate(prompts, sampling_params)
|
||||
|
||||
# Batch-case with TopK
|
||||
for B in [4, 16]:
|
||||
p = prompts * B
|
||||
sampling_params = [
|
||||
SamplingParams(
|
||||
temperature=0.1,
|
||||
min_p=0.8,
|
||||
max_tokens=64,
|
||||
# Vary number of ks
|
||||
top_k=random.randint(4, 12)) for _ in range(B)
|
||||
]
|
||||
# Make sure first two reqs have the same K
|
||||
sampling_params[0] = sampling_params[1]
|
||||
output = llm.generate(p, sampling_params)
|
||||
assert output[0].outputs[0].text == output[1].outputs[0].text
|
||||
|
||||
Reference in New Issue
Block a user