[Attention] Update attention imports (#29540)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
@@ -14,6 +14,7 @@ import regex as re
|
||||
import torch
|
||||
|
||||
from vllm import envs
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .interface import CpuArchEnum, Platform, PlatformEnum
|
||||
@@ -21,10 +22,8 @@ from .interface import CpuArchEnum, Platform, PlatformEnum
|
||||
logger = init_logger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.config import VllmConfig
|
||||
else:
|
||||
AttentionBackendEnum = None
|
||||
VllmConfig = None
|
||||
|
||||
|
||||
@@ -135,8 +134,6 @@ class CpuPlatform(Platform):
|
||||
use_sparse: bool,
|
||||
attn_type: str | None = None,
|
||||
) -> str:
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
|
||||
if selected_backend and selected_backend != AttentionBackendEnum.CPU_ATTN:
|
||||
logger.info("Cannot use %s backend on CPU.", selected_backend)
|
||||
if use_mla:
|
||||
|
||||
@@ -15,6 +15,8 @@ from typing_extensions import ParamSpec
|
||||
# import custom ops, trigger op registration
|
||||
import vllm._C # noqa
|
||||
import vllm.envs as envs
|
||||
from vllm.attention.backends.abstract import AttentionType
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.import_utils import import_pynvml
|
||||
from vllm.utils.torch_utils import cuda_device_count_stateless
|
||||
@@ -22,11 +24,9 @@ from vllm.utils.torch_utils import cuda_device_count_stateless
|
||||
from .interface import DeviceCapability, Platform, PlatformEnum
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.cache import CacheDType
|
||||
else:
|
||||
AttentionBackendEnum = None
|
||||
VllmConfig = None
|
||||
CacheDType = None
|
||||
|
||||
@@ -48,8 +48,6 @@ def _get_backend_priorities(
|
||||
device_capability: DeviceCapability,
|
||||
) -> list[AttentionBackendEnum]:
|
||||
"""Get backend priorities with lazy import to avoid circular dependency."""
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
|
||||
if use_mla:
|
||||
if device_capability.major == 10:
|
||||
return [
|
||||
@@ -265,8 +263,6 @@ class CudaPlatformBase(Platform):
|
||||
def get_vit_attn_backend(
|
||||
cls, head_size: int, dtype: torch.dtype
|
||||
) -> "AttentionBackendEnum":
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
|
||||
# Try FlashAttention first
|
||||
try:
|
||||
backend_class = AttentionBackendEnum.FLASH_ATTN.get_class()
|
||||
@@ -335,8 +331,6 @@ class CudaPlatformBase(Platform):
|
||||
use_sparse: bool,
|
||||
attn_type: str | None = None,
|
||||
) -> str:
|
||||
from vllm.attention.backends.abstract import AttentionType
|
||||
|
||||
if attn_type is None:
|
||||
attn_type = AttentionType.DECODER
|
||||
|
||||
|
||||
@@ -12,12 +12,12 @@ from typing import TYPE_CHECKING, Any, NamedTuple
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.logger import init_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch.distributed import PrefixStore, ProcessGroup
|
||||
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.cache import CacheDType
|
||||
from vllm.inputs import ProcessorInputs, PromptType
|
||||
@@ -226,9 +226,6 @@ class Platform:
|
||||
def get_vit_attn_backend(
|
||||
cls, head_size: int, dtype: torch.dtype
|
||||
) -> "AttentionBackendEnum":
|
||||
# Import AttentionBackendEnum here to avoid circular import.
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
|
||||
return AttentionBackendEnum.TORCH_SDPA
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -8,16 +8,14 @@ from typing import TYPE_CHECKING
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.torch_utils import cuda_device_count_stateless
|
||||
|
||||
from .interface import DeviceCapability, Platform, PlatformEnum
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.config import VllmConfig
|
||||
else:
|
||||
AttentionBackendEnum = None
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -196,7 +194,6 @@ class RocmPlatform(Platform):
|
||||
from importlib.util import find_spec
|
||||
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
|
||||
if rocm_aiter_ops.is_mha_enabled():
|
||||
# Note: AITER FA is only supported for Qwen-VL models.
|
||||
@@ -222,7 +219,6 @@ class RocmPlatform(Platform):
|
||||
attn_type: str | None = None,
|
||||
) -> str:
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
|
||||
if use_sparse:
|
||||
if kv_cache_dtype.startswith("fp8"):
|
||||
|
||||
@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, cast
|
||||
import torch
|
||||
from tpu_info import device
|
||||
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.inputs import ProcessorInputs, PromptType
|
||||
from vllm.logger import init_logger
|
||||
|
||||
@@ -15,7 +16,6 @@ from .interface import Platform, PlatformEnum
|
||||
if TYPE_CHECKING:
|
||||
from typing import TypeAlias
|
||||
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.cache import BlockSize
|
||||
from vllm.pooling_params import PoolingParams
|
||||
@@ -26,7 +26,6 @@ else:
|
||||
BlockSize = None
|
||||
VllmConfig = None
|
||||
PoolingParams = None
|
||||
AttentionBackendEnum = None
|
||||
ParamsType = None
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -67,8 +66,6 @@ class TpuPlatform(Platform):
|
||||
use_sparse,
|
||||
attn_type: str | None = None,
|
||||
) -> str:
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
|
||||
if use_sparse:
|
||||
raise NotImplementedError("Sparse Attention is not supported on TPU.")
|
||||
if selected_backend != AttentionBackendEnum.PALLAS:
|
||||
|
||||
@@ -8,16 +8,15 @@ from typing import TYPE_CHECKING
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.logger import init_logger
|
||||
|
||||
from .interface import DeviceCapability, Platform, PlatformEnum
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.config import VllmConfig
|
||||
else:
|
||||
VllmConfig = None
|
||||
AttentionBackendEnum = None
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -60,8 +59,6 @@ class XPUPlatform(Platform):
|
||||
"only NHD layout is supported by XPU attention kernels."
|
||||
)
|
||||
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
|
||||
if use_sparse:
|
||||
raise NotImplementedError("Sparse Attention is not supported on XPU.")
|
||||
if selected_backend == AttentionBackendEnum.TRITON_ATTN:
|
||||
@@ -116,8 +113,6 @@ class XPUPlatform(Platform):
|
||||
def get_vit_attn_backend(
|
||||
cls, head_size: int, dtype: torch.dtype
|
||||
) -> "AttentionBackendEnum":
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
|
||||
return AttentionBackendEnum.FLASH_ATTN
|
||||
|
||||
@classmethod
|
||||
|
||||
Reference in New Issue
Block a user