[Kernel] Use flashinfer for decoding (#4353)

Co-authored-by: LiuXiaoxuanPKU <llilyliupku@gmail.com>
Author: Lily Liu
Date: 2024-05-03 15:51:27 -07:00
Committed by: GitHub
parent f8e7adda21
commit 43c413ec57
15 changed files with 600 additions and 53 deletions
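
The test changes below gate on a `VLLM_ATTENTION_BACKEND` environment variable. As a minimal sketch of that pattern (illustrative only; the backend names and function here are assumptions, not vLLM's actual selector logic), env-var-driven backend selection looks like this:

```python
import os

def choose_attention_backend() -> str:
    # Read the same variable the test below checks via os.getenv.
    backend_by_env_var = os.getenv("VLLM_ATTENTION_BACKEND")
    if backend_by_env_var == "FLASHINFER":
        # Decode attention is served by flashinfer kernels.
        return "FlashInferBackend"
    # Otherwise fall back to the default flash-attention path.
    return "FlashAttentionBackend"
```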

tests/basic_correctness/test_basic_correctness.py

@@ -2,12 +2,15 @@
 
 Run `pytest tests/basic_correctness/test_basic_correctness.py`.
 """
+import os
+
 import pytest
 
 MODELS = [
     "facebook/opt-125m",
     "meta-llama/Llama-2-7b-hf",
 ]
 
+VLLM_ATTENTION_BACKEND = "VLLM_ATTENTION_BACKEND"
 
 @pytest.mark.parametrize("model", MODELS)
@@ -23,11 +26,18 @@ def test_models(
     max_tokens: int,
     enforce_eager: bool,
 ) -> None:
+    backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND)
+    if backend_by_env_var == "FLASHINFER" and enforce_eager is False:
+        pytest.skip("Skipping non-eager test for FlashInferBackend.")
+
     hf_model = hf_runner(model, dtype=dtype)
     hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
     del hf_model
 
-    vllm_model = vllm_runner(model, dtype=dtype, enforce_eager=enforce_eager)
+    vllm_model = vllm_runner(model,
+                             dtype=dtype,
+                             enforce_eager=enforce_eager,
+                             gpu_memory_utilization=0.7)
     vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
     del vllm_model
 
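
To exercise the new decode path, the backend must be selected through the environment before the test process imports vLLM. A minimal driver sketch, assuming only the env var and test path shown in the diff above:

```python
import os
import pytest

# Select FlashInfer for decoding before anything imports vLLM, since the
# backend is chosen from the environment. Assumes a CUDA GPU with
# flashinfer installed; non-eager combinations are skipped by the test itself.
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"
pytest.main(["tests/basic_correctness/test_basic_correctness.py", "-q"])
```

The lowered `gpu_memory_utilization=0.7` presumably leaves headroom for the HF reference model, which runs in the same process just before the vLLM runner is constructed.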