[platform] Allow platform specify attention backend (#11609)

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Signed-off-by: Mengqing Cao <cmq0113@163.com>
Co-authored-by: Mengqing Cao <cmq0113@163.com>
This commit is contained in:
wangxiyuan
2025-01-09 21:46:50 +08:00
committed by GitHub
parent 65097ca0af
commit 405eb8e396
10 changed files with 164 additions and 175 deletions

View File

@@ -1,10 +1,10 @@
from unittest.mock import patch
from unittest.mock import Mock, patch
import pytest
import torch
from tests.kernels.utils import override_backend_env_variable
from vllm.attention.selector import which_attn_to_use
from vllm.attention.selector import _cached_get_attn_backend, get_attn_backend
from vllm.platforms.cpu import CpuPlatform
from vllm.platforms.cuda import CudaPlatform
from vllm.platforms.openvino import OpenVinoPlatform
@@ -12,6 +12,13 @@ from vllm.platforms.rocm import RocmPlatform
from vllm.utils import STR_FLASH_ATTN_VAL, STR_INVALID_VAL
@pytest.fixture(autouse=True)
def clear_cache():
"""Clear lru cache to ensure each test case runs without caching.
"""
_cached_get_attn_backend.cache_clear()
@pytest.mark.parametrize(
"name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER", "OPENVINO"])
@pytest.mark.parametrize("device", ["cpu", "openvino", "hip", "cuda"])
@@ -24,67 +31,70 @@ def test_env(name: str, device: str, monkeypatch):
if device == "cpu":
with patch("vllm.attention.selector.current_platform", CpuPlatform()):
backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
False)
assert backend.name == "TORCH_SDPA"
backend = get_attn_backend(16, torch.float16, torch.float16, 16,
False)
assert backend.get_name() == "TORCH_SDPA"
elif device == "hip":
with patch("vllm.attention.selector.current_platform", RocmPlatform()):
backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
False)
assert backend.name == "ROCM_FLASH"
backend = get_attn_backend(16, torch.float16, torch.float16, 16,
False)
assert backend.get_name() == "ROCM_FLASH"
elif device == "openvino":
with patch("vllm.attention.selector.current_platform",
OpenVinoPlatform()):
backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
False)
assert backend.name == "OPENVINO"
OpenVinoPlatform()), patch.dict('sys.modules',
{'openvino': Mock()}):
backend = get_attn_backend(16, torch.float16, torch.float16, 16,
False)
assert backend.get_name() == "OPENVINO"
else:
with patch("vllm.attention.selector.current_platform", CudaPlatform()):
backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
False)
assert backend.name == name
if name in ["XFORMERS", "FLASHINFER"]:
with patch("vllm.attention.selector.current_platform",
CudaPlatform()):
backend = get_attn_backend(16, torch.float16, torch.float16,
16, False)
assert backend.get_name() == name
def test_flash_attn(monkeypatch):
"""Test FlashAttn validation."""
# TODO: When testing for v1, pipe in `use_v1` as an argument to
# which_attn_to_use
# get_attn_backend
override_backend_env_variable(monkeypatch, STR_FLASH_ATTN_VAL)
# Unsupported CUDA arch
with patch("torch.cuda.get_device_capability", return_value=(7, 5)):
backend = which_attn_to_use(16, torch.float16, None, 16, False)
assert backend.name != STR_FLASH_ATTN_VAL
backend = get_attn_backend(16, torch.float16, None, 16, False)
assert backend.get_name() != STR_FLASH_ATTN_VAL
# Unsupported data type
backend = which_attn_to_use(16, torch.float8_e4m3fn, None, 16, False)
assert backend.name != STR_FLASH_ATTN_VAL
backend = get_attn_backend(16, torch.float8_e4m3fn, None, 16, False)
assert backend.get_name() != STR_FLASH_ATTN_VAL
# Unsupported kv cache data type
backend = which_attn_to_use(16, torch.float16, "fp8", 16, False)
assert backend.name != STR_FLASH_ATTN_VAL
backend = get_attn_backend(16, torch.float16, "fp8", 16, False)
assert backend.get_name() != STR_FLASH_ATTN_VAL
# Unsupported block size
backend = which_attn_to_use(16, torch.float16, None, 8, False)
assert backend.name != STR_FLASH_ATTN_VAL
backend = get_attn_backend(16, torch.float16, None, 8, False)
assert backend.get_name() != STR_FLASH_ATTN_VAL
# flash-attn is not installed
with patch.dict('sys.modules', {'vllm_flash_attn': None}):
backend = which_attn_to_use(16, torch.float16, None, 16, False)
assert backend.name != STR_FLASH_ATTN_VAL
backend = get_attn_backend(16, torch.float16, None, 16, False)
assert backend.get_name() != STR_FLASH_ATTN_VAL
# Unsupported head size
backend = which_attn_to_use(17, torch.float16, None, 16, False)
assert backend.name != STR_FLASH_ATTN_VAL
backend = get_attn_backend(17, torch.float16, None, 16, False)
assert backend.get_name() != STR_FLASH_ATTN_VAL
# Attention-free models should bypass env and use PlaceholderAttention
backend = which_attn_to_use(16, torch.float16, torch.float16, 16, True)
assert backend.name != STR_FLASH_ATTN_VAL
backend = get_attn_backend(16, torch.float16, torch.float16, 16, True)
assert backend.get_name() != STR_FLASH_ATTN_VAL
def test_invalid_env(monkeypatch):
"""Throw an exception if the backend name is invalid."""
override_backend_env_variable(monkeypatch, STR_INVALID_VAL)
with pytest.raises(ValueError):
which_attn_to_use(16, torch.float16, None, 16, False)
get_attn_backend(16, torch.float16, None, 16, False)