diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py index 17d468fe6..fe9f3d785 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py @@ -107,6 +107,22 @@ class LMCacheConnectorV1(KVConnectorBase_V1): # ============================== # Worker-side methods # ============================== + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + """ + Initialize with the KV caches. Useful for pre-registering the + KV Caches in the KVConnector (e.g. for NIXL). + + Args: + kv_caches: dictionary of layer names, kv cache + """ + if hasattr(self._lmcache_engine, "register_kv_caches"): + self._lmcache_engine.register_kv_caches(kv_caches) + else: + logger.warning( + "LMCache engine does not support register_kv_caches, " + "please check and use the latest version" + ) + def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None: """ Start loading the KV cache from the connector to vLLM's paged diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py index 09af128f3..b9db6a168 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py @@ -782,6 +782,16 @@ class LMCacheConnectorV1Impl: #################### # Worker side APIs #################### + @_lmcache_nvtx_annotate + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + logger.info("Registering KV caches") + # TODO(chunxiaozheng): `_init_kv_caches_from_forward_context` is + # not called, we should consider removing it. + assert len(self.kv_caches) == 0 and len(kv_caches) > 0 + self.kv_caches = kv_caches + if self.lmcache_engine is not None: + kvcaches = list(self.kv_caches.values()) + self.lmcache_engine.post_init(kvcaches=kvcaches) @_lmcache_nvtx_annotate def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: