[CI/Build] Update CPU tests to include all "standard" tests (#5481)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2024-11-08 23:30:04 +08:00
committed by GitHub
parent 208ce622c7
commit b489fc3c91
14 changed files with 63 additions and 48 deletions

View File

@@ -5,11 +5,11 @@ import pytest
import pytest_asyncio
from transformers import AutoModel, AutoTokenizer, BatchEncoding
from tests.utils import RemoteOpenAIServer
from vllm.sequence import SampleLogprobs
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from ....conftest import HfRunner, VllmRunner
from ....utils import RemoteOpenAIServer
from ...utils import check_logprobs_close
MODEL_NAME = "fixie-ai/ultravox-v0_3"
@@ -39,7 +39,10 @@ def audio(request):
return AudioAsset(request.param)
@pytest.fixture(params=({}, CHUNKED_PREFILL_KWARGS))
@pytest.fixture(params=[
pytest.param({}, marks=pytest.mark.cpu_model),
pytest.param(CHUNKED_PREFILL_KWARGS),
])
def server(request, audio_assets):
args = [
"--dtype=bfloat16", "--max-model-len=4096", "--enforce-eager",
@@ -185,7 +188,10 @@ def run_multi_audio_test(
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("vllm_kwargs", [{}, CHUNKED_PREFILL_KWARGS])
@pytest.mark.parametrize("vllm_kwargs", [
pytest.param({}, marks=pytest.mark.cpu_model),
pytest.param(CHUNKED_PREFILL_KWARGS),
])
def test_models(hf_runner, vllm_runner, audio, dtype: str, max_tokens: int,
num_logprobs: int, vllm_kwargs: dict) -> None:
@@ -207,7 +213,10 @@ def test_models(hf_runner, vllm_runner, audio, dtype: str, max_tokens: int,
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize("vllm_kwargs", [{}, CHUNKED_PREFILL_KWARGS])
@pytest.mark.parametrize("vllm_kwargs", [
pytest.param({}, marks=pytest.mark.cpu_model),
pytest.param(CHUNKED_PREFILL_KWARGS),
])
def test_models_with_multiple_audios(vllm_runner, audio_assets, dtype: str,
max_tokens: int, num_logprobs: int,
vllm_kwargs: dict) -> None: