[KVConnector]: Enable Cross-layers KV cache layout for MultiConnector (#30761)

Signed-off-by: Kfir Toledo <kfir.toledo@ibm.com>
This commit is contained in:
Kfir Toledo
2026-01-07 18:59:43 +02:00
committed by GitHub
parent 1d9e9ae8a4
commit b89443b8d9
4 changed files with 71 additions and 11 deletions

View File

@@ -49,6 +49,33 @@ class MockConnector(KVConnectorBase_V1):
) -> KVConnectorStats | None:
return MockConnectorStats(data=data) if data is not None else None
def start_load_kv(self, forward_context, **kwargs):
    """No-op stub: the mock ignores its arguments and loads nothing."""
    return None
def wait_for_layer_load(self, layer_name):
    """No-op stub: the mock has nothing to wait for and returns at once."""
    return None
def save_kv_layer(self, layer_name, kv_layer, attn_metadata, **kwargs):
    """No-op stub: the mock discards every argument and saves nothing."""
    return None
def wait_for_save(self):
    """No-op stub: the mock never saves, so there is nothing to wait for."""
    return None
def build_connector_meta(self, scheduler_output):
    """Return None for any ``scheduler_output``; the mock builds no metadata."""
    no_metadata = None
    return no_metadata
def get_num_new_matched_tokens(self, request, num_computed_tokens):
    """Return the constant pair ``(0, False)`` regardless of the arguments."""
    matched_token_count = 0
    return matched_token_count, False
def update_state_after_alloc(self, request, blocks, num_tokens) -> None:
    """No-op stub: the mock keeps no state, so nothing is updated."""
    return None
class MockCrossLayerConnector(MockConnector):
    """MockConnector variant that reports a preference for cross-layer blocks."""

    @property
    def prefer_cross_layer_blocks(self) -> bool:
        """Always True for this mock."""
        prefers_cross_layer = True
        return prefers_cross_layer
# Register the mock connector with the factory under the name "MockConnector",
# passing this module's name and the mock's class name.
KVConnectorFactory.register_connector(
    "MockConnector",
    __name__,
    MockConnector.__name__,
)
@@ -601,3 +628,21 @@ class TestMultiConnectorStats:
# One non-empty
stats.data["NixlConnector"].data["transfer_duration"].append(1.0)
assert not stats.is_empty()
class TestMultiConnectorPreferCrossLayerBlocks:
    """Checks for ``MultiConnector.prefer_cross_layer_blocks``.

    Objects are created with ``__new__`` so no ``__init__`` runs; only the
    ``_connectors`` attribute is populated before the property is read.
    """

    @staticmethod
    def _multi_connector_with(*connector_classes):
        """Return a bare MultiConnector holding one bare instance per class.

        The name has no ``test_`` prefix, so pytest does not collect it.
        """
        mc = MultiConnector.__new__(MultiConnector)
        mc._connectors = [cls.__new__(cls) for cls in connector_classes]
        return mc

    def test_all_connectors_prefer_cross_layer_blocks(self):
        mc = self._multi_connector_with(
            MockCrossLayerConnector,
            MockCrossLayerConnector,
        )
        assert mc.prefer_cross_layer_blocks is True

    def test_mixed_connectors_do_not_prefer_cross_layer_blocks(self):
        mc = self._multi_connector_with(
            MockCrossLayerConnector,
            MockConnector,  # default False
        )
        assert mc.prefer_cross_layer_blocks is False

    def test_no_connectors_do_not_prefer_cross_layer_blocks(self):
        # all() over an empty sequence is True, so the empty case must be
        # checked separately to make sure it still reports False.
        mc = self._multi_connector_with()
        assert mc.prefer_cross_layer_blocks is False

View File

@@ -38,7 +38,7 @@ The class provides the following primitives:
import enum
from abc import ABC, abstractmethod
from collections.abc import Callable, Iterable
from typing import TYPE_CHECKING, Any, ClassVar, Literal, Optional
from typing import TYPE_CHECKING, Any, Literal, Optional
import torch
@@ -144,15 +144,15 @@ class KVConnectorMetadata(ABC): # noqa: B024
class KVConnectorBase_V1(ABC):
"""
Base class for KV connectors.
Attributes:
prefer_cross_layer_blocks (bool): Indicates whether this connector
prefers KV blocks that hold KV data for all layers (for speeding
up KV data transfers).
Defaults to False.
"""
prefer_cross_layer_blocks: ClassVar[bool] = False
@property
def prefer_cross_layer_blocks(self) -> bool:
    """
    Whether this connector prefers KV blocks that hold the KV data of all
    layers; such blocks can speed up KV data transfers.

    Defaults to False.
    """
    default_preference = False
    return default_preference
def __init__(
self,

View File

@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Any
import torch
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
from vllm.config import VllmConfig
from vllm.config.kv_transfer import KVTransferConfig
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType
@@ -138,6 +138,12 @@ class MultiConnector(KVConnectorBase_V1):
# Propagated from scheduler to worker side via the connector metadata.
self._extra_async_saves: dict[str, int] = {}
@property
def prefer_cross_layer_blocks(self) -> bool:
    """
    True only when at least one connector is configured and every one of
    them prefers cross-layer blocks; False otherwise.
    """
    connectors = self._connectors
    # all() is True for an empty sequence, so rule out "no connectors" first.
    return bool(connectors) and all(
        connector.prefer_cross_layer_blocks for connector in connectors
    )
@classmethod
def _get_connector_classes_and_configs(
cls, vllm_config: "VllmConfig"
@@ -164,6 +170,13 @@ class MultiConnector(KVConnectorBase_V1):
)
return ret
def register_cross_layers_kv_cache(
    self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend]
):
    """Forward the cross-layers KV cache registration to every connector.

    Args:
        kv_cache: Tensor handed unchanged to each connector.
        attn_backend: Attention backend class, also handed on unchanged.
    """
    for connector in self._connectors:
        connector.register_cross_layers_kv_cache(kv_cache, attn_backend)
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
    """Pass the same ``kv_caches`` mapping to every connector, in order."""
    for connector in self._connectors:
        connector.register_kv_caches(kv_caches)

View File

@@ -4,7 +4,7 @@ from collections import defaultdict
from collections.abc import Iterable
from dataclasses import dataclass
from itertools import islice
from typing import Any, ClassVar
from typing import Any
import torch
@@ -44,7 +44,9 @@ class OffloadingConnectorMetadata(KVConnectorMetadata):
class OffloadingConnector(KVConnectorBase_V1):
prefer_cross_layer_blocks: ClassVar[bool] = True
@property
def prefer_cross_layer_blocks(self) -> bool:
    """Always True: this connector prefers cross-layer KV blocks."""
    prefers_cross_layer = True
    return prefers_cross_layer
def __init__(
self,