Feature/vit attention unification #23880 (#23978)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
@@ -23,6 +23,9 @@ def clear_cache():
|
||||
"""Clear lru cache to ensure each test case runs without caching.
|
||||
"""
|
||||
_cached_get_attn_backend.cache_clear()
|
||||
# Clear xformers availability cache
|
||||
import vllm.attention.layer as layer_module
|
||||
layer_module.USE_XFORMERS_OPS = None
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", ["cpu", "hip", "cuda"])
|
||||
@@ -33,19 +36,28 @@ def test_mha_attn_platform(device: str):
|
||||
torch.set_default_dtype(torch.float16)
|
||||
|
||||
if device == "cpu":
|
||||
with patch("vllm.attention.selector.current_platform", CpuPlatform()):
|
||||
with patch("vllm.attention.selector.current_platform",
|
||||
CpuPlatform()), \
|
||||
patch("vllm.platforms.current_platform", CpuPlatform()):
|
||||
attn = MultiHeadAttention(16, 64, scale=1)
|
||||
assert attn.attn_backend == _Backend.TORCH_SDPA
|
||||
assert attn.attn_backend == _Backend.TORCH_SDPA_VLLM_V1
|
||||
elif device == "hip":
|
||||
with patch("vllm.attention.selector.current_platform", RocmPlatform()):
|
||||
with patch("vllm.attention.selector.current_platform",
|
||||
RocmPlatform()), \
|
||||
patch("vllm.platforms.current_platform", RocmPlatform()), \
|
||||
patch("vllm.attention.layer.current_platform", RocmPlatform()):
|
||||
attn = MultiHeadAttention(16, 64, scale=1)
|
||||
assert attn.attn_backend == _Backend.TORCH_SDPA
|
||||
else:
|
||||
with patch("vllm.attention.selector.current_platform", CudaPlatform()):
|
||||
with patch("vllm.attention.selector.current_platform",
|
||||
CudaPlatform()), \
|
||||
patch("vllm.platforms.current_platform", CudaPlatform()):
|
||||
attn = MultiHeadAttention(16, 64, scale=1)
|
||||
assert attn.attn_backend == _Backend.XFORMERS
|
||||
|
||||
with patch("vllm.attention.selector.current_platform", CudaPlatform()):
|
||||
with patch("vllm.attention.selector.current_platform",
|
||||
CudaPlatform()), \
|
||||
patch("vllm.platforms.current_platform", CudaPlatform()):
|
||||
attn = MultiHeadAttention(16, 72, scale=1)
|
||||
assert attn.attn_backend == _Backend.XFORMERS
|
||||
|
||||
|
||||
Reference in New Issue
Block a user