[V1] V1 FlashInfer Attention (#16684)
Signed-off-by: mgoin <mgoin64@gmail.com>
Co-authored-by: Aurick Qiao <qiao@aurick.net>
@@ -1,13 +1,21 @@
 # SPDX-License-Identifier: Apache-2.0
 
+import pytest
+
 from vllm import LLM, SamplingParams
 
+from ...utils import fork_new_process_for_each_test
+
 
-def test_cascade_attention(example_system_message, monkeypatch):
+@fork_new_process_for_each_test
+@pytest.mark.parametrize("attn_backend",
+                         ["FLASH_ATTN_VLLM_V1", "FLASHINFER_VLLM_V1"])
+def test_cascade_attention(example_system_message, monkeypatch, attn_backend):
     prompt = "\n<User>: Implement fibonacci sequence in Python.\n<Claude>:"
 
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "1")
+        m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
 
         llm = LLM(model="Qwen/Qwen2-1.5B-Instruct")
         sampling_params = SamplingParams(temperature=0.0, max_tokens=100)
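For context, the parametrized test above selects the attention backend purely through environment variables. The following is a minimal standalone sketch, not part of this commit, showing the same selection outside pytest: the prompt, model name, sampling parameters, and both env vars are taken from the test; the generate() call and the way the completion text is read back follow the standard vLLM API.

# Standalone sketch (illustrative only): select the V1 FlashInfer attention
# backend with the same environment variables the test sets via monkeypatch.
# Setting them before importing vllm is the safest way to ensure they are
# picked up.
import os

os.environ["VLLM_USE_V1"] = "1"
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER_VLLM_V1"

from vllm import LLM, SamplingParams

prompt = "\n<User>: Implement fibonacci sequence in Python.\n<Claude>:"

llm = LLM(model="Qwen/Qwen2-1.5B-Instruct")
sampling_params = SamplingParams(temperature=0.0, max_tokens=100)

# Greedy decoding of up to 100 new tokens; print the completion for the
# single prompt in the batch.
outputs = llm.generate([prompt], sampling_params)
print(outputs[0].outputs[0].text)

In the test itself, each pytest parameter runs in a freshly forked process (via fork_new_process_for_each_test), so VLLM_ATTENTION_BACKEND is read cleanly for each backend rather than leaking state between runs.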