[Attention] Refactor CUDA attention backend selection logic (#24794)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Author: Matthew Bonanni
Committed: 2025-11-11 06:40:44 -06:00 (committed by GitHub)
Parent: 2e78150d24
Commit: b30dfa03c5
61 changed files with 1338 additions and 1002 deletions

--- next file ---

@@ -23,10 +23,10 @@ from .interface import CpuArchEnum, Platform, PlatformEnum
logger = init_logger(__name__)
if TYPE_CHECKING:
from vllm.attention.backends.registry import _Backend
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import VllmConfig
else:
_Backend = None
AttentionBackendEnum = None
VllmConfig = None
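The `if TYPE_CHECKING:` block above is the standard idiom for keeping the registry import out of the runtime import graph (avoiding a circular import) while still typing the signatures. A minimal sketch of the pattern, using only names that appear in this diff:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by static type checkers, never at runtime, so it
    # cannot participate in an import cycle with the registry module.
    from vllm.attention.backends.registry import AttentionBackendEnum
else:
    # Keep the name bound at runtime for code that introspects the module.
    AttentionBackendEnum = None


def is_sdpa(backend: "AttentionBackendEnum") -> bool:
    # The quoted annotation defers evaluation; the real import happens
    # lazily inside the function body, as get_attn_backend_cls does below.
    from vllm.attention.backends.registry import AttentionBackendEnum

    return backend == AttentionBackendEnum.TORCH_SDPA
```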
@@ -127,7 +127,7 @@ class CpuPlatform(Platform):
@classmethod
def get_attn_backend_cls(
cls,
selected_backend: "_Backend",
selected_backend: "AttentionBackendEnum",
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: str | None,
@@ -137,9 +137,9 @@ class CpuPlatform(Platform):
has_sink: bool,
use_sparse: bool,
) -> str:
from vllm.attention.backends.registry import _Backend
from vllm.attention.backends.registry import AttentionBackendEnum
if selected_backend and selected_backend != _Backend.TORCH_SDPA:
if selected_backend and selected_backend != AttentionBackendEnum.TORCH_SDPA:
logger.info("Cannot use %s backend on CPU.", selected_backend)
if use_mla:
raise NotImplementedError("MLA is not supported on CPU.")
@@ -148,7 +148,7 @@ class CpuPlatform(Platform):
logger.info("Using Torch SDPA backend.")
if not use_v1:
raise ValueError("CPU backend only supports V1.")
return "vllm.v1.attention.backends.cpu_attn.TorchSDPABackend"
return AttentionBackendEnum.TORCH_SDPA.get_path()
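Returning `AttentionBackendEnum.TORCH_SDPA.get_path()` instead of the hard-coded dotted path moves the class locations into the registry. The registry itself is outside this diff, so the following is only a sketch of the `get_path()`/`get_class()` pattern, not vLLM's actual implementation:

```python
from enum import Enum
from importlib import import_module


class BackendPathEnum(Enum):
    # Dotted path taken from the line this call replaces above.
    TORCH_SDPA = "vllm.v1.attention.backends.cpu_attn.TorchSDPABackend"

    def get_path(self) -> str:
        # The "module.ClassName" string handed back to the engine.
        return self.value

    def get_class(self) -> type:
        # Resolve the dotted path to the backend class on demand.
        module_name, _, class_name = self.value.rpartition(".")
        return getattr(import_module(module_name), class_name)
```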
@classmethod
def get_device_total_memory(cls, device_id: int = 0) -> int:

--- next file ---

@@ -22,10 +22,13 @@ from vllm.utils.torch_utils import cuda_device_count_stateless
from .interface import DeviceCapability, Platform, PlatformEnum
if TYPE_CHECKING:
from vllm.attention.backends.registry import _Backend
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import VllmConfig
from vllm.config.cache import CacheDType
else:
_Backend = None
AttentionBackendEnum = None
VllmConfig = None
CacheDType = None
logger = init_logger(__name__)
@@ -39,6 +42,49 @@ pynvml = import_pynvml()
torch.backends.cuda.enable_cudnn_sdp(False)
@cache
def _get_backend_priorities(
use_mla: bool,
device_capability: DeviceCapability,
) -> list[AttentionBackendEnum]:
"""Get backend priorities with lazy import to avoid circular dependency."""
from vllm.attention.backends.registry import AttentionBackendEnum
if use_mla:
if device_capability.major == 10:
return [
AttentionBackendEnum.CUTLASS_MLA,
AttentionBackendEnum.FLASHINFER_MLA,
AttentionBackendEnum.FLASHMLA,
AttentionBackendEnum.FLASH_ATTN_MLA,
AttentionBackendEnum.TRITON_MLA,
AttentionBackendEnum.FLASHMLA_SPARSE,
]
else:
return [
AttentionBackendEnum.FLASHMLA,
AttentionBackendEnum.FLASH_ATTN_MLA,
AttentionBackendEnum.FLASHINFER_MLA,
AttentionBackendEnum.TRITON_MLA,
AttentionBackendEnum.FLASHMLA_SPARSE,
]
else:
if device_capability.major == 10:
return [
AttentionBackendEnum.FLASHINFER,
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.TRITON_ATTN,
AttentionBackendEnum.FLEX_ATTENTION,
]
else:
return [
AttentionBackendEnum.FLASH_ATTN,
AttentionBackendEnum.FLASHINFER,
AttentionBackendEnum.TRITON_ATTN,
AttentionBackendEnum.FLEX_ATTENTION,
]
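Reviewer note: `@cache` memoizes one priority list per `(use_mla, device_capability)` pair, and `DeviceCapability` stays hashable because `typing.NamedTuple` attaches the comparison methods added later in this PR after class creation, leaving the tuple `__hash__` intact. A quick usage sketch (module paths as laid out in this PR; capability value illustrative):

```python
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.platforms.cuda import _get_backend_priorities
from vllm.platforms.interface import DeviceCapability

blackwell = DeviceCapability(major=10, minor=0)
order = _get_backend_priorities(use_mla=False, device_capability=blackwell)
# FlashInfer leads the non-MLA list on SM 10.x; FlashAttention elsewhere.
assert order[0] is AttentionBackendEnum.FLASHINFER
```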
def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
@wraps(fn)
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
@@ -216,217 +262,171 @@ class CudaPlatformBase(Platform):
return torch.cuda.max_memory_allocated(device)
@classmethod
def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend":
from vllm.attention.backends.registry import _Backend
def get_vit_attn_backend(
cls, head_size: int, dtype: torch.dtype
) -> "AttentionBackendEnum":
from vllm.attention.backends.registry import AttentionBackendEnum
# For Blackwell GPUs, force TORCH_SDPA for now.
# See https://github.com/facebookresearch/xformers/issues/1317#issuecomment-3199392579 # noqa: E501
if cls.has_device_capability(100):
return _Backend.TORCH_SDPA
return AttentionBackendEnum.TORCH_SDPA
if dtype not in (torch.float16, torch.bfloat16):
return _Backend.XFORMERS
return AttentionBackendEnum.XFORMERS
if cls.has_device_capability(80):
FLASH_ATTN_V1 = (
"vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501
)
from vllm.attention.selector import is_attn_backend_supported
is_default_fa_supported = is_attn_backend_supported(
FLASH_ATTN_V1, head_size, dtype, allow_import_error=False
)
if is_default_fa_supported:
return _Backend.FLASH_ATTN
backend_class = AttentionBackendEnum.FLASH_ATTN.get_class()
if backend_class.supports_head_size(
head_size
) and backend_class.supports_dtype(dtype):
return AttentionBackendEnum.FLASH_ATTN
else:
# Fallback to XFORMERS
return _Backend.XFORMERS
return AttentionBackendEnum.XFORMERS
else:
# Fallback for Volta/Turing GPUs or FA not supported
return _Backend.XFORMERS
return AttentionBackendEnum.XFORMERS
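The old string-path probe through `is_attn_backend_supported` becomes a direct feature query on the backend class. The `supports_head_size`/`supports_dtype` classmethods live on the backend classes outside this excerpt; a hedged sketch of their likely shape (the concrete sets below are invented for illustration):

```python
import torch


class FlashAttentionBackendSketch:
    # Hypothetical supported sets; not vLLM's actual lists.
    _HEAD_SIZES = (64, 96, 128, 192, 256)
    _DTYPES = (torch.float16, torch.bfloat16)

    @classmethod
    def supports_head_size(cls, head_size: int) -> bool:
        return head_size in cls._HEAD_SIZES

    @classmethod
    def supports_dtype(cls, dtype: torch.dtype) -> bool:
        return dtype in cls._DTYPES
```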
@classmethod
def get_attn_backend_cls(
def get_valid_backends(
cls,
selected_backend,
head_size,
dtype,
kv_cache_dtype,
block_size,
use_v1,
use_mla,
has_sink,
use_sparse,
device_capability,
) -> tuple[
list[tuple["AttentionBackendEnum", int]],
dict["AttentionBackendEnum", list[str]],
]:
valid_backends_priorities = []
invalid_reasons = {}
backend_priorities = _get_backend_priorities(use_mla, device_capability)
for priority, backend in enumerate(backend_priorities):
try:
backend_class = backend.get_class()
invalid_reasons_i = backend_class.validate_configuration(
head_size,
dtype,
kv_cache_dtype,
block_size,
use_mla,
has_sink,
use_sparse,
device_capability,
)
except ImportError:
invalid_reasons_i = ["ImportError"]
if invalid_reasons_i:
invalid_reasons[backend] = invalid_reasons_i
else:
valid_backends_priorities.append((backend, priority))
return valid_backends_priorities, invalid_reasons
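`get_valid_backends` validates every candidate rather than stopping at the first success or failure, so the caller can report all rejection reasons at once. The core pattern, reduced to a standalone sketch in which `validate(backend)` stands in for `validate_configuration` (it returns a list of human-readable reasons, empty meaning "usable"):

```python
def partition_backends(candidates, validate):
    # Mirrors get_valid_backends above: keep (backend, priority) pairs
    # that validate cleanly, and record why each other one was rejected.
    valid, invalid = [], {}
    for priority, backend in enumerate(candidates):
        try:
            reasons = validate(backend)
        except ImportError:
            # An uninstallable backend is just another rejection reason.
            reasons = ["ImportError"]
        if reasons:
            invalid[backend] = reasons
        else:
            valid.append((backend, priority))
    return valid, invalid
```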
@classmethod
def get_attn_backend_cls(
cls,
selected_backend: "AttentionBackendEnum",
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: "CacheDType | None",
block_size: int | None,
use_v1: bool,
use_mla: bool,
has_sink: bool,
use_sparse: bool,
) -> str:
from vllm.attention.backends.registry import _Backend
if not use_v1:
raise RuntimeError(
"V0 attention backends have been removed. Set VLLM_USE_V1=1 "
"to select a supported backend."
)
if use_mla:
# explicitly reject non-MLA backends when MLA is enabled to avoid
# silently selecting an incompatible backend (e.g., FLASHINFER).
if selected_backend in {
_Backend.FLASHINFER,
_Backend.FLASH_ATTN,
_Backend.TRITON_ATTN,
_Backend.TREE_ATTN,
_Backend.XFORMERS,
}:
device_capability = cls.get_device_capability()
assert device_capability is not None
# First try checking just the selected backend, if there is one.
if selected_backend is not None:
try:
backend_class = selected_backend.get_class()
invalid_reasons = backend_class.validate_configuration(
head_size,
dtype,
kv_cache_dtype,
None,
use_mla,
has_sink,
use_sparse,
device_capability,
)
except ImportError:
invalid_reasons = ["ImportError"]
if invalid_reasons:
raise ValueError(
f"Attention backend {selected_backend} incompatible with MLA. "
"Please use one of the MLA backends: FLASHINFER_MLA, CUTLASS_MLA, "
"FLASHMLA, FLASH_ATTN_MLA, or TRITON_MLA. Alternatively, set "
"VLLM_MLA_DISABLE=1 to disable MLA for this model."
f"Selected backend {selected_backend} is not valid for "
f"this configuration. Reason: {invalid_reasons}"
)
else:
logger.info("Using %s backend.", selected_backend)
return selected_backend.get_path()
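Note that the explicit-selection path validates with `block_size=None` (the auto-selection call below passes the same placeholder), so a pinned backend is judged on the rest of the configuration, and a failure raises immediately instead of silently falling back. Condensed sketch of this branch (`cfg` is a hypothetical bundle of the arguments above, not a vLLM type):

```python
def resolve_pinned_backend(backend, cfg) -> str:
    # Fail fast: an explicitly selected backend either validates or raises.
    try:
        reasons = backend.get_class().validate_configuration(
            cfg.head_size, cfg.dtype, cfg.kv_cache_dtype, None,
            cfg.use_mla, cfg.has_sink, cfg.use_sparse, cfg.device_capability,
        )
    except ImportError:
        reasons = ["ImportError"]
    if reasons:
        raise ValueError(
            f"Selected backend {backend} is not valid for this "
            f"configuration. Reason: {reasons}"
        )
    return backend.get_path()
```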
from vllm.attention.ops.flashmla import is_flashmla_dense_supported
from vllm.attention.utils.fa_utils import flash_attn_supports_mla
if use_sparse:
logger.info_once("Using Sparse MLA backend.")
return (
"vllm.v1.attention.backends.mla.flashmla_sparse."
"FlashMLASparseBackend"
)
use_cutlassmla = selected_backend == _Backend.CUTLASS_MLA or (
selected_backend is None
and cls.is_device_capability(100)
and block_size % 128 == 0
)
use_flashinfermla = selected_backend == _Backend.FLASHINFER_MLA or (
selected_backend is None
and cls.is_device_capability(100)
and (block_size == 32 or block_size % 64 == 0)
)
use_flashmla = selected_backend == _Backend.FLASHMLA or (
selected_backend is None and is_flashmla_dense_supported()[0]
)
use_flashattn = selected_backend == _Backend.FLASH_ATTN_MLA or (
selected_backend is None and flash_attn_supports_mla()
)
use_triton = selected_backend == _Backend.TRITON_MLA or (
selected_backend is None
)
if use_cutlassmla:
logger.info_once("Using Cutlass MLA backend.", scope="local")
return "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend"
if use_flashinfermla:
from vllm.v1.attention.backends.utils import set_kv_cache_layout
set_kv_cache_layout("HND")
logger.info_once("Using FlashInfer MLA backend.")
return (
"vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend"
)
if use_flashmla:
if block_size % 64 != 0:
logger.warning(
"FlashMLA backend is not supported for block size %d"
" (currently only supports block size 64).",
block_size,
)
else:
logger.info_once("Using FlashMLA backend.")
return "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend"
if use_flashattn:
logger.info_once("Using FlashAttention MLA backend.")
return (
"vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend"
)
if use_triton:
logger.info_once("Using Triton MLA backend.")
return "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend"
FLASHINFER_V1 = "vllm.v1.attention.backends.flashinfer.FlashInferBackend" # noqa: E501
FLEX_ATTENTION_V1 = (
"vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" # noqa: E501
# No backend was explicitly selected (an invalid explicit selection
# raised above), so search the priority list for a valid one.
valid_backends_priorities, invalid_reasons = cls.get_valid_backends(
head_size,
dtype,
kv_cache_dtype,
None,
use_mla,
has_sink,
use_sparse,
device_capability,
)
TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" # noqa: E501
FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501
TREE_ATTN_V1 = "vllm.v1.attention.backends.tree_attn.TreeAttentionBackend" # noqa: E501
XFORMERS_V1 = "vllm.v1.attention.backends.xformers.XFormersAttentionBackend" # noqa: E501
reasons_str = (
"{"
+ ", ".join(
f"{backend.name}: [{', '.join(reasons)}]"
for backend, reasons in invalid_reasons.items()
)
+ "}"
)
config_str = (
f"head_size: {head_size}, dtype: {dtype}, "
f"kv_cache_dtype: {kv_cache_dtype}, block_size: {block_size}, "
f"use_mla: {use_mla}, has_sink: {has_sink}, use_sparse: {use_sparse}"
)
logger.debug_once(
f"Some attention backends are not valid for {cls.device_name} with "
f"{config_str}. Reasons: {reasons_str}."
)
if len(valid_backends_priorities) == 0:
raise ValueError(
f"No valid attention backend found for {cls.device_name} "
f"with {config_str}. Reasons: {reasons_str}."
)
use_fp8_kv_cache = kv_cache_dtype is not None and kv_cache_dtype.startswith(
"fp8"
# We have found some valid backends. Select the one with the
# highest priority.
logger.info(
"Valid backends: %s", [b[0].name for b in valid_backends_priorities]
)
sorted_indices = sorted(
range(len(valid_backends_priorities)),
key=lambda i: valid_backends_priorities[i][1],
)
selected_index = sorted_indices[0]
selected_backend = valid_backends_priorities[selected_index][0]
logger.info(
"Using %s backend.",
selected_backend.name,
)
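Reviewer note: because each entry already carries its priority, the sort-then-index above is equivalent to a single `min` (priorities come from `enumerate`, so they are unique and no tie-breaking is needed):

```python
# Equivalent selection in one step: the pair with the smallest priority.
selected_backend, _ = min(valid_backends_priorities, key=lambda bp: bp[1])
```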
if selected_backend == _Backend.FLASHINFER:
logger.info_once("Using FlashInfer backend.")
if cls.has_device_capability(100):
from vllm.v1.attention.backends.utils import set_kv_cache_layout
set_kv_cache_layout("HND")
return FLASHINFER_V1
elif selected_backend == _Backend.FLEX_ATTENTION:
logger.info_once("Using FlexAttention backend.")
return FLEX_ATTENTION_V1
elif selected_backend == _Backend.TRITON_ATTN:
logger.info_once("Using Triton backend.")
return TRITON_ATTN
elif selected_backend == _Backend.FLASH_ATTN:
logger.info_once("Using Flash Attention backend.")
return FLASH_ATTN_V1
elif selected_backend == _Backend.TREE_ATTN:
logger.info_once("Using Tree Attention backend.")
return TREE_ATTN_V1
elif selected_backend == _Backend.XFORMERS:
logger.info_once("Using XFormers backend.")
return XFORMERS_V1
from vllm.attention.selector import is_attn_backend_supported
# Default backends for V1 engine
# Prefer FlashInfer for Blackwell GPUs if installed
if cls.is_device_capability(100):
if is_default_backend_supported := is_attn_backend_supported(
FLASHINFER_V1, head_size, dtype
):
from vllm.v1.attention.backends.utils import set_kv_cache_layout
logger.info_once(
"Using FlashInfer backend with HND KV cache layout on "
"V1 engine by default for Blackwell (SM 10.0) GPUs."
)
set_kv_cache_layout("HND")
return FLASHINFER_V1
if not is_default_backend_supported.can_import:
logger.warning_once(
"FlashInfer failed to import on Blackwell (SM 10.0) GPUs; "
"it is recommended to install FlashInfer for better "
"performance."
)
# FlashAttention is the default for SM 8.0+ GPUs
if cls.has_device_capability(80):
if (has_sink or use_fp8_kv_cache) and not cls.is_device_capability(90):
logger.info_once("Using Triton backend.")
return TRITON_ATTN
elif is_default_backend_supported := is_attn_backend_supported(
FLASH_ATTN_V1, head_size, dtype, allow_import_error=False
):
logger.info_once("Using Flash Attention backend.")
return FLASH_ATTN_V1
# FlexAttention is the default for older GPUs
else:
logger.info_once("Using FlexAttention backend.")
return FLEX_ATTENTION_V1
assert not is_default_backend_supported
use_flex_attention_reason = {}
if not is_default_backend_supported.head_size:
use_flex_attention_reason["head_size"] = head_size
if not is_default_backend_supported.dtype:
use_flex_attention_reason["dtype"] = dtype
logger.info_once(
"Using FlexAttention backend for %s.",
", ".join(f"{k}={v}" for k, v in use_flex_attention_reason.items()),
)
return FLEX_ATTENTION_V1
return selected_backend.get_path()
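Taken together, the new control flow is: fail fast on an invalid explicit choice; otherwise validate the architecture's priority list, log every rejection reason, and return the import path of the highest-priority survivor. A condensed restatement reusing the `partition_backends` sketch from above (not a vLLM function):

```python
def select_backend_path(selected, candidates, validate) -> str:
    # Miniature of get_attn_backend_cls above, minus logging.
    if selected is not None:
        reasons = validate(selected)
        if reasons:
            raise ValueError(f"Selected backend {selected} invalid: {reasons}")
        return selected.get_path()
    valid, invalid = partition_backends(candidates, validate)
    if not valid:
        raise ValueError(f"No valid attention backend found: {invalid}")
    backend, _ = min(valid, key=lambda bp: bp[1])
    return backend.get_path()
```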
@classmethod
def get_punica_wrapper(cls) -> str:

--- next file ---

@@ -17,8 +17,9 @@ from vllm.logger import init_logger
if TYPE_CHECKING:
from torch.distributed import PrefixStore, ProcessGroup
from vllm.attention.backends.registry import _Backend
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import VllmConfig
from vllm.config.cache import CacheDType
from vllm.inputs import ProcessorInputs, PromptType
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
@@ -58,6 +59,31 @@ class DeviceCapability(NamedTuple):
major: int
minor: int
def __lt__(self, other: Any) -> bool:
if not isinstance(other, DeviceCapability):
return NotImplemented
return (self.major, self.minor) < (other.major, other.minor)
def __le__(self, other: Any) -> bool:
if not isinstance(other, DeviceCapability):
return NotImplemented
return (self.major, self.minor) <= (other.major, other.minor)
def __eq__(self, other: Any) -> bool:
if not isinstance(other, DeviceCapability):
return NotImplemented
return (self.major, self.minor) == (other.major, other.minor)
def __ge__(self, other: Any) -> bool:
if not isinstance(other, DeviceCapability):
return NotImplemented
return (self.major, self.minor) >= (other.major, other.minor)
def __gt__(self, other: Any) -> bool:
if not isinstance(other, DeviceCapability):
return NotImplemented
return (self.major, self.minor) > (other.major, other.minor)
def as_version_str(self) -> str:
return f"{self.major}.{self.minor}"
@@ -173,19 +199,21 @@ class Platform:
import vllm._moe_C # noqa: F401
@classmethod
def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend":
# Import _Backend here to avoid circular import.
from vllm.attention.backends.registry import _Backend
def get_vit_attn_backend(
cls, head_size: int, dtype: torch.dtype
) -> "AttentionBackendEnum":
# Import AttentionBackendEnum here to avoid circular import.
from vllm.attention.backends.registry import AttentionBackendEnum
return _Backend.TORCH_SDPA
return AttentionBackendEnum.TORCH_SDPA
@classmethod
def get_attn_backend_cls(
cls,
selected_backend: "_Backend",
selected_backend: "AttentionBackendEnum",
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: str | None,
kv_cache_dtype: "CacheDType | None",
block_size: int,
use_v1: bool,
use_mla: bool,

--- next file ---

@@ -14,10 +14,10 @@ from vllm.utils.torch_utils import cuda_device_count_stateless
from .interface import DeviceCapability, Platform, PlatformEnum
if TYPE_CHECKING:
from vllm.attention.backends.registry import _Backend
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import VllmConfig
else:
_Backend = None
AttentionBackendEnum = None
logger = init_logger(__name__)
@@ -204,21 +204,23 @@ class RocmPlatform(Platform):
]
@classmethod
def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> _Backend:
def get_vit_attn_backend(
cls, head_size: int, dtype: torch.dtype
) -> AttentionBackendEnum:
from importlib.util import find_spec
from vllm._aiter_ops import rocm_aiter_ops
from vllm.attention.backends.registry import _Backend
from vllm.attention.backends.registry import AttentionBackendEnum
if rocm_aiter_ops.is_mha_enabled():
# Note: AITER FA is only supported for Qwen-VL models.
# TODO: Add support for other VL models in their model class.
return _Backend.ROCM_AITER_FA
return AttentionBackendEnum.ROCM_AITER_FA
if on_gfx9() and find_spec("flash_attn") is not None:
return _Backend.FLASH_ATTN
return AttentionBackendEnum.FLASH_ATTN
return _Backend.TORCH_SDPA
return AttentionBackendEnum.TORCH_SDPA
@classmethod
def get_attn_backend_cls(
@@ -234,7 +236,7 @@ class RocmPlatform(Platform):
use_sparse,
) -> str:
from vllm._aiter_ops import rocm_aiter_ops
from vllm.attention.backends.registry import _Backend
from vllm.attention.backends.registry import AttentionBackendEnum
if use_sparse:
raise NotImplementedError("Sparse Attention is not supported on ROCm.")
@@ -248,55 +250,52 @@ class RocmPlatform(Platform):
if use_mla:
if selected_backend is None:
selected_backend = (
_Backend.ROCM_AITER_MLA
AttentionBackendEnum.ROCM_AITER_MLA
if rocm_aiter_ops.is_mla_enabled() or block_size == 1
else _Backend.TRITON_MLA
else AttentionBackendEnum.TRITON_MLA
)
if selected_backend == _Backend.TRITON_MLA:
if selected_backend == AttentionBackendEnum.TRITON_MLA:
if block_size != 1:
logger.info_once("Using Triton MLA backend.")
return "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend"
return AttentionBackendEnum.TRITON_MLA.get_path()
raise ValueError(
f" The selected backend, {selected_backend.name},"
f"does not support block size {block_size}."
)
if selected_backend == _Backend.ROCM_AITER_MLA:
if selected_backend == AttentionBackendEnum.ROCM_AITER_MLA:
logger.info("Using AITER MLA backend.")
return "vllm.v1.attention.backends.mla.rocm_aiter_mla.AiterMLABackend" # noqa: E501
return AttentionBackendEnum.ROCM_AITER_MLA.get_path()
raise ValueError(
f" The selected backend, {selected_backend.name},"
f"is not MLA type while requested for MLA backend."
)
if selected_backend == _Backend.FLEX_ATTENTION:
if selected_backend == AttentionBackendEnum.FLEX_ATTENTION:
logger.info("Using FlexAttention backend.")
return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend"
if (
rocm_aiter_ops.is_mha_enabled()
) or selected_backend == _Backend.ROCM_AITER_FA:
) or selected_backend == AttentionBackendEnum.ROCM_AITER_FA:
logger.info("Using Aiter Flash Attention backend.")
return "vllm.v1.attention.backends.rocm_aiter_fa.AiterFlashAttentionBackend"
return AttentionBackendEnum.ROCM_AITER_FA.get_path()
if (
rocm_aiter_ops.is_triton_unified_attn_enabled()
) or selected_backend == _Backend.ROCM_AITER_UNIFIED_ATTN:
) or selected_backend == AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN:
logger.info("Using Aiter Unified Attention backend.")
return (
"vllm.v1.attention.backends."
"rocm_aiter_unified_attn.RocmAiterUnifiedAttentionBackend"
)
return AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path()
if (
envs.VLLM_V1_USE_PREFILL_DECODE_ATTENTION
or selected_backend == _Backend.ROCM_ATTN
or selected_backend == AttentionBackendEnum.ROCM_ATTN
):
# rocm specific backend, with aiter and/or
# triton prefix-prefill
logger.info("Using Rocm Attention backend.")
return "vllm.v1.attention.backends.rocm_attn.RocmAttentionBackend"
return AttentionBackendEnum.ROCM_ATTN.get_path()
# default case, using triton unified attention
logger.info("Using Triton Attention backend.")
return "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend"
return AttentionBackendEnum.TRITON_ATTN.get_path()
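Unlike the CUDA path, ROCm keeps an if/elif fallthrough instead of a priority table. Restated as a first-match-wins sketch (the boolean parameters stand in for the `rocm_aiter_ops` and environment checks above; the explicit FLEX_ATTENTION selection branch is omitted):

```python
from vllm.attention.backends.registry import AttentionBackendEnum


def rocm_non_mla_choice(
    aiter_mha: bool, aiter_unified: bool, prefill_decode: bool
) -> "AttentionBackendEnum":
    # Order mirrors the chain above; Triton unified attention is the default.
    if aiter_mha:
        return AttentionBackendEnum.ROCM_AITER_FA
    if aiter_unified:
        return AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN
    if prefill_decode:
        return AttentionBackendEnum.ROCM_ATTN
    return AttentionBackendEnum.TRITON_ATTN
```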
@classmethod
def set_device(cls, device: torch.device) -> None:

--- next file ---

@@ -15,16 +15,15 @@ from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
from .interface import Platform, PlatformEnum
if TYPE_CHECKING:
from vllm.attention.backends.registry import _Backend
from vllm.config import ModelConfig, VllmConfig
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import VllmConfig
from vllm.config.cache import BlockSize
from vllm.pooling_params import PoolingParams
else:
BlockSize = None
ModelConfig = None
VllmConfig = None
PoolingParams = None
_Backend = None
AttentionBackendEnum = None
logger = init_logger(__name__)
@@ -54,7 +53,7 @@ class TpuPlatform(Platform):
@classmethod
def get_attn_backend_cls(
cls,
selected_backend: "_Backend",
selected_backend: "AttentionBackendEnum",
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: str | None,
@@ -64,17 +63,17 @@ class TpuPlatform(Platform):
has_sink,
use_sparse,
) -> str:
from vllm.attention.backends.registry import _Backend
from vllm.attention.backends.registry import AttentionBackendEnum
if use_sparse:
raise NotImplementedError("Sparse Attention is not supported on TPU.")
if selected_backend != _Backend.PALLAS:
if selected_backend != AttentionBackendEnum.PALLAS:
logger.info("Cannot use %s backend on TPU.", selected_backend)
if not use_v1:
raise ValueError("TPU backend only supports V1.")
logger.info("Using Pallas V1 backend.")
return "vllm.v1.attention.backends.pallas.PallasAttentionBackend"
return AttentionBackendEnum.PALLAS.get_path()
@classmethod
def set_device(cls, device: torch.device) -> None:

--- next file ---

@@ -14,12 +14,11 @@ from vllm.utils import DEFAULT_MAX_NUM_BATCHED_TOKENS
from .interface import DeviceCapability, Platform, PlatformEnum
if TYPE_CHECKING:
from vllm.attention.backends.registry import _Backend
from vllm.config import ModelConfig, VllmConfig
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import VllmConfig
else:
ModelConfig = None
VllmConfig = None
_Backend = None
AttentionBackendEnum = None
logger = init_logger(__name__)
@@ -44,7 +43,7 @@ class XPUPlatform(Platform):
@classmethod
def get_attn_backend_cls(
cls,
selected_backend: "_Backend",
selected_backend: "AttentionBackendEnum",
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: str | None,
@@ -62,18 +61,19 @@ class XPUPlatform(Platform):
"only NHD layout is supported by XPU attention kernels."
)
from vllm.attention.backends.registry import _Backend
from vllm.attention.backends.registry import AttentionBackendEnum
if use_sparse:
raise NotImplementedError("Sparse Attention is not supported on XPU.")
TRITON_ATTN = "vllm.v1.attention.backends.triton_attn.TritonAttentionBackend" # noqa: E501
FLASH_ATTN = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" # noqa: E501
if selected_backend == _Backend.TRITON_ATTN:
use_v1 = envs.VLLM_USE_V1
if not use_v1:
raise ValueError("XPU backend only supports V1.")
if selected_backend == AttentionBackendEnum.TRITON_ATTN:
logger.info_once("Using Triton backend.")
return TRITON_ATTN
elif selected_backend == _Backend.FLASH_ATTN:
return AttentionBackendEnum.TRITON_ATTN.get_path()
elif selected_backend == AttentionBackendEnum.FLASH_ATTN:
logger.info_once("Using Flash Attention backend.")
return FLASH_ATTN
return AttentionBackendEnum.FLASH_ATTN.get_path()
elif selected_backend:
raise ValueError(
f"Invalid attention backend for {cls.device_name}, "
@@ -81,7 +81,7 @@ class XPUPlatform(Platform):
)
logger.info("Using Flash Attention backend.")
return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
return AttentionBackendEnum.FLASH_ATTN.get_path()
@classmethod
def set_device(cls, device: torch.device) -> None:
@@ -113,10 +113,10 @@ class XPUPlatform(Platform):
return device_props.total_memory
@classmethod
def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> _Backend:
from vllm.attention.backends.registry import _Backend
return _Backend.FLASH_ATTN
def get_vit_attn_backend(
cls, head_size: int, dtype: torch.dtype
) -> AttentionBackendEnum:
return AttentionBackendEnum.FLASH_ATTN
@classmethod
def inference_mode(cls):