[Chore] Migrate V0 attention utils (#31891)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2026-01-07 21:44:36 +08:00
committed by GitHub
parent 974138751b
commit b665bbc2d4
10 changed files with 30 additions and 47 deletions

View File

@@ -7,12 +7,12 @@ import torch
import torch.nn.functional as F
from einops import rearrange
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn,
causal_conv1d_update,
)
from vllm.utils.torch_utils import set_random_seed
from vllm.v1.attention.backends.utils import PAD_SLOT_ID
def causal_conv1d_ref(

View File

@@ -8,12 +8,12 @@ from einops import rearrange, repeat
from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops # noqa: F401
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
selective_scan_fn,
selective_state_update,
)
from vllm.utils.torch_utils import set_random_seed
from vllm.v1.attention.backends.utils import PAD_SLOT_ID
def selective_state_update_ref(

View File

@@ -1,33 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Attention backend utils"""
from dataclasses import dataclass
from vllm.config import ModelConfig
from vllm.logger import init_logger
logger = init_logger(__name__)
PAD_SLOT_ID = -1
@dataclass
class MLADims:
q_lora_rank: int | None
kv_lora_rank: int
qk_nope_head_dim: int
qk_rope_head_dim: int
v_head_dim: int
def get_mla_dims(model_config: ModelConfig) -> MLADims:
hf_text_config = model_config.hf_text_config
return MLADims(
q_lora_rank=getattr(hf_text_config, "q_lora_rank", None),
kv_lora_rank=hf_text_config.kv_lora_rank,
qk_nope_head_dim=hf_text_config.qk_nope_head_dim,
qk_rope_head_dim=hf_text_config.qk_rope_head_dim,
v_head_dim=hf_text_config.v_head_dim,
)

View File

@@ -8,8 +8,8 @@
import numpy as np
import torch
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.triton_utils import tl, triton
from vllm.v1.attention.backends.utils import PAD_SLOT_ID
@triton.jit()

View File

@@ -8,8 +8,8 @@ import torch
from packaging import version
from vllm import _custom_ops as ops
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.triton_utils import HAS_TRITON, tl, triton
from vllm.v1.attention.backends.utils import PAD_SLOT_ID
TRITON3 = HAS_TRITON and (version.parse(triton.__version__) >= version.parse("3.0.0"))

View File

@@ -7,9 +7,9 @@ from dataclasses import dataclass
import torch
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.config import VllmConfig
from vllm.v1.attention.backends.utils import (
PAD_SLOT_ID,
AttentionCGSupport,
AttentionMetadataBuilder,
CommonAttentionMetadata,

View File

@@ -205,11 +205,10 @@ from vllm.attention.backends.abstract import (
AttentionMetadata,
MLAAttentionImpl,
)
from vllm.attention.backends.utils import get_mla_dims
from vllm.attention.ops.common import cp_lse_ag_out_rs
from vllm.attention.ops.merge_attn_states import merge_attn_states
from vllm.attention.utils.fa_utils import get_flash_attn_version
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.config import ModelConfig, VllmConfig, get_current_vllm_config
from vllm.distributed.parallel_state import get_dcp_group, is_global_first_rank
from vllm.logger import init_logger
from vllm.model_executor.layers.batch_invariant import (
@@ -479,6 +478,27 @@ def use_trtllm_ragged_deepseek_prefill() -> bool:
)
@dataclass
class MLADims:
q_lora_rank: int | None
kv_lora_rank: int
qk_nope_head_dim: int
qk_rope_head_dim: int
v_head_dim: int
def get_mla_dims(model_config: ModelConfig) -> MLADims:
hf_text_config = model_config.hf_text_config
return MLADims(
q_lora_rank=getattr(hf_text_config, "q_lora_rank", None),
kv_lora_rank=hf_text_config.kv_lora_rank,
qk_nope_head_dim=hf_text_config.qk_nope_head_dim,
qk_rope_head_dim=hf_text_config.qk_rope_head_dim,
v_head_dim=hf_text_config.v_head_dim,
)
class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
"""
NOTE: Please read the comment at the top of the file before trying to

View File

@@ -13,7 +13,6 @@ from vllm.attention.backends.abstract import (
AttentionMetadata,
MultipleOf,
)
from vllm.attention.backends.utils import get_mla_dims
from vllm.attention.ops.flashmla import (
flash_mla_sparse_prefill,
flash_mla_with_kvcache,
@@ -26,7 +25,7 @@ from vllm.platforms import current_platform
from vllm.platforms.interface import DeviceCapability
from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.mla.common import MLACommonBaseImpl
from vllm.v1.attention.backends.mla.common import MLACommonBaseImpl, get_mla_dims
from vllm.v1.attention.backends.utils import (
AttentionCGSupport,
AttentionMetadataBuilder,

View File

@@ -14,12 +14,9 @@ from vllm.attention.backends.abstract import (
AttentionLayer,
AttentionMetadata,
)
from vllm.attention.backends.utils import get_mla_dims
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import (
MLACommonBaseImpl,
)
from vllm.v1.attention.backends.mla.common import MLACommonBaseImpl, get_mla_dims
from vllm.v1.attention.backends.mla.flashmla_sparse import (
triton_convert_req_index_to_global_index,
)

View File

@@ -4,9 +4,9 @@ from collections.abc import Iterable
import torch
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv
from vllm.v1.attention.backends.utils import PAD_SLOT_ID
from vllm.v1.utils import CpuGpuBuffer