[V1] V1 Enablement Oracle (#13726)
Signed-off-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com> Co-authored-by: rshaw@neuralmagic.com <rshaw@neuralmagic.com> Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com> Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com> Co-authored-by: Michael Goin <michael@neuralmagic.com>
This commit is contained in:
@@ -23,12 +23,14 @@ def clear_cache():
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER", "OPENVINO"])
|
||||
@pytest.mark.parametrize("use_v1", [True, False])
|
||||
@pytest.mark.parametrize("device", ["cpu", "openvino", "hip", "cuda"])
|
||||
def test_env(name: str, device: str, monkeypatch):
|
||||
def test_env(name: str, use_v1: bool, device: str, monkeypatch):
|
||||
"""Test that the attention selector can be set via environment variable.
|
||||
Note that we do not test FlashAttn because it is the default backend.
|
||||
"""
|
||||
|
||||
monkeypatch.setenv("VLLM_USE_V1", "1" if use_v1 else "0")
|
||||
override_backend_env_variable(monkeypatch, name)
|
||||
|
||||
if device == "cpu":
|
||||
@@ -40,7 +42,8 @@ def test_env(name: str, device: str, monkeypatch):
|
||||
with patch("vllm.attention.selector.current_platform", RocmPlatform()):
|
||||
backend = get_attn_backend(16, torch.float16, torch.float16, 16,
|
||||
False)
|
||||
assert backend.get_name() == "ROCM_FLASH"
|
||||
EXPECTED = "ROCM_ATTN_VLLM_V1" if use_v1 else "ROCM_FLASH"
|
||||
assert backend.get_name() == EXPECTED
|
||||
elif device == "openvino":
|
||||
with patch("vllm.attention.selector.current_platform",
|
||||
OpenVinoPlatform()), patch.dict('sys.modules',
|
||||
@@ -54,7 +57,8 @@ def test_env(name: str, device: str, monkeypatch):
|
||||
CudaPlatform()):
|
||||
backend = get_attn_backend(16, torch.float16, torch.float16,
|
||||
16, False)
|
||||
assert backend.get_name() == name
|
||||
EXPECTED = "FLASH_ATTN_VLLM_V1" if use_v1 else name
|
||||
assert backend.get_name() == EXPECTED
|
||||
|
||||
|
||||
def test_flash_attn(monkeypatch):
|
||||
@@ -95,13 +99,23 @@ def test_flash_attn(monkeypatch):
|
||||
assert backend.get_name() != STR_FLASH_ATTN_VAL
|
||||
|
||||
|
||||
def test_invalid_env(monkeypatch):
|
||||
@pytest.mark.parametrize("use_v1", [True, False])
|
||||
def test_invalid_env(use_v1: bool, monkeypatch):
|
||||
"""Ignore the invalid env variable if it is set."""
|
||||
monkeypatch.setenv("VLLM_USE_V1", "1" if use_v1 else "0")
|
||||
override_backend_env_variable(monkeypatch, STR_INVALID_VAL)
|
||||
|
||||
with patch("vllm.attention.selector.current_platform", CudaPlatform()):
|
||||
backend = get_attn_backend(32, torch.float16, None, 16, False)
|
||||
assert backend.get_name() == "FLASH_ATTN"
|
||||
EXPECTED = "FLASH_ATTN_VLLM_V1" if use_v1 else "FLASH_ATTN"
|
||||
assert backend.get_name() == EXPECTED
|
||||
|
||||
# when block size == 16, backend will fall back to XFORMERS
|
||||
backend = get_attn_backend(16, torch.float16, None, 16, False)
|
||||
assert backend.get_name() == "XFORMERS"
|
||||
# this behavior is not yet supported on V1.
|
||||
if use_v1:
|
||||
# TODO: support fallback on V1!
|
||||
# https://github.com/vllm-project/vllm/issues/14524
|
||||
pass
|
||||
else:
|
||||
backend = get_attn_backend(16, torch.float16, None, 16, False)
|
||||
assert backend.get_name() == "XFORMERS"
|
||||
|
||||
Reference in New Issue
Block a user