[Platform][Refactor] Extract func get_default_attn_backend to Platform (#10358)
Signed-off-by: Mengqing Cao <cmq0113@163.com>
This commit is contained in:
@@ -5,6 +5,7 @@ import torch
|
||||
|
||||
from tests.kernels.utils import override_backend_env_variable
|
||||
from vllm.attention.selector import which_attn_to_use
|
||||
from vllm.platforms import cpu, cuda, openvino, rocm
|
||||
from vllm.utils import STR_FLASH_ATTN_VAL, STR_INVALID_VAL
|
||||
|
||||
|
||||
@@ -19,26 +20,28 @@ def test_env(name: str, device: str, monkeypatch):
|
||||
override_backend_env_variable(monkeypatch, name)
|
||||
|
||||
if device == "cpu":
|
||||
with patch("vllm.attention.selector.current_platform.is_cpu",
|
||||
return_value=True):
|
||||
with patch("vllm.attention.selector.current_platform",
|
||||
cpu.CpuPlatform()):
|
||||
backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
|
||||
False)
|
||||
assert backend.name == "TORCH_SDPA"
|
||||
elif device == "hip":
|
||||
with patch("vllm.attention.selector.current_platform.is_rocm",
|
||||
return_value=True):
|
||||
with patch("vllm.attention.selector.current_platform",
|
||||
rocm.RocmPlatform()):
|
||||
backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
|
||||
False)
|
||||
assert backend.name == "ROCM_FLASH"
|
||||
elif device == "openvino":
|
||||
with patch("vllm.attention.selector.current_platform.is_openvino",
|
||||
return_value=True):
|
||||
with patch("vllm.attention.selector.current_platform",
|
||||
openvino.OpenVinoPlatform()):
|
||||
backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
|
||||
False)
|
||||
assert backend.name == "OPENVINO"
|
||||
else:
|
||||
backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
|
||||
False)
|
||||
with patch("vllm.attention.selector.current_platform",
|
||||
cuda.CudaPlatform()):
|
||||
backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
|
||||
False)
|
||||
assert backend.name == name
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user