Feature/vit attention unification# 23880 (#23978)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
baonudesifeizhai
2025-09-10 09:10:14 -04:00
committed by GitHub
parent 72d30108a0
commit 6cbd41909e
9 changed files with 68 additions and 56 deletions

View File

@@ -23,6 +23,9 @@ def clear_cache():
"""Clear lru cache to ensure each test case runs without caching.
"""
_cached_get_attn_backend.cache_clear()
# Clear xformers availability cache
import vllm.attention.layer as layer_module
layer_module.USE_XFORMERS_OPS = None
@pytest.mark.parametrize("device", ["cpu", "hip", "cuda"])
@@ -33,19 +36,28 @@ def test_mha_attn_platform(device: str):
torch.set_default_dtype(torch.float16)
if device == "cpu":
with patch("vllm.attention.selector.current_platform", CpuPlatform()):
with patch("vllm.attention.selector.current_platform",
CpuPlatform()), \
patch("vllm.platforms.current_platform", CpuPlatform()):
attn = MultiHeadAttention(16, 64, scale=1)
assert attn.attn_backend == _Backend.TORCH_SDPA
assert attn.attn_backend == _Backend.TORCH_SDPA_VLLM_V1
elif device == "hip":
with patch("vllm.attention.selector.current_platform", RocmPlatform()):
with patch("vllm.attention.selector.current_platform",
RocmPlatform()), \
patch("vllm.platforms.current_platform", RocmPlatform()), \
patch("vllm.attention.layer.current_platform", RocmPlatform()):
attn = MultiHeadAttention(16, 64, scale=1)
assert attn.attn_backend == _Backend.TORCH_SDPA
else:
with patch("vllm.attention.selector.current_platform", CudaPlatform()):
with patch("vllm.attention.selector.current_platform",
CudaPlatform()), \
patch("vllm.platforms.current_platform", CudaPlatform()):
attn = MultiHeadAttention(16, 64, scale=1)
assert attn.attn_backend == _Backend.XFORMERS
with patch("vllm.attention.selector.current_platform", CudaPlatform()):
with patch("vllm.attention.selector.current_platform",
CudaPlatform()), \
patch("vllm.platforms.current_platform", CudaPlatform()):
attn = MultiHeadAttention(16, 72, scale=1)
assert attn.attn_backend == _Backend.XFORMERS