diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index bc2afd2c3..462d5802e 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -1439,7 +1439,7 @@ def test_register_kv_caches(default_vllm_config, dist_init, attn_backend): patch(f"{nixl_module}.NixlWrapper") as mock_nixl_wrapper, patch(f"{nixl_module}.threading.Event"), patch(f"{nixl_module}.threading.Thread") as mock_thread, - patch(f"{nixl_module}.get_attn_backend") as mock_get_attn_backend, + patch(f"{nixl_module}.get_current_attn_backend") as mock_get_attn_backend, ): # Ensure get_attn_backend returns the correct value due to # _cached_get_attn_backend returning the backend from previous diff --git a/vllm/distributed/kv_transfer/kv_connector/utils.py b/vllm/distributed/kv_transfer/kv_connector/utils.py index 39d3085ba..fd833e293 100644 --- a/vllm/distributed/kv_transfer/kv_connector/utils.py +++ b/vllm/distributed/kv_transfer/kv_connector/utils.py @@ -6,13 +6,14 @@ KV cache helper for store. 
def get_current_attn_backend(vllm_config: VllmConfig):
    """Resolve the attention backend to use for ``vllm_config``.

    The layers of type ``AttentionLayerBase`` registered in the config are
    looked up first. If any exist, the backend reported by the first one is
    returned; the remaining layers are not consulted. If none exist, the
    backend is chosen by the attention selector from the model and cache
    settings instead.

    Args:
        vllm_config: Configuration providing the registered layers and, for
            the fallback path, ``model_config`` and ``cache_config``.

    Returns:
        The value returned by the first layer's ``get_attn_backend()``, or
        by the selector's ``get_attn_backend(...)`` on the fallback path.
    """
    # The cast only widens the static type handed to the lookup helper; the
    # runtime argument is still AttentionLayerBase itself.
    attn_layers = get_layers_from_vllm_config(
        vllm_config, cast(type[Any], AttentionLayerBase), None
    )

    if attn_layers:
        first_layer = next(iter(attn_layers.values()))
        return first_layer.get_attn_backend()

    # Fallback for tests, when static_forward_context is empty.
    logger.debug(
        "No layers found in the vLLM config. "
        "Falling back to default attention backend."
    )
    # Imported here rather than at module scope, as in the original.
    from vllm.v1.attention.selector import get_attn_backend

    model_config = vllm_config.model_config
    cache_config = vllm_config.cache_config
    return get_attn_backend(
        head_size=model_config.get_head_size(),
        dtype=model_config.dtype,
        kv_cache_dtype=cache_config.cache_dtype,
        block_size=cache_config.block_size,
        use_mla=model_config.use_mla,
    )
backend.get_name() self.kv_cache_layout = get_kv_cache_layout() logger.debug("Detected attention backend %s", self.backend_name) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index 131cb3ec9..797af6d9c 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -24,6 +24,7 @@ from vllm.config import VllmConfig from vllm.distributed.kv_transfer.kv_connector.utils import ( EngineId, TpKVTopology, + get_current_attn_backend, kv_postprocess_blksize_and_layout_on_receive, kv_postprocess_blksize_on_receive, kv_postprocess_layout_on_receive, @@ -53,7 +54,6 @@ from vllm.platforms import current_platform from vllm.utils.network_utils import make_zmq_path, make_zmq_socket from vllm.v1.attention.backend import AttentionMetadata from vllm.v1.attention.backends.utils import get_kv_cache_layout -from vllm.v1.attention.selector import get_attn_backend from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.worker.block_table import BlockTable @@ -957,13 +957,10 @@ class NixlConnectorWorker: self.block_window_per_layer: list[int | None] = [] self.use_mla = self.model_config.use_mla - backend = get_attn_backend( - self.model_config.get_head_size(), - self.model_config.dtype, - self.cache_config.cache_dtype, - self.block_size, - use_mla=self.use_mla, - ) + # Get the attention backend from the first layer + # NOTE (NickLucche) models with multiple backends are not supported yet + backend = get_current_attn_backend(vllm_config) + self.backend_name = backend.get_name() self.kv_cache_layout = get_kv_cache_layout() self.host_buffer_kv_cache_layout = self.kv_cache_layout