From e94ec597334d9a3e9b0d04bc17152e2747c83d51 Mon Sep 17 00:00:00 2001 From: Yuwei An Date: Mon, 9 Feb 2026 17:18:42 -0800 Subject: [PATCH] [LMCache] Token Base IPC API (#34175) Signed-off-by: Oasis-Git --- .../multi_process_adapter.py | 417 +++++++++++++++--- .../kv_connector/v1/lmcache_mp_connector.py | 49 +- 2 files changed, 376 insertions(+), 90 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/multi_process_adapter.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/multi_process_adapter.py index d865f70bd..e476cba7c 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/multi_process_adapter.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/multi_process_adapter.py @@ -20,16 +20,42 @@ from lmcache.v1.multiprocess.protocol import RequestType, get_response_class logger = init_logger(__name__) -def wrap_kv_caches(kv_caches: dict[str, KVCache]) -> KVCache: +def wrap_kv_caches(kv_caches: dict[str, torch.Tensor]) -> KVCache: logger.info("KV caches keys are %s", list(kv_caches.keys())) return [CudaIPCWrapper(tensor) for tensor in kv_caches.values()] +def striding_block_hashes( + block_hashes: list[bytes], blocks_in_chunk: int +) -> Iterable[bytes]: + """Extract chunk-level hashes from block hashes by striding. + + In hash-based vLLM, each vLLM block has its own hash. LMCache chunks + span ``blocks_in_chunk`` consecutive blocks. The representative hash + for a chunk is the hash of the **last** block in that chunk (because + each block hash already encodes its prefix). So we start at index + ``blocks_in_chunk - 1`` and stride by ``blocks_in_chunk``. + """ + return islice(block_hashes, blocks_in_chunk - 1, None, blocks_in_chunk) + + def send_lmcache_request( mq_client: MessageQueueClient, request_type: RequestType, payloads: list[Any], ) -> MessagingFuture[Any]: + """ + Helper function to send the request to the LMCache multiprocess server + + Args: + mq_client: The LMCache multiprocess mode message queue client + request_type: The request type + payloads: The request payloads + + Returns: + A messaging future for the request + """ + future = mq_client.submit_request( request_type, payloads, get_response_class(request_type) ) @@ -39,40 +65,44 @@ def send_lmcache_request( def get_lmcache_chunk_size( mq_client: MessageQueueClient, ) -> int: + """ + Helper function to get the LMCache chunk size from the server + + Args: + mq_client: The LMCache multiprocess mode message queue client + + Returns: + An integer representing the LMCache chunk size + """ future = send_lmcache_request(mq_client, RequestType.GET_CHUNK_SIZE, []) chunk_size = future.result() return chunk_size -def striding_block_hashes( - block_hashes: list[bytes], - blocks_in_chunk, -) -> Iterable[bytes]: - """Striding the block hashes to get the block hashes for each chunk. - For example, if blocks_in_chunk is 16, then we will get the block hashes - for the 16th, 32nd, 48th, ... blocks. - """ - return islice(block_hashes, blocks_in_chunk - 1, None, blocks_in_chunk) - - @dataclass class LoadStoreOp: - block_hashes: list[bytes] block_ids: list[int] + """Block ids for the load/store operation""" + + token_ids: list[int] | None = None + """Token IDs for the load/store operation (token mode)""" + + block_hashes: list[bytes] | None = None + """Block hashes for the load/store operation (hash mode)""" + + start: int = 0 + """Start token index (token mode only)""" + + end: int = 0 + """End token index (token mode only)""" def __len__(self) -> int: - return len(self.block_hashes) - - def __post_init__(self): - assert len(self.block_hashes) == len(self.block_ids), ( - "The number of block hashes should be equal to the number of block ids " - f"But got {len(self.block_hashes)} and {len(self.block_ids)}" - ) + return len(self.block_ids) StoreResult = bool RetrieveResult = list[bool] -LookupResult = list[bool] +LookupResult = int class LMCacheMPSchedulerAdapter: @@ -95,10 +125,6 @@ class LMCacheMPSchedulerAdapter: kv_rank: The kv rank used for LMCache keys vllm_block_size: The block size used in vLLM """ - logger.warning( - "Importing LMCacheMPSchedulerAdapter is deprecated. " - "Please update your LMCache to the latest version." - ) self.mq_client = MessageQueueClient(server_url, context) # Request futures @@ -116,22 +142,89 @@ class LMCacheMPSchedulerAdapter: self.blocks_in_chunk = self.chunk_size // vllm_block_size @_lmcache_nvtx_annotate - def maybe_submit_lookup_request(self, request_id: str, block_hashes: list[bytes]): + def maybe_submit_lookup_request( + self, + request_id: str, + block_hashes: list[bytes] | None = None, + token_ids: list[int] | None = None, + ) -> None: + """ + Submit a new lookup request to LMCache if there is no ongoing request. + + Supports both token-based and hash-based vLLM: + - token_ids: token IDs (token-based vLLM) -> single token-mode key + - block_hashes: block hashes (hash-based vLLM) -> strided hash-mode keys + + Exactly one of block_hashes or token_ids must be provided. + + Args: + request_id: The ID of the lookup request. The same ID indicates it's + from the same request + block_hashes: Block hashes to lookup from LMCache (hash mode) + token_ids: Token IDs to lookup from LMCache (token mode) + + Returns: + None + + Notes: + This function will have a side-effect: submitting a look up request to + LMCache, which will essentially 'lock' the KV cache chunks in the LMCache + for later retrieve operations. + In the meantime, this function will record the lookup request, and the + status of the look up request can be checked by `check_lookup_result`. + """ if request_id in self.lookup_futures: # Skip if there is already a lookup request return - s = striding_block_hashes(block_hashes, self.blocks_in_chunk) - keys = [self._create_key(block_hash) for block_hash in s] + assert (block_hashes is None) != (token_ids is None), ( + "Exactly one of block_hashes or token_ids must be provided" + ) + + if block_hashes is not None: + # Hash mode: stride block hashes -> N hash-mode keys + chunk_hashes = list( + striding_block_hashes(block_hashes, self.blocks_in_chunk) + ) + keys = [ + self._create_hash_key(ch, request_id=request_id) for ch in chunk_hashes + ] + else: + # Token mode: truncate to chunk-aligned length + assert token_ids is not None + aligned_end = (len(token_ids) // self.chunk_size) * self.chunk_size + if aligned_end == 0: + return + keys = [ + self._create_key( + token_ids, + start=0, + end=aligned_end, + request_id=request_id, + ).no_worker_id_version() + ] + future = send_lmcache_request( self.mq_client, RequestType.LOOKUP, - [keys, True], + [keys], ) self.lookup_futures[request_id] = future @_lmcache_nvtx_annotate def check_lookup_result(self, request_id: str) -> int | None: + """ + Check the result of a previously submitted lookup request. + + Args: + request_id: The ID of the lookup request submitted in + `maybe_submit_lookup_request` + + Returns: + An integer representing the total number of tokens matched + in LMCache (prefix matching), or + None if the lookup request is not finished yet. + """ assert request_id in self.lookup_futures, ( f"Lookup request for request_id={request_id} has not been submitted" ) @@ -141,7 +234,7 @@ class LMCacheMPSchedulerAdapter: return None result = future.result() - num_chunks = sum(result) + num_chunks = result return num_chunks * self.chunk_size def num_blocks_per_chunk(self) -> int: @@ -159,14 +252,47 @@ class LMCacheMPSchedulerAdapter: """ self.lookup_futures.pop(request_id, None) + def end_session(self, request_id: str) -> None: + """ + Notify LMCache server to remove the session for a finished request. + Args: + request_id: The ID of the finished request. + """ + send_lmcache_request( + self.mq_client, + RequestType.END_SESSION, + [request_id], + ) + # Helper functions - def _create_key(self, block_hash: bytes) -> IPCCacheEngineKey: - """Convert a block hash to an IPC cache engine key""" + def _create_key( + self, + token_ids: list[int], + start: int = 0, + end: int = 0, + request_id: str | None = None, + ) -> IPCCacheEngineKey: + """Convert token IDs to an IPC cache engine key""" return IPCCacheEngineKey( model_name=self.model_name, world_size=self.world_size, worker_id=self.worker_id, - chunk_hash=block_hash, + token_ids=tuple(token_ids), + start=start, + end=end, + request_id=request_id, + ) + + def _create_hash_key( + self, chunk_hash: bytes, request_id: str | None = None + ) -> IPCCacheEngineKey: + """Create a hash-mode IPC cache engine key""" + return IPCCacheEngineKey( + model_name=self.model_name, + world_size=self.world_size, + worker_id=None, + chunk_hash=chunk_hash, + request_id=request_id, ) @@ -180,10 +306,6 @@ class LMCacheMPWorkerAdapter: kv_rank: int, vllm_block_size: int, ): - logger.warning( - "Importing LMCacheMPWorkerAdapter is deprecated. " - "Please update your LMCache to the latest version." - ) self.mq_client = MessageQueueClient(server_url, context) # Instance id for GPU worker @@ -201,7 +323,10 @@ class LMCacheMPWorkerAdapter: str, tuple[MessagingFuture[RetrieveResult], list[str]] ] = {} + # The store requests that have finished execution in LMCache self.finished_stores: set[str] = set() + # The finished request ids that are passed via vLLM and also + # have corresponding store requests submitted to LMCache before self.previously_finished: set[str] = set() self.model_name = model_name @@ -215,7 +340,14 @@ class LMCacheMPWorkerAdapter: ) self.blocks_in_chunk = chunk_size // vllm_block_size - def register_kv_caches(self, kv_caches: dict[str, KVCache]): + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + """ + Register the kv caches with LMCache server + + Args: + kv_caches: A dict of kv caches to register. The keys are the + layer names and the values are the corresponding tensors. + """ # Register kv cache and send the request self.kv_caches = kv_caches logger.info("Registering kv caches") @@ -230,7 +362,29 @@ class LMCacheMPWorkerAdapter: def submit_store_request( self, request_id: str, op: LoadStoreOp, event: torch.cuda.Event ): - keys = self._block_hashes_to_keys(op.block_hashes) + """ + Submit a KV cache store request to LMCache + + Args: + request_id: The ID of the request + op: The LoadStoreOp describing the store operation. + event: The CUDA event that is recorded after the current + model inference step + """ + if op.block_hashes is not None: + # Hash mode + chunk_hashes = list( + striding_block_hashes(op.block_hashes, self.blocks_in_chunk) + ) + keys = [ + self._create_hash_key(ch, request_id=request_id) for ch in chunk_hashes + ] + else: + # Token mode + assert op.token_ids is not None + keys = [ + self._create_key(op.token_ids, op.start, op.end, request_id=request_id) + ] future = send_lmcache_request( self.mq_client, RequestType.STORE, @@ -242,7 +396,29 @@ class LMCacheMPWorkerAdapter: def submit_retrieve_request( self, request_id: str, op: LoadStoreOp, event: torch.cuda.Event ): - keys = self._block_hashes_to_keys(op.block_hashes) + """ + Submit a KV cache retrieve request to LMCache + + Args: + request_id: The ID of the request + op: The LoadStoreOp describing the retrieve operation. + event: The CUDA event that is recorded after the current + model inference step + """ + if op.block_hashes is not None: + # Hash mode + chunk_hashes = list( + striding_block_hashes(op.block_hashes, self.blocks_in_chunk) + ) + keys = [ + self._create_hash_key(ch, request_id=request_id) for ch in chunk_hashes + ] + else: + # Token mode + assert op.token_ids is not None + keys = [ + self._create_key(op.token_ids, op.start, op.end, request_id=request_id) + ] future = send_lmcache_request( self.mq_client, RequestType.RETRIEVE, @@ -257,17 +433,47 @@ class LMCacheMPWorkerAdapter: ops: list[LoadStoreOp], event: torch.cuda.Event, ): - keys = [] - block_ids = [] - for op in ops: - keys.extend(self._block_hashes_to_keys(op.block_hashes)) + """ + Submit a batched store request to LMCache + + Args: + request_ids: The IDs of the requests + ops: The LoadStoreOps describing the store operations. Should have + the same length as request_ids + event: The CUDA event that is recorded after the current + model inference step + """ + all_keys: list[IPCCacheEngineKey] = [] + block_ids: list[int] = [] + for request_id, op in zip(request_ids, ops, strict=False): + if op.block_hashes is not None: + chunk_hashes = list( + striding_block_hashes(op.block_hashes, self.blocks_in_chunk) + ) + keys = [ + self._create_hash_key(ch, request_id=request_id) + for ch in chunk_hashes + ] + all_keys.extend(keys) + else: + assert op.token_ids is not None + all_keys.append( + self._create_key( + op.token_ids, op.start, op.end, request_id=request_id + ) + ) block_ids.extend(op.block_ids) future = send_lmcache_request( self.mq_client, RequestType.STORE, - [keys, self.instance_id, block_ids, event.ipc_handle()], + [ + all_keys, + self.instance_id, + block_ids, + event.ipc_handle(), + ], ).to_cuda_future() - self.store_futures[request_ids[0]] = (future, request_ids[1:]) + self.store_futures[request_ids[0]] = (future, list(request_ids[1:])) @_lmcache_nvtx_annotate def batched_submit_retrieve_requests( @@ -276,34 +482,83 @@ class LMCacheMPWorkerAdapter: ops: list[LoadStoreOp], event: torch.cuda.Event, ): - keys = [] - block_ids = [] + """ + Submit a batched retrieve request to LMCache - for op in ops: - keys.extend(self._block_hashes_to_keys(op.block_hashes)) + Args: + request_ids: The IDs of the requests + ops: The LoadStoreOps describing the retrieve operations. Should have + the same length as request_ids + event: The CUDA event that is recorded after the current + model inference step + """ + all_keys: list[IPCCacheEngineKey] = [] + block_ids: list[int] = [] + for request_id, op in zip(request_ids, ops, strict=False): + if op.block_hashes is not None: + chunk_hashes = list( + striding_block_hashes(op.block_hashes, self.blocks_in_chunk) + ) + keys = [ + self._create_hash_key(ch, request_id=request_id) + for ch in chunk_hashes + ] + all_keys.extend(keys) + else: + assert op.token_ids is not None + all_keys.append( + self._create_key( + op.token_ids, op.start, op.end, request_id=request_id + ) + ) block_ids.extend(op.block_ids) future = send_lmcache_request( self.mq_client, RequestType.RETRIEVE, - [keys, self.instance_id, block_ids, event.ipc_handle()], + [ + all_keys, + self.instance_id, + block_ids, + event.ipc_handle(), + ], ).to_cuda_future() - self.retrieve_futures[request_ids[0]] = (future, request_ids[1:]) + self.retrieve_futures[request_ids[0]] = (future, list(request_ids[1:])) @_lmcache_nvtx_annotate def get_finished( - self, finished_req_ids: set[str] + self, finished_req_ids_from_engine: set[str] ) -> tuple[set[str] | None, set[str] | None]: + """ + Check and get the finished store and retrieve requests. + + Args: + finished_req_ids_from_engine: the set of request ids that are + reported as finished from the vLLM engine side. + + Returns: + A tuple of two sets: + - The first set contains the finished store request ids. The returned + store request ids MUST be seen before in the + `finished_req_ids_from_engine`. + - The second set contains the finished retrieve request ids. + + Notes: + When enabling async scheduling in vLLM, the same request ID may appear + multiple times in `finished_req_ids_from_engine`. The adapter should + take care of deduplicating the request IDs and only return the request + IDs that have not been returned before. + """ finished_stores = set() finished_retrieves = set() - for request_id, (future, other_reqs) in self.store_futures.items(): - if not future.query(): + for request_id, (s_future, other_reqs) in self.store_futures.items(): + if not s_future.query(): continue - result = future.result() + s_result = s_future.result() finished_stores.add(request_id) finished_stores.update(other_reqs) - if not result: + if not s_result: # TODO: add error handling here logger.error( "Something went wrong when processing the " @@ -311,21 +566,21 @@ class LMCacheMPWorkerAdapter: request_id, ) - for request_id, (future, other_reqs) in self.retrieve_futures.items(): - if not future.query(): + for request_id, (r_future, other_reqs) in self.retrieve_futures.items(): + if not r_future.query(): continue - result = future.result() + r_result = r_future.result() finished_retrieves.add(request_id) finished_retrieves.update(other_reqs) - if not all(result): + if not all(r_result): # TODO: add error handing here logger.error( "Something went wrong when processing the " "retrieve request for request_id=%s, result=%s", request_id, - result, + r_result, ) # Remove the finished requests from the tracking dicts @@ -338,7 +593,7 @@ class LMCacheMPWorkerAdapter: self.finished_stores.update(finished_stores) ret_stores = set() - for req_id in finished_req_ids: + for req_id in finished_req_ids_from_engine: if req_id in self.finished_stores or req_id in self.store_futures: self.previously_finished.add(req_id) else: @@ -357,7 +612,9 @@ class LMCacheMPWorkerAdapter: return self.blocks_in_chunk def shutdown(self): - # Unregister kv cache + """ + Shutdown the LMCache MP worker adapter + """ logger.info("Unregistering kv caches") send_lmcache_request( self.mq_client, RequestType.UNREGISTER_KV_CACHE, [self.instance_id] @@ -378,18 +635,32 @@ class LMCacheMPWorkerAdapter: return safe_finished_s - def _create_key(self, block_hash: bytes) -> IPCCacheEngineKey: - """Convert a block hash to an IPC cache engine key""" + def _create_key( + self, + token_ids: list[int], + start: int = 0, + end: int = 0, + request_id: str | None = None, + ) -> IPCCacheEngineKey: + """Convert token IDs to an IPC cache engine key""" return IPCCacheEngineKey( model_name=self.model_name, world_size=self.world_size, worker_id=self.worker_id, - chunk_hash=block_hash, + token_ids=tuple(token_ids), + start=start, + end=end, + request_id=request_id, ) - def _block_hashes_to_keys( - self, block_hashes: list[bytes] - ) -> list[IPCCacheEngineKey]: - """Convert block hashes to IPC cache engine keys""" - s = striding_block_hashes(block_hashes, self.blocks_in_chunk) - return [self._create_key(block_hash) for block_hash in s] + def _create_hash_key( + self, chunk_hash: bytes, request_id: str | None = None + ) -> IPCCacheEngineKey: + """Create a hash-mode IPC cache engine key""" + return IPCCacheEngineKey( + model_name=self.model_name, + world_size=self.world_size, + worker_id=self.worker_id, + chunk_hash=chunk_hash, + request_id=request_id, + ) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_mp_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_mp_connector.py index b542265dd..0379011e7 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_mp_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/lmcache_mp_connector.py @@ -3,7 +3,7 @@ import enum from collections.abc import Iterable from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Literal, cast +from typing import TYPE_CHECKING, Any, Literal import torch import zmq @@ -130,12 +130,6 @@ def create_worker_adapter( ) -def convert_block_hashes_to_bytes( - block_hashes: list["BlockHash"], -) -> list[bytes]: - return cast(list[bytes], block_hashes) - - class LMCacheMPRequestState(enum.Enum): """ State machine: @@ -266,6 +260,7 @@ class LMCacheMPRequestMetadata: Args: tracker: The request tracker to generate the metadata from. blocks_in_chunk: the number of blocks in a LMCache data chunk + vllm_block_size: the block size used in vLLM """ # Store the blocks that has block hashes # NOTE: the invariant here is that `num_stored_blocks` should @@ -282,15 +277,21 @@ class LMCacheMPRequestMetadata: if num_chunks >= 1: start = tracker.num_stored_blocks end = start + num_chunks * blocks_in_chunk - block_hashes = convert_block_hashes_to_bytes( - tracker.block_hashes[start:end] - ) block_ids = tracker.allocated_block_ids[start:end] + start_token_idx = start * vllm_block_size + end_token_idx = end * vllm_block_size + token_ids = list(tracker.all_token_ids) + op = LoadStoreOp( + token_ids=token_ids, + block_ids=block_ids, + start=start_token_idx, + end=end_token_idx, + ) ret = LMCacheMPRequestMetadata( request_id=tracker.request_id, direction="STORE", - op=LoadStoreOp(block_hashes=block_hashes, block_ids=block_ids), + op=op, ) # Update the request tracker @@ -303,6 +304,7 @@ class LMCacheMPRequestMetadata: def GetRetrieveMetadata( tracker: LMCacheMPRequestTracker, blocks_in_chunk: int, + vllm_block_size: int, ) -> "LMCacheMPRequestMetadata | None": """ Generate the retrieve metadata for the current request tracker. @@ -310,6 +312,7 @@ class LMCacheMPRequestMetadata: Args: tracker: The request tracker to generate the metadata from. blocks_in_chunk: the number of blocks in a LMCache data chunk + vllm_block_size: the block size used in vLLM """ if not tracker.is_ready_for_retrieving(): return None @@ -330,15 +333,21 @@ class LMCacheMPRequestMetadata: "number of LMCache hit blocks. " ) if end > start: - block_hashes = convert_block_hashes_to_bytes( - tracker.block_hashes[start:end] - ) block_ids = tracker.allocated_block_ids[start:end] + start_token_idx = start * vllm_block_size + end_token_idx = end * vllm_block_size + token_ids = list(tracker.all_token_ids) + op = LoadStoreOp( + token_ids=token_ids, + block_ids=block_ids, + start=start_token_idx, + end=end_token_idx, + ) ret = LMCacheMPRequestMetadata( request_id=tracker.request_id, direction="RETRIEVE", - op=LoadStoreOp(block_hashes=block_hashes, block_ids=block_ids), + op=op, ) return ret @@ -643,7 +652,8 @@ class LMCacheMPConnector(KVConnectorBase_V1): return 0, False self.scheduler_adapter.maybe_submit_lookup_request( - request.request_id, convert_block_hashes_to_bytes(request.block_hashes) + request.request_id, + token_ids=list(request.all_token_ids), ) ret = self.scheduler_adapter.check_lookup_result(request.request_id) @@ -766,6 +776,9 @@ class LMCacheMPConnector(KVConnectorBase_V1): """ # Clean up request tracker to prevent memory leak self._cleanup_request_tracker(request.request_id) + # Notify LMCache to end the session for this request + self.scheduler_adapter.end_session(request.request_id) + return True, None def take_events(self) -> Iterable["KVCacheEvent"]: @@ -846,7 +859,9 @@ class LMCacheMPConnector(KVConnectorBase_V1): if request_tracker.state != LMCacheMPRequestState.WAITING_FOR_LOAD: continue r_metadata = LMCacheMPRequestMetadata.GetRetrieveMetadata( - request_tracker, blocks_per_chunk + request_tracker, + blocks_per_chunk, + vllm_block_size=self.vllm_block_size, ) if r_metadata is not None: metadata.add_request_metadata(r_metadata)