[KVConnector]: Enable Cross-layers KV cache layout for MultiConnector (#30761)
Signed-off-by: Kfir Toledo <kfir.toledo@ibm.com>
This commit is contained in:
@@ -49,6 +49,33 @@ class MockConnector(KVConnectorBase_V1):
|
||||
) -> KVConnectorStats | None:
|
||||
return MockConnectorStats(data=data) if data is not None else None
|
||||
|
||||
def start_load_kv(self, forward_context, **kwargs):
|
||||
pass
|
||||
|
||||
def wait_for_layer_load(self, layer_name):
|
||||
pass
|
||||
|
||||
def save_kv_layer(self, layer_name, kv_layer, attn_metadata, **kwargs):
|
||||
pass
|
||||
|
||||
def wait_for_save(self):
|
||||
pass
|
||||
|
||||
def build_connector_meta(self, scheduler_output):
|
||||
return None
|
||||
|
||||
def get_num_new_matched_tokens(self, request, num_computed_tokens):
|
||||
return (0, False)
|
||||
|
||||
def update_state_after_alloc(self, request, blocks, num_tokens) -> None:
|
||||
pass
|
||||
|
||||
|
||||
class MockCrossLayerConnector(MockConnector):
|
||||
@property
|
||||
def prefer_cross_layer_blocks(self) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
# Register the mock connector
|
||||
KVConnectorFactory.register_connector("MockConnector", __name__, MockConnector.__name__)
|
||||
@@ -601,3 +628,21 @@ class TestMultiConnectorStats:
|
||||
# One non-empty
|
||||
stats.data["NixlConnector"].data["transfer_duration"].append(1.0)
|
||||
assert not stats.is_empty()
|
||||
|
||||
|
||||
class TestMultiConnectorPreferCrossLayerBlocks:
|
||||
def test_all_connectors_prefer_cross_layer_blocks(self):
|
||||
mc = MultiConnector.__new__(MultiConnector)
|
||||
mc._connectors = [
|
||||
MockCrossLayerConnector.__new__(MockCrossLayerConnector),
|
||||
MockCrossLayerConnector.__new__(MockCrossLayerConnector),
|
||||
]
|
||||
assert mc.prefer_cross_layer_blocks is True
|
||||
|
||||
def test_mixed_connectors_do_not_prefer_cross_layer_blocks(self):
|
||||
mc = MultiConnector.__new__(MultiConnector)
|
||||
mc._connectors = [
|
||||
MockCrossLayerConnector.__new__(MockCrossLayerConnector),
|
||||
MockConnector.__new__(MockConnector), # default False
|
||||
]
|
||||
assert mc.prefer_cross_layer_blocks is False
|
||||
|
||||
@@ -38,7 +38,7 @@ The class provides the following primitives:
|
||||
import enum
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable, Iterable
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Literal, Optional
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -144,15 +144,15 @@ class KVConnectorMetadata(ABC): # noqa: B024
|
||||
class KVConnectorBase_V1(ABC):
|
||||
"""
|
||||
Base class for KV connectors.
|
||||
|
||||
Attributes:
|
||||
prefer_cross_layer_blocks (bool): Indicates whether this connector
|
||||
prefers KV blocks that hold KV data for all layers (for speeding
|
||||
up KV data transfers).
|
||||
Defaults to False.
|
||||
"""
|
||||
|
||||
prefer_cross_layer_blocks: ClassVar[bool] = False
|
||||
@property
|
||||
def prefer_cross_layer_blocks(self) -> bool:
|
||||
"""
|
||||
Indicates whether this connector prefers KV blocks that hold KV data for all
|
||||
layers, which can speed up KV data transfers. Defaults to False.
|
||||
"""
|
||||
return False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.attention.backends.abstract import AttentionBackend, AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.kv_transfer import KVTransferConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBaseType
|
||||
@@ -138,6 +138,12 @@ class MultiConnector(KVConnectorBase_V1):
|
||||
# Propagated from scheduler to worker side via the connector metadata.
|
||||
self._extra_async_saves: dict[str, int] = {}
|
||||
|
||||
@property
|
||||
def prefer_cross_layer_blocks(self) -> bool:
|
||||
if not self._connectors:
|
||||
return False
|
||||
return all(c.prefer_cross_layer_blocks for c in self._connectors)
|
||||
|
||||
@classmethod
|
||||
def _get_connector_classes_and_configs(
|
||||
cls, vllm_config: "VllmConfig"
|
||||
@@ -164,6 +170,13 @@ class MultiConnector(KVConnectorBase_V1):
|
||||
)
|
||||
return ret
|
||||
|
||||
def register_cross_layers_kv_cache(
|
||||
self, kv_cache: torch.Tensor, attn_backend: type[AttentionBackend]
|
||||
):
|
||||
# Register on all connectors
|
||||
for c in self._connectors:
|
||||
c.register_cross_layers_kv_cache(kv_cache, attn_backend)
|
||||
|
||||
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
||||
for c in self._connectors:
|
||||
c.register_kv_caches(kv_caches)
|
||||
|
||||
@@ -4,7 +4,7 @@ from collections import defaultdict
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from itertools import islice
|
||||
from typing import Any, ClassVar
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
@@ -44,7 +44,9 @@ class OffloadingConnectorMetadata(KVConnectorMetadata):
|
||||
|
||||
|
||||
class OffloadingConnector(KVConnectorBase_V1):
|
||||
prefer_cross_layer_blocks: ClassVar[bool] = True
|
||||
@property
|
||||
def prefer_cross_layer_blocks(self) -> bool:
|
||||
return True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user