[Misc] Add chunked-prefill support on FlashInfer. (#9781)
@@ -11,6 +11,8 @@ from contextlib import nullcontext
 
 import pytest
 
+from tests.kernels.utils import override_backend_env_variable
+
 from ..models.utils import check_logprobs_close, check_outputs_equal
 from ..utils import multi_gpu_test
 
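For context, override_backend_env_variable is the helper imported above. In vLLM's test utilities it amounts to pinning the attention backend through the VLLM_ATTENTION_BACKEND environment variable for the duration of a single test. A minimal sketch of an equivalent helper, assuming pytest's monkeypatch fixture (the real implementation in tests/kernels/utils.py may differ in detail):

import pytest

def override_backend_env_variable(monkeypatch: pytest.MonkeyPatch,
                                  backend_name: str) -> None:
    # Pin vLLM's attention backend (e.g. "FLASHINFER" or "FLASH_ATTN")
    # for this test only; monkeypatch restores the variable afterwards.
    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", backend_name)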
@@ -28,6 +30,7 @@ MODELS = [
 # NOTE: Increasing this in this suite will fail CI because we currently cannot
 # reset distributed env properly. Use a value > 1 just when you test.
 @pytest.mark.parametrize("tensor_parallel_size", [1])
+@pytest.mark.parametrize("attention_backend", ["FLASHINFER", "FLASH_ATTN"])
 def test_models(
     hf_runner,
     vllm_runner,
@@ -38,11 +41,15 @@ def test_models(
     chunked_prefill_token_size: int,
     enforce_eager: bool,
     tensor_parallel_size: int,
+    attention_backend: str,
+    monkeypatch,
 ) -> None:
     """
     Checks exact match decode between huggingface model and vllm runner with
     chunked prefill.
     """
+    override_backend_env_variable(monkeypatch, attention_backend)
+
     max_num_seqs = chunked_prefill_token_size
     max_num_batched_tokens = chunked_prefill_token_size
 
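The setup above makes any prefill longer than chunked_prefill_token_size span multiple scheduler steps, since max_num_batched_tokens caps how many prompt tokens each step may process. A back-of-envelope illustration with made-up numbers:

import math

prompt_len = 1000              # hypothetical prompt length in tokens
max_num_batched_tokens = 16    # per-step token budget, as set above

# Each scheduler step prefills at most max_num_batched_tokens tokens,
# so the prompt's prefill is split across this many chunks:
num_chunks = math.ceil(prompt_len / max_num_batched_tokens)
print(num_chunks)  # 63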
@@ -71,13 +78,18 @@ def test_models(
 @multi_gpu_test(num_gpus=2)
 @pytest.mark.parametrize("distributed_executor_backend", ["ray", "mp"])
 @pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("attention_backend", ["FLASHINFER", "FLASH_ATTN"])
 def test_models_distributed(
     hf_runner,
     vllm_runner,
     example_prompts,
     model: str,
     distributed_executor_backend: str,
+    attention_backend: str,
+    monkeypatch,
 ) -> None:
+    override_backend_env_variable(monkeypatch, attention_backend)
+
     if (model == "meta-llama/Llama-2-7b-hf"
             and distributed_executor_backend == "ray"):
         # test ray adag
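Outside the test harness, the feature this diff exercises can be tried directly. A minimal usage sketch, assuming vLLM's offline LLM entry point with FlashInfer installed; the model name and token budget below are placeholders, not values from this commit:

import os

# Select the attention backend before the engine initializes.
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"

from vllm import LLM, SamplingParams

llm = LLM(
    model="facebook/opt-125m",    # placeholder model
    enable_chunked_prefill=True,  # split long prefills across steps
    max_num_batched_tokens=16,    # tiny budget to force chunking
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=16))
print(outputs[0].outputs[0].text)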