[Platform][Refactor] Extract func get_default_attn_backend to Platform (#10358)

Signed-off-by: Mengqing Cao <cmq0113@163.com>
This commit is contained in:
Mengqing Cao
2024-11-19 11:22:26 +08:00
committed by GitHub
parent 7eb719df13
commit 8c1fb50705
14 changed files with 99 additions and 69 deletions

View File

@@ -5,6 +5,7 @@ import torch
from tests.kernels.utils import override_backend_env_variable
from vllm.attention.selector import which_attn_to_use
from vllm.platforms import cpu, cuda, openvino, rocm
from vllm.utils import STR_FLASH_ATTN_VAL, STR_INVALID_VAL
@@ -19,26 +20,28 @@ def test_env(name: str, device: str, monkeypatch):
override_backend_env_variable(monkeypatch, name)
if device == "cpu":
with patch("vllm.attention.selector.current_platform.is_cpu",
return_value=True):
with patch("vllm.attention.selector.current_platform",
cpu.CpuPlatform()):
backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
False)
assert backend.name == "TORCH_SDPA"
elif device == "hip":
with patch("vllm.attention.selector.current_platform.is_rocm",
return_value=True):
with patch("vllm.attention.selector.current_platform",
rocm.RocmPlatform()):
backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
False)
assert backend.name == "ROCM_FLASH"
elif device == "openvino":
with patch("vllm.attention.selector.current_platform.is_openvino",
return_value=True):
with patch("vllm.attention.selector.current_platform",
openvino.OpenVinoPlatform()):
backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
False)
assert backend.name == "OPENVINO"
else:
backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
False)
with patch("vllm.attention.selector.current_platform",
cuda.CudaPlatform()):
backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
False)
assert backend.name == name