[FEAT][ROCm]: Support AITER MLA (#15893)

Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Co-authored-by: qli88 <qiang.li2@amd.com>
This commit is contained in:
vllmellm
2025-04-23 00:31:13 +08:00
committed by GitHub
parent f34410715f
commit 30bc3e0f66
9 changed files with 668 additions and 30 deletions

View File

@@ -19,45 +19,152 @@ def clear_cache():
_cached_get_attn_backend.cache_clear()
@pytest.mark.parametrize(
"name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER"])
# Define MLA and non-MLA backends separately
DEVICE_MLA_BACKENDS = {
"cuda": ["TRITON_MLA", "FLASHMLA"],
"hip": ["TRITON_MLA", "ROCM_AITER_MLA"],
"cpu": [],
}
DEVICE_REGULAR_ATTN_BACKENDS = {
"cuda": ["XFORMERS", "FLASHINFER"],
"hip": ["ROCM_FLASH"],
"cpu": ["TORCH_SDPA"],
}
DEVICE_MLA_BLOCK_SIZES = {
"cuda": [16, 64], # CUDA supports both standard and extended block sizes
"hip": [16, 1], # HIP requires special handling for block_size=1
"cpu": [16] # CPU uses fixed block size from test cases
}
def generate_params():
params = []
for use_mla in [True, False]:
for device in ["cuda", "hip", "cpu"]:
backends = DEVICE_MLA_BACKENDS[
device] if use_mla else DEVICE_REGULAR_ATTN_BACKENDS[device]
for name in backends:
block_sizes = DEVICE_MLA_BLOCK_SIZES[device] if use_mla else [
16
]
for block_size in block_sizes:
params.append(
pytest.param(
device,
name,
use_mla,
block_size,
id=
f"{device}_{name}_mla_{str(use_mla)[0]}_blks{block_size}"
))
return params
@pytest.mark.parametrize("device, name, use_mla, block_size",
generate_params())
@pytest.mark.parametrize("use_v1", [True, False])
@pytest.mark.parametrize("device", ["cpu", "hip", "cuda"])
def test_env(
name: str,
use_v1: bool,
device: str,
name: str,
use_mla: bool,
block_size: int,
use_v1: bool,
monkeypatch: pytest.MonkeyPatch,
):
"""Test that the attention selector can be set via environment variable.
Note that we do not test FlashAttn because it is the default backend.
"""
"""Test attention backend selection with valid device-backend pairs."""
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1" if use_v1 else "0")
m.setenv(STR_BACKEND_ENV_VAR, name)
m.setenv("VLLM_MLA_DISABLE", "1" if use_mla else "0")
if device == "cpu":
with patch("vllm.attention.selector.current_platform",
CpuPlatform()):
backend = get_attn_backend(16, torch.float16, torch.float16,
16, False)
block_size, False)
assert backend.get_name() == "TORCH_SDPA"
elif device == "hip":
with patch("vllm.attention.selector.current_platform",
RocmPlatform()):
backend = get_attn_backend(16, torch.float16, torch.float16,
16, False)
EXPECTED = "TRITON_ATTN_VLLM_V1" if use_v1 else "ROCM_FLASH"
assert backend.get_name() == EXPECTED
else:
if name in ["XFORMERS", "FLASHINFER"]:
with patch("vllm.attention.selector.current_platform",
CudaPlatform()):
backend = get_attn_backend(16, torch.float16,
torch.float16, 16, False)
EXPECTED = "FLASH_ATTN_VLLM_V1" if use_v1 else name
assert backend.get_name() == EXPECTED
if use_mla:
# Validate HIP MLA backend-block_size combinations
valid_combination = (
(name == "TRITON_MLA" and block_size != 1)
or (name == "ROCM_AITER_MLA" and block_size == 1))
if valid_combination:
backend = get_attn_backend(16,
torch.float16,
torch.float16,
block_size,
False,
use_mla=use_mla)
assert backend.get_name() == name
else:
with pytest.raises(ValueError) as exc_info:
get_attn_backend(16,
torch.float16,
torch.float16,
block_size,
False,
use_mla=use_mla)
assert f"The selected backend, {name}" in str(
exc_info.value)
else:
backend = get_attn_backend(16,
torch.float16,
torch.float16,
block_size,
False,
use_mla=use_mla)
expected = "TRITON_ATTN_VLLM_V1" if use_v1 else "ROCM_FLASH"
assert backend.get_name() == expected
elif device == "cuda":
with patch("vllm.attention.selector.current_platform",
CudaPlatform()):
if use_mla:
if name == "FLASHMLA" and block_size == 64:
from vllm.attention.backends.flashmla import (
is_flashmla_supported)
# only on cuda platforms with specific capability.
is_supported, _ = is_flashmla_supported()
if not is_supported:
# if platform is not supported then skip this case.
pytest.skip()
else:
backend = get_attn_backend(16,
torch.float16,
torch.float16,
block_size,
False,
use_mla=use_mla)
expected = f"{name}_VLLM_V1" if use_v1 else name
assert backend.get_name() == expected
else:
backend = get_attn_backend(16,
torch.float16,
torch.float16,
block_size,
False,
use_mla=use_mla)
expected = ("TRITON_MLA_VLLM_V1"
if use_v1 else "TRITON_MLA")
assert backend.get_name() == expected
else:
backend = get_attn_backend(16,
torch.float16,
torch.float16,
block_size,
False,
use_mla=use_mla)
expected = "FLASH_ATTN_VLLM_V1" if use_v1 else name
assert backend.get_name() == expected
def test_flash_attn(monkeypatch: pytest.MonkeyPatch):