[multi-step] add flashinfer backend (#7928)
This commit is contained in:
@@ -1,9 +1,10 @@
|
||||
# Test the AsyncLLMEngine with multi-step-decoding
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.kernels.utils import override_backend_env_variable
|
||||
|
||||
from ..models.utils import check_logprobs_close
|
||||
from ..utils import (completions_with_server_args, get_client_text_generations,
|
||||
get_client_text_logprob_generations)
|
||||
@@ -33,8 +34,9 @@ DEFAULT_SERVER_ARGS: List[str] = [
|
||||
@pytest.mark.parametrize("eager_mode", [False, True])
|
||||
@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS)
|
||||
@pytest.mark.parametrize("num_prompts", NUM_PROMPTS)
|
||||
@pytest.mark.parametrize("num_logprobs", [None, 5])
|
||||
@pytest.mark.parametrize("is_async", [False, True])
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
@pytest.mark.parametrize("is_async", [True])
|
||||
@pytest.mark.parametrize("attention_backend", ["FLASHINFER", "FLASH_ATTN"])
|
||||
@pytest.mark.asyncio
|
||||
async def test_multi_step(
|
||||
example_prompts,
|
||||
@@ -46,6 +48,8 @@ async def test_multi_step(
|
||||
num_prompts: int,
|
||||
is_async: bool,
|
||||
num_logprobs: Optional[int],
|
||||
attention_backend: str,
|
||||
monkeypatch,
|
||||
) -> None:
|
||||
"""Test vLLM engine with multi-step scheduling in an OpenAI-protocol
|
||||
client/server environment.
|
||||
@@ -71,6 +75,8 @@ async def test_multi_step(
|
||||
completions endpoint; `None` -> no logprobs
|
||||
"""
|
||||
|
||||
override_backend_env_variable(monkeypatch, attention_backend)
|
||||
|
||||
prompts = example_prompts
|
||||
if len(prompts) < num_prompts:
|
||||
prompts = prompts * ((num_prompts // len(prompts)) + 1)
|
||||
|
||||
Reference in New Issue
Block a user