implements register kv caches in lmcache connector (#31397)
Signed-off-by: idellzheng <idellzheng@tencent.com>
This commit is contained in:
@@ -107,6 +107,22 @@ class LMCacheConnectorV1(KVConnectorBase_V1):
|
||||
# ==============================
|
||||
# Worker-side methods
|
||||
# ==============================
|
||||
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
    """
    Hand the KV caches over to the underlying LMCache engine.

    The call is forwarded to ``self._lmcache_engine.register_kv_caches``
    when the engine exposes that method. Engines without the hook are
    left untouched and a warning is logged instead, so pre-registration
    (e.g. for NIXL) is best-effort.

    Args:
        kv_caches: dictionary of layer names, kv cache
    """
    engine = self._lmcache_engine

    # Engines that predate this hook: warn and carry on rather than fail.
    if not hasattr(engine, "register_kv_caches"):
        logger.warning(
            "LMCache engine does not support register_kv_caches, "
            "please check and use the latest version"
        )
        return

    engine.register_kv_caches(kv_caches)
|
||||
|
||||
def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None:
|
||||
"""
|
||||
Start loading the KV cache from the connector to vLLM's paged
|
||||
|
||||
@@ -782,6 +782,16 @@ class LMCacheConnectorV1Impl:
|
||||
####################
|
||||
# Worker side APIs
|
||||
####################
|
||||
@_lmcache_nvtx_annotate
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
    """
    Store the KV caches on this instance and pass them to the engine.

    Asserts that no KV caches have been stored on ``self.kv_caches`` yet
    and that ``kv_caches`` is non-empty. When ``self.lmcache_engine`` is
    set, its ``post_init`` is called with the cache tensors as a list.

    Args:
        kv_caches: dictionary of layer names, kv cache
    """
    logger.info("Registering KV caches")
    # TODO(chunxiaozheng): `_init_kv_caches_from_forward_context` is
    #   not called, we should consider removing it.

    # Checked in the same order as before: existing state first, then
    # the incoming mapping.
    assert len(self.kv_caches) == 0
    assert len(kv_caches) > 0
    self.kv_caches = kv_caches

    engine = self.lmcache_engine
    if engine is None:
        return

    engine.post_init(kvcaches=list(self.kv_caches.values()))
|
||||
|
||||
@_lmcache_nvtx_annotate
|
||||
def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
|
||||
|
||||
Reference in New Issue
Block a user