monkey patch the monkey pathing vllm nonsense
This commit is contained in:
379
lmcache_connector.py
Normal file
379
lmcache_connector.py
Normal file
@@ -0,0 +1,379 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Iterable
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_events import (
|
||||
BlockStored,
|
||||
KVCacheEvent,
|
||||
KVConnectorKVEvents,
|
||||
KVEventAggregator,
|
||||
)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorBase_V1,
|
||||
KVConnectorMetadata,
|
||||
KVConnectorRole,
|
||||
SupportsHMA,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.attention.backend import AttentionMetadata
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.outputs import KVConnectorOutput
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||
from vllm.v1.request import Request
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class LMCacheKVEvents(KVConnectorKVEvents):
|
||||
"""
|
||||
Concrete implementation of KVConnectorKVEvents using KVEventAggregator.
|
||||
"""
|
||||
|
||||
def __init__(self, num_workers: int) -> None:
|
||||
self._aggregator = KVEventAggregator(num_workers)
|
||||
|
||||
def add_events(self, events: list[KVCacheEvent]) -> None:
|
||||
self._aggregator.add_events(events)
|
||||
|
||||
def aggregate(self) -> "LMCacheKVEvents":
|
||||
"""
|
||||
Aggregate KV events and retain only common events.
|
||||
"""
|
||||
common_events = self._aggregator.get_common_events()
|
||||
self._aggregator.clear_events()
|
||||
self._aggregator.add_events(common_events)
|
||||
self._aggregator.reset_workers()
|
||||
return self
|
||||
|
||||
def increment_workers(self, count: int = 1) -> None:
|
||||
self._aggregator.increment_workers(count)
|
||||
|
||||
def get_all_events(self) -> list[KVCacheEvent]:
|
||||
return self._aggregator.get_all_events()
|
||||
|
||||
def get_number_of_workers(self) -> int:
|
||||
return self._aggregator.get_number_of_workers()
|
||||
|
||||
def clear_events(self) -> None:
|
||||
self._aggregator.clear_events()
|
||||
self._aggregator.reset_workers()
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<LMCacheKVEvents events={self.get_all_events()}>"
|
||||
|
||||
|
||||
class LMCacheConnectorV1(KVConnectorBase_V1, SupportsHMA):
|
||||
@classmethod
|
||||
def requires_piecewise_for_cudagraph(cls, extra_config: dict[str, Any]) -> bool:
|
||||
"""
|
||||
LMCache requires PIECEWISE CUDA graph mode when layerwise
|
||||
operations are enabled. The wait_for_layer_load and save_kv_layer
|
||||
methods perform actual async synchronization that cannot be
|
||||
captured in CUDA graphs.
|
||||
"""
|
||||
return extra_config.get("use_layerwise", False)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: "VllmConfig",
|
||||
role: KVConnectorRole,
|
||||
kv_cache_config: "KVCacheConfig",
|
||||
):
|
||||
super().__init__(
|
||||
vllm_config=vllm_config, role=role, kv_cache_config=kv_cache_config
|
||||
)
|
||||
assert vllm_config.kv_transfer_config is not None
|
||||
use_native = vllm_config.kv_transfer_config.get_from_extra_config(
|
||||
"use_native", False
|
||||
)
|
||||
if use_native:
|
||||
logger.info("Initializing native LMCache connector")
|
||||
# lazy import
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1 import lmcache_integration
|
||||
|
||||
_adapter = lmcache_integration.vllm_v1_adapter
|
||||
|
||||
cls = _adapter.LMCacheConnectorV1Impl
|
||||
else:
|
||||
logger.info("Initializing latest dev LMCache connector")
|
||||
# lazy import
|
||||
from lmcache.integration.vllm.vllm_v1_adapter import (
|
||||
LMCacheConnectorV1Impl as LMCacheConnectorLatestImpl,
|
||||
)
|
||||
|
||||
cls = LMCacheConnectorLatestImpl
|
||||
|
||||
self._lmcache_engine = cls(vllm_config, role, self)
|
||||
|
||||
self._kv_cache_events: LMCacheKVEvents | None = None
|
||||
|
||||
# ==============================
|
||||
# Worker-side methods
|
||||
# ==============================
|
||||
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
||||
"""
|
||||
Initialize with the KV caches. Useful for pre-registering the
|
||||
KV Caches in the KVConnector (e.g. for NIXL).
|
||||
|
||||
Args:
|
||||
kv_caches: dictionary of layer names, kv cache
|
||||
"""
|
||||
if hasattr(self._lmcache_engine, "register_kv_caches"):
|
||||
self._lmcache_engine.register_kv_caches(kv_caches)
|
||||
else:
|
||||
logger.warning(
|
||||
"LMCache engine does not support register_kv_caches, "
|
||||
"please check and use the latest version"
|
||||
)
|
||||
|
||||
def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None:
|
||||
"""
|
||||
Start loading the KV cache from the connector to vLLM's paged
|
||||
KV buffer. This is called from the forward context before the
|
||||
forward pass to enable async loading during model execution.
|
||||
|
||||
Args:
|
||||
forward_context (ForwardContext): the forward context.
|
||||
**kwargs: additional arguments for the load operation
|
||||
|
||||
Note:
|
||||
The number of elements in kv_caches and layer_names should be
|
||||
the same.
|
||||
|
||||
"""
|
||||
self._lmcache_engine.start_load_kv(forward_context, **kwargs)
|
||||
|
||||
def wait_for_layer_load(self, layer_name: str) -> None:
|
||||
"""
|
||||
Block until the KV for a specific layer is loaded into vLLM's
|
||||
paged buffer. This is called from within attention layer to ensure
|
||||
async copying from start_load_kv is complete.
|
||||
|
||||
This interface will be useful for layer-by-layer pipelining.
|
||||
|
||||
Args:
|
||||
layer_name: the name of that layer
|
||||
"""
|
||||
self._lmcache_engine.wait_for_layer_load(layer_name)
|
||||
|
||||
def save_kv_layer(
|
||||
self,
|
||||
layer_name: str,
|
||||
kv_layer: torch.Tensor,
|
||||
attn_metadata: AttentionMetadata,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Start saving the a layer of KV cache from vLLM's paged buffer
|
||||
to the connector. This is called from within attention layer to
|
||||
enable async copying during execution.
|
||||
|
||||
Args:
|
||||
layer_name (str): the name of the layer.
|
||||
kv_layer (torch.Tensor): the paged KV buffer of the current
|
||||
layer in vLLM.
|
||||
attn_metadata (AttentionMetadata): the attention metadata.
|
||||
**kwargs: additional arguments for the save operation.
|
||||
"""
|
||||
self._lmcache_engine.save_kv_layer(
|
||||
layer_name, kv_layer, attn_metadata, **kwargs
|
||||
)
|
||||
|
||||
def wait_for_save(self):
|
||||
"""
|
||||
Block until all the save operations is done. This is called
|
||||
as the forward context exits to ensure that the async saving
|
||||
from save_kv_layer is complete before finishing the forward.
|
||||
|
||||
This prevents overwrites of paged KV buffer before saving done.
|
||||
"""
|
||||
self._lmcache_engine.wait_for_save()
|
||||
|
||||
def get_finished(
|
||||
self, finished_req_ids: set[str]
|
||||
) -> tuple[set[str] | None, set[str] | None]:
|
||||
"""
|
||||
Notifies worker-side connector ids of requests that have
|
||||
finished generating tokens.
|
||||
|
||||
Returns:
|
||||
ids of requests that have finished asynchronous transfer
|
||||
(requests that previously returned True from request_finished()),
|
||||
tuple of (sending/saving ids, recving/loading ids).
|
||||
The finished saves/sends req ids must belong to a set provided in a
|
||||
call to this method (this call or a prior one).
|
||||
"""
|
||||
return self._lmcache_engine.get_finished(finished_req_ids)
|
||||
|
||||
def get_block_ids_with_load_errors(self) -> set[int]:
|
||||
"""
|
||||
Get the set of block IDs that failed to load.
|
||||
|
||||
Returns:
|
||||
Set of block IDs that encountered load errors.
|
||||
Empty set if no load errors occurred.
|
||||
"""
|
||||
method = getattr(self._lmcache_engine, "get_block_ids_with_load_errors", None)
|
||||
if callable(method):
|
||||
return method()
|
||||
|
||||
# Fallback for older versions that don't support this method
|
||||
return set()
|
||||
|
||||
def get_kv_connector_kv_cache_events(self) -> LMCacheKVEvents | None:
|
||||
"""
|
||||
Get the KV connector kv cache events collected during the last interval.
|
||||
"""
|
||||
|
||||
events = self._lmcache_engine.get_kv_events() # type: ignore [attr-defined]
|
||||
if not events:
|
||||
return None
|
||||
|
||||
blocks: list[BlockStored] = [
|
||||
BlockStored(
|
||||
block_hashes=e.block_hashes,
|
||||
parent_block_hash=e.parent_block_hash,
|
||||
token_ids=e.token_ids,
|
||||
lora_id=e.lora_id,
|
||||
block_size=e.block_size,
|
||||
medium=e.medium,
|
||||
lora_name=getattr(e, "lora_name", None),
|
||||
)
|
||||
for e in events
|
||||
]
|
||||
|
||||
lmcache_kv_events = LMCacheKVEvents(num_workers=1)
|
||||
lmcache_kv_events.add_events(blocks)
|
||||
return lmcache_kv_events
|
||||
|
||||
# ==============================
|
||||
# Scheduler-side methods
|
||||
# ==============================
|
||||
def get_num_new_matched_tokens(
|
||||
self,
|
||||
request: "Request",
|
||||
num_computed_tokens: int,
|
||||
) -> tuple[int | None, bool]:
|
||||
"""
|
||||
Get number of new tokens that can be loaded from the
|
||||
external KV cache beyond the num_computed_tokens.
|
||||
|
||||
Args:
|
||||
request (Request): the request object.
|
||||
num_computed_tokens (int): the number of locally
|
||||
computed tokens for this request
|
||||
|
||||
Returns:
|
||||
the number of tokens that can be loaded from the
|
||||
external KV cache beyond what is already computed.
|
||||
"""
|
||||
return self._lmcache_engine.get_num_new_matched_tokens(
|
||||
request, num_computed_tokens
|
||||
), False
|
||||
|
||||
def update_state_after_alloc(
|
||||
self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
|
||||
):
|
||||
"""
|
||||
Update KVConnector state after block allocation.
|
||||
"""
|
||||
self._lmcache_engine.update_state_after_alloc(request, num_external_tokens)
|
||||
|
||||
def build_connector_meta(
|
||||
self, scheduler_output: SchedulerOutput
|
||||
) -> KVConnectorMetadata:
|
||||
"""
|
||||
Build the connector metadata for this step.
|
||||
|
||||
This function should NOT modify fields in the scheduler_output.
|
||||
Also, calling this function will reset the state of the connector.
|
||||
|
||||
Args:
|
||||
scheduler_output (SchedulerOutput): the scheduler output object.
|
||||
"""
|
||||
return self._lmcache_engine.build_connector_meta(scheduler_output)
|
||||
|
||||
def update_connector_output(self, connector_output: KVConnectorOutput):
|
||||
"""
|
||||
Update KVConnector state from worker-side connectors output.
|
||||
|
||||
Args:
|
||||
connector_output (KVConnectorOutput): the worker-side
|
||||
connectors output.
|
||||
"""
|
||||
# Get the KV events
|
||||
kv_cache_events = connector_output.kv_cache_events
|
||||
if not kv_cache_events or not isinstance(kv_cache_events, LMCacheKVEvents):
|
||||
return
|
||||
|
||||
if self._kv_cache_events is None:
|
||||
self._kv_cache_events = kv_cache_events
|
||||
else:
|
||||
self._kv_cache_events.add_events(kv_cache_events.get_all_events())
|
||||
self._kv_cache_events.increment_workers(
|
||||
kv_cache_events.get_number_of_workers()
|
||||
)
|
||||
return
|
||||
|
||||
def request_finished(
|
||||
self,
|
||||
request: "Request",
|
||||
block_ids: list[int],
|
||||
) -> tuple[bool, dict[str, Any] | None]:
|
||||
"""
|
||||
Called when a request has finished, before its blocks are freed.
|
||||
|
||||
Returns:
|
||||
True if the request is being saved/sent asynchronously and blocks
|
||||
should not be freed until the request_id is returned from
|
||||
get_finished().
|
||||
Optional KVTransferParams to be included in the request outputs
|
||||
returned by the engine.
|
||||
"""
|
||||
return self._lmcache_engine.request_finished(request, block_ids)
|
||||
|
||||
def request_finished_all_groups(
|
||||
self,
|
||||
request: "Request",
|
||||
block_ids: tuple[list[int], ...],
|
||||
) -> tuple[bool, dict[str, Any] | None]:
|
||||
"""
|
||||
Called exactly once when a request has finished for all KV cache
|
||||
groups (HMA support for hybrid Mamba/Attention models).
|
||||
|
||||
LMCache only stores/offloads attention KV cache blocks, so we
|
||||
extract the first group's block IDs and delegate to the
|
||||
single-group request_finished.
|
||||
|
||||
Args:
|
||||
request: the request object.
|
||||
block_ids: tuple of block ID lists, one per KV cache group.
|
||||
|
||||
Returns:
|
||||
Same as request_finished.
|
||||
"""
|
||||
# LMCache only handles attention (first) group blocks.
|
||||
# Mamba SSM state is managed separately by the scheduler.
|
||||
return self.request_finished(request, block_ids[0])
|
||||
|
||||
def take_events(self) -> Iterable["KVCacheEvent"]:
|
||||
"""
|
||||
Take the KV cache events from the connector.
|
||||
|
||||
Yields:
|
||||
New KV cache events since the last call.
|
||||
"""
|
||||
if self._kv_cache_events is not None:
|
||||
self._kv_cache_events.aggregate()
|
||||
kv_cache_events = self._kv_cache_events.get_all_events()
|
||||
yield from kv_cache_events
|
||||
self._kv_cache_events.clear_events()
|
||||
self._kv_cache_events = None
|
||||
Reference in New Issue
Block a user