diff --git a/tests/test_regression.py b/tests/test_regression.py index 978e07839..a38b4428d 100644 --- a/tests/test_regression.py +++ b/tests/test_regression.py @@ -12,6 +12,7 @@ import gc import pytest import torch +from tests.utils import large_gpu_mark from vllm import LLM, SamplingParams from vllm.platforms import current_platform @@ -32,10 +33,21 @@ def test_duplicated_ignored_sequence_group(): assert len(prompts) == len(outputs) -def test_max_tokens_none(): +@pytest.mark.parametrize( + "model", + [ + pytest.param( + "distilbert/distilgpt2", + marks=[ + *([large_gpu_mark(min_gb=80)] if current_platform.is_rocm() else []), + ], + ), + ], +) +def test_max_tokens_none(model): sampling_params = SamplingParams(temperature=0.01, top_p=0.1, max_tokens=None) llm = LLM( - model="distilbert/distilgpt2", + model=model, max_num_batched_tokens=4096, tensor_parallel_size=1, )