[Hardware][CPU] Support chunked-prefill and prefix-caching on CPU (#10355)

Signed-off-by: jiang1.li <jiang1.li@intel.com>
Author: Li, Jiang
Date: 2024-11-20 18:57:39 +08:00 (committed by GitHub)
Commit: 63f1fde277 (parent d5b28447e0)
8 changed files with 558 additions and 368 deletions
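
For context, the feature exercised by the new tests can be enabled through the regular vLLM entry point. The snippet below is a minimal sketch, not code from this commit: the model, chunk size, sequence limit, and VLLM_CPU_KVCACHE_SPACE value are illustrative choices.

import os

from vllm import LLM, SamplingParams

# The CPU backend sizes its KV cache from this environment variable (GiB).
os.environ.setdefault("VLLM_CPU_KVCACHE_SPACE", "4")

# Illustrative settings: chunked prefill splits long prompts into
# max_num_batched_tokens-sized chunks per scheduler step, and prefix
# caching lets requests sharing a prompt prefix reuse computed KV blocks.
llm = LLM(
    model="facebook/opt-125m",
    dtype="bfloat16",
    enable_chunked_prefill=True,
    enable_prefix_caching=True,
    max_num_batched_tokens=32,
    max_num_seqs=2,
)

prefix = "You are a helpful assistant. Answer concisely.\n"
prompts = [
    prefix + "What is chunked prefill?",
    prefix + "What is prefix caching?",
]

outputs = llm.generate(prompts, SamplingParams(temperature=0.0, max_tokens=32))
for out in outputs:
    print(out.outputs[0].text)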


@@ -12,6 +12,7 @@ from contextlib import nullcontext
 import pytest
 
 from tests.kernels.utils import override_backend_env_variable
+from vllm.platforms import current_platform
 
 from ..models.utils import check_logprobs_close, check_outputs_equal
 from ..utils import multi_gpu_test
@@ -206,12 +207,14 @@ def test_models_with_fp8_kv_cache(
 # NOTE: Increasing this in this suite will fail CI because we currently cannot
 # reset distributed env properly. Use a value > 1 just when you test.
 @pytest.mark.parametrize("tensor_parallel_size", [1])
+@pytest.mark.parametrize("dtype", ["half"])
 def test_with_prefix_caching(
     vllm_runner,
     max_tokens: int,
     enforce_eager: bool,
     chunk_size: int,
     tensor_parallel_size: int,
+    dtype: str,
 ) -> None:
     """
     Checks exact match decode with and without prefix caching
@@ -233,7 +236,7 @@ def test_with_prefix_caching(
     for enable in (True, False):
         with vllm_runner(
                 model,
-                dtype="half",
+                dtype=dtype,
                 max_num_batched_tokens=max_num_batched_tokens,
                 enable_chunked_prefill=True,
                 enable_prefix_caching=enable,
@@ -260,3 +263,61 @@ def test_with_prefix_caching(
             name_0="w/o prefix caching",
             name_1="with prefix caching",
         )
+
+
+@pytest.mark.parametrize("model", ["facebook/opt-125m"])
+@pytest.mark.parametrize("dtype", ["bfloat16"])
+@pytest.mark.parametrize("max_tokens", [32])
+@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16])
+@pytest.mark.parametrize("enforce_eager", [False])
+@pytest.mark.parametrize("attention_backend", ["TORCH_SDPA"])
+@pytest.mark.cpu_model
+@pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only")
+def test_models_cpu(
+    hf_runner,
+    vllm_runner,
+    example_prompts,
+    model: str,
+    dtype: str,
+    max_tokens: int,
+    chunked_prefill_token_size: int,
+    enforce_eager: bool,
+    attention_backend: str,
+    monkeypatch,
+) -> None:
+    test_models(
+        hf_runner,
+        vllm_runner,
+        example_prompts,
+        model,
+        dtype,
+        max_tokens,
+        chunked_prefill_token_size,
+        enforce_eager,
+        1,
+        attention_backend,
+        monkeypatch,
+    )
+
+
+@pytest.mark.parametrize("max_tokens", [16])
+@pytest.mark.parametrize("enforce_eager", [False])
+@pytest.mark.parametrize("chunk_size", [30, 32])
+@pytest.mark.parametrize("dtype", ["bfloat16"])
+@pytest.mark.cpu_model
+@pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only")
+def test_with_prefix_caching_cpu(
+    vllm_runner,
+    max_tokens: int,
+    enforce_eager: bool,
+    chunk_size: int,
+    dtype: str,
+) -> None:
+    test_with_prefix_caching(
+        vllm_runner,
+        max_tokens,
+        enforce_eager,
+        chunk_size,
+        1,
+        dtype,
+    )
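
The two CPU wrappers above simply re-enter the existing test bodies with tensor_parallel_size fixed to 1, so the pass criterion on CPU is the same exact-match check used on GPU. Below is a standalone sketch of that comparison pattern; the prompts, model, and token counts are illustrative and not taken from the test suite.

# Sketch of the exact-match criterion behind test_with_prefix_caching:
# greedy outputs must be identical with and without prefix caching.
from vllm import LLM, SamplingParams


def greedy_texts(enable_prefix_caching: bool) -> list[str]:
    llm = LLM(
        model="facebook/opt-125m",       # illustrative small model
        dtype="bfloat16",
        enable_chunked_prefill=True,
        enable_prefix_caching=enable_prefix_caching,
        max_num_batched_tokens=32,
        max_num_seqs=2,
    )
    params = SamplingParams(temperature=0.0, max_tokens=16)
    prompts = ["Hello, my name is", "The capital of France is"]
    return [out.outputs[0].text for out in llm.generate(prompts, params)]


# Token-for-token equality, mirroring check_outputs_equal in spirit.
assert greedy_texts(True) == greedy_texts(False)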