[Attention] Update attention imports (#29540)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Author: Matthew Bonanni
Date: 2025-11-27 11:19:09 -05:00
Committed by: GitHub
Parent: cd007a53b4
Commit: fc1d8be3dc
38 changed files with 63 additions and 126 deletions

View File

@@ -3,14 +3,11 @@
"""Base class for attention-like layers."""
from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING
+from vllm.attention.backends.abstract import AttentionBackend
from vllm.config import VllmConfig
from vllm.v1.kv_cache_interface import KVCacheSpec
-if TYPE_CHECKING:
-    from vllm.attention.backends.abstract import AttentionBackend
class AttentionLayerBase(ABC):
"""
@@ -22,7 +19,7 @@ class AttentionLayerBase(ABC):
"""
@abstractmethod
-def get_attn_backend(self) -> type["AttentionBackend"]:
+def get_attn_backend(self) -> type[AttentionBackend]:
"""Get the attention backend class for this layer."""
pass
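
The hunk above replaces the TYPE_CHECKING-guarded import and the quoted forward reference with an eager module-level import and an unquoted annotation. A minimal sketch of the same pattern, where Backend, LayerBase, and MyLayer are illustrative stand-ins rather than vLLM's names:

from abc import ABC, abstractmethod


class Backend:
    """Hypothetical stand-in for vllm.attention.backends.abstract.AttentionBackend."""


class LayerBase(ABC):
    @abstractmethod
    def get_backend(self) -> type[Backend]:
        # Before the change this annotation was the string "Backend" and the
        # import sat under an `if TYPE_CHECKING:` guard; with the circular
        # import gone, the class is referenced directly and the quotes drop.
        raise NotImplementedError


class MyLayer(LayerBase):
    def get_backend(self) -> type[Backend]:
        return Backend


assert MyLayer().get_backend() is Backend

One practical difference: with the eager import, typing.get_type_hints() can resolve the annotation at runtime, which a TYPE_CHECKING-only import does not allow.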

View File

@@ -2,18 +2,15 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import abstractmethod
from collections.abc import Iterable
-from typing import TYPE_CHECKING
import torch
+from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.selector import get_mamba_attn_backend
from vllm.config import VllmConfig
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.v1.kv_cache_interface import KVCacheSpec, MambaSpec
-if TYPE_CHECKING:
-    from vllm.attention.backends.abstract import AttentionBackend
class MambaBase(AttentionLayerBase):
"""
@@ -66,6 +63,6 @@ class MambaBase(AttentionLayerBase):
),
)
-def get_attn_backend(self) -> type["AttentionBackend"]:
+def get_attn_backend(self) -> type[AttentionBackend]:
"""Get the attention backend class for this Mamba layer."""
return get_mamba_attn_backend(self.mamba_type)
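
MambaBase takes the same approach: AttentionBackend is imported eagerly and get_attn_backend delegates to the get_mamba_attn_backend selector. A runnable sketch of that selector pattern, with a hypothetical registry standing in for vLLM's selection logic:

class AttentionBackend:
    """Hypothetical stand-in for the abstract backend class."""


class Mamba2Backend(AttentionBackend):
    pass


# Hypothetical registry; vLLM resolves this inside get_mamba_attn_backend.
_MAMBA_BACKENDS: dict[str, type[AttentionBackend]] = {"mamba2": Mamba2Backend}


def get_mamba_attn_backend(mamba_type: str) -> type[AttentionBackend]:
    return _MAMBA_BACKENDS[mamba_type]


class MambaLayer:
    mamba_type = "mamba2"

    def get_attn_backend(self) -> type[AttentionBackend]:
        # Unquoted annotation, now that the backend type is a real import.
        return get_mamba_attn_backend(self.mamba_type)


assert MambaLayer().get_attn_backend() is Mamba2Backend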

View File

@@ -18,6 +18,7 @@ from compressed_tensors.quantization import (
from compressed_tensors.transform import TransformConfig
import vllm.envs as envs
+from vllm.attention.layer import Attention
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (
@@ -131,8 +132,6 @@ class CompressedTensorsConfig(QuantizationConfig):
layer: torch.nn.Module,
prefix: str,
) -> Optional["QuantizeMethodBase"]:
-from vllm.attention.layer import Attention # Avoid circular import
if isinstance(layer, LinearBase):
# collect schemes
quant_scheme = self.get_scheme(layer=layer, layer_name=prefix)
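
The remaining files all follow the same recipe: the Attention import that used to live inside get_quant_method (commented "Avoid circular import") moves to module scope, and the method keeps its isinstance dispatch over layer types. A compact sketch of that dispatch, using hypothetical stand-in classes rather than vLLM's:

# Stand-ins for vllm.attention.layer.Attention, LinearBase, and the
# per-layer quantization method objects; all names here are illustrative.
class Attention: ...
class LinearBase: ...
class KVCacheQuantMethod: ...
class LinearQuantMethod: ...


class QuantConfig:
    def get_quant_method(self, layer: object, prefix: str):
        # Previously `from vllm.attention.layer import Attention` appeared
        # right here, deferred to call time to dodge a circular import.
        if isinstance(layer, Attention):
            return KVCacheQuantMethod()   # kv-cache quantization path
        if isinstance(layer, LinearBase):
            return LinearQuantMethod()    # weight quantization path
        return None


method = QuantConfig().get_quant_method(Attention(), prefix="model.layers.0.attn")
assert isinstance(method, KVCacheQuantMethod)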

View File

@@ -14,6 +14,7 @@ import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
+from vllm.attention.layer import Attention
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
@@ -277,7 +278,6 @@ class Fp8Config(QuantizationConfig):
def get_xpu_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
-from vllm.attention.layer import Attention
from vllm.model_executor.layers.quantization.ipex_quant import (
XPUFp8LinearMethod,
XPUFp8MoEMethod,
@@ -307,8 +307,6 @@ class Fp8Config(QuantizationConfig):
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
-from vllm.attention.layer import Attention # Avoid circular import
if current_platform.is_xpu():
return self.get_xpu_quant_method(layer, prefix)
if isinstance(layer, LinearBase):

View File

@@ -12,6 +12,7 @@ from torch.nn.parameter import Parameter
import vllm.envs as envs
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm._custom_ops import cutlass_scaled_fp4_mm, scaled_fp4_quant
+from vllm.attention.layer import Attention
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
@@ -149,8 +150,6 @@ class ModelOptQuantConfigBase(QuantizationConfig):
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
-from vllm.attention.layer import Attention # Avoid circular import
# handle kv-cache first so we can focus only on weight quantization thereafter
if isinstance(layer, Attention):
return self.KVCacheMethodCls(self)

View File

@@ -8,6 +8,7 @@ import torch
from torch.nn.parameter import Parameter
from vllm import envs
+from vllm.attention.layer import Attention
from vllm.config import get_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (
@@ -184,8 +185,6 @@ class Mxfp4Config(QuantizationConfig):
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
-from vllm.attention.layer import Attention # Avoid circular import
if isinstance(layer, LinearBase):
if self.ignored_layers and is_layer_skipped(
prefix=prefix,

View File

@@ -8,6 +8,7 @@ import regex as re
import torch
from torch.nn.parameter import Parameter
+from vllm.attention.layer import Attention
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (
LinearBase,
@@ -159,8 +160,6 @@ class PetitNvFp4Config(QuantizationConfig):
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
-from vllm.attention.layer import Attention # Avoid circular import
exclude = self.require_exclude_modules()
if isinstance(layer, LinearBase):

View File

@@ -7,6 +7,7 @@ import torch
from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
+from vllm.attention.layer import Attention
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod
from vllm.model_executor.layers.quantization import QuantizationMethods
@@ -65,8 +66,6 @@ class PTPCFp8Config(Fp8Config):
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
-from vllm.attention.layer import Attention # Avoid circular import
if isinstance(layer, LinearBase):
if is_layer_skipped(prefix, self.ignored_layers):
return UnquantizedLinearMethod()

View File

@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Any, Optional, cast
import torch
+from vllm.attention.layer import Attention
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (
@@ -102,8 +103,6 @@ class QuarkConfig(QuantizationConfig):
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["QuantizeMethodBase"]:
-from vllm.attention.layer import Attention # Avoid circular import
# Check if the layer is skipped for quantization.
exclude_layers = cast(list[str], self.quant_config.get("exclude"))
if should_ignore_layer(