[4/N][Attention] Move MLA common to model_executor (#32060)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
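In short, every touched call site swaps the old v1 backend import path for the new model_executor one. A minimal before/after sketch of the import change (class names and module paths taken from the hunks below; assumes a vLLM checkout at this commit's revision):

    # Old location (imports removed in this PR):
    from vllm.v1.attention.backends.mla.common import (
        MLACommonBackend,
        MLACommonImpl,
        MLACommonMetadata,
        MLACommonMetadataBuilder,
    )

    # New location (imports added in this PR):
    from vllm.model_executor.layers.attention.mla_attention import (
        MLACommonBackend,
        MLACommonImpl,
        MLACommonMetadata,
        MLACommonMetadataBuilder,
    )
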
@@ -19,12 +19,12 @@ from tests.v1.attention.utils import (
 )
 from vllm import _custom_ops as ops
 from vllm.config.vllm import set_current_vllm_config
+from vllm.model_executor.layers.attention.mla_attention import QueryLenSupport
 from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
 from vllm.utils.math_utils import cdiv
 from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
 from vllm.v1.attention.backend import CommonAttentionMetadata
 from vllm.v1.attention.backends.fa_utils import flash_attn_supports_mla
-from vllm.v1.attention.backends.mla.common import QueryLenSupport
 from vllm.v1.attention.backends.registry import AttentionBackendEnum
 from vllm.v1.attention.ops.flashmla import is_flashmla_dense_supported
 from vllm.v1.kv_cache_interface import FullAttentionSpec

@@ -14,9 +14,9 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import (
     KVConnectorRole,
 )
 from vllm.logger import init_logger
+from vllm.model_executor.layers.attention.mla_attention import MLACommonMetadata
 from vllm.utils.hashing import safe_hash
 from vllm.v1.attention.backend import AttentionMetadata
-from vllm.v1.attention.backends.mla.common import MLACommonMetadata
 from vllm.v1.core.sched.output import SchedulerOutput

 if TYPE_CHECKING:

@@ -18,8 +18,8 @@ from vllm.distributed.kv_transfer.kv_connector.v1.p2p.p2p_nccl_engine import (
 )
 from vllm.distributed.parallel_state import get_world_group
 from vllm.logger import init_logger
+from vllm.model_executor.layers.attention.mla_attention import MLACommonMetadata
 from vllm.v1.attention.backend import AttentionMetadata
-from vllm.v1.attention.backends.mla.common import MLACommonMetadata
 from vllm.v1.core.sched.output import SchedulerOutput

 if TYPE_CHECKING:

@@ -33,8 +33,6 @@ class NewLineFormatter(logging.Formatter):
         model_executor/.../quantization/utils/fp8_utils.py
     vllm/model_executor/layers/quantization/awq.py ->
         model_executor/layers/quantization/awq.py
-    vllm/v1/attention/backends/mla/common.py ->
-        v1/attention/backends/mla/common.py

     Args:
         relpath (Path): The relative path to be shortened.

@@ -9,6 +9,12 @@ import torch
 import vllm._custom_ops as ops
 from vllm.config.cache import CacheDType
 from vllm.logger import init_logger
+from vllm.model_executor.layers.attention.mla_attention import (
+    MLACommonBackend,
+    MLACommonImpl,
+    MLACommonMetadata,
+    MLACommonMetadataBuilder,
+)
 from vllm.platforms.interface import DeviceCapability
 from vllm.v1.attention.backend import (
     AttentionCGSupport,
@@ -17,12 +23,6 @@ from vllm.v1.attention.backend import (
     MultipleOf,
     is_quantized_kv_cache,
 )
-from vllm.v1.attention.backends.mla.common import (
-    MLACommonBackend,
-    MLACommonImpl,
-    MLACommonMetadata,
-    MLACommonMetadataBuilder,
-)

 logger = init_logger(__name__)

@@ -9,6 +9,14 @@ import torch
 from vllm.config import VllmConfig
 from vllm.config.cache import CacheDType
 from vllm.logger import init_logger
+from vllm.model_executor.layers.attention.mla_attention import (
+    MLACommonBackend,
+    MLACommonDecodeMetadata,
+    MLACommonImpl,
+    MLACommonMetadata,
+    MLACommonMetadataBuilder,
+    QueryLenSupport,
+)
 from vllm.model_executor.layers.batch_invariant import (
     vllm_is_batch_invariant,
 )
@@ -24,14 +32,6 @@ from vllm.v1.attention.backends.fa_utils import (
     flash_attn_supports_mla,
     get_flash_attn_version,
 )
-from vllm.v1.attention.backends.mla.common import (
-    MLACommonBackend,
-    MLACommonDecodeMetadata,
-    MLACommonImpl,
-    MLACommonMetadata,
-    MLACommonMetadataBuilder,
-    QueryLenSupport,
-)
 from vllm.v1.kv_cache_interface import AttentionSpec
 from vllm.vllm_flash_attn import (  # type: ignore[attr-defined]
     flash_attn_varlen_func,

@@ -8,6 +8,13 @@ from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla

 from vllm.config.cache import CacheDType
 from vllm.logger import init_logger
+from vllm.model_executor.layers.attention.mla_attention import (
+    MLACommonBackend,
+    MLACommonImpl,
+    MLACommonMetadata,
+    MLACommonMetadataBuilder,
+    QueryLenSupport,
+)
 from vllm.platforms.interface import DeviceCapability
 from vllm.v1.attention.backend import (
     AttentionCGSupport,
@@ -15,13 +22,6 @@ from vllm.v1.attention.backend import (
     AttentionType,
     MultipleOf,
 )
-from vllm.v1.attention.backends.mla.common import (
-    MLACommonBackend,
-    MLACommonImpl,
-    MLACommonMetadata,
-    MLACommonMetadataBuilder,
-    QueryLenSupport,
-)
 from vllm.v1.attention.backends.utils import KVCacheLayoutType

 logger = init_logger(__name__)

@@ -9,6 +9,14 @@ import torch
 from vllm.config import VllmConfig
 from vllm.config.cache import CacheDType
 from vllm.logger import init_logger
+from vllm.model_executor.layers.attention.mla_attention import (
+    MLACommonBackend,
+    MLACommonDecodeMetadata,
+    MLACommonImpl,
+    MLACommonMetadata,
+    MLACommonMetadataBuilder,
+    QueryLenSupport,
+)
 from vllm.model_executor.layers.batch_invariant import (
     vllm_is_batch_invariant,
 )
@@ -19,14 +27,6 @@ from vllm.v1.attention.backend import (
     AttentionType,
     MultipleOf,
 )
-from vllm.v1.attention.backends.mla.common import (
-    MLACommonBackend,
-    MLACommonDecodeMetadata,
-    MLACommonImpl,
-    MLACommonMetadata,
-    MLACommonMetadataBuilder,
-    QueryLenSupport,
-)
 from vllm.v1.attention.backends.utils import (
     reshape_attn_output_for_spec_decode,
     reshape_query_for_spec_decode,

@@ -10,6 +10,10 @@ from vllm import _custom_ops as ops
 from vllm.config import VllmConfig, get_current_vllm_config
 from vllm.config.cache import CacheDType
 from vllm.logger import init_logger
+from vllm.model_executor.layers.attention.mla_attention import (
+    MLACommonBaseImpl,
+    get_mla_dims,
+)
 from vllm.platforms import current_platform
 from vllm.platforms.interface import DeviceCapability
 from vllm.triton_utils import tl, triton
@@ -23,7 +27,6 @@ from vllm.v1.attention.backend import (
     CommonAttentionMetadata,
     MultipleOf,
 )
-from vllm.v1.attention.backends.mla.common import MLACommonBaseImpl, get_mla_dims
 from vllm.v1.attention.backends.utils import (
     reshape_attn_output_for_spec_decode,
     reshape_query_for_spec_decode,

@@ -8,8 +8,7 @@ import torch

 from vllm._aiter_ops import rocm_aiter_ops
 from vllm.config import VllmConfig
-from vllm.v1.attention.backend import AttentionCGSupport, AttentionLayer, MultipleOf
-from vllm.v1.attention.backends.mla.common import (
+from vllm.model_executor.layers.attention.mla_attention import (
     MLACommonBackend,
     MLACommonDecodeMetadata,
     MLACommonImpl,
@@ -17,6 +16,7 @@ from vllm.v1.attention.backends.mla.common import (
     MLACommonMetadataBuilder,
     QueryLenSupport,
 )
+from vllm.v1.attention.backend import AttentionCGSupport, AttentionLayer, MultipleOf
 from vllm.v1.kv_cache_interface import AttentionSpec


@@ -11,6 +11,10 @@ from vllm import _custom_ops as ops
 from vllm._aiter_ops import rocm_aiter_ops
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
+from vllm.model_executor.layers.attention.mla_attention import (
+    MLACommonBaseImpl,
+    get_mla_dims,
+)
 from vllm.v1.attention.backend import (
     AttentionBackend,
     AttentionCGSupport,
@@ -19,7 +23,6 @@ from vllm.v1.attention.backend import (
     AttentionMetadataBuilder,
     CommonAttentionMetadata,
 )
-from vllm.v1.attention.backends.mla.common import MLACommonBaseImpl, get_mla_dims
 from vllm.v1.attention.backends.mla.flashmla_sparse import (
     triton_convert_req_index_to_global_index,
 )

@@ -7,6 +7,11 @@ import torch

 from vllm.config.cache import CacheDType
 from vllm.logger import init_logger
+from vllm.model_executor.layers.attention.mla_attention import (
+    MLACommonBackend,
+    MLACommonImpl,
+    MLACommonMetadata,
+)
 from vllm.model_executor.layers.batch_invariant import (
     vllm_is_batch_invariant,
 )
@@ -16,11 +21,6 @@ from vllm.v1.attention.backend import (
     AttentionType,
     is_quantized_kv_cache,
 )
-from vllm.v1.attention.backends.mla.common import (
-    MLACommonBackend,
-    MLACommonImpl,
-    MLACommonMetadata,
-)
 from vllm.v1.attention.ops.triton_decode_attention import decode_attention_fwd

 logger = init_logger(__name__)

@@ -183,7 +183,9 @@ class EagleProposer:
             rocm_types.append(AiterFlashAttentionMetadata)

             # TRITON_MLA backend support for MLA models (e.g., DeepSeek)
-            from vllm.v1.attention.backends.mla.common import MLACommonMetadata
+            from vllm.model_executor.layers.attention.mla_attention import (
+                MLACommonMetadata,
+            )

             rocm_types.append(MLACommonMetadata)
