[Attention] Update attention imports (#29540)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
Matthew Bonanni
2025-11-27 11:19:09 -05:00
committed by GitHub
parent cd007a53b4
commit fc1d8be3dc
38 changed files with 63 additions and 126 deletions

View File

@@ -14,6 +14,7 @@ import regex as re
import torch
from vllm import envs
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.logger import init_logger
from .interface import CpuArchEnum, Platform, PlatformEnum
@@ -21,10 +22,8 @@ from .interface import CpuArchEnum, Platform, PlatformEnum
logger = init_logger(__name__)
if TYPE_CHECKING:
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import VllmConfig
else:
AttentionBackendEnum = None
VllmConfig = None
@@ -135,8 +134,6 @@ class CpuPlatform(Platform):
use_sparse: bool,
attn_type: str | None = None,
) -> str:
from vllm.attention.backends.registry import AttentionBackendEnum
if selected_backend and selected_backend != AttentionBackendEnum.CPU_ATTN:
logger.info("Cannot use %s backend on CPU.", selected_backend)
if use_mla:

View File

@@ -15,6 +15,8 @@ from typing_extensions import ParamSpec
# import custom ops, trigger op registration
import vllm._C # noqa
import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.logger import init_logger
from vllm.utils.import_utils import import_pynvml
from vllm.utils.torch_utils import cuda_device_count_stateless
@@ -22,11 +24,9 @@ from vllm.utils.torch_utils import cuda_device_count_stateless
from .interface import DeviceCapability, Platform, PlatformEnum
if TYPE_CHECKING:
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import VllmConfig
from vllm.config.cache import CacheDType
else:
AttentionBackendEnum = None
VllmConfig = None
CacheDType = None
@@ -48,8 +48,6 @@ def _get_backend_priorities(
device_capability: DeviceCapability,
) -> list[AttentionBackendEnum]:
"""Get backend priorities with lazy import to avoid circular dependency."""
from vllm.attention.backends.registry import AttentionBackendEnum
if use_mla:
if device_capability.major == 10:
return [
@@ -265,8 +263,6 @@ class CudaPlatformBase(Platform):
def get_vit_attn_backend(
cls, head_size: int, dtype: torch.dtype
) -> "AttentionBackendEnum":
from vllm.attention.backends.registry import AttentionBackendEnum
# Try FlashAttention first
try:
backend_class = AttentionBackendEnum.FLASH_ATTN.get_class()
@@ -335,8 +331,6 @@ class CudaPlatformBase(Platform):
use_sparse: bool,
attn_type: str | None = None,
) -> str:
from vllm.attention.backends.abstract import AttentionType
if attn_type is None:
attn_type = AttentionType.DECODER

View File

@@ -12,12 +12,12 @@ from typing import TYPE_CHECKING, Any, NamedTuple
import numpy as np
import torch
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.logger import init_logger
if TYPE_CHECKING:
from torch.distributed import PrefixStore, ProcessGroup
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import VllmConfig
from vllm.config.cache import CacheDType
from vllm.inputs import ProcessorInputs, PromptType
@@ -226,9 +226,6 @@ class Platform:
def get_vit_attn_backend(
cls, head_size: int, dtype: torch.dtype
) -> "AttentionBackendEnum":
# Import AttentionBackendEnum here to avoid circular import.
from vllm.attention.backends.registry import AttentionBackendEnum
return AttentionBackendEnum.TORCH_SDPA
@classmethod

View File

@@ -8,16 +8,14 @@ from typing import TYPE_CHECKING
import torch
import vllm.envs as envs
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.logger import init_logger
from vllm.utils.torch_utils import cuda_device_count_stateless
from .interface import DeviceCapability, Platform, PlatformEnum
if TYPE_CHECKING:
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import VllmConfig
else:
AttentionBackendEnum = None
logger = init_logger(__name__)
@@ -196,7 +194,6 @@ class RocmPlatform(Platform):
from importlib.util import find_spec
from vllm._aiter_ops import rocm_aiter_ops
from vllm.attention.backends.registry import AttentionBackendEnum
if rocm_aiter_ops.is_mha_enabled():
# Note: AITER FA is only supported for Qwen-VL models.
@@ -222,7 +219,6 @@ class RocmPlatform(Platform):
attn_type: str | None = None,
) -> str:
from vllm._aiter_ops import rocm_aiter_ops
from vllm.attention.backends.registry import AttentionBackendEnum
if use_sparse:
if kv_cache_dtype.startswith("fp8"):

View File

@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, cast
import torch
from tpu_info import device
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.inputs import ProcessorInputs, PromptType
from vllm.logger import init_logger
@@ -15,7 +16,6 @@ from .interface import Platform, PlatformEnum
if TYPE_CHECKING:
from typing import TypeAlias
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import VllmConfig
from vllm.config.cache import BlockSize
from vllm.pooling_params import PoolingParams
@@ -26,7 +26,6 @@ else:
BlockSize = None
VllmConfig = None
PoolingParams = None
AttentionBackendEnum = None
ParamsType = None
logger = init_logger(__name__)
@@ -67,8 +66,6 @@ class TpuPlatform(Platform):
use_sparse,
attn_type: str | None = None,
) -> str:
from vllm.attention.backends.registry import AttentionBackendEnum
if use_sparse:
raise NotImplementedError("Sparse Attention is not supported on TPU.")
if selected_backend != AttentionBackendEnum.PALLAS:

View File

@@ -8,16 +8,15 @@ from typing import TYPE_CHECKING
import torch
import vllm.envs as envs
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.logger import init_logger
from .interface import DeviceCapability, Platform, PlatformEnum
if TYPE_CHECKING:
from vllm.attention.backends.registry import AttentionBackendEnum
from vllm.config import VllmConfig
else:
VllmConfig = None
AttentionBackendEnum = None
logger = init_logger(__name__)
@@ -60,8 +59,6 @@ class XPUPlatform(Platform):
"only NHD layout is supported by XPU attention kernels."
)
from vllm.attention.backends.registry import AttentionBackendEnum
if use_sparse:
raise NotImplementedError("Sparse Attention is not supported on XPU.")
if selected_backend == AttentionBackendEnum.TRITON_ATTN:
@@ -116,8 +113,6 @@ class XPUPlatform(Platform):
def get_vit_attn_backend(
cls, head_size: int, dtype: torch.dtype
) -> "AttentionBackendEnum":
from vllm.attention.backends.registry import AttentionBackendEnum
return AttentionBackendEnum.FLASH_ATTN
@classmethod