[multi-step] add flashinfer backend (#7928)

Author: William Lin
Date:   2024-09-12 11:16:22 -07:00
Committed by: GitHub
Parent: f2e263b801
Commit: a6c0f3658d

9 changed files with 371 additions and 84 deletions


@@ -1,9 +1,10 @@
 # Test the AsyncLLMEngine with multi-step-decoding
 from typing import List, Optional
 import pytest
+from tests.kernels.utils import override_backend_env_variable
 from ..models.utils import check_logprobs_close
 from ..utils import (completions_with_server_args, get_client_text_generations,
                      get_client_text_logprob_generations)
@@ -33,8 +34,9 @@ DEFAULT_SERVER_ARGS: List[str] = [
 @pytest.mark.parametrize("eager_mode", [False, True])
 @pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
 @pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
-@pytest.mark.parametrize("num_logprobs", [None, 5])
-@pytest.mark.parametrize("is_async", [False, True])
+@pytest.mark.parametrize("num_logprobs", [5])
+@pytest.mark.parametrize("is_async", [True])
+@pytest.mark.parametrize("attention_backend", ["FLASHINFER", "FLASH_ATTN"])
 @pytest.mark.asyncio
 async def test_multi_step(
     example_prompts,
@@ -46,6 +48,8 @@ async def test_multi_step(
     num_prompts: int,
     is_async: bool,
     num_logprobs: Optional[int],
+    attention_backend: str,
+    monkeypatch,
 ) -> None:
     """Test vLLM engine with multi-step scheduling in an OpenAI-protocol
     client/server environment.
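
The docstring describes an OpenAI-protocol client/server test. As a rough illustration of what that setup exercises (not the test's actual helpers such as `completions_with_server_args`, whose signatures are not shown here), a client would query a vLLM OpenAI-compatible server roughly like this; the URL and model name below are placeholders:

```python
# Illustrative sketch only: querying an OpenAI-compatible completions endpoint
# with logprobs, the protocol this multi-step test exercises end to end.
import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
completion = client.completions.create(
    model="JackFram/llama-160m",  # example model; the test's model may differ
    prompt="Hello, my name is",
    max_tokens=5,
    temperature=0.0,
    logprobs=5,  # request top-5 logprobs, matching num_logprobs=5 above
)
print(completion.choices[0].text)
print(completion.choices[0].logprobs.top_logprobs)
```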
@@ -71,6 +75,8 @@ async def test_multi_step(
       completions endpoint; `None` -> no logprobs
     """
+    override_backend_env_variable(monkeypatch, attention_backend)
+
     prompts = example_prompts
     if len(prompts) < num_prompts:
         prompts = prompts * ((num_prompts // len(prompts)) + 1)
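
For reference, a minimal sketch of how the new `attention_backend` parameter plausibly takes effect, and of the prompt-replication arithmetic above. The assumption is that `override_backend_env_variable` sets vLLM's `VLLM_ATTENTION_BACKEND` environment variable via pytest's `monkeypatch`, so the server launched by the test selects either FLASHINFER or FLASH_ATTN; the trailing slice in `replicate` is also an assumption, since the diff is truncated after the multiplication line:

```python
# Sketch (assumption): backend override and prompt replication used by the test.
import pytest

VLLM_ATTENTION_BACKEND = "VLLM_ATTENTION_BACKEND"  # env var vLLM reads to pick a backend

def override_backend_env_variable(monkeypatch: pytest.MonkeyPatch,
                                  backend_name: str) -> None:
    # Force the given attention backend (e.g. "FLASHINFER" or "FLASH_ATTN") for
    # the duration of the test; monkeypatch restores the prior value afterwards.
    monkeypatch.setenv(VLLM_ATTENTION_BACKEND, backend_name)

def replicate(prompts: list, num_prompts: int) -> list:
    # Repeat the example prompts until there are at least num_prompts of them.
    # With 8 example prompts and num_prompts=10: 8 * (10 // 8 + 1) = 16 prompts,
    # presumably sliced back down to exactly num_prompts afterwards.
    if len(prompts) < num_prompts:
        prompts = prompts * ((num_prompts // len(prompts)) + 1)
    return prompts[:num_prompts]
```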