[4/N][Attention] Move MLA common to model_executor (#32060)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
Matthew Bonanni
2026-01-13 12:08:45 -05:00
committed by GitHub
parent 4f3676e726
commit 2263d44b68
14 changed files with 50 additions and 44 deletions

View File

@@ -19,12 +19,12 @@ from tests.v1.attention.utils import (
)
from vllm import _custom_ops as ops
from vllm.config.vllm import set_current_vllm_config
from vllm.model_executor.layers.attention.mla_attention import QueryLenSupport
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.attention.backend import CommonAttentionMetadata
from vllm.v1.attention.backends.fa_utils import flash_attn_supports_mla
from vllm.v1.attention.backends.mla.common import QueryLenSupport
from vllm.v1.attention.backends.registry import AttentionBackendEnum
from vllm.v1.attention.ops.flashmla import is_flashmla_dense_supported
from vllm.v1.kv_cache_interface import FullAttentionSpec

View File

@@ -14,9 +14,9 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorRole,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.attention.mla_attention import MLACommonMetadata
from vllm.utils.hashing import safe_hash
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.attention.backends.mla.common import MLACommonMetadata
from vllm.v1.core.sched.output import SchedulerOutput
if TYPE_CHECKING:

View File

@@ -18,8 +18,8 @@ from vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_engine import (
)
from vllm.distributed.parallel_state import get_world_group
from vllm.logger import init_logger
from vllm.model_executor.layers.attention.mla_attention import MLACommonMetadata
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.attention.backends.mla.common import MLACommonMetadata
from vllm.v1.core.sched.output import SchedulerOutput
if TYPE_CHECKING:

View File

@@ -33,8 +33,6 @@ class NewLineFormatter(logging.Formatter):
model_executor/.../quantization/utils/fp8_utils.py
vllm/model_executor/layers/quantization/awq.py ->
model_executor/layers/quantization/awq.py
vllm/v1/attention/backends/mla/common.py ->
v1/attention/backends/mla/common.py
Args:
relpath (Path): The relative path to be shortened.

View File

@@ -9,6 +9,12 @@ import torch
import vllm._custom_ops as ops
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.model_executor.layers.attention.mla_attention import (
MLACommonBackend,
MLACommonImpl,
MLACommonMetadata,
MLACommonMetadataBuilder,
)
from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backend import (
AttentionCGSupport,
@@ -17,12 +23,6 @@ from vllm.v1.attention.backend import (
MultipleOf,
is_quantized_kv_cache,
)
from vllm.v1.attention.backends.mla.common import (
MLACommonBackend,
MLACommonImpl,
MLACommonMetadata,
MLACommonMetadataBuilder,
)
logger = init_logger(__name__)

View File

@@ -9,6 +9,14 @@ import torch
from vllm.config import VllmConfig
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.model_executor.layers.attention.mla_attention import (
MLACommonBackend,
MLACommonDecodeMetadata,
MLACommonImpl,
MLACommonMetadata,
MLACommonMetadataBuilder,
QueryLenSupport,
)
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
@@ -24,14 +32,6 @@ from vllm.v1.attention.backends.fa_utils import (
flash_attn_supports_mla,
get_flash_attn_version,
)
from vllm.v1.attention.backends.mla.common import (
MLACommonBackend,
MLACommonDecodeMetadata,
MLACommonImpl,
MLACommonMetadata,
MLACommonMetadataBuilder,
QueryLenSupport,
)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.vllm_flash_attn import ( # type: ignore[attr-defined]
flash_attn_varlen_func,

View File

@@ -8,6 +8,13 @@ from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.model_executor.layers.attention.mla_attention import (
MLACommonBackend,
MLACommonImpl,
MLACommonMetadata,
MLACommonMetadataBuilder,
QueryLenSupport,
)
from vllm.platforms.interface import DeviceCapability
from vllm.v1.attention.backend import (
AttentionCGSupport,
@@ -15,13 +22,6 @@ from vllm.v1.attention.backend import (
AttentionType,
MultipleOf,
)
from vllm.v1.attention.backends.mla.common import (
MLACommonBackend,
MLACommonImpl,
MLACommonMetadata,
MLACommonMetadataBuilder,
QueryLenSupport,
)
from vllm.v1.attention.backends.utils import KVCacheLayoutType
logger = init_logger(__name__)

View File

@@ -9,6 +9,14 @@ import torch
from vllm.config import VllmConfig
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.model_executor.layers.attention.mla_attention import (
MLACommonBackend,
MLACommonDecodeMetadata,
MLACommonImpl,
MLACommonMetadata,
MLACommonMetadataBuilder,
QueryLenSupport,
)
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
@@ -19,14 +27,6 @@ from vllm.v1.attention.backend import (
AttentionType,
MultipleOf,
)
from vllm.v1.attention.backends.mla.common import (
MLACommonBackend,
MLACommonDecodeMetadata,
MLACommonImpl,
MLACommonMetadata,
MLACommonMetadataBuilder,
QueryLenSupport,
)
from vllm.v1.attention.backends.utils import (
reshape_attn_output_for_spec_decode,
reshape_query_for_spec_decode,

View File

@@ -10,6 +10,10 @@ from vllm import _custom_ops as ops
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.model_executor.layers.attention.mla_attention import (
MLACommonBaseImpl,
get_mla_dims,
)
from vllm.platforms import current_platform
from vllm.platforms.interface import DeviceCapability
from vllm.triton_utils import tl, triton
@@ -23,7 +27,6 @@ from vllm.v1.attention.backend import (
CommonAttentionMetadata,
MultipleOf,
)
from vllm.v1.attention.backends.mla.common import MLACommonBaseImpl, get_mla_dims
from vllm.v1.attention.backends.utils import (
reshape_attn_output_for_spec_decode,
reshape_query_for_spec_decode,

View File

@@ -8,8 +8,7 @@ import torch
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import VllmConfig
from vllm.v1.attention.backend import AttentionCGSupport, AttentionLayer, MultipleOf
from vllm.v1.attention.backends.mla.common import (
from vllm.model_executor.layers.attention.mla_attention import (
MLACommonBackend,
MLACommonDecodeMetadata,
MLACommonImpl,
@@ -17,6 +16,7 @@ from vllm.v1.attention.backends.mla.common import (
MLACommonMetadataBuilder,
QueryLenSupport,
)
from vllm.v1.attention.backend import AttentionCGSupport, AttentionLayer, MultipleOf
from vllm.v1.kv_cache_interface import AttentionSpec

View File

@@ -11,6 +11,10 @@ from vllm import _custom_ops as ops
from vllm._aiter_ops import rocm_aiter_ops
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.attention.mla_attention import (
MLACommonBaseImpl,
get_mla_dims,
)
from vllm.v1.attention.backend import (
AttentionBackend,
AttentionCGSupport,
@@ -19,7 +23,6 @@ from vllm.v1.attention.backend import (
AttentionMetadataBuilder,
CommonAttentionMetadata,
)
from vllm.v1.attention.backends.mla.common import MLACommonBaseImpl, get_mla_dims
from vllm.v1.attention.backends.mla.flashmla_sparse import (
triton_convert_req_index_to_global_index,
)

View File

@@ -7,6 +7,11 @@ import torch
from vllm.config.cache import CacheDType
from vllm.logger import init_logger
from vllm.model_executor.layers.attention.mla_attention import (
MLACommonBackend,
MLACommonImpl,
MLACommonMetadata,
)
from vllm.model_executor.layers.batch_invariant import (
vllm_is_batch_invariant,
)
@@ -16,11 +21,6 @@ from vllm.v1.attention.backend import (
AttentionType,
is_quantized_kv_cache,
)
from vllm.v1.attention.backends.mla.common import (
MLACommonBackend,
MLACommonImpl,
MLACommonMetadata,
)
from vllm.v1.attention.ops.triton_decode_attention import decode_attention_fwd
logger = init_logger(__name__)

View File

@@ -183,7 +183,9 @@ class EagleProposer:
rocm_types.append(AiterFlashAttentionMetadata)
# TRITON_MLA backend support for MLA models (e.g., DeepSeek)
from vllm.v1.attention.backends.mla.common import MLACommonMetadata
from vllm.model_executor.layers.attention.mla_attention import (
MLACommonMetadata,
)
rocm_types.append(MLACommonMetadata)