2025-02-02 14:58:18 -05:00
|
|
|
# SPDX-License-Identifier: Apache-2.0
|
2025-06-03 11:20:17 -07:00
|
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
2025-02-02 14:58:18 -05:00
|
|
|
|
2025-03-22 17:06:39 -04:00
|
|
|
from unittest.mock import patch
|
2024-05-22 15:55:56 -07:00
|
|
|
|
|
|
|
|
import pytest
|
|
|
|
|
import torch
|
|
|
|
|
|
2025-12-17 12:49:59 -05:00
|
|
|
from vllm.config import AttentionConfig, VllmConfig, set_current_vllm_config
|
2025-11-20 14:45:56 -06:00
|
|
|
from vllm.platforms import current_platform
|
2024-12-30 20:24:45 +08:00
|
|
|
from vllm.platforms.cpu import CpuPlatform
|
|
|
|
|
from vllm.platforms.cuda import CudaPlatform
|
|
|
|
|
from vllm.platforms.rocm import RocmPlatform
|
2026-01-09 16:10:24 -05:00
|
|
|
from vllm.v1.attention.backends.registry import AttentionBackendEnum
|
|
|
|
|
from vllm.v1.attention.selector import _cached_get_attn_backend, get_attn_backend
|
2024-05-22 15:55:56 -07:00
|
|
|
|
|
|
|
|
|
2025-01-09 21:46:50 +08:00
|
|
|
@pytest.fixture(autouse=True)
def clear_cache():
    """Clear the attention-backend LRU cache around every test.

    Clearing before the test keeps selections made by earlier tests from being
    reused. Clearing again afterwards drops entries that were selected while
    ``vllm.platforms.current_platform`` was patched, so they cannot leak into
    tests that run later in the same session.
    """
    _cached_get_attn_backend.cache_clear()
    yield
    _cached_get_attn_backend.cache_clear()
|
|
|
|
|
|
|
|
|
|
|
2025-04-23 00:31:13 +08:00
|
|
|
# MLA backends exercised per device type. Regular (non-MLA) backends live in
# their own table below.
DEVICE_MLA_BACKENDS = {
    "cuda": [
        "TRITON_MLA",
        "FLASHMLA",
        "FLASHINFER_MLA",
        "FLASH_ATTN_MLA",
        "CUTLASS_MLA",
    ],
    "hip": [
        "TRITON_MLA",
        "ROCM_AITER_MLA",
    ],
    # No MLA backend is exercised on CPU.
    "cpu": [],
}
|
|
|
|
|
|
|
|
|
|
# Regular (non-MLA) backends exercised per device type.
DEVICE_REGULAR_ATTN_BACKENDS = {
    "cuda": [
        "FLASHINFER",
        "FLASH_ATTN",
    ],
    "hip": ["ROCM_ATTN"],
    "cpu": ["CPU_ATTN"],
}
|
|
|
|
|
|
|
|
|
|
# KV-cache block sizes tried for the MLA cases of each device type.
DEVICE_MLA_BLOCK_SIZES = {
    # Both the standard (16) and an extended (64) block size.
    "cuda": [16, 64],
    # HIP requires special handling for block_size=1, so it is covered too.
    "hip": [16, 1],
    # FIXME(woosuk): Temporarily disable CPU tests. (Previously [16], the
    # fixed block size used by the CPU test cases.)
    "cpu": [],
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_params():
    """Return the ``pytest.param`` matrix consumed by ``test_backend_selection``.

    The device list is ``["hip", "cpu"]`` on a ROCm host and ``["cuda", "cpu"]``
    otherwise. One param is produced per (use_mla, device, backend name,
    block size) combination, nested in that order. Each param id encodes all
    four values. MLA cases take their block sizes from
    ``DEVICE_MLA_BLOCK_SIZES``; regular cases always use block size 16.
    """
    devices = ["hip", "cpu"] if current_platform.is_rocm() else ["cuda", "cpu"]
    return [
        pytest.param(
            device,
            name,
            use_mla,
            block_size,
            id=f"{device}_{name}_mla_{str(use_mla)[0]}_blks{block_size}",
        )
        for use_mla in [True, False]
        for device in devices
        for name in (
            DEVICE_MLA_BACKENDS[device]
            if use_mla
            else DEVICE_REGULAR_ATTN_BACKENDS[device]
        )
        for block_size in (DEVICE_MLA_BLOCK_SIZES[device] if use_mla else [16])
    ]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("device, name, use_mla, block_size", generate_params())
def test_backend_selection(
    device: str,
    name: str,
    use_mla: bool,
    block_size: int,
):
    """Test attention backend selection with valid device-backend pairs.

    The flow is:

    1. The backend ``name`` is forced through ``AttentionConfig``.
    2. The platform class for ``device`` is patched in as
       ``vllm.platforms.current_platform``.
    3. ``get_attn_backend`` is then expected to return the forced backend.

    Pairs whose hardware or block-size requirements are not met on this
    machine are skipped. ROCm TRITON_MLA with ``block_size == 1`` is expected
    to raise ``ValueError``.

    Args:
        device: ``"cpu"``, ``"hip"`` or ``"cuda"``.
        name: Name of the ``AttentionBackendEnum`` member to force.
        use_mla: Whether an MLA backend is requested.
        block_size: KV-cache block size handed to the selector.
    """
    # Create AttentionConfig with the specified backend
    attention_config = AttentionConfig(backend=AttentionBackendEnum[name])
    vllm_config = VllmConfig(attention_config=attention_config)

    with set_current_vllm_config(vllm_config):
        if device == "cpu":
            with patch("vllm.platforms.current_platform", CpuPlatform()):
                backend = get_attn_backend(16, torch.float16, None, block_size)
                assert backend.get_name() == "CPU_ATTN"
        elif device == "hip":
            with patch("vllm.platforms.current_platform", RocmPlatform()):
                _check_rocm_selection(name, use_mla, block_size)
        elif device == "cuda":
            with patch("vllm.platforms.current_platform", CudaPlatform()):
                _check_cuda_selection(name, use_mla, block_size)
        else:
            # Fail loudly instead of passing vacuously if a new device is
            # parametrized without a matching expectation here.
            pytest.fail(f"No backend-selection expectation for device {device!r}")


def _check_rocm_selection(name: str, use_mla: bool, block_size: int) -> None:
    """Check the forced backend on ROCm; RocmPlatform must already be patched."""
    if not use_mla:
        # Every regular (non-MLA) ROCm case is expected to resolve to ROCM_ATTN.
        backend = get_attn_backend(
            16, torch.float16, None, block_size, use_mla=use_mla
        )
        assert backend.get_name() == "ROCM_ATTN"
        return

    # ROCm MLA expectations:
    # - TRITON_MLA with block_size == 1 must be rejected with a ValueError
    #   whose message names the selected backend.
    # - Every other (backend, block_size) pair must return the forced backend.
    if name == "TRITON_MLA" and block_size == 1:
        with pytest.raises(ValueError) as exc_info:
            get_attn_backend(16, torch.float16, None, block_size, use_mla=use_mla)
        assert f"The selected backend, {name}" in str(exc_info.value)
        return

    backend = get_attn_backend(16, torch.float16, None, block_size, use_mla=use_mla)
    assert backend.get_name() == name


def _check_cuda_selection(name: str, use_mla: bool, block_size: int) -> None:
    """Check the forced backend on CUDA; CudaPlatform must already be patched.

    Backends whose hardware or block-size requirements are not met are skipped.
    """
    capability = torch.cuda.get_device_capability()

    if not use_mla:
        # Head size passed to the selector for each regular CUDA backend.
        regular_head_sizes = {"FLASHINFER": 64, "FLASH_ATTN": 32}
        if name not in regular_head_sizes:
            # Fail loudly instead of passing vacuously for an unhandled backend.
            pytest.fail(f"No expectation for CUDA backend {name!r}")
        head_size = regular_head_sizes[name]
        backend = get_attn_backend(
            head_size, torch.float16, None, block_size, use_mla=use_mla
        )
        assert backend.get_name() == name
        return

    # CUDA MLA expectations:
    # - CUTLASS_MLA: only block_size == 128 on SM 10.x GPUs.
    #   NOTE(review): DEVICE_MLA_BLOCK_SIZES["cuda"] is [16, 64], so this case
    #   is currently always skipped.
    # - FLASHINFER_MLA: only SM 10.x GPUs with block_size 32 or 64.
    # - FLASHMLA: only block_size == 64 where dense FlashMLA is supported.
    # - FLASH_ATTN_MLA: only where FlashAttention supports MLA.
    # - anything else (TRITON_MLA) is expected to select TRITON_MLA.
    if name == "CUTLASS_MLA":
        if block_size != 128:
            pytest.skip("CUTLASS_MLA only supports block_size 128")
        if capability[0] != 10:
            pytest.skip("CUTLASS MLA is not supported on this platform")
        expected = "CUTLASS_MLA"
    elif name == "FLASHINFER_MLA":
        if capability[0] != 10:
            pytest.skip("FlashInfer MLA is not supported on this platform")
        if block_size not in [32, 64]:
            pytest.skip("FlashInfer MLA only supports block_size 32 or 64")
        expected = "FLASHINFER_MLA"
    elif name == "FLASHMLA":
        if block_size != 64:
            pytest.skip("FlashMLA only supports block_size 64")
        from vllm.v1.attention.backends.mla.flashmla import (
            is_flashmla_dense_supported,
        )

        is_supported, _ = is_flashmla_dense_supported()
        if not is_supported:
            pytest.skip("FlashMLA not supported on this platform")
        expected = "FLASHMLA"
    elif name == "FLASH_ATTN_MLA":
        from vllm.v1.attention.backends.fa_utils import flash_attn_supports_mla

        if not flash_attn_supports_mla():
            pytest.skip("FlashAttention MLA not supported on this platform")
        expected = "FLASH_ATTN_MLA"
    else:
        # TRITON_MLA or other fallback
        expected = "TRITON_MLA"

    # Every CUDA MLA case uses head size 576.
    backend = get_attn_backend(576, torch.float16, None, block_size, use_mla=use_mla)
    assert backend.get_name() == expected
|
2025-07-07 00:54:36 +08:00
|
|
|
|
2024-05-22 15:55:56 -07:00
|
|
|
|
2025-07-05 17:41:10 +08:00
|
|
|
@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_fp32_fallback(device: str):
    """Test attention backend selection with fp32.

    No backend is forced. The platform class for ``device`` is patched in and
    the backend picked for a float32 model is compared with the expected name.
    """
    # Use default config (no backend specified)
    vllm_config = VllmConfig()
    # device -> (platform class to patch in, expected backend name)
    expectations = {
        "cpu": (CpuPlatform, "CPU_ATTN"),
        "cuda": (CudaPlatform, "FLEX_ATTENTION"),
    }

    with set_current_vllm_config(vllm_config):
        entry = expectations.get(device)
        if entry is not None:
            platform_cls, expected_name = entry
            with patch("vllm.platforms.current_platform", platform_cls()):
                backend = get_attn_backend(16, torch.float32, None, 16)
                assert backend.get_name() == expected_name
|
2025-07-05 17:41:10 +08:00
|
|
|
|
|
|
|
|
|
2025-03-17 11:35:57 +08:00
|
|
|
def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
    """Test FlashAttn validation.

    FLASH_ATTN is forced, and each scenario below must select something other
    than FLASH_ATTN. The whole test is currently skipped unconditionally by
    the first statement.
    """
    pytest.skip(
        "Skipping as current backend selector does not "
        "handle fallbacks when a backend is explicitly set."
    )

    vllm_config = VllmConfig(
        attention_config=AttentionConfig(backend=AttentionBackendEnum.FLASH_ATTN)
    )

    with set_current_vllm_config(vllm_config):
        # Unsupported CUDA arch; the patch is undone when the context exits.
        with monkeypatch.context() as mp:
            mp.setattr(torch.cuda, "get_device_capability", lambda _=None: (7, 5))
            backend = get_attn_backend(16, torch.float16, None, 16)
            assert backend.get_name() != "FLASH_ATTN"

        # Unsupported data type, unsupported kv cache data type, and
        # unsupported block size, in that order.
        for head_size, dtype, kv_cache_dtype, block_size in (
            (16, torch.float8_e4m3fn, None, 16),
            (16, torch.float16, "fp8", 16),
            (16, torch.float16, None, 8),
        ):
            backend = get_attn_backend(head_size, dtype, kv_cache_dtype, block_size)
            assert backend.get_name() != "FLASH_ATTN"

        # flash-attn is not installed. The context restores (or removes) the
        # original sys.modules entry on exit.
        import sys

        with monkeypatch.context() as mp:
            mp.setitem(sys.modules, "vllm_flash_attn", None)
            backend = get_attn_backend(16, torch.float16, None, 16)
            assert backend.get_name() != "FLASH_ATTN"

        # Unsupported head size
        backend = get_attn_backend(17, torch.float16, None, 16)
        assert backend.get_name() != "FLASH_ATTN"
|
2024-05-22 15:55:56 -07:00
|
|
|
|
|
|
|
|
|
2025-12-17 12:49:59 -05:00
|
|
|
def test_invalid_backend():
    """Test that invalid attention backend names raise ValueError."""
    # The unknown-member lookup on AttentionBackendEnum is expected to raise
    # ValueError before AttentionConfig is ever constructed.
    with pytest.raises(ValueError):
        AttentionConfig(backend=AttentionBackendEnum["INVALID"])
|