[Attention] Refactor CUDA attention backend selection logic (#24794)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
commit b30dfa03c5
parent 2e78150d24
Author: Matthew Bonanni
Date: 2025-11-11 06:40:44 -06:00 (committed by GitHub)

61 changed files with 1338 additions and 1002 deletions

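The diff below replaces the private `_Backend` enum with the public `AttentionBackendEnum` throughout the compile/fusion tests. As a rough illustration of what call sites look like after the rename, here is a minimal, self-contained sketch: the member names (TRITON_ATTN, FLASHINFER, ROCM_ATTN, ROCM_AITER_UNIFIED_ATTN) are taken from the diff, but the standalone Enum definition and the select_backend helper are hypothetical stand-ins for vLLM's actual selection logic, not its real API.

from enum import Enum


# Hypothetical stand-in for vLLM's AttentionBackendEnum; only the member
# names are taken from the diff below, everything else is illustrative.
class AttentionBackendEnum(Enum):
    TRITON_ATTN = "triton_attn"
    FLASHINFER = "flashinfer"
    ROCM_ATTN = "rocm_attn"
    ROCM_AITER_UNIFIED_ATTN = "rocm_aiter_unified_attn"


def select_backend(is_blackwell: bool, has_flashinfer: bool) -> AttentionBackendEnum:
    """Illustrative selection only: prefer FlashInfer on Blackwell when it is
    installed, otherwise fall back to the Triton attention backend."""
    if is_blackwell and has_flashinfer:
        return AttentionBackendEnum.FLASHINFER
    return AttentionBackendEnum.TRITON_ATTN


# Call sites compare against the enum members, as the updated tests do:
assert select_backend(True, True) == AttentionBackendEnum.FLASHINFER
assert select_backend(False, True) == AttentionBackendEnum.TRITON_ATTN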

@@ -11,7 +11,7 @@ from typing import Any, NamedTuple
 import pytest
 import regex as re
-from tests.v1.attention.utils import _Backend
+from tests.v1.attention.utils import AttentionBackendEnum
 from vllm import LLM, SamplingParams
 from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode, PassConfig
 from vllm.platforms import current_platform
@@ -24,7 +24,7 @@ from ..utils import flat_product, multi_gpu_test
 class ModelBackendTestCase(NamedTuple):
     model_name: str
     model_kwargs: dict[str, Any]
-    backend: _Backend
+    backend: AttentionBackendEnum
     attention_fusions: int
     allreduce_fusions: int | None = None
@@ -39,14 +39,14 @@ if current_platform.is_cuda():
             # Use smaller model for L40s in CI
             model_name="RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8",
             model_kwargs=dict(max_model_len=1024),
-            backend=_Backend.TRITON_ATTN,
+            backend=AttentionBackendEnum.TRITON_ATTN,
             attention_fusions=32,
             allreduce_fusions=65,
         ),
         ModelBackendTestCase(
             model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
             model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
-            backend=_Backend.FLASHINFER,
+            backend=AttentionBackendEnum.FLASHINFER,
             attention_fusions=48,
             allreduce_fusions=96,
         ),
@@ -56,7 +56,7 @@ if current_platform.is_cuda():
         ModelBackendTestCase(
             model_name="nvidia/Llama-3.1-8B-Instruct-FP4",
             model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
-            backend=_Backend.FLASHINFER,
+            backend=AttentionBackendEnum.FLASHINFER,
             attention_fusions=32,
             allreduce_fusions=65,
         ),
@@ -67,7 +67,7 @@ if current_platform.is_cuda():
         ModelBackendTestCase(
             model_name="meta-llama/Llama-3.1-8B-Instruct",
             model_kwargs=dict(max_model_len=1024),
-            backend=_Backend.TRITON_ATTN,
+            backend=AttentionBackendEnum.TRITON_ATTN,
             attention_fusions=0,
             allreduce_fusions=65,
         ),
@@ -85,19 +85,19 @@ elif current_platform.is_rocm():
         ModelBackendTestCase(
             model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
             model_kwargs=dict(max_model_len=1024),
-            backend=_Backend.TRITON_ATTN,
+            backend=AttentionBackendEnum.TRITON_ATTN,
             attention_fusions=32,
         ),
         ModelBackendTestCase(
             model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
             model_kwargs=dict(max_model_len=1024),
-            backend=_Backend.ROCM_ATTN,
+            backend=AttentionBackendEnum.ROCM_ATTN,
             attention_fusions=32,
         ),
         ModelBackendTestCase(
             model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
             model_kwargs=dict(max_model_len=1024),
-            backend=_Backend.ROCM_AITER_UNIFIED_ATTN,
+            backend=AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN,
             attention_fusions=32,
         ),
     ]
@@ -117,7 +117,7 @@ CUSTOM_OPS_FP8 = ["-quant_fp8", "+quant_fp8"]
 def test_attn_quant(
     model_name: str,
     model_kwargs: dict[str, Any],
-    backend: _Backend,
+    backend: AttentionBackendEnum,
     attention_fusions: int,
     allreduce_fusions: int,
     custom_ops: str,
@@ -125,7 +125,7 @@ def test_attn_quant(
     caplog_mp_spawn,
     monkeypatch,
 ):
-    if backend == _Backend.FLASHINFER and (
+    if backend == AttentionBackendEnum.FLASHINFER and (
         not current_platform.is_device_capability((10, 0)) or not has_flashinfer()
     ):
         pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer")
@@ -208,7 +208,7 @@ def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]:
 def test_tp2_attn_quant_allreduce_rmsnorm(
     model_name: str,
     model_kwargs: dict,
-    backend: _Backend,
+    backend: AttentionBackendEnum,
     attention_fusions: int,
     allreduce_fusions: int,
     custom_ops: str,