[FEAT][ROCm]: Support AITER MLA (#15893)
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Co-authored-by: qli88 <qiang.li2@amd.com>
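For context, the new ROCM_AITER_MLA backend is selected the same way the test below selects it: through the backend-override environment variable plus a block size of 1. A minimal sketch, assuming STR_BACKEND_ENV_VAR resolves to VLLM_ATTENTION_BACKEND as in vllm.attention.selector (illustrative only, not part of this diff):

    # Sketch: opt in to the AITER MLA backend the way test_env does below.
    # Assumes VLLM_ATTENTION_BACKEND is the variable behind STR_BACKEND_ENV_VAR;
    # ROCM_AITER_MLA additionally requires the engine to use block_size=1.
    import os
    os.environ["VLLM_ATTENTION_BACKEND"] = "ROCM_AITER_MLA"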
@@ -19,45 +19,152 @@ def clear_cache():
     _cached_get_attn_backend.cache_clear()
 
 
-@pytest.mark.parametrize(
-    "name", ["TORCH_SDPA", "ROCM_FLASH", "XFORMERS", "FLASHINFER"])
+# Define MLA and non-MLA backends separately
+DEVICE_MLA_BACKENDS = {
+    "cuda": ["TRITON_MLA", "FLASHMLA"],
+    "hip": ["TRITON_MLA", "ROCM_AITER_MLA"],
+    "cpu": [],
+}
+
+DEVICE_REGULAR_ATTN_BACKENDS = {
+    "cuda": ["XFORMERS", "FLASHINFER"],
+    "hip": ["ROCM_FLASH"],
+    "cpu": ["TORCH_SDPA"],
+}
+
+DEVICE_MLA_BLOCK_SIZES = {
+    "cuda": [16, 64],  # CUDA supports both standard and extended block sizes
+    "hip": [16, 1],  # HIP requires special handling for block_size=1
+    "cpu": [16],  # CPU uses a fixed block size in these test cases
+}
+
+
+def generate_params():
+    params = []
+    for use_mla in [True, False]:
+        for device in ["cuda", "hip", "cpu"]:
+            backends = (DEVICE_MLA_BACKENDS[device]
+                        if use_mla else DEVICE_REGULAR_ATTN_BACKENDS[device])
+            for name in backends:
+                block_sizes = (DEVICE_MLA_BLOCK_SIZES[device]
+                               if use_mla else [16])
+                for block_size in block_sizes:
+                    params.append(
+                        pytest.param(
+                            device,
+                            name,
+                            use_mla,
+                            block_size,
+                            id=f"{device}_{name}_mla_"
+                            f"{str(use_mla)[0]}_blks{block_size}",
+                        ))
+    return params
+
+
@pytest.mark.parametrize("device, name, use_mla, block_size",
|
||||
generate_params())
|
||||
@pytest.mark.parametrize("use_v1", [True, False])
|
||||
@pytest.mark.parametrize("device", ["cpu", "hip", "cuda"])
|
||||
def test_env(
|
||||
name: str,
|
||||
use_v1: bool,
|
||||
device: str,
|
||||
name: str,
|
||||
use_mla: bool,
|
||||
block_size: int,
|
||||
use_v1: bool,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
):
|
||||
"""Test that the attention selector can be set via environment variable.
|
||||
Note that we do not test FlashAttn because it is the default backend.
|
||||
"""
|
||||
|
||||
"""Test attention backend selection with valid device-backend pairs."""
|
||||
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "1" if use_v1 else "0")
         m.setenv(STR_BACKEND_ENV_VAR, name)
+        # VLLM_MLA_DISABLE=1 disables MLA, so it must stay "0" for MLA cases.
+        m.setenv("VLLM_MLA_DISABLE", "0" if use_mla else "1")
 
         if device == "cpu":
             with patch("vllm.attention.selector.current_platform",
                        CpuPlatform()):
                 backend = get_attn_backend(16, torch.float16, torch.float16,
-                                           16, False)
+                                           block_size, False)
             assert backend.get_name() == "TORCH_SDPA"
 
elif device == "hip":
|
||||
with patch("vllm.attention.selector.current_platform",
|
||||
RocmPlatform()):
|
||||
backend = get_attn_backend(16, torch.float16, torch.float16,
|
||||
16, False)
|
||||
EXPECTED = "TRITON_ATTN_VLLM_V1" if use_v1 else "ROCM_FLASH"
|
||||
assert backend.get_name() == EXPECTED
|
||||
else:
|
||||
if name in ["XFORMERS", "FLASHINFER"]:
|
||||
with patch("vllm.attention.selector.current_platform",
|
||||
CudaPlatform()):
|
||||
backend = get_attn_backend(16, torch.float16,
|
||||
torch.float16, 16, False)
|
||||
EXPECTED = "FLASH_ATTN_VLLM_V1" if use_v1 else name
|
||||
assert backend.get_name() == EXPECTED
|
||||
+                if use_mla:
+                    # Validate HIP MLA backend/block_size combinations:
+                    # TRITON_MLA requires block_size != 1, ROCM_AITER_MLA
+                    # requires block_size == 1.
+                    valid_combination = (
+                        (name == "TRITON_MLA" and block_size != 1)
+                        or (name == "ROCM_AITER_MLA" and block_size == 1))
+
+                    if valid_combination:
+                        backend = get_attn_backend(16, torch.float16,
+                                                   torch.float16, block_size,
+                                                   False, use_mla=use_mla)
+                        assert backend.get_name() == name
+                    else:
+                        with pytest.raises(ValueError) as exc_info:
+                            get_attn_backend(16, torch.float16, torch.float16,
+                                             block_size, False,
+                                             use_mla=use_mla)
+                        assert f"The selected backend, {name}" in str(
+                            exc_info.value)
+                else:
+                    backend = get_attn_backend(16, torch.float16,
+                                               torch.float16, block_size,
+                                               False, use_mla=use_mla)
+                    expected = "TRITON_ATTN_VLLM_V1" if use_v1 else "ROCM_FLASH"
+                    assert backend.get_name() == expected
+
+        elif device == "cuda":
+            with patch("vllm.attention.selector.current_platform",
+                       CudaPlatform()):
+                if use_mla:
+                    if name == "FLASHMLA" and block_size == 64:
+                        from vllm.attention.backends.flashmla import (
+                            is_flashmla_supported)
+
+                        # FlashMLA is only available on CUDA devices with a
+                        # specific compute capability; skip elsewhere.
+                        is_supported, _ = is_flashmla_supported()
+                        if not is_supported:
+                            pytest.skip(
+                                "FlashMLA not supported on this platform")
+                        else:
+                            backend = get_attn_backend(16, torch.float16,
+                                                       torch.float16,
+                                                       block_size, False,
+                                                       use_mla=use_mla)
+                            expected = f"{name}_VLLM_V1" if use_v1 else name
+                            assert backend.get_name() == expected
+                    else:
+                        backend = get_attn_backend(16, torch.float16,
+                                                   torch.float16, block_size,
+                                                   False, use_mla=use_mla)
+                        expected = ("TRITON_MLA_VLLM_V1"
+                                    if use_v1 else "TRITON_MLA")
+                        assert backend.get_name() == expected
+                else:
+                    backend = get_attn_backend(16, torch.float16,
+                                               torch.float16, block_size,
+                                               False, use_mla=use_mla)
+                    expected = "FLASH_ATTN_VLLM_V1" if use_v1 else name
+                    assert backend.get_name() == expected
 
 
 def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
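For orientation, here are a few of the test ids generate_params() yields; the id format string converts use_mla to "T"/"F" via str(use_mla)[0]. Note the matrix deliberately includes invalid HIP pairings, which test_env expects to raise:

    # Illustrative sample of generated ids (derived from the format string
    # above; not exhaustive):
    #   cuda_TRITON_MLA_mla_T_blks16
    #   cuda_FLASHMLA_mla_T_blks64
    #   hip_TRITON_MLA_mla_T_blks1      # invalid pair, ValueError expected
    #   hip_ROCM_AITER_MLA_mla_T_blks1  # the only valid AITER MLA block size
    #   hip_ROCM_FLASH_mla_F_blks16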
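Those ids make targeted runs straightforward; a sketch of running only the AITER MLA cases (the test file path is assumed, since the hunk does not name it):

    # Equivalent to `pytest <test-file> -k ROCM_AITER_MLA` on the command line.
    import pytest
    pytest.main(["tests/kernels/test_attention_selector.py",
                 "-k", "ROCM_AITER_MLA"])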
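Finally, the invalid-combination branch asserted in test_env is the same failure a user would hit when forcing the backend with a regular block size. A sketch, assuming a ROCm platform and the error-message prefix asserted above:

    # Hedged sketch of the failure mode test_env asserts for invalid pairs;
    # requires a ROCm build for the selector to reach the MLA validation.
    import os
    import pytest
    import torch
    from vllm.attention.selector import get_attn_backend

    os.environ["VLLM_ATTENTION_BACKEND"] = "ROCM_AITER_MLA"
    with pytest.raises(ValueError,
                       match="The selected backend, ROCM_AITER_MLA"):
        # ROCM_AITER_MLA accepts only block_size=1, so 16 is rejected.
        get_attn_backend(16, torch.float16, torch.float16, 16, False,
                         use_mla=True)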