diff --git a/tests/kernels/mamba/test_causal_conv1d.py b/tests/kernels/mamba/test_causal_conv1d.py
index d16205694..039f2fc06 100644
--- a/tests/kernels/mamba/test_causal_conv1d.py
+++ b/tests/kernels/mamba/test_causal_conv1d.py
@@ -7,12 +7,12 @@ import torch
 import torch.nn.functional as F
 from einops import rearrange
 
-from vllm.attention.backends.utils import PAD_SLOT_ID
 from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
     causal_conv1d_fn,
     causal_conv1d_update,
 )
 from vllm.utils.torch_utils import set_random_seed
+from vllm.v1.attention.backends.utils import PAD_SLOT_ID
 
 
 def causal_conv1d_ref(
diff --git a/tests/kernels/mamba/test_mamba_ssm.py b/tests/kernels/mamba/test_mamba_ssm.py
index f50ab5344..905207109 100644
--- a/tests/kernels/mamba/test_mamba_ssm.py
+++ b/tests/kernels/mamba/test_mamba_ssm.py
@@ -8,12 +8,12 @@ from einops import rearrange, repeat
 
 from tests.kernels.utils import opcheck
 from vllm import _custom_ops as ops  # noqa: F401
-from vllm.attention.backends.utils import PAD_SLOT_ID
 from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
     selective_scan_fn,
     selective_state_update,
 )
 from vllm.utils.torch_utils import set_random_seed
+from vllm.v1.attention.backends.utils import PAD_SLOT_ID
 
 
 def selective_state_update_ref(
diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py
deleted file mode 100644
index 4c7fa477b..000000000
--- a/vllm/attention/backends/utils.py
+++ /dev/null
@@ -1,33 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""Attention backend utils"""
-
-from dataclasses import dataclass
-
-from vllm.config import ModelConfig
-from vllm.logger import init_logger
-
-logger = init_logger(__name__)
-
-PAD_SLOT_ID = -1
-
-
-@dataclass
-class MLADims:
-    q_lora_rank: int | None
-    kv_lora_rank: int
-    qk_nope_head_dim: int
-    qk_rope_head_dim: int
-    v_head_dim: int
-
-
-def get_mla_dims(model_config: ModelConfig) -> MLADims:
-    hf_text_config = model_config.hf_text_config
-
-    return MLADims(
-        q_lora_rank=getattr(hf_text_config, "q_lora_rank", None),
-        kv_lora_rank=hf_text_config.kv_lora_rank,
-        qk_nope_head_dim=hf_text_config.qk_nope_head_dim,
-        qk_rope_head_dim=hf_text_config.qk_rope_head_dim,
-        v_head_dim=hf_text_config.v_head_dim,
-    )
diff --git a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
index 83c2c5f11..157f9f346 100644
--- a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
+++ b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py
@@ -8,8 +8,8 @@
 import numpy as np
 import torch
 
-from vllm.attention.backends.utils import PAD_SLOT_ID
 from vllm.triton_utils import tl, triton
+from vllm.v1.attention.backends.utils import PAD_SLOT_ID
 
 
 @triton.jit()
diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py
index 800f8bd84..628ad970c 100644
--- a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py
+++ b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py
@@ -8,8 +8,8 @@ import torch
 from packaging import version
 
 from vllm import _custom_ops as ops
-from vllm.attention.backends.utils import PAD_SLOT_ID
 from vllm.triton_utils import HAS_TRITON, tl, triton
+from vllm.v1.attention.backends.utils import PAD_SLOT_ID
 
 TRITON3 = HAS_TRITON and (version.parse(triton.__version__) >= version.parse("3.0.0"))
 
diff --git a/vllm/v1/attention/backends/gdn_attn.py b/vllm/v1/attention/backends/gdn_attn.py
index 79636ecab..96f0d20ac 100644
--- a/vllm/v1/attention/backends/gdn_attn.py
+++ b/vllm/v1/attention/backends/gdn_attn.py
@@ -7,9 +7,9 @@ from dataclasses import dataclass
 import torch
 
 from vllm.attention.backends.abstract import AttentionBackend
-from vllm.attention.backends.utils import PAD_SLOT_ID
 from vllm.config import VllmConfig
 from vllm.v1.attention.backends.utils import (
+    PAD_SLOT_ID,
     AttentionCGSupport,
     AttentionMetadataBuilder,
     CommonAttentionMetadata,
diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py
index 4805bf2ee..2ee2740a5 100755
--- a/vllm/v1/attention/backends/mla/common.py
+++ b/vllm/v1/attention/backends/mla/common.py
@@ -205,11 +205,10 @@ from vllm.attention.backends.abstract import (
     AttentionMetadata,
     MLAAttentionImpl,
 )
-from vllm.attention.backends.utils import get_mla_dims
 from vllm.attention.ops.common import cp_lse_ag_out_rs
 from vllm.attention.ops.merge_attn_states import merge_attn_states
 from vllm.attention.utils.fa_utils import get_flash_attn_version
-from vllm.config import VllmConfig, get_current_vllm_config
+from vllm.config import ModelConfig, VllmConfig, get_current_vllm_config
 from vllm.distributed.parallel_state import get_dcp_group, is_global_first_rank
 from vllm.logger import init_logger
 from vllm.model_executor.layers.batch_invariant import (
@@ -479,6 +478,27 @@ def use_trtllm_ragged_deepseek_prefill() -> bool:
     )
 
 
+@dataclass
+class MLADims:
+    q_lora_rank: int | None
+    kv_lora_rank: int
+    qk_nope_head_dim: int
+    qk_rope_head_dim: int
+    v_head_dim: int
+
+
+def get_mla_dims(model_config: ModelConfig) -> MLADims:
+    hf_text_config = model_config.hf_text_config
+
+    return MLADims(
+        q_lora_rank=getattr(hf_text_config, "q_lora_rank", None),
+        kv_lora_rank=hf_text_config.kv_lora_rank,
+        qk_nope_head_dim=hf_text_config.qk_nope_head_dim,
+        qk_rope_head_dim=hf_text_config.qk_rope_head_dim,
+        v_head_dim=hf_text_config.v_head_dim,
+    )
+
+
 class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
     """
     NOTE: Please read the comment at the top of the file before trying to
diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py
index 64cca2888..dec92d2d4 100644
--- a/vllm/v1/attention/backends/mla/flashmla_sparse.py
+++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py
@@ -13,7 +13,6 @@ from vllm.attention.backends.abstract import (
     AttentionMetadata,
     MultipleOf,
 )
-from vllm.attention.backends.utils import get_mla_dims
 from vllm.attention.ops.flashmla import (
     flash_mla_sparse_prefill,
     flash_mla_with_kvcache,
@@ -26,7 +25,7 @@ from vllm.platforms import current_platform
 from vllm.platforms.interface import DeviceCapability
 from vllm.triton_utils import tl, triton
 from vllm.utils.math_utils import cdiv
-from vllm.v1.attention.backends.mla.common import MLACommonBaseImpl
+from vllm.v1.attention.backends.mla.common import MLACommonBaseImpl, get_mla_dims
 from vllm.v1.attention.backends.utils import (
     AttentionCGSupport,
     AttentionMetadataBuilder,
diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
index a461a2155..e68e80e86 100644
--- a/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
+++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py
@@ -14,12 +14,9 @@ from vllm.attention.backends.abstract import (
     AttentionLayer,
     AttentionMetadata,
 )
-from vllm.attention.backends.utils import get_mla_dims
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
-from vllm.v1.attention.backends.mla.common import (
-    MLACommonBaseImpl,
-)
+from vllm.v1.attention.backends.mla.common import MLACommonBaseImpl, get_mla_dims
 from vllm.v1.attention.backends.mla.flashmla_sparse import (
     triton_convert_req_index_to_global_index,
 )
diff --git a/vllm/v1/worker/gpu/block_table.py b/vllm/v1/worker/gpu/block_table.py
index b31e9b179..ee18401f6 100644
--- a/vllm/v1/worker/gpu/block_table.py
+++ b/vllm/v1/worker/gpu/block_table.py
@@ -4,9 +4,9 @@ from collections.abc import Iterable
 
 import torch
 
-from vllm.attention.backends.utils import PAD_SLOT_ID
 from vllm.triton_utils import tl, triton
 from vllm.utils.math_utils import cdiv
+from vllm.v1.attention.backends.utils import PAD_SLOT_ID
 from vllm.v1.utils import CpuGpuBuffer
 
 