[V0 Deprecation] Remove VLLM_USE_V1 from tests (#26341)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -80,7 +80,6 @@ def test_env(
|
||||
):
|
||||
"""Test attention backend selection with valid device-backend pairs."""
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
m.setenv(STR_BACKEND_ENV_VAR, name)
|
||||
m.setenv("VLLM_MLA_DISABLE", "1" if use_mla else "0")
|
||||
|
||||
@@ -212,30 +211,21 @@ def test_env(
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", ["cpu", "cuda"])
|
||||
def test_fp32_fallback(
|
||||
device: str,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
):
|
||||
def test_fp32_fallback(device: str):
|
||||
"""Test attention backend selection with fp32."""
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
if device == "cpu":
|
||||
with patch("vllm.attention.selector.current_platform", CpuPlatform()):
|
||||
backend = get_attn_backend(16, torch.float32, None, 16)
|
||||
assert backend.get_name() == "TORCH_SDPA"
|
||||
|
||||
if device == "cpu":
|
||||
with patch("vllm.attention.selector.current_platform", CpuPlatform()):
|
||||
backend = get_attn_backend(16, torch.float32, None, 16)
|
||||
assert backend.get_name() == "TORCH_SDPA"
|
||||
|
||||
elif device == "cuda":
|
||||
with patch("vllm.attention.selector.current_platform", CudaPlatform()):
|
||||
backend = get_attn_backend(16, torch.float32, None, 16)
|
||||
assert backend.get_name() == "FLEX_ATTENTION"
|
||||
elif device == "cuda":
|
||||
with patch("vllm.attention.selector.current_platform", CudaPlatform()):
|
||||
backend = get_attn_backend(16, torch.float32, None, 16)
|
||||
assert backend.get_name() == "FLEX_ATTENTION"
|
||||
|
||||
|
||||
def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
|
||||
"""Test FlashAttn validation."""
|
||||
# TODO: When testing for v1, pipe in `use_v1` as an argument to
|
||||
# get_attn_backend
|
||||
|
||||
pytest.skip(
|
||||
"Skipping as current backend selector does not "
|
||||
"handle fallbacks when a backend is set via env var."
|
||||
@@ -289,7 +279,6 @@ def test_invalid_env(monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.context() as m,
|
||||
patch("vllm.attention.selector.current_platform", CudaPlatform()),
|
||||
):
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
m.setenv(STR_BACKEND_ENV_VAR, STR_INVALID_VAL)
|
||||
|
||||
# Should raise ValueError for invalid backend
|
||||
|
||||
Reference in New Issue
Block a user