[Misc][PD] Fix get_attn_backend usage in transfer connectors (#31988)
Signed-off-by: NickLucche <nlucches@redhat.com>
This commit is contained in:
@@ -1439,7 +1439,7 @@ def test_register_kv_caches(default_vllm_config, dist_init, attn_backend):
|
||||
patch(f"{nixl_module}.NixlWrapper") as mock_nixl_wrapper,
|
||||
patch(f"{nixl_module}.threading.Event"),
|
||||
patch(f"{nixl_module}.threading.Thread") as mock_thread,
|
||||
patch(f"{nixl_module}.get_attn_backend") as mock_get_attn_backend,
|
||||
patch(f"{nixl_module}.get_current_attn_backend") as mock_get_attn_backend,
|
||||
):
|
||||
# Ensure get_attn_backend returns the correct value due to
|
||||
# _cached_get_attn_backend returning the backend from previous
|
||||
|
||||
@@ -6,13 +6,14 @@ KV cache helper for store.
|
||||
|
||||
from collections.abc import Iterator
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Literal
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.config import VllmConfig, get_current_vllm_config, get_layers_from_vllm_config
|
||||
from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||
from vllm.v1.attention.backend import AttentionBackend
|
||||
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
|
||||
|
||||
@@ -433,3 +434,26 @@ class TpKVTopology:
|
||||
) -> list[int]:
|
||||
remote_tp_size = self.remote_tp_size[remote_engine_id]
|
||||
return self.get_target_remote_ranks(remote_tp_size)
|
||||
|
||||
|
||||
def get_current_attn_backend(vllm_config: VllmConfig):
    """Return the attention backend to use for the given vLLM config.

    The backend is taken from the first attention layer returned by
    ``get_layers_from_vllm_config``. When that lookup yields no layers,
    the backend is instead resolved through the attention selector using
    the head size, dtype, KV-cache dtype, block size and MLA flag read
    from ``vllm_config``.

    Args:
        vllm_config: Config whose layers (or, failing that, model and
            cache settings) determine the backend.

    Returns:
        Whatever the first layer's ``get_attn_backend()`` returns, or the
        result of the selector's ``get_attn_backend(...)`` in the fallback.
    """
    attn_layers = get_layers_from_vllm_config(
        vllm_config, cast(type[Any], AttentionLayerBase), None
    )

    # Only the first layer is consulted; any further layers are ignored.
    if attn_layers:
        first_layer = next(iter(attn_layers.values()))
        return first_layer.get_attn_backend()

    # Fallback for tests, when static_forward_context is empty.
    logger.debug(
        "No layers found in the vLLM config. "
        "Falling back to default attention backend."
    )
    # Imported at call time, only on the fallback path.
    from vllm.v1.attention.selector import get_attn_backend

    model_config = vllm_config.model_config
    cache_config = vllm_config.cache_config
    return get_attn_backend(
        head_size=model_config.get_head_size(),
        dtype=model_config.dtype,
        kv_cache_dtype=cache_config.cache_dtype,
        block_size=cache_config.block_size,
        use_mla=model_config.use_mla,
    )
|
||||
|
||||
@@ -16,7 +16,10 @@ import zmq.asyncio
|
||||
|
||||
from vllm import envs
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.utils import TpKVTopology
|
||||
from vllm.distributed.kv_transfer.kv_connector.utils import (
|
||||
TpKVTopology,
|
||||
get_current_attn_backend,
|
||||
)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorBase_V1,
|
||||
KVConnectorMetadata,
|
||||
@@ -32,7 +35,6 @@ from vllm.logger import init_logger
|
||||
from vllm.utils.network_utils import get_ip, make_zmq_path, make_zmq_socket
|
||||
from vllm.v1.attention.backend import AttentionMetadata
|
||||
from vllm.v1.attention.backends.utils import get_kv_cache_layout
|
||||
from vllm.v1.attention.selector import get_attn_backend
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.request import RequestStatus
|
||||
|
||||
@@ -468,13 +470,9 @@ class MooncakeConnectorWorker:
|
||||
self.cache_config = vllm_config.cache_config
|
||||
self.use_mla = self.model_config.use_mla
|
||||
|
||||
backend = get_attn_backend(
|
||||
self.model_config.get_head_size(),
|
||||
self.model_config.dtype,
|
||||
self.cache_config.cache_dtype,
|
||||
self.block_size,
|
||||
use_mla=self.use_mla,
|
||||
)
|
||||
# Get the attention backend from the first layer
|
||||
# NOTE (NickLucche) models with multiple backends are not supported yet
|
||||
backend = get_current_attn_backend(vllm_config)
|
||||
self.backend_name = backend.get_name()
|
||||
self.kv_cache_layout = get_kv_cache_layout()
|
||||
logger.debug("Detected attention backend %s", self.backend_name)
|
||||
|
||||
@@ -24,6 +24,7 @@ from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.utils import (
|
||||
EngineId,
|
||||
TpKVTopology,
|
||||
get_current_attn_backend,
|
||||
kv_postprocess_blksize_and_layout_on_receive,
|
||||
kv_postprocess_blksize_on_receive,
|
||||
kv_postprocess_layout_on_receive,
|
||||
@@ -53,7 +54,6 @@ from vllm.platforms import current_platform
|
||||
from vllm.utils.network_utils import make_zmq_path, make_zmq_socket
|
||||
from vllm.v1.attention.backend import AttentionMetadata
|
||||
from vllm.v1.attention.backends.utils import get_kv_cache_layout
|
||||
from vllm.v1.attention.selector import get_attn_backend
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.worker.block_table import BlockTable
|
||||
|
||||
@@ -957,13 +957,10 @@ class NixlConnectorWorker:
|
||||
self.block_window_per_layer: list[int | None] = []
|
||||
self.use_mla = self.model_config.use_mla
|
||||
|
||||
backend = get_attn_backend(
|
||||
self.model_config.get_head_size(),
|
||||
self.model_config.dtype,
|
||||
self.cache_config.cache_dtype,
|
||||
self.block_size,
|
||||
use_mla=self.use_mla,
|
||||
)
|
||||
# Get the attention backend from the first layer
|
||||
# NOTE (NickLucche) models with multiple backends are not supported yet
|
||||
backend = get_current_attn_backend(vllm_config)
|
||||
|
||||
self.backend_name = backend.get_name()
|
||||
self.kv_cache_layout = get_kv_cache_layout()
|
||||
self.host_buffer_kv_cache_layout = self.kv_cache_layout
|
||||
|
||||
Reference in New Issue
Block a user