[Chore] Migrate V0 attention utils (#31891)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Cyrus Leung
Date: 2026-01-07 21:44:36 +08:00
Committed by: GitHub
Parent: 974138751b
Commit: b665bbc2d4
10 changed files with 30 additions and 47 deletions

View File

@@ -7,9 +7,9 @@ from dataclasses import dataclass
 import torch
 from vllm.attention.backends.abstract import AttentionBackend
-from vllm.attention.backends.utils import PAD_SLOT_ID
 from vllm.config import VllmConfig
 from vllm.v1.attention.backends.utils import (
+    PAD_SLOT_ID,
     AttentionCGSupport,
     AttentionMetadataBuilder,
     CommonAttentionMetadata,
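
Note for downstream users: PAD_SLOT_ID now lives under the V1 utils module. A minimal import shim, assuming only the two paths shown in this hunk (the fallback branch targets pre-migration installs and is not part of the commit):

try:
    # New location after this commit
    from vllm.v1.attention.backends.utils import PAD_SLOT_ID
except ImportError:
    # Old V0 location removed by this commit (fallback for older installs)
    from vllm.attention.backends.utils import PAD_SLOT_ID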

View File

@@ -205,11 +205,10 @@ from vllm.attention.backends.abstract import (
     AttentionMetadata,
     MLAAttentionImpl,
 )
-from vllm.attention.backends.utils import get_mla_dims
 from vllm.attention.ops.common import cp_lse_ag_out_rs
 from vllm.attention.ops.merge_attn_states import merge_attn_states
 from vllm.attention.utils.fa_utils import get_flash_attn_version
-from vllm.config import VllmConfig, get_current_vllm_config
+from vllm.config import ModelConfig, VllmConfig, get_current_vllm_config
 from vllm.distributed.parallel_state import get_dcp_group, is_global_first_rank
 from vllm.logger import init_logger
 from vllm.model_executor.layers.batch_invariant import (
@@ -479,6 +478,27 @@ def use_trtllm_ragged_deepseek_prefill() -> bool:
 )


+@dataclass
+class MLADims:
+    q_lora_rank: int | None
+    kv_lora_rank: int
+    qk_nope_head_dim: int
+    qk_rope_head_dim: int
+    v_head_dim: int
+
+
+def get_mla_dims(model_config: ModelConfig) -> MLADims:
+    hf_text_config = model_config.hf_text_config
+
+    return MLADims(
+        q_lora_rank=getattr(hf_text_config, "q_lora_rank", None),
+        kv_lora_rank=hf_text_config.kv_lora_rank,
+        qk_nope_head_dim=hf_text_config.qk_nope_head_dim,
+        qk_rope_head_dim=hf_text_config.qk_rope_head_dim,
+        v_head_dim=hf_text_config.v_head_dim,
+    )
+
+
 class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
     """
     NOTE: Please read the comment at the top of the file before trying to
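
For context, a sketch of how a caller might use the migrated helper; constructing ModelConfig directly and the example checkpoint name are assumptions, not taken from this commit:

# Requires a vLLM build that includes this commit; fetching the example
# model's HF config needs network access.
from vllm.config import ModelConfig
from vllm.v1.attention.backends.mla.common import get_mla_dims

model_config = ModelConfig(model="deepseek-ai/DeepSeek-V2-Lite")  # example only
dims = get_mla_dims(model_config)

# MLA splits the query/key head dim into an unrotated ("nope") part and a
# rotary ("rope") part; values keep their own head dim.
qk_head_dim = dims.qk_nope_head_dim + dims.qk_rope_head_dim
print(qk_head_dim, dims.v_head_dim, dims.kv_lora_rank)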

View File

@@ -13,7 +13,6 @@ from vllm.attention.backends.abstract import (
     AttentionMetadata,
     MultipleOf,
 )
-from vllm.attention.backends.utils import get_mla_dims
 from vllm.attention.ops.flashmla import (
     flash_mla_sparse_prefill,
     flash_mla_with_kvcache,
@@ -26,7 +25,7 @@ from vllm.platforms import current_platform
 from vllm.platforms.interface import DeviceCapability
 from vllm.triton_utils import tl, triton
 from vllm.utils.math_utils import cdiv
-from vllm.v1.attention.backends.mla.common import MLACommonBaseImpl
+from vllm.v1.attention.backends.mla.common import MLACommonBaseImpl, get_mla_dims
 from vllm.v1.attention.backends.utils import (
     AttentionCGSupport,
     AttentionMetadataBuilder,
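
The same one-line substitution repeats across the remaining changed files. For an out-of-tree backend, a throwaway script along these lines applies it in bulk (entirely a sketch; the directory name is a placeholder):

from pathlib import Path

OLD = "from vllm.attention.backends.utils import get_mla_dims"
NEW = "from vllm.v1.attention.backends.mla.common import get_mla_dims"

for path in Path("my_plugin").rglob("*.py"):  # "my_plugin" is hypothetical
    text = path.read_text()
    if OLD in text:
        path.write_text(text.replace(OLD, NEW))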

View File

@@ -14,12 +14,9 @@ from vllm.attention.backends.abstract import (
     AttentionLayer,
     AttentionMetadata,
 )
-from vllm.attention.backends.utils import get_mla_dims
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
-from vllm.v1.attention.backends.mla.common import (
-    MLACommonBaseImpl,
-)
+from vllm.v1.attention.backends.mla.common import MLACommonBaseImpl, get_mla_dims
 from vllm.v1.attention.backends.mla.flashmla_sparse import (
     triton_convert_req_index_to_global_index,
 )