[Attention] FlashAttn MLA (#14258)

Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
Co-authored-by: Matthew Bonanni <mbonanni001@gmail.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
Lucas Wilkinson
2025-09-04 05:47:59 -04:00
committed by GitHub
parent 2c301ee2eb
commit 402759d472
22 changed files with 480 additions and 200 deletions

View File

@@ -22,7 +22,7 @@ def clear_cache():
# Define MLA and non-MLA backends separately
DEVICE_MLA_BACKENDS = {
"cuda": ["TRITON_MLA", "FLASHMLA"],
"cuda": ["TRITON_MLA", "FLASHMLA", "FLASH_ATTN_MLA", "CUTLASS_MLA"],
"hip": ["TRITON_MLA", "ROCM_AITER_MLA"],
"cpu": [],
}
@@ -98,21 +98,14 @@ def test_env(
with patch("vllm.attention.selector.current_platform",
RocmPlatform()):
if use_mla:
# Validate HIP MLA backend-block_size combinations
valid_combination = (
(name == "TRITON_MLA" and block_size != 1)
or (name == "ROCM_AITER_MLA" and block_size == 1))
# ROCm MLA backend logic:
# - TRITON_MLA: supported when block_size != 1
# - ROCM_AITER_MLA: supported when block_size == 1
# If backend is forced but doesn't match block_size,
# should raise ValueError
if valid_combination:
backend = get_attn_backend(16,
torch.float16,
torch.float16,
block_size,
False,
use_mla=use_mla)
expected = f"{name}_VLLM_V1" if use_v1 else name
assert backend.get_name() == expected
else:
if name == "TRITON_MLA" and block_size == 1:
# TRITON_MLA doesn't support block_size == 1
with pytest.raises(ValueError) as exc_info:
get_attn_backend(16,
torch.float16,
@@ -122,6 +115,27 @@ def test_env(
use_mla=use_mla)
assert f"The selected backend, {name}" in str(
exc_info.value)
elif name == "ROCM_AITER_MLA" and block_size != 1:
# ROCM_AITER_MLA only supports block_size == 1
with pytest.raises(ValueError) as exc_info:
get_attn_backend(16,
torch.float16,
torch.float16,
block_size,
False,
use_mla=use_mla)
assert f"The selected backend, {name}" in str(
exc_info.value)
else:
# Valid backend-block_size combination
backend = get_attn_backend(16,
torch.float16,
torch.float16,
block_size,
False,
use_mla=use_mla)
expected = f"{name}_VLLM_V1" if use_v1 else name
assert backend.get_name() == expected
else:
backend = get_attn_backend(16,
torch.float16,
@@ -136,16 +150,22 @@ def test_env(
with patch("vllm.attention.selector.current_platform",
CudaPlatform()):
if use_mla:
if name == "FLASHMLA" and block_size == 64:
from vllm.attention.backends.flashmla import (
is_flashmla_supported)
# CUDA MLA backend logic:
# - CUTLASS_MLA: only supported with block_size == 128
# and Blackwell GPUs (SM 10.0), V1 only
# - FLASHMLA: only supported with block_size == 64
# - FLASH_ATTN_MLA: V1 only
# - TRITON_MLA: fallback for other cases
# only on cuda platforms with specific capability.
is_supported, _ = is_flashmla_supported()
if not is_supported:
# if platform is not supported then skip this case.
pytest.skip()
if name == "CUTLASS_MLA":
if not use_v1:
# CUTLASS_MLA only supported on V1 engine
pytest.skip(
"CUTLASS_MLA only supported on V1 engine")
elif block_size != 128:
# CUTLASS_MLA only supports block_size == 128
pytest.skip(
"CUTLASS_MLA only supports block_size 128")
else:
backend = get_attn_backend(16,
torch.float16,
@@ -153,9 +173,45 @@ def test_env(
block_size,
False,
use_mla=use_mla)
expected = f"{name}_VLLM_V1" if use_v1 else name
expected = "CUTLASS_MLA_VLLM_V1"
assert backend.get_name() == expected
elif name == "FLASHMLA":
if block_size != 64:
# FlashMLA only supports block_size == 64
pytest.skip("FlashMLA only supports block_size 64")
else:
from vllm.attention.backends.flashmla import (
is_flashmla_supported)
is_supported, _ = is_flashmla_supported()
if not is_supported:
pytest.skip(
"FlashMLA not supported on this platform")
else:
backend = get_attn_backend(16,
torch.float16,
torch.float16,
block_size,
False,
use_mla=use_mla)
expected = f"{name}_VLLM_V1" if use_v1 else name
assert backend.get_name() == expected
elif name == "FLASH_ATTN_MLA":
if not use_v1:
# FlashAttention MLA only supported on V1 engine
pytest.skip(
"FlashAttention MLA only supported on V1 engine"
)
else:
backend = get_attn_backend(16,
torch.float16,
torch.float16,
block_size,
False,
use_mla=use_mla)
expected = "FLASH_ATTN_MLA"
assert backend.get_name() == expected
else:
# TRITON_MLA or other fallback
backend = get_attn_backend(16,
torch.float16,
torch.float16,