# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any

import torch

from vllm.config import VllmConfig
from vllm.distributed.kv_events import (
    BlockStored,
    KVCacheEvent,
    KVConnectorKVEvents,
    KVEventAggregator,
)
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
    KVConnectorBase_V1,
    KVConnectorMetadata,
    KVConnectorRole,
    SupportsHMA,
)
from vllm.logger import init_logger
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import KVConnectorOutput

if TYPE_CHECKING:
    from vllm.forward_context import ForwardContext
    from vllm.v1.core.kv_cache_manager import KVCacheBlocks
    from vllm.v1.kv_cache_interface import KVCacheConfig
    from vllm.v1.request import Request

logger = init_logger(__name__)


class LMCacheKVEvents(KVConnectorKVEvents):
    """
    Concrete implementation of KVConnectorKVEvents using KVEventAggregator.

    Every method delegates to a private ``KVEventAggregator`` created with
    the worker count passed to the constructor.
    """

    def __init__(self, num_workers: int) -> None:
        self._aggregator = KVEventAggregator(num_workers)

    def add_events(self, events: list[KVCacheEvent]) -> None:
        """Forward ``events`` to the underlying aggregator."""
        self._aggregator.add_events(events)

    def aggregate(self) -> "LMCacheKVEvents":
        """
        Aggregate KV events and retain only common events.

        The common events are fetched first, then the aggregator is cleared,
        re-populated with just those events, and its worker count is reset.

        Returns:
            This same instance, to allow call chaining.
        """
        common_events = self._aggregator.get_common_events()
        self._aggregator.clear_events()
        self._aggregator.add_events(common_events)
        self._aggregator.reset_workers()
        return self

    def increment_workers(self, count: int = 1) -> None:
        """Increase the aggregator's worker count by ``count``."""
        self._aggregator.increment_workers(count)

    def get_all_events(self) -> list[KVCacheEvent]:
        """Return every event currently held by the aggregator."""
        return self._aggregator.get_all_events()

    def get_number_of_workers(self) -> int:
        """Return the aggregator's current worker count."""
        return self._aggregator.get_number_of_workers()

    def clear_events(self) -> None:
        """Drop all held events and reset the aggregator's worker count."""
        self._aggregator.clear_events()
        self._aggregator.reset_workers()

    def __repr__(self) -> str:
        # The previous implementation returned an empty string, which made
        # instances invisible in logs and debuggers.
        return f"<LMCacheKVEvents events={self.get_all_events()}>"


class LMCacheConnectorV1(KVConnectorBase_V1, SupportsHMA):
    """
    vLLM KV connector that forwards connector calls to an LMCache engine.

    The engine implementation is chosen in ``__init__`` from the
    ``use_native`` entry of the KV transfer extra config.
    """

    @classmethod
    def requires_piecewise_for_cudagraph(cls, extra_config: dict[str, Any]) -> bool:
        """
        LMCache requires PIECEWISE CUDA graph mode when layerwise
        operations are enabled. The wait_for_layer_load and save_kv_layer
        methods perform actual async synchronization that cannot be
        captured in CUDA graphs.

        Args:
            extra_config: connector extra config; only the
                ``"use_layerwise"`` key is read.

        Returns:
            The value stored under ``"use_layerwise"``, or False if absent.
        """
        return extra_config.get("use_layerwise", False)

    def __init__(
        self,
        vllm_config: "VllmConfig",
        role: KVConnectorRole,
        kv_cache_config: "KVCacheConfig",
    ):
        """
        Create the connector and instantiate the underlying LMCache engine.

        Args:
            vllm_config: the vLLM config; ``kv_transfer_config`` must be set.
            role: the role of this connector instance.
            kv_cache_config: the KV cache config, passed to the base class.
        """
        super().__init__(
            vllm_config=vllm_config, role=role, kv_cache_config=kv_cache_config
        )
        assert vllm_config.kv_transfer_config is not None
        use_native = vllm_config.kv_transfer_config.get_from_extra_config(
            "use_native", False
        )
        if use_native:
            logger.info("Initializing native LMCache connector")
            # lazy import
            from vllm.distributed.kv_transfer.kv_connector.v1 import lmcache_integration

            _adapter = lmcache_integration.vllm_v1_adapter

            impl_cls = _adapter.LMCacheConnectorV1Impl
        else:
            logger.info("Initializing latest dev LMCache connector")
            # lazy import
            from lmcache.integration.vllm.vllm_v1_adapter import (
                LMCacheConnectorV1Impl as LMCacheConnectorLatestImpl,
            )

            impl_cls = LMCacheConnectorLatestImpl

        self._lmcache_engine = impl_cls(vllm_config, role, self)

        # Events accumulated by update_connector_output() and drained by
        # take_events(); None while nothing is pending.
        self._kv_cache_events: LMCacheKVEvents | None = None

    # ==============================
    # Worker-side methods
    # ==============================
    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
        """
        Initialize with the KV caches. Useful for pre-registering the
        KV Caches in the KVConnector (e.g. for NIXL).

        If the engine has no ``register_kv_caches`` attribute, a warning is
        logged and nothing else happens.

        Args:
            kv_caches: dictionary of layer names, kv cache
        """
        if hasattr(self._lmcache_engine, "register_kv_caches"):
            self._lmcache_engine.register_kv_caches(kv_caches)
        else:
            logger.warning(
                "LMCache engine does not support register_kv_caches, "
                "please check and use the latest version"
            )

    def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None:
        """
        Start loading the KV cache from the connector to vLLM's paged
        KV buffer. This is called from the forward context before the
        forward pass to enable async loading during model execution.

        Args:
            forward_context (ForwardContext): the forward context.
            **kwargs: additional arguments for the load operation

        Note:
            The number of elements in kv_caches and layer_names should be
            the same.

        """
        self._lmcache_engine.start_load_kv(forward_context, **kwargs)

    def wait_for_layer_load(self, layer_name: str) -> None:
        """
        Block until the KV for a specific layer is loaded into vLLM's
        paged buffer. This is called from within attention layer to ensure
        async copying from start_load_kv is complete.

        This interface will be useful for layer-by-layer pipelining.

        Args:
            layer_name: the name of that layer
        """
        self._lmcache_engine.wait_for_layer_load(layer_name)

    def save_kv_layer(
        self,
        layer_name: str,
        kv_layer: torch.Tensor,
        attn_metadata: AttentionMetadata,
        **kwargs: Any,
    ) -> None:
        """
        Start saving a layer of KV cache from vLLM's paged buffer
        to the connector. This is called from within attention layer to
        enable async copying during execution.

        Args:
            layer_name (str): the name of the layer.
            kv_layer (torch.Tensor): the paged KV buffer of the current
                layer in vLLM.
            attn_metadata (AttentionMetadata): the attention metadata.
            **kwargs: additional arguments for the save operation.
        """
        self._lmcache_engine.save_kv_layer(
            layer_name, kv_layer, attn_metadata, **kwargs
        )

    def wait_for_save(self):
        """
        Block until all the save operations are done. This is called
        as the forward context exits to ensure that the async saving
        from save_kv_layer is complete before finishing the forward.

        This prevents overwrites of paged KV buffer before saving is done.
        """
        self._lmcache_engine.wait_for_save()

    def get_finished(
        self, finished_req_ids: set[str]
    ) -> tuple[set[str] | None, set[str] | None]:
        """
        Notifies worker-side connector ids of requests that have
        finished generating tokens.

        Returns:
            ids of requests that have finished asynchronous transfer
            (requests that previously returned True from request_finished()),
            tuple of (sending/saving ids, recving/loading ids).
            The finished saves/sends req ids must belong to a set provided in a
            call to this method (this call or a prior one).
        """
        return self._lmcache_engine.get_finished(finished_req_ids)

    def get_block_ids_with_load_errors(self) -> set[int]:
        """
        Get the set of block IDs that failed to load.

        Returns:
            Set of block IDs that encountered load errors.
            Empty set if no load errors occurred.
        """
        method = getattr(self._lmcache_engine, "get_block_ids_with_load_errors", None)
        if callable(method):
            return method()

        # Fallback for older versions that don't support this method
        return set()

    def get_kv_connector_kv_cache_events(self) -> LMCacheKVEvents | None:
        """
        Get the KV connector kv cache events collected during the last interval.

        Each event reported by the engine is copied into a ``BlockStored``
        and the result is wrapped in a single-worker ``LMCacheKVEvents``.

        Returns:
            The wrapped events, or None when the engine reports no events.
        """

        events = self._lmcache_engine.get_kv_events()  # type: ignore [attr-defined]
        if not events:
            return None

        blocks: list[BlockStored] = [
            BlockStored(
                block_hashes=e.block_hashes,
                parent_block_hash=e.parent_block_hash,
                token_ids=e.token_ids,
                lora_id=e.lora_id,
                block_size=e.block_size,
                medium=e.medium,
                # Tolerate engine events that carry no lora_name attribute.
                lora_name=getattr(e, "lora_name", None),
            )
            for e in events
        ]

        lmcache_kv_events = LMCacheKVEvents(num_workers=1)
        lmcache_kv_events.add_events(blocks)
        return lmcache_kv_events

    # ==============================
    # Scheduler-side methods
    # ==============================
    def get_num_new_matched_tokens(
        self,
        request: "Request",
        num_computed_tokens: int,
    ) -> tuple[int | None, bool]:
        """
        Get number of new tokens that can be loaded from the
        external KV cache beyond the num_computed_tokens.

        Args:
            request (Request): the request object.
            num_computed_tokens (int): the number of locally
                computed tokens for this request

        Returns:
            A tuple whose first element is the number of tokens that can be
            loaded from the external KV cache beyond what is already
            computed (as reported by the engine), and whose second element
            is always False.
        """
        return self._lmcache_engine.get_num_new_matched_tokens(
            request, num_computed_tokens
        ), False

    def update_state_after_alloc(
        self, request: "Request", blocks: "KVCacheBlocks", num_external_tokens: int
    ):
        """
        Update KVConnector state after block allocation.

        Only ``request`` and ``num_external_tokens`` are forwarded to the
        engine; ``blocks`` is not used here.
        """
        self._lmcache_engine.update_state_after_alloc(request, num_external_tokens)

    def build_connector_meta(
        self, scheduler_output: SchedulerOutput
    ) -> KVConnectorMetadata:
        """
        Build the connector metadata for this step.

        This function should NOT modify fields in the scheduler_output.
        Also, calling this function will reset the state of the connector.

        Args:
            scheduler_output (SchedulerOutput): the scheduler output object.
        """
        return self._lmcache_engine.build_connector_meta(scheduler_output)

    def update_connector_output(self, connector_output: KVConnectorOutput):
        """
        Update KVConnector state from worker-side connectors output.

        Output whose ``kv_cache_events`` is falsy or is not an
        ``LMCacheKVEvents`` is ignored.

        Args:
            connector_output (KVConnectorOutput): the worker-side
                connectors output.
        """
        # Get the KV events
        kv_cache_events = connector_output.kv_cache_events
        if not kv_cache_events or not isinstance(kv_cache_events, LMCacheKVEvents):
            return

        if self._kv_cache_events is None:
            # First batch since the last take_events(): adopt it directly.
            self._kv_cache_events = kv_cache_events
        else:
            # Merge into the pending batch and account for the extra workers.
            self._kv_cache_events.add_events(kv_cache_events.get_all_events())
            self._kv_cache_events.increment_workers(
                kv_cache_events.get_number_of_workers()
            )

    def request_finished(
        self,
        request: "Request",
        block_ids: list[int],
    ) -> tuple[bool, dict[str, Any] | None]:
        """
        Called when a request has finished, before its blocks are freed.

        Returns:
            True if the request is being saved/sent asynchronously and blocks
            should not be freed until the request_id is returned from
            get_finished().
            Optional KVTransferParams to be included in the request outputs
            returned by the engine.
        """
        return self._lmcache_engine.request_finished(request, block_ids)

    def request_finished_all_groups(
        self,
        request: "Request",
        block_ids: tuple[list[int], ...],
    ) -> tuple[bool, dict[str, Any] | None]:
        """
        Called exactly once when a request has finished for all KV cache
        groups (HMA support for hybrid Mamba/Attention models).

        LMCache only stores/offloads attention KV cache blocks, so we
        extract the first group's block IDs and delegate to the
        single-group request_finished.

        Args:
            request: the request object.
            block_ids: tuple of block ID lists, one per KV cache group.

        Returns:
            Same as request_finished.
        """
        # LMCache only handles attention (first) group blocks.
        # Mamba SSM state is managed separately by the scheduler.
        return self.request_finished(request, block_ids[0])

    def take_events(self) -> Iterable["KVCacheEvent"]:
        """
        Take the KV cache events from the connector.

        Pending events are aggregated before being yielded. The pending
        state is cleared only after the last event has been yielded.

        Yields:
            New KV cache events since the last call.
        """
        if self._kv_cache_events is not None:
            self._kv_cache_events.aggregate()
            kv_cache_events = self._kv_cache_events.get_all_events()
            yield from kv_cache_events
            self._kv_cache_events.clear_events()
            self._kv_cache_events = None