[Attention] Refactor CUDA attention backend selection logic (#24794)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
@@ -120,12 +120,13 @@ def test_env(
         elif device == "cuda":
             with patch("vllm.platforms.current_platform", CudaPlatform()):
+                capability = torch.cuda.get_device_capability()
                 if use_mla:
                     # CUDA MLA backend logic:
                     # - CUTLASS_MLA: only supported with block_size == 128
-                    #   and Blackwell GPUs (SM 10.0), V1 only
+                    #   and Blackwell GPUs (SM 10.x), V1 only
                     # - FLASHINFER_MLA: only supported on Blackwell GPUs
-                    #   (SM 10.0+), V1 only
+                    #   (SM 10.x), V1 only
                     # - FLASHMLA: only supported with block_size == 64
                     # - FLASH_ATTN_MLA: V1 only
                     # - TRITON_MLA: fallback for other cases
@@ -134,58 +135,72 @@ def test_env(
                         if block_size != 128:
                             # CUTLASS_MLA only supports block_size == 128
                             pytest.skip("CUTLASS_MLA only supports block_size 128")
-                        else:
-                            backend = get_attn_backend(
-                                16, torch.float16, None, block_size, use_mla=use_mla
-                            )
-                            expected = "CUTLASS_MLA"
-                            assert backend.get_name() == expected
+                        if capability[0] != 10:
+                            pytest.skip("CUTLASS MLA is not supported on this platform")
+                        backend = get_attn_backend(
+                            576, torch.float16, None, block_size, use_mla=use_mla
+                        )
+                        expected = "CUTLASS_MLA"
+                        assert backend.get_name() == expected
                     elif name == "FLASHINFER_MLA":
+                        if capability[0] != 10:
+                            pytest.skip(
+                                "FlashInfer MLA is not supported on this platform"
+                            )
                         if block_size not in [32, 64]:
                             # FlashInfer MLA only supports block_size 32 or 64
                             pytest.skip(
                                 "FlashInfer MLA only supports block_size 32 or 64"
                             )
-                        else:
-                            backend = get_attn_backend(
-                                16, torch.float16, None, block_size, use_mla=use_mla
-                            )
-                            expected = "FLASHINFER_MLA"
-                            assert backend.get_name() == expected
+                        backend = get_attn_backend(
+                            576, torch.float16, None, block_size, use_mla=use_mla
+                        )
+                        expected = "FLASHINFER_MLA"
+                        assert backend.get_name() == expected
                     elif name == "FLASHMLA":
                         if block_size != 64:
                             # FlashMLA only supports block_size == 64
                             pytest.skip("FlashMLA only supports block_size 64")
-                        else:
-                            from vllm.v1.attention.backends.mla.flashmla import (
-                                is_flashmla_dense_supported,
-                            )
+                        from vllm.v1.attention.backends.mla.flashmla import (
+                            is_flashmla_dense_supported,
+                        )
 
-                            is_supported, _ = is_flashmla_dense_supported()
-                            if not is_supported:
-                                pytest.skip("FlashMLA not supported on this platform")
-                            else:
-                                backend = get_attn_backend(
-                                    16, torch.float16, None, block_size, use_mla=use_mla
-                                )
-                                expected = name
-                                assert backend.get_name() == expected
-                    elif name == "FLASH_ATTN_MLA":
+                        is_supported, _ = is_flashmla_dense_supported()
+                        if not is_supported:
+                            pytest.skip("FlashMLA not supported on this platform")
                         backend = get_attn_backend(
-                            16, torch.float16, None, block_size, use_mla=use_mla
+                            576,
+                            torch.float16,
+                            None,
+                            block_size,
+                            use_mla=use_mla,
                         )
                         expected = name
                         assert backend.get_name() == expected
+                    elif name == "FLASH_ATTN_MLA":
+                        from vllm.attention.utils.fa_utils import (
+                            flash_attn_supports_mla,
+                        )
+
+                        if not flash_attn_supports_mla():
+                            pytest.skip(
+                                "FlashAttention MLA not supported on this platform"
+                            )
+                        backend = get_attn_backend(
+                            576, torch.float16, None, block_size, use_mla=use_mla
+                        )
+                        expected = "FLASH_ATTN_MLA"
+                        assert backend.get_name() == expected
                     else:
                         # TRITON_MLA or other fallback
                         backend = get_attn_backend(
-                            16, torch.float16, None, block_size, use_mla=use_mla
+                            576, torch.float16, None, block_size, use_mla=use_mla
                         )
                         expected = "TRITON_MLA"
                         assert backend.get_name() == expected
                 elif name == "FLASHINFER":
                     backend = get_attn_backend(
-                        16, torch.float16, None, block_size, use_mla=use_mla
+                        64, torch.float16, None, block_size, use_mla=use_mla
                     )
                     expected = "FLASHINFER"
                     assert backend.get_name() == expected
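For reference, the per-backend support rules that the comment block and the skip conditions above encode can be restated as one small helper. The sketch below is only an illustration of those rules; mla_backend_is_supported and its parameters are hypothetical names invented here, not vLLM APIs, and the runtime probes the test also applies (is_flashmla_dense_supported(), flash_attn_supports_mla()) are mentioned only in comments.

# Illustrative sketch: restates the static skip conditions from the test above.
# Hypothetical helper, not part of vLLM.


def mla_backend_is_supported(name: str, capability_major: int, block_size: int) -> bool:
    """Return True if the forced MLA backend can run with this device/config."""
    is_blackwell = capability_major == 10  # SM 10.x
    if name == "CUTLASS_MLA":
        # Blackwell only, and only with 128-token KV-cache blocks.
        return is_blackwell and block_size == 128
    if name == "FLASHINFER_MLA":
        # Blackwell only, with 32- or 64-token blocks.
        return is_blackwell and block_size in (32, 64)
    if name == "FLASHMLA":
        # 64-token blocks; the test additionally probes
        # is_flashmla_dense_supported() at runtime.
        return block_size == 64
    if name == "FLASH_ATTN_MLA":
        # No static constraint here; the test gates on flash_attn_supports_mla().
        return True
    # TRITON_MLA is the fallback and has no block-size or SM constraint.
    return True


# Example: on a Hopper-class GPU (SM 9.x) CUTLASS_MLA is rejected, so the test
# skips that combination instead of asserting a backend name.
assert not mla_backend_is_supported("CUTLASS_MLA", capability_major=9, block_size=128)

When a combination passes these checks, the test asserts that get_attn_backend, called with the same 576 / torch.float16 / block_size / use_mla arguments shown in the diff above, reports the forced backend's name.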