[Attention] Refactor CUDA attention backend selection logic (#24794)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
@@ -11,7 +11,7 @@ from typing import Any, NamedTuple
 import pytest
 import regex as re
 
-from tests.v1.attention.utils import _Backend
+from tests.v1.attention.utils import AttentionBackendEnum
 from vllm import LLM, SamplingParams
 from vllm.config import CompilationConfig, CompilationMode, CUDAGraphMode, PassConfig
 from vllm.platforms import current_platform
@@ -24,7 +24,7 @@ from ..utils import flat_product, multi_gpu_test
 class ModelBackendTestCase(NamedTuple):
     model_name: str
     model_kwargs: dict[str, Any]
-    backend: _Backend
+    backend: AttentionBackendEnum
     attention_fusions: int
     allreduce_fusions: int | None = None
 
@@ -39,14 +39,14 @@ if current_platform.is_cuda():
             # Use smaller model for L40s in CI
             model_name="RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8",
             model_kwargs=dict(max_model_len=1024),
-            backend=_Backend.TRITON_ATTN,
+            backend=AttentionBackendEnum.TRITON_ATTN,
             attention_fusions=32,
             allreduce_fusions=65,
         ),
         ModelBackendTestCase(
             model_name="nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
             model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
-            backend=_Backend.FLASHINFER,
+            backend=AttentionBackendEnum.FLASHINFER,
             attention_fusions=48,
             allreduce_fusions=96,
         ),
@@ -56,7 +56,7 @@ if current_platform.is_cuda():
         ModelBackendTestCase(
             model_name="nvidia/Llama-3.1-8B-Instruct-FP4",
             model_kwargs=dict(max_model_len=1024, kv_cache_dtype="fp8"),
-            backend=_Backend.FLASHINFER,
+            backend=AttentionBackendEnum.FLASHINFER,
             attention_fusions=32,
             allreduce_fusions=65,
         ),
@@ -67,7 +67,7 @@ if current_platform.is_cuda():
         ModelBackendTestCase(
             model_name="meta-llama/Llama-3.1-8B-Instruct",
             model_kwargs=dict(max_model_len=1024),
-            backend=_Backend.TRITON_ATTN,
+            backend=AttentionBackendEnum.TRITON_ATTN,
             attention_fusions=0,
             allreduce_fusions=65,
         ),
@@ -85,19 +85,19 @@ elif current_platform.is_rocm():
         ModelBackendTestCase(
             model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
             model_kwargs=dict(max_model_len=1024),
-            backend=_Backend.TRITON_ATTN,
+            backend=AttentionBackendEnum.TRITON_ATTN,
             attention_fusions=32,
         ),
         ModelBackendTestCase(
             model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
             model_kwargs=dict(max_model_len=1024),
-            backend=_Backend.ROCM_ATTN,
+            backend=AttentionBackendEnum.ROCM_ATTN,
             attention_fusions=32,
         ),
         ModelBackendTestCase(
             model_name="amd/Llama-3.1-8B-Instruct-FP8-KV",
             model_kwargs=dict(max_model_len=1024),
-            backend=_Backend.ROCM_AITER_UNIFIED_ATTN,
+            backend=AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN,
             attention_fusions=32,
         ),
     ]
@@ -117,7 +117,7 @@ CUSTOM_OPS_FP8 = ["-quant_fp8", "+quant_fp8"]
 def test_attn_quant(
     model_name: str,
     model_kwargs: dict[str, Any],
-    backend: _Backend,
+    backend: AttentionBackendEnum,
     attention_fusions: int,
     allreduce_fusions: int,
     custom_ops: str,
@@ -125,7 +125,7 @@ def test_attn_quant(
     caplog_mp_spawn,
     monkeypatch,
 ):
-    if backend == _Backend.FLASHINFER and (
+    if backend == AttentionBackendEnum.FLASHINFER and (
         not current_platform.is_device_capability((10, 0)) or not has_flashinfer()
     ):
         pytest.skip("FlashInfer attn fusion requires Blackwell and flashinfer")
@@ -208,7 +208,7 @@ def custom_ops_product(*custom_ops_lists: list[str]) -> Iterable[str]:
 def test_tp2_attn_quant_allreduce_rmsnorm(
     model_name: str,
     model_kwargs: dict,
-    backend: _Backend,
+    backend: AttentionBackendEnum,
     attention_fusions: int,
     allreduce_fusions: int,
     custom_ops: str,
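
For readers following the rename: the diff replaces the test-internal _Backend enum with AttentionBackendEnum wherever a fusion test pins an attention backend. Below is a minimal sketch of what a test entry looks like after the change; the ModelBackendTestCase fields, import path, and enum members are taken from the diff above, while the standalone layout and the skip comment are illustrative assumptions rather than the file's exact code.

from typing import Any, NamedTuple

# Assumption: same import path as shown in the diff above.
from tests.v1.attention.utils import AttentionBackendEnum


class ModelBackendTestCase(NamedTuple):
    model_name: str
    model_kwargs: dict[str, Any]
    backend: AttentionBackendEnum
    attention_fusions: int
    allreduce_fusions: int | None = None


# Entry mirroring one of the CUDA cases in the diff: the backend is now
# selected via AttentionBackendEnum instead of the old _Backend enum.
case = ModelBackendTestCase(
    model_name="RedHatAI/Meta-Llama-3.1-8B-Instruct-FP8",
    model_kwargs=dict(max_model_len=1024),
    backend=AttentionBackendEnum.TRITON_ATTN,
    attention_fusions=32,
    allreduce_fusions=65,
)

# Backend-specific guards compare against the new enum as well, e.g. the
# test_attn_quant skip for FlashInfer on non-Blackwell hardware:
if case.backend == AttentionBackendEnum.FLASHINFER:
    pass  # skip unless Blackwell (SM 10.0) and flashinfer are available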