[1/N][Attention] Restructure attention: move files (#31916)

Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
This commit is contained in:
Matthew Bonanni
2026-01-09 16:10:24 -05:00
committed by GitHub
parent 1f8b7c536b
commit 2612ba9285
195 changed files with 426 additions and 396 deletions

View File

@@ -33,9 +33,7 @@ from torch import nn
from transformers import DeepseekV2Config, DeepseekV3Config
from vllm._aiter_ops import rocm_aiter_ops
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.layer import Attention
from vllm.attention.ops.common import pack_seq_triton, unpack_seq_triton
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, ParallelConfig, VllmConfig, get_current_vllm_config
from vllm.distributed import (
@@ -78,10 +76,12 @@ from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backend import AttentionBackend
from vllm.v1.attention.backends.mla.indexer import (
DeepseekV32IndexerBackend,
DeepseekV32IndexerMetadata,
)
from vllm.v1.attention.ops.common import pack_seq_triton, unpack_seq_triton
from vllm.v1.kv_cache_interface import KVCacheSpec, MLAAttentionSpec
from vllm.v1.worker.workspace import current_workspace_manager
@@ -679,7 +679,9 @@ def sparse_attn_indexer(
)
fp8_mqa_logits_func = fp8_mqa_logits
if current_platform.is_rocm():
from vllm.attention.ops.rocm_aiter_mla_sparse import rocm_fp8_mqa_logits
from vllm.v1.attention.ops.rocm_aiter_mla_sparse import (
rocm_fp8_mqa_logits,
)
fp8_mqa_logits_func = rocm_fp8_mqa_logits
logits = fp8_mqa_logits_func(
@@ -729,7 +731,7 @@ def sparse_attn_indexer(
num_padded_tokens = batch_size * next_n
fp8_paged_mqa_logits_func = fp8_paged_mqa_logits
if current_platform.is_rocm():
from vllm.attention.ops.rocm_aiter_mla_sparse import (
from vllm.v1.attention.ops.rocm_aiter_mla_sparse import (
rocm_fp8_paged_mqa_logits,
)