Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
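The change is mechanical throughout the diff below: yapf's paren-aligned continuation style is replaced by ruff's black-compatible formatting, which collapses a call onto one line whenever it fits within the line length, and isort's job is taken over by ruff as well (typically via its `I` import-sorting rules), so a single tool now handles both formatting and import order. A minimal sketch of the two styles, using a hypothetical helper rather than code from this commit:

```python
def add3(a: int, b: int, c: int) -> int:
    """Hypothetical helper, used only to illustrate the formatting change."""
    return a + b + c

# yapf (before): continuation arguments aligned under the opening parenthesis.
total = add3(1,
             2,
             3)

# ruff format (after): black-style, collapsed onto one line because it fits.
total = add3(1, 2, 3)
```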
@@ -11,8 +11,7 @@ from vllm.utils import STR_BACKEND_ENV_VAR
 
 @pytest.fixture(autouse=True)
 def clear_cache():
-    """Clear lru cache to ensure each test case runs without caching.
-    """
+    """Clear lru cache to ensure each test case runs without caching."""
     _cached_get_attn_backend.cache_clear()
 
 
@@ -22,46 +21,29 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
         m.setenv(STR_BACKEND_ENV_VAR, "ROCM_FLASH")
 
         # Set the current platform to ROCm using monkeypatch
-        monkeypatch.setattr("vllm.attention.selector.current_platform",
-                            RocmPlatform())
+        monkeypatch.setattr("vllm.attention.selector.current_platform", RocmPlatform())
 
         # Test standard ROCm attention
         backend = get_attn_backend(16, torch.float16, torch.float16, 16, False)
-        assert (backend.get_name() == "ROCM_FLASH"
-                or backend.get_name() == "TRITON_ATTN")
+        assert backend.get_name() == "ROCM_FLASH" or backend.get_name() == "TRITON_ATTN"
 
         # MLA test for deepseek related
 
         # change the attention backend to triton MLA
         m.setenv(STR_BACKEND_ENV_VAR, "TRITON_MLA")
-        backend = get_attn_backend(576,
-                                   torch.bfloat16,
-                                   "auto",
-                                   16,
-                                   False,
-                                   use_mla=True)
+        backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, use_mla=True)
         assert backend.get_name() == "TRITON_MLA"
 
         # If attention backend is None
         # If use_mla is true
         # The selected backend is triton MLA
         m.setenv(STR_BACKEND_ENV_VAR, None)
-        backend = get_attn_backend(576,
-                                   torch.bfloat16,
-                                   "auto",
-                                   16,
-                                   False,
-                                   use_mla=True)
+        backend = get_attn_backend(576, torch.bfloat16, "auto", 16, False, use_mla=True)
         assert backend.get_name() == "TRITON_MLA"
 
         # change the attention backend to AITER MLA
         m.setenv(STR_BACKEND_ENV_VAR, "ROCM_AITER_MLA")
-        backend = get_attn_backend(576,
-                                   torch.bfloat16,
-                                   "auto",
-                                   1,
-                                   False,
-                                   use_mla=True)
+        backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False, use_mla=True)
         assert backend.get_name() == "ROCM_AITER_MLA"
 
         # If attention backend is None
@@ -70,10 +52,5 @@ def test_selector(monkeypatch: pytest.MonkeyPatch):
         # The selected backend is ROCM_AITER_MLA
         m.setenv(STR_BACKEND_ENV_VAR, None)
         m.setenv("VLLM_ROCM_USE_AITER", "1")
-        backend = get_attn_backend(576,
-                                   torch.bfloat16,
-                                   "auto",
-                                   1,
-                                   False,
-                                   use_mla=True)
+        backend = get_attn_backend(576, torch.bfloat16, "auto", 1, False, use_mla=True)
         assert backend.get_name() == "ROCM_AITER_MLA"
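The assertions above all hinge on the same mechanism: the backend is chosen from the environment variable named by STR_BACKEND_ENV_VAR, plus ROCm-specific flags such as VLLM_ROCM_USE_AITER, at the time get_attn_backend first runs. A hedged sketch of how a user would pin the backend outside the test suite, assuming STR_BACKEND_ENV_VAR resolves to "VLLM_ATTENTION_BACKEND" as in current vLLM:

```python
import os

# Assumption: STR_BACKEND_ENV_VAR in vllm.utils is "VLLM_ATTENTION_BACKEND";
# verify against your vLLM version. Set it before the selector first runs,
# since _cached_get_attn_backend memoizes its result (hence the cache-clearing
# fixture at the top of the diff).
os.environ["VLLM_ATTENTION_BACKEND"] = "TRITON_MLA"

# ROCm only: opt in to AITER kernels so MLA resolves to ROCM_AITER_MLA,
# mirroring the last hunk of the test.
os.environ["VLLM_ROCM_USE_AITER"] = "1"
```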