Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -5,6 +5,7 @@ Test:
|
||||
|
||||
* Tests for MultiHeadAttention layer
|
||||
"""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
@@ -21,11 +22,11 @@ from vllm.platforms.rocm import RocmPlatform
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_cache():
|
||||
"""Clear lru cache to ensure each test case runs without caching.
|
||||
"""
|
||||
"""Clear lru cache to ensure each test case runs without caching."""
|
||||
_cached_get_attn_backend.cache_clear()
|
||||
# Clear xformers availability cache
|
||||
import vllm.attention.layer as layer_module
|
||||
|
||||
layer_module.USE_XFORMERS_OPS = None
|
||||
|
||||
|
||||
@@ -37,49 +38,63 @@ def test_mha_attn_platform(device: str):
|
||||
torch.set_default_dtype(torch.float16)
|
||||
|
||||
if device == "cpu":
|
||||
with patch("vllm.attention.layer.current_platform", CpuPlatform()), \
|
||||
patch("vllm.model_executor.models.vision.current_platform",
|
||||
CpuPlatform()):
|
||||
with (
|
||||
patch("vllm.attention.layer.current_platform", CpuPlatform()),
|
||||
patch("vllm.model_executor.models.vision.current_platform", CpuPlatform()),
|
||||
):
|
||||
attn = MultiHeadAttention(16, 64, scale=1)
|
||||
assert attn.attn_backend == _Backend.TORCH_SDPA
|
||||
elif device == "hip":
|
||||
with patch("vllm.attention.layer.current_platform", RocmPlatform()), \
|
||||
patch("vllm.model_executor.models.vision.current_platform",
|
||||
RocmPlatform()):
|
||||
with (
|
||||
patch("vllm.attention.layer.current_platform", RocmPlatform()),
|
||||
patch("vllm.model_executor.models.vision.current_platform", RocmPlatform()),
|
||||
):
|
||||
attn = MultiHeadAttention(16, 64, scale=1)
|
||||
assert attn.attn_backend == _Backend.TORCH_SDPA
|
||||
else:
|
||||
# Test CUDA with head_size=64 (divisible by 32)
|
||||
# - should use vLLM's FlashAttention
|
||||
with patch("vllm.attention.layer.current_platform", CudaPlatform()), \
|
||||
patch("vllm.model_executor.models.vision.current_platform",
|
||||
CudaPlatform()):
|
||||
with (
|
||||
patch("vllm.attention.layer.current_platform", CudaPlatform()),
|
||||
patch("vllm.model_executor.models.vision.current_platform", CudaPlatform()),
|
||||
):
|
||||
attn = MultiHeadAttention(16, 64, scale=1)
|
||||
assert attn.attn_backend == _Backend.FLASH_ATTN
|
||||
|
||||
# Test CUDA with head_size=72 (not divisible by 32)
|
||||
# - with upstream FA not available
|
||||
# - should use xformers
|
||||
with patch("vllm.attention.layer.current_platform", CudaPlatform()), \
|
||||
patch("vllm.model_executor.models.vision.current_platform",
|
||||
CudaPlatform()), \
|
||||
patch("vllm.attention.layer.check_upstream_fa_availability",
|
||||
return_value=False):
|
||||
with (
|
||||
patch("vllm.attention.layer.current_platform", CudaPlatform()),
|
||||
patch("vllm.model_executor.models.vision.current_platform", CudaPlatform()),
|
||||
patch(
|
||||
"vllm.attention.layer.check_upstream_fa_availability",
|
||||
return_value=False,
|
||||
),
|
||||
):
|
||||
attn = MultiHeadAttention(16, 72, scale=1)
|
||||
assert attn.attn_backend == _Backend.XFORMERS
|
||||
|
||||
# Test CUDA with head_size=72 (not divisible by 32)
|
||||
# - with upstream FA available
|
||||
# - should use upstream FA
|
||||
with patch("vllm.attention.layer.current_platform", CudaPlatform()), \
|
||||
patch("vllm.model_executor.models.vision.current_platform",
|
||||
CudaPlatform()), \
|
||||
patch("vllm.attention.layer.check_upstream_fa_availability",
|
||||
return_value=True), \
|
||||
patch.dict('sys.modules', {'flash_attn': type('MockFlashAttn', (),
|
||||
{
|
||||
'flash_attn_varlen_func': lambda *args, **kwargs: None
|
||||
})()}):
|
||||
with (
|
||||
patch("vllm.attention.layer.current_platform", CudaPlatform()),
|
||||
patch("vllm.model_executor.models.vision.current_platform", CudaPlatform()),
|
||||
patch(
|
||||
"vllm.attention.layer.check_upstream_fa_availability", return_value=True
|
||||
),
|
||||
patch.dict(
|
||||
"sys.modules",
|
||||
{
|
||||
"flash_attn": type(
|
||||
"MockFlashAttn",
|
||||
(),
|
||||
{"flash_attn_varlen_func": lambda *args, **kwargs: None},
|
||||
)()
|
||||
},
|
||||
),
|
||||
):
|
||||
attn = MultiHeadAttention(16, 72, scale=1)
|
||||
assert attn.attn_backend == _Backend.FLASH_ATTN
|
||||
|
||||
@@ -108,9 +123,11 @@ NUM_HEADS = [1, 16]
|
||||
NUM_KV_HEADS = [1]
|
||||
HEAD_SIZES = [64, 80]
|
||||
# flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16}
|
||||
DTYPES = [
|
||||
torch.half, torch.bfloat16, torch.float
|
||||
] if not current_platform.is_rocm() else [torch.half, torch.bfloat16]
|
||||
DTYPES = (
|
||||
[torch.half, torch.bfloat16, torch.float]
|
||||
if not current_platform.is_rocm()
|
||||
else [torch.half, torch.bfloat16]
|
||||
)
|
||||
CUDA_DEVICES = ["cuda"]
|
||||
|
||||
|
||||
@@ -138,10 +155,9 @@ def test_mha_attn_forward(
|
||||
k = torch.randn(batch_size, seq_len, num_kv_heads * head_size)
|
||||
v = torch.randn(batch_size, seq_len, num_kv_heads * head_size)
|
||||
scale = 1.0 / head_size**0.5
|
||||
attn = MultiHeadAttention(num_heads,
|
||||
head_size,
|
||||
scale=scale,
|
||||
num_kv_heads=num_kv_heads)
|
||||
attn = MultiHeadAttention(
|
||||
num_heads, head_size, scale=scale, num_kv_heads=num_kv_heads
|
||||
)
|
||||
output = attn(q, k, v)
|
||||
|
||||
assert num_heads % num_kv_heads == 0
|
||||
|
||||
Reference in New Issue
Block a user