[Attention] Refactor CUDA attention backend selection logic (#24794)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com> Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
@@ -23,10 +23,10 @@ from .interface import CpuArchEnum, Platform, PlatformEnum
|
||||
logger = init_logger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.config import VllmConfig
|
||||
else:
|
||||
_Backend = None
|
||||
AttentionBackendEnum = None
|
||||
VllmConfig = None
|
||||
|
||||
|
||||
@@ -127,7 +127,7 @@ class CpuPlatform(Platform):
|
||||
@classmethod
|
||||
def get_attn_backend_cls(
|
||||
cls,
|
||||
selected_backend: "_Backend",
|
||||
selected_backend: "AttentionBackendEnum",
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: str | None,
|
||||
@@ -137,9 +137,9 @@ class CpuPlatform(Platform):
|
||||
has_sink: bool,
|
||||
use_sparse: bool,
|
||||
) -> str:
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
|
||||
if selected_backend and selected_backend != _Backend.TORCH_SDPA:
|
||||
if selected_backend and selected_backend != AttentionBackendEnum.TORCH_SDPA:
|
||||
logger.info("Cannot use %s backend on CPU.", selected_backend)
|
||||
if use_mla:
|
||||
raise NotImplementedError("MLA is not supported on CPU.")
|
||||
@@ -148,7 +148,7 @@ class CpuPlatform(Platform):
|
||||
logger.info("Using Torch SDPA backend.")
|
||||
if not use_v1:
|
||||
raise ValueError("CPU backend only supports V1.")
|
||||
return "vllm.v1.attention.backends.cpu_attn.TorchSDPABackend"
|
||||
return AttentionBackendEnum.TORCH_SDPA.get_path()
|
||||
|
||||
@classmethod
|
||||
def get_device_total_memory(cls, device_id: int = 0) -> int:
|
||||
|
||||
@@ -22,10 +22,13 @@ from vllm.utils.torch_utils import cuda_device_count_stateless
|
||||
from .interface import DeviceCapability, Platform, PlatformEnum
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.cache import CacheDType
|
||||
else:
|
||||
_Backend = None
|
||||
AttentionBackendEnum = None
|
||||
VllmConfig = None
|
||||
CacheDType = None
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -39,6 +42,49 @@ pynvml = import_pynvml()
|
||||
torch.backends.cuda.enable_cudnn_sdp(False)
|
||||
|
||||
|
||||
@cache
|
||||
def _get_backend_priorities(
|
||||
use_mla: bool,
|
||||
device_capability: DeviceCapability,
|
||||
) -> list[AttentionBackendEnum]:
|
||||
"""Get backend priorities with lazy import to avoid circular dependency."""
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
|
||||
if use_mla:
|
||||
if device_capability.major == 10:
|
||||
return [
|
||||
AttentionBackendEnum.CUTLASS_MLA,
|
||||
AttentionBackendEnum.FLASHINFER_MLA,
|
||||
AttentionBackendEnum.FLASHMLA,
|
||||
AttentionBackendEnum.FLASH_ATTN_MLA,
|
||||
AttentionBackendEnum.TRITON_MLA,
|
||||
AttentionBackendEnum.FLASHMLA_SPARSE,
|
||||
]
|
||||
else:
|
||||
return [
|
||||
AttentionBackendEnum.FLASHMLA,
|
||||
AttentionBackendEnum.FLASH_ATTN_MLA,
|
||||
AttentionBackendEnum.FLASHINFER_MLA,
|
||||
AttentionBackendEnum.TRITON_MLA,
|
||||
AttentionBackendEnum.FLASHMLA_SPARSE,
|
||||
]
|
||||
else:
|
||||
if device_capability.major == 10:
|
||||
return [
|
||||
AttentionBackendEnum.FLASHINFER,
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.TRITON_ATTN,
|
||||
AttentionBackendEnum.FLEX_ATTENTION,
|
||||
]
|
||||
else:
|
||||
return [
|
||||
AttentionBackendEnum.FLASH_ATTN,
|
||||
AttentionBackendEnum.FLASHINFER,
|
||||
AttentionBackendEnum.TRITON_ATTN,
|
||||
AttentionBackendEnum.FLEX_ATTENTION,
|
||||
]
|
||||
|
||||
|
||||
def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
|
||||
@wraps(fn)
|
||||
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
|
||||
@@ -216,217 +262,171 @@ class CudaPlatformBase(Platform):
|
||||
return torch.cuda.max_memory_allocated(device)
|
||||
|
||||
@classmethod
|
||||
def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend":
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
def get_vit_attn_backend(
|
||||
cls, head_size: int, dtype: torch.dtype
|
||||
) -> "AttentionBackendEnum":
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
|
||||
# For Blackwell GPUs, force TORCH_SDPA for now.
|
||||
# See https://github.com/facebookresearch/xformers/issues/1317#issuecomment-3199392579 # noqa: E501
|
||||
if cls.has_device_capability(100):
|
||||
return _Backend.TORCH_SDPA
|
||||
return AttentionBackendEnum.TORCH_SDPA
|
||||
|
||||
if dtype not in (torch.float16, torch.bfloat16):
|
||||
return _Backend.XFORMERS
|
||||
return AttentionBackendEnum.XFORMERS
|
||||
|
||||
if cls.has_device_capability(80):
|
||||
FLASH_ATTN_V1 = (
|
||||
"vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501
|
||||
)
|
||||
from vllm.attention.selector import is_attn_backend_supported
|
||||
|
||||
is_default_fa_supported = is_attn_backend_supported(
|
||||
FLASH_ATTN_V1, head_size, dtype, allow_import_error=False
|
||||
)
|
||||
if is_default_fa_supported:
|
||||
return _Backend.FLASH_ATTN
|
||||
backend_class = AttentionBackendEnum.FLASH_ATTN.get_class()
|
||||
if backend_class.supports_head_size(
|
||||
head_size
|
||||
) and backend_class.supports_dtype(dtype):
|
||||
return AttentionBackendEnum.FLASH_ATTN
|
||||
else:
|
||||
# Fallback to XFORMERS
|
||||
return _Backend.XFORMERS
|
||||
return AttentionBackendEnum.XFORMERS
|
||||
else:
|
||||
# Fallback for Volta/Turing GPUs or FA not supported
|
||||
return _Backend.XFORMERS
|
||||
return AttentionBackendEnum.XFORMERS
|
||||
|
||||
@classmethod
|
||||
def get_attn_backend_cls(
|
||||
def get_valid_backends(
|
||||
cls,
|
||||
selected_backend,
|
||||
head_size,
|
||||
dtype,
|
||||
kv_cache_dtype,
|
||||
block_size,
|
||||
use_v1,
|
||||
use_mla,
|
||||
has_sink,
|
||||
use_sparse,
|
||||
device_capability,
|
||||
) -> tuple[
|
||||
list[tuple["AttentionBackendEnum", int]],
|
||||
dict["AttentionBackendEnum", list[str]],
|
||||
]:
|
||||
valid_backends_priorities = []
|
||||
invalid_reasons = {}
|
||||
|
||||
backend_priorities = _get_backend_priorities(use_mla, device_capability)
|
||||
for priority, backend in enumerate(backend_priorities):
|
||||
try:
|
||||
backend_class = backend.get_class()
|
||||
invalid_reasons_i = backend_class.validate_configuration(
|
||||
head_size,
|
||||
dtype,
|
||||
kv_cache_dtype,
|
||||
block_size,
|
||||
use_mla,
|
||||
has_sink,
|
||||
use_sparse,
|
||||
device_capability,
|
||||
)
|
||||
except ImportError:
|
||||
invalid_reasons_i = ["ImportError"]
|
||||
if invalid_reasons_i:
|
||||
invalid_reasons[backend] = invalid_reasons_i
|
||||
else:
|
||||
valid_backends_priorities.append((backend, priority))
|
||||
|
||||
return valid_backends_priorities, invalid_reasons
|
||||
|
||||
@classmethod
|
||||
def get_attn_backend_cls(
|
||||
cls,
|
||||
selected_backend: "AttentionBackendEnum",
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: "CacheDType | None",
|
||||
block_size: int | None,
|
||||
use_v1: bool,
|
||||
use_mla: bool,
|
||||
has_sink: bool,
|
||||
use_sparse: bool,
|
||||
) -> str:
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
if not use_v1:
|
||||
raise RuntimeError(
|
||||
"V0 attention backends have been removed. Set VLLM_USE_V1=1 "
|
||||
"to select a supported backend."
|
||||
)
|
||||
|
||||
if use_mla:
|
||||
# explicitly reject non-MLA backends when MLA is enabled to avoid
|
||||
# silently selecting an incompatible backend (e.g., FLASHINFER).
|
||||
if selected_backend in {
|
||||
_Backend.FLASHINFER,
|
||||
_Backend.FLASH_ATTN,
|
||||
_Backend.TRITON_ATTN,
|
||||
_Backend.TREE_ATTN,
|
||||
_Backend.XFORMERS,
|
||||
}:
|
||||
device_capability = cls.get_device_capability()
|
||||
assert device_capability is not None
|
||||
|
||||
# First try checking just the selected backend, if there is one.
|
||||
if selected_backend is not None:
|
||||
try:
|
||||
backend_class = selected_backend.get_class()
|
||||
invalid_reasons = backend_class.validate_configuration(
|
||||
head_size,
|
||||
dtype,
|
||||
kv_cache_dtype,
|
||||
None,
|
||||
use_mla,
|
||||
has_sink,
|
||||
use_sparse,
|
||||
device_capability,
|
||||
)
|
||||
except ImportError:
|
||||
invalid_reasons = ["ImportError"]
|
||||
if invalid_reasons:
|
||||
raise ValueError(
|
||||
f"Attention backend {selected_backend} incompatible with MLA. "
|
||||
"Please use one of the MLA backends: FLASHINFER_MLA, CUTLASS_MLA, "
|
||||
"FLASHMLA, FLASH_ATTN_MLA, or TRITON_MLA. Alternatively, set "
|
||||
"VLLM_MLA_DISABLE=1 to disable MLA for this model."
|
||||
f"Selected backend {selected_backend} is not valid for "
|
||||
f"this configuration. Reason: {invalid_reasons}"
|
||||
)
|
||||
else:
|
||||
logger.info("Using %s backend.", selected_backend)
|
||||
return selected_backend.get_path()
|
||||
|
||||
from vllm.attention.ops.flashmla import is_flashmla_dense_supported
|
||||
from vllm.attention.utils.fa_utils import flash_attn_supports_mla
|
||||
|
||||
if use_sparse:
|
||||
logger.info_once("Using Sparse MLA backend.")
|
||||
return (
|
||||
"vllm.v1.attention.backends.mla.flashmla_sparse."
|
||||
"FlashMLASparseBackend"
|
||||
)
|
||||
|
||||
use_cutlassmla = selected_backend == _Backend.CUTLASS_MLA or (
|
||||
selected_backend is None
|
||||
and cls.is_device_capability(100)
|
||||
and block_size % 128 == 0
|
||||
)
|
||||
use_flashinfermla = selected_backend == _Backend.FLASHINFER_MLA or (
|
||||
selected_backend is None
|
||||
and cls.is_device_capability(100)
|
||||
and (block_size == 32 or block_size % 64 == 0)
|
||||
)
|
||||
use_flashmla = selected_backend == _Backend.FLASHMLA or (
|
||||
selected_backend is None and is_flashmla_dense_supported()[0]
|
||||
)
|
||||
use_flashattn = selected_backend == _Backend.FLASH_ATTN_MLA or (
|
||||
selected_backend is None and flash_attn_supports_mla()
|
||||
)
|
||||
use_triton = selected_backend == _Backend.TRITON_MLA or (
|
||||
selected_backend is None
|
||||
)
|
||||
|
||||
if use_cutlassmla:
|
||||
logger.info_once("Using Cutlass MLA backend.", scope="local")
|
||||
return "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend"
|
||||
if use_flashinfermla:
|
||||
from vllm.v1.attention.backends.utils import set_kv_cache_layout
|
||||
|
||||
set_kv_cache_layout("HND")
|
||||
logger.info_once("Using FlashInfer MLA backend.")
|
||||
return (
|
||||
"vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend"
|
||||
)
|
||||
if use_flashmla:
|
||||
if block_size % 64 != 0:
|
||||
logger.warning(
|
||||
"FlashMLA backend is not supported for block size %d"
|
||||
" (currently only supports block size 64).",
|
||||
block_size,
|
||||
)
|
||||
else:
|
||||
logger.info_once("Using FlashMLA backend.")
|
||||
return "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend"
|
||||
if use_flashattn:
|
||||
logger.info_once("Using FlashAttention MLA backend.")
|
||||
return (
|
||||
"vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend"
|
||||
)
|
||||
if use_triton:
|
||||
logger.info_once("Using Triton MLA backend.")
|
||||
return "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend"
|
||||
|
||||
FLASHINFER_V1 = "vllm.v1.attention.backends.flashinfer.FlashInferBackend" # noqa: E501
|
||||
FLEX_ATTENTION_V1 = (
|
||||
"vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" # noqa: E501
|
||||
# No selected backend or the selected backend is invalid,
|
||||
# so we try finding a valid backend.
|
||||
valid_backends_priorities, invalid_reasons = cls.get_valid_backends(
|
||||
head_size,
|
||||
dtype,
|
||||
kv_cache_dtype,
|
||||
None,
|
||||
use_mla,
|
||||
has_sink,
|
||||
use_sparse,
|
||||
device_capability,
|
||||
)
|
||||
TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" # noqa: E501
|
||||
FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501
|
||||
TREE_ATTN_V1 = "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend" # noqa: E501
|
||||
XFORMERS_V1 = "vllm.v1.attention.backends.xformers.XFormersAttentionBackend" # noqa: E501
|
||||
reasons_str = (
|
||||
"{"
|
||||
+ ", ".join(
|
||||
f"{backend.name}: [{', '.join(reasons)}]"
|
||||
for backend, reasons in invalid_reasons.items()
|
||||
)
|
||||
+ "}"
|
||||
)
|
||||
config_str = (
|
||||
f"head_size: {head_size}, dtype: {dtype}, "
|
||||
f"kv_cache_dtype: {kv_cache_dtype}, block_size: {block_size}, "
|
||||
f"use_mla: {use_mla}, has_sink: {has_sink}, use_sparse: {use_sparse}"
|
||||
)
|
||||
logger.debug_once(
|
||||
f"Some attention backends are not valid for {cls.device_name} with "
|
||||
f"{config_str}. Reasons: {reasons_str}."
|
||||
)
|
||||
if len(valid_backends_priorities) == 0:
|
||||
raise ValueError(
|
||||
f"No valid attention backend found for {cls.device_name} "
|
||||
f"with {config_str}. Reasons: {reasons_str}."
|
||||
)
|
||||
|
||||
use_fp8_kv_cache = kv_cache_dtype is not None and kv_cache_dtype.startswith(
|
||||
"fp8"
|
||||
# We have found some valid backends. Select the one with the
|
||||
# highest priority.
|
||||
logger.info(
|
||||
"Valid backends: %s", [b[0].name for b in valid_backends_priorities]
|
||||
)
|
||||
sorted_indices = sorted(
|
||||
range(len(valid_backends_priorities)),
|
||||
key=lambda i: valid_backends_priorities[i][1],
|
||||
)
|
||||
selected_index = sorted_indices[0]
|
||||
selected_backend = valid_backends_priorities[selected_index][0]
|
||||
logger.info(
|
||||
"Using %s backend.",
|
||||
selected_backend.name,
|
||||
)
|
||||
|
||||
if selected_backend == _Backend.FLASHINFER:
|
||||
logger.info_once("Using FlashInfer backend.")
|
||||
if cls.has_device_capability(100):
|
||||
from vllm.v1.attention.backends.utils import set_kv_cache_layout
|
||||
|
||||
set_kv_cache_layout("HND")
|
||||
return FLASHINFER_V1
|
||||
elif selected_backend == _Backend.FLEX_ATTENTION:
|
||||
logger.info_once("Using FlexAttention backend.")
|
||||
return FLEX_ATTENTION_V1
|
||||
elif selected_backend == _Backend.TRITON_ATTN:
|
||||
logger.info_once("Using Triton backend.")
|
||||
return TRITON_ATTN
|
||||
elif selected_backend == _Backend.FLASH_ATTN:
|
||||
logger.info_once("Using Flash Attention backend.")
|
||||
return FLASH_ATTN_V1
|
||||
elif selected_backend == _Backend.TREE_ATTN:
|
||||
logger.info_once("Using Tree Attention backend.")
|
||||
return TREE_ATTN_V1
|
||||
elif selected_backend == _Backend.XFORMERS:
|
||||
logger.info_once("Using XFormers backend.")
|
||||
return XFORMERS_V1
|
||||
|
||||
from vllm.attention.selector import is_attn_backend_supported
|
||||
|
||||
# Default backends for V1 engine
|
||||
# Prefer FlashInfer for Blackwell GPUs if installed
|
||||
if cls.is_device_capability(100):
|
||||
if is_default_backend_supported := is_attn_backend_supported(
|
||||
FLASHINFER_V1, head_size, dtype
|
||||
):
|
||||
from vllm.v1.attention.backends.utils import set_kv_cache_layout
|
||||
|
||||
logger.info_once(
|
||||
"Using FlashInfer backend with HND KV cache layout on "
|
||||
"V1 engine by default for Blackwell (SM 10.0) GPUs."
|
||||
)
|
||||
set_kv_cache_layout("HND")
|
||||
|
||||
return FLASHINFER_V1
|
||||
|
||||
if not is_default_backend_supported.can_import:
|
||||
logger.warning_once(
|
||||
"FlashInfer failed to import on Blackwell (SM 10.0) GPUs; "
|
||||
"it is recommended to install FlashInfer for better "
|
||||
"performance."
|
||||
)
|
||||
|
||||
# FlashAttention is the default for SM 8.0+ GPUs
|
||||
if cls.has_device_capability(80):
|
||||
if (has_sink or use_fp8_kv_cache) and not cls.is_device_capability(90):
|
||||
logger.info_once("Using Triton backend.")
|
||||
return TRITON_ATTN
|
||||
elif is_default_backend_supported := is_attn_backend_supported(
|
||||
FLASH_ATTN_V1, head_size, dtype, allow_import_error=False
|
||||
):
|
||||
logger.info_once("Using Flash Attention backend.")
|
||||
return FLASH_ATTN_V1
|
||||
|
||||
# FlexAttention is the default for older GPUs
|
||||
else:
|
||||
logger.info_once("Using FlexAttention backend.")
|
||||
return FLEX_ATTENTION_V1
|
||||
|
||||
assert not is_default_backend_supported
|
||||
|
||||
use_flex_attention_reason = {}
|
||||
if not is_default_backend_supported.head_size:
|
||||
use_flex_attention_reason["head_size"] = head_size
|
||||
if not is_default_backend_supported.dtype:
|
||||
use_flex_attention_reason["dtype"] = dtype
|
||||
|
||||
logger.info_once(
|
||||
"Using FlexAttention backend for %s.",
|
||||
", ".join(f"{k}={v}" for k, v in use_flex_attention_reason.items()),
|
||||
)
|
||||
return FLEX_ATTENTION_V1
|
||||
return selected_backend.get_path()
|
||||
|
||||
@classmethod
|
||||
def get_punica_wrapper(cls) -> str:
|
||||
|
||||
@@ -17,8 +17,9 @@ from vllm.logger import init_logger
|
||||
if TYPE_CHECKING:
|
||||
from torch.distributed import PrefixStore, ProcessGroup
|
||||
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.inputs import ProcessorInputs, PromptType
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sampling_params import SamplingParams
|
||||
@@ -58,6 +59,31 @@ class DeviceCapability(NamedTuple):
|
||||
major: int
|
||||
minor: int
|
||||
|
||||
def __lt__(self, other: Any) -> bool:
|
||||
if not isinstance(other, DeviceCapability):
|
||||
return NotImplemented
|
||||
return (self.major, self.minor) < (other.major, other.minor)
|
||||
|
||||
def __le__(self, other: Any) -> bool:
|
||||
if not isinstance(other, DeviceCapability):
|
||||
return NotImplemented
|
||||
return (self.major, self.minor) <= (other.major, other.minor)
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
if not isinstance(other, DeviceCapability):
|
||||
return NotImplemented
|
||||
return (self.major, self.minor) == (other.major, other.minor)
|
||||
|
||||
def __ge__(self, other: Any) -> bool:
|
||||
if not isinstance(other, DeviceCapability):
|
||||
return NotImplemented
|
||||
return (self.major, self.minor) >= (other.major, other.minor)
|
||||
|
||||
def __gt__(self, other: Any) -> bool:
|
||||
if not isinstance(other, DeviceCapability):
|
||||
return NotImplemented
|
||||
return (self.major, self.minor) > (other.major, other.minor)
|
||||
|
||||
def as_version_str(self) -> str:
|
||||
return f"{self.major}.{self.minor}"
|
||||
|
||||
@@ -173,19 +199,21 @@ class Platform:
|
||||
import vllm._moe_C # noqa: F401
|
||||
|
||||
@classmethod
|
||||
def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend":
|
||||
# Import _Backend here to avoid circular import.
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
def get_vit_attn_backend(
|
||||
cls, head_size: int, dtype: torch.dtype
|
||||
) -> "AttentionBackendEnum":
|
||||
# Import AttentionBackendEnum here to avoid circular import.
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
|
||||
return _Backend.TORCH_SDPA
|
||||
return AttentionBackendEnum.TORCH_SDPA
|
||||
|
||||
@classmethod
|
||||
def get_attn_backend_cls(
|
||||
cls,
|
||||
selected_backend: "_Backend",
|
||||
selected_backend: "AttentionBackendEnum",
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: str | None,
|
||||
kv_cache_dtype: "CacheDType | None",
|
||||
block_size: int,
|
||||
use_v1: bool,
|
||||
use_mla: bool,
|
||||
|
||||
@@ -14,10 +14,10 @@ from vllm.utils.torch_utils import cuda_device_count_stateless
|
||||
from .interface import DeviceCapability, Platform, PlatformEnum
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.config import VllmConfig
|
||||
else:
|
||||
_Backend = None
|
||||
AttentionBackendEnum = None
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -204,21 +204,23 @@ class RocmPlatform(Platform):
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> _Backend:
|
||||
def get_vit_attn_backend(
|
||||
cls, head_size: int, dtype: torch.dtype
|
||||
) -> AttentionBackendEnum:
|
||||
from importlib.util import find_spec
|
||||
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
|
||||
if rocm_aiter_ops.is_mha_enabled():
|
||||
# Note: AITER FA is only supported for Qwen-VL models.
|
||||
# TODO: Add support for other VL models in their model class.
|
||||
return _Backend.ROCM_AITER_FA
|
||||
return AttentionBackendEnum.ROCM_AITER_FA
|
||||
|
||||
if on_gfx9() and find_spec("flash_attn") is not None:
|
||||
return _Backend.FLASH_ATTN
|
||||
return AttentionBackendEnum.FLASH_ATTN
|
||||
|
||||
return _Backend.TORCH_SDPA
|
||||
return AttentionBackendEnum.TORCH_SDPA
|
||||
|
||||
@classmethod
|
||||
def get_attn_backend_cls(
|
||||
@@ -234,7 +236,7 @@ class RocmPlatform(Platform):
|
||||
use_sparse,
|
||||
) -> str:
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
|
||||
if use_sparse:
|
||||
raise NotImplementedError("Sparse Attention is not supported on ROCm.")
|
||||
@@ -248,55 +250,52 @@ class RocmPlatform(Platform):
|
||||
if use_mla:
|
||||
if selected_backend is None:
|
||||
selected_backend = (
|
||||
_Backend.ROCM_AITER_MLA
|
||||
AttentionBackendEnum.ROCM_AITER_MLA
|
||||
if rocm_aiter_ops.is_mla_enabled() or block_size == 1
|
||||
else _Backend.TRITON_MLA
|
||||
else AttentionBackendEnum.TRITON_MLA
|
||||
)
|
||||
|
||||
if selected_backend == _Backend.TRITON_MLA:
|
||||
if selected_backend == AttentionBackendEnum.TRITON_MLA:
|
||||
if block_size != 1:
|
||||
logger.info_once("Using Triton MLA backend.")
|
||||
return "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend"
|
||||
return AttentionBackendEnum.TRITON_MLA.get_path()
|
||||
raise ValueError(
|
||||
f" The selected backend, {selected_backend.name},"
|
||||
f"does not support block size {block_size}."
|
||||
)
|
||||
if selected_backend == _Backend.ROCM_AITER_MLA:
|
||||
if selected_backend == AttentionBackendEnum.ROCM_AITER_MLA:
|
||||
logger.info("Using AITER MLA backend.")
|
||||
return "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend" # noqa: E501
|
||||
return AttentionBackendEnum.ROCM_AITER_MLA.get_path()
|
||||
|
||||
raise ValueError(
|
||||
f" The selected backend, {selected_backend.name},"
|
||||
f"is not MLA type while requested for MLA backend."
|
||||
)
|
||||
|
||||
if selected_backend == _Backend.FLEX_ATTENTION:
|
||||
if selected_backend == AttentionBackendEnum.FLEX_ATTENTION:
|
||||
logger.info("Using FlexAttention backend.")
|
||||
return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"
|
||||
if (
|
||||
rocm_aiter_ops.is_mha_enabled()
|
||||
) or selected_backend == _Backend.ROCM_AITER_FA:
|
||||
) or selected_backend == AttentionBackendEnum.ROCM_AITER_FA:
|
||||
logger.info("Using Aiter Flash Attention backend.")
|
||||
return "vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend"
|
||||
return AttentionBackendEnum.ROCM_AITER_FA.get_path()
|
||||
if (
|
||||
rocm_aiter_ops.is_triton_unified_attn_enabled()
|
||||
) or selected_backend == _Backend.ROCM_AITER_UNIFIED_ATTN:
|
||||
) or selected_backend == AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN:
|
||||
logger.info("Using Aiter Unified Attention backend.")
|
||||
return (
|
||||
"vllm.v1.attention.backends."
|
||||
"rocm_aiter_unified_attn.RocmAiterUnifiedAttentionBackend"
|
||||
)
|
||||
return AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path()
|
||||
if (
|
||||
envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION
|
||||
or selected_backend == _Backend.ROCM_ATTN
|
||||
or selected_backend == AttentionBackendEnum.ROCM_ATTN
|
||||
):
|
||||
# rocm specific backend, with aiter and/or
|
||||
# triton prefix-prefill
|
||||
logger.info("Using Rocm Attention backend.")
|
||||
return "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend"
|
||||
return AttentionBackendEnum.ROCM_ATTN.get_path()
|
||||
# default case, using triton unified attention
|
||||
logger.info("Using Triton Attention backend.")
|
||||
return "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"
|
||||
return AttentionBackendEnum.TRITON_ATTN.get_path()
|
||||
|
||||
@classmethod
|
||||
def set_device(cls, device: torch.device) -> None:
|
||||
|
||||
@@ -15,16 +15,15 @@ from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
|
||||
from .interface import Platform, PlatformEnum
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.cache import BlockSize
|
||||
from vllm.pooling_params import PoolingParams
|
||||
else:
|
||||
BlockSize = None
|
||||
ModelConfig = None
|
||||
VllmConfig = None
|
||||
PoolingParams = None
|
||||
_Backend = None
|
||||
AttentionBackendEnum = None
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -54,7 +53,7 @@ class TpuPlatform(Platform):
|
||||
@classmethod
|
||||
def get_attn_backend_cls(
|
||||
cls,
|
||||
selected_backend: "_Backend",
|
||||
selected_backend: "AttentionBackendEnum",
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: str | None,
|
||||
@@ -64,17 +63,17 @@ class TpuPlatform(Platform):
|
||||
has_sink,
|
||||
use_sparse,
|
||||
) -> str:
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
|
||||
if use_sparse:
|
||||
raise NotImplementedError("Sparse Attention is not supported on TPU.")
|
||||
if selected_backend != _Backend.PALLAS:
|
||||
if selected_backend != AttentionBackendEnum.PALLAS:
|
||||
logger.info("Cannot use %s backend on TPU.", selected_backend)
|
||||
|
||||
if not use_v1:
|
||||
raise ValueError("TPU backend only supports V1.")
|
||||
logger.info("Using Pallas V1 backend.")
|
||||
return "vllm.v1.attention.backends.pallas.PallasAttentionBackend"
|
||||
return AttentionBackendEnum.PALLAS.get_path()
|
||||
|
||||
@classmethod
|
||||
def set_device(cls, device: torch.device) -> None:
|
||||
|
||||
@@ -14,12 +14,11 @@ from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
|
||||
from .interface import DeviceCapability, Platform, PlatformEnum
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.config import VllmConfig
|
||||
else:
|
||||
ModelConfig = None
|
||||
VllmConfig = None
|
||||
_Backend = None
|
||||
AttentionBackendEnum = None
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -44,7 +43,7 @@ class XPUPlatform(Platform):
|
||||
@classmethod
|
||||
def get_attn_backend_cls(
|
||||
cls,
|
||||
selected_backend: "_Backend",
|
||||
selected_backend: "AttentionBackendEnum",
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: str | None,
|
||||
@@ -62,18 +61,19 @@ class XPUPlatform(Platform):
|
||||
"only NHD layout is supported by XPU attention kernels."
|
||||
)
|
||||
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
|
||||
if use_sparse:
|
||||
raise NotImplementedError("Sparse Attention is not supported on XPU.")
|
||||
TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" # noqa: E501
|
||||
FLASH_ATTN = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501
|
||||
if selected_backend == _Backend.TRITON_ATTN:
|
||||
use_v1 = envs.VLLM_USE_V1
|
||||
if not use_v1:
|
||||
raise ValueError("XPU backend only supports V1.")
|
||||
if selected_backend == AttentionBackendEnum.TRITON_ATTN:
|
||||
logger.info_once("Using Triton backend.")
|
||||
return TRITON_ATTN
|
||||
elif selected_backend == _Backend.FLASH_ATTN:
|
||||
return AttentionBackendEnum.TRITON_ATTN.get_path()
|
||||
elif selected_backend == AttentionBackendEnum.FLASH_ATTN:
|
||||
logger.info_once("Using Flash Attention backend.")
|
||||
return FLASH_ATTN
|
||||
return AttentionBackendEnum.FLASH_ATTN.get_path()
|
||||
elif selected_backend:
|
||||
raise ValueError(
|
||||
f"Invalid attention backend for {cls.device_name}, "
|
||||
@@ -81,7 +81,7 @@ class XPUPlatform(Platform):
|
||||
)
|
||||
|
||||
logger.info("Using Flash Attention backend.")
|
||||
return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
|
||||
return AttentionBackendEnum.FLASH_ATTN.get_path()
|
||||
|
||||
@classmethod
|
||||
def set_device(cls, device: torch.device) -> None:
|
||||
@@ -113,10 +113,10 @@ class XPUPlatform(Platform):
|
||||
return device_props.total_memory
|
||||
|
||||
@classmethod
|
||||
def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> _Backend:
|
||||
from vllm.attention.backends.registry import _Backend
|
||||
|
||||
return _Backend.FLASH_ATTN
|
||||
def get_vit_attn_backend(
|
||||
cls, head_size: int, dtype: torch.dtype
|
||||
) -> AttentionBackendEnum:
|
||||
return AttentionBackendEnum.FLASH_ATTN
|
||||
|
||||
@classmethod
|
||||
def inference_mode(cls):
|
||||
|
||||
Reference in New Issue
Block a user