[Attention] FlashAttn MLA (#14258)
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
Co-authored-by: Matthew Bonanni <mbonanni001@gmail.com>
Co-authored-by: Matthew Bonanni <mbonanni@redhat.com>
@@ -22,7 +22,7 @@ def clear_cache():
 # Define MLA and non-MLA backends separately
 DEVICE_MLA_BACKENDS = {
-    "cuda": ["TRITON_MLA", "FLASHMLA"],
+    "cuda": ["TRITON_MLA", "FLASHMLA", "FLASH_ATTN_MLA", "CUTLASS_MLA"],
     "hip": ["TRITON_MLA", "ROCM_AITER_MLA"],
     "cpu": [],
 }
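The widened "cuda" entry is what drives the new test cases below. For orientation, here is a minimal sketch of how a table like this expands into per-(device, backend) pytest cases; the helper and test names are illustrative, not the test file's actual code:

import pytest

# Local copy of the table above, for a self-contained sketch.
DEVICE_MLA_BACKENDS = {
    "cuda": ["TRITON_MLA", "FLASHMLA", "FLASH_ATTN_MLA", "CUTLASS_MLA"],
    "hip": ["TRITON_MLA", "ROCM_AITER_MLA"],
    "cpu": [],
}

def mla_test_params():
    # Flatten {device: [backend, ...]} into (device, backend) pairs so each
    # combination becomes its own parametrized test case.
    return [(device, name)
            for device, names in DEVICE_MLA_BACKENDS.items()
            for name in names]

@pytest.mark.parametrize("device, name", mla_test_params())
def test_backend_is_listed_for_device(device, name):
    assert name in DEVICE_MLA_BACKENDS[device]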
@@ -98,21 +98,14 @@ def test_env(
         with patch("vllm.attention.selector.current_platform",
                    RocmPlatform()):
             if use_mla:
-                # Validate HIP MLA backend-block_size combinations
-                valid_combination = (
-                    (name == "TRITON_MLA" and block_size != 1)
-                    or (name == "ROCM_AITER_MLA" and block_size == 1))
+                # ROCm MLA backend logic:
+                # - TRITON_MLA: supported when block_size != 1
+                # - ROCM_AITER_MLA: supported when block_size == 1
+                # If backend is forced but doesn't match block_size,
+                # should raise ValueError
+
-                if valid_combination:
-                    backend = get_attn_backend(16,
-                                               torch.float16,
-                                               torch.float16,
-                                               block_size,
-                                               False,
-                                               use_mla=use_mla)
-                    expected = f"{name}_VLLM_V1" if use_v1 else name
-                    assert backend.get_name() == expected
-                else:
+                if name == "TRITON_MLA" and block_size == 1:
                     # TRITON_MLA doesn't support block_size == 1
                     with pytest.raises(ValueError) as exc_info:
                         get_attn_backend(16,
                                          torch.float16,
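This hunk replaces the explicit valid_combination expression with per-backend branches; the rule itself is unchanged. Restated as a standalone predicate (an illustration, not a vLLM helper):

def rocm_mla_combination_is_valid(name: str, block_size: int) -> bool:
    # TRITON_MLA supports any block size except 1; ROCM_AITER_MLA supports
    # only block_size == 1. Forcing a backend that violates its constraint
    # makes get_attn_backend raise ValueError.
    if name == "TRITON_MLA":
        return block_size != 1
    if name == "ROCM_AITER_MLA":
        return block_size == 1
    return False  # no other MLA backends are listed for "hip"

assert rocm_mla_combination_is_valid("TRITON_MLA", 16)
assert rocm_mla_combination_is_valid("ROCM_AITER_MLA", 1)
assert not rocm_mla_combination_is_valid("ROCM_AITER_MLA", 16)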
@@ -122,6 +115,27 @@ def test_env(
                                          use_mla=use_mla)
                     assert f"The selected backend, {name}" in str(
                         exc_info.value)
+                elif name == "ROCM_AITER_MLA" and block_size != 1:
+                    # ROCM_AITER_MLA only supports block_size == 1
+                    with pytest.raises(ValueError) as exc_info:
+                        get_attn_backend(16,
+                                         torch.float16,
+                                         torch.float16,
+                                         block_size,
+                                         False,
+                                         use_mla=use_mla)
+                    assert f"The selected backend, {name}" in str(
+                        exc_info.value)
+                else:
+                    # Valid backend-block_size combination
+                    backend = get_attn_backend(16,
+                                               torch.float16,
+                                               torch.float16,
+                                               block_size,
+                                               False,
+                                               use_mla=use_mla)
+                    expected = f"{name}_VLLM_V1" if use_v1 else name
+                    assert backend.get_name() == expected
             else:
                 backend = get_attn_backend(16,
                                            torch.float16,
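The raise-path assertions above presuppose that the backend has been forced; the fixture setup is outside these hunks, so the mechanism is an assumption here, but vLLM tests typically force a backend through the VLLM_ATTENTION_BACKEND environment variable. A condensed sketch of that pattern:

import pytest
import torch

def assert_forced_backend_rejected(name: str, block_size: int, monkeypatch):
    # Assumed setup: the selector honors the VLLM_ATTENTION_BACKEND env var.
    # Mirrors the raise path above; not a drop-in replacement for the test.
    from vllm.attention.selector import get_attn_backend
    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", name)
    with pytest.raises(ValueError) as exc_info:
        get_attn_backend(16,
                         torch.float16,
                         torch.float16,
                         block_size,
                         False,
                         use_mla=True)
    assert f"The selected backend, {name}" in str(exc_info.value)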
@@ -136,16 +150,22 @@ def test_env(
         with patch("vllm.attention.selector.current_platform",
                    CudaPlatform()):
             if use_mla:
-                if name == "FLASHMLA" and block_size == 64:
-                    from vllm.attention.backends.flashmla import (
-                        is_flashmla_supported)
+                # CUDA MLA backend logic:
+                # - CUTLASS_MLA: only supported with block_size == 128
+                #   and Blackwell GPUs (SM 10.0), V1 only
+                # - FLASHMLA: only supported with block_size == 64
+                # - FLASH_ATTN_MLA: V1 only
+                # - TRITON_MLA: fallback for other cases
+
-                    # only on cuda platforms with specific capability.
-                    is_supported, _ = is_flashmla_supported()
-
-                    if not is_supported:
-                        # if platform is not supported then skip this case.
-                        pytest.skip()
+                if name == "CUTLASS_MLA":
+                    if not use_v1:
+                        # CUTLASS_MLA only supported on V1 engine
+                        pytest.skip(
+                            "CUTLASS_MLA only supported on V1 engine")
+                    elif block_size != 128:
+                        # CUTLASS_MLA only supports block_size == 128
+                        pytest.skip(
+                            "CUTLASS_MLA only supports block_size 128")
+                    else:
+                        backend = get_attn_backend(16,
+                                                   torch.float16,
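The new comment block is effectively a decision table for the CUDA MLA backends. Summarized as plain Python (illustrative only; the real gating lives in vLLM's platform and selector code):

def cuda_mla_constraint_ok(name: str, block_size: int, use_v1: bool) -> bool:
    # Encodes the comment block above; hardware checks (Blackwell SM 10.0
    # for CUTLASS_MLA, is_flashmla_supported() for FLASHMLA) are runtime
    # conditions not modeled here.
    if name == "CUTLASS_MLA":
        return use_v1 and block_size == 128
    if name == "FLASHMLA":
        return block_size == 64
    if name == "FLASH_ATTN_MLA":
        return use_v1
    return True  # TRITON_MLA or other fallback

assert cuda_mla_constraint_ok("CUTLASS_MLA", 128, use_v1=True)
assert not cuda_mla_constraint_ok("FLASHMLA", 128, use_v1=True)
assert cuda_mla_constraint_ok("TRITON_MLA", 16, use_v1=False)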
@@ -153,9 +173,45 @@ def test_env(
                                                    block_size,
                                                    False,
                                                    use_mla=use_mla)
-                        expected = f"{name}_VLLM_V1" if use_v1 else name
+                        expected = "CUTLASS_MLA_VLLM_V1"
                         assert backend.get_name() == expected
+                elif name == "FLASHMLA":
+                    if block_size != 64:
+                        # FlashMLA only supports block_size == 64
+                        pytest.skip("FlashMLA only supports block_size 64")
+                    else:
+                        from vllm.attention.backends.flashmla import (
+                            is_flashmla_supported)
+                        is_supported, _ = is_flashmla_supported()
+                        if not is_supported:
+                            pytest.skip(
+                                "FlashMLA not supported on this platform")
+                        else:
+                            backend = get_attn_backend(16,
+                                                       torch.float16,
+                                                       torch.float16,
+                                                       block_size,
+                                                       False,
+                                                       use_mla=use_mla)
+                            expected = f"{name}_VLLM_V1" if use_v1 else name
+                            assert backend.get_name() == expected
+                elif name == "FLASH_ATTN_MLA":
+                    if not use_v1:
+                        # FlashAttention MLA only supported on V1 engine
+                        pytest.skip(
+                            "FlashAttention MLA only supported on V1 engine"
+                        )
+                    else:
+                        backend = get_attn_backend(16,
+                                                   torch.float16,
+                                                   torch.float16,
+                                                   block_size,
+                                                   False,
+                                                   use_mla=use_mla)
+                        expected = "FLASH_ATTN_MLA"
+                        assert backend.get_name() == expected
                 else:
+                    # TRITON_MLA or other fallback
                     backend = get_attn_backend(16,
                                                torch.float16,
                                                torch.float16,
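The FLASHMLA branch gates on is_flashmla_supported(), which the diff shows returning a 2-tuple whose first element is the support flag. That check generalizes to a small skip helper (a sketch; the tuple's second element is discarded here exactly as in the test):

import pytest

def skip_unless_flashmla_supported():
    # Import path taken from the diff above; skips the current test case on
    # platforms where FlashMLA kernels are unavailable.
    from vllm.attention.backends.flashmla import is_flashmla_supported
    is_supported, _ = is_flashmla_supported()
    if not is_supported:
        pytest.skip("FlashMLA not supported on this platform")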