[LMCache] Token Base IPC API (#34175)

Signed-off-by: Oasis-Git <ayw.sirius19@gmail.com>
This commit is contained in:
Yuwei An
2026-02-09 17:18:42 -08:00
committed by GitHub
parent 13397841ab
commit e94ec59733
2 changed files with 376 additions and 90 deletions

View File

@@ -20,16 +20,42 @@ from lmcache.v1.multiprocess.protocol import RequestType, get_response_class
logger = init_logger(__name__)
def wrap_kv_caches(kv_caches: dict[str, KVCache]) -> KVCache:
def wrap_kv_caches(kv_caches: dict[str, torch.Tensor]) -> KVCache:
logger.info("KV caches keys are %s", list(kv_caches.keys()))
return [CudaIPCWrapper(tensor) for tensor in kv_caches.values()]
def striding_block_hashes(
block_hashes: list[bytes], blocks_in_chunk: int
) -> Iterable[bytes]:
"""Extract chunk-level hashes from block hashes by striding.
In hash-based vLLM, each vLLM block has its own hash. LMCache chunks
span ``blocks_in_chunk`` consecutive blocks. The representative hash
for a chunk is the hash of the **last** block in that chunk (because
each block hash already encodes its prefix). So we start at index
``blocks_in_chunk - 1`` and stride by ``blocks_in_chunk``.
"""
return islice(block_hashes, blocks_in_chunk - 1, None, blocks_in_chunk)
def send_lmcache_request(
mq_client: MessageQueueClient,
request_type: RequestType,
payloads: list[Any],
) -> MessagingFuture[Any]:
"""
Helper function to send the request to the LMCache multiprocess server
Args:
mq_client: The LMCache multiprocess mode message queue client
request_type: The request type
payloads: The request payloads
Returns:
A messaging future for the request
"""
future = mq_client.submit_request(
request_type, payloads, get_response_class(request_type)
)
@@ -39,40 +65,44 @@ def send_lmcache_request(
def get_lmcache_chunk_size(
mq_client: MessageQueueClient,
) -> int:
"""
Helper function to get the LMCache chunk size from the server
Args:
mq_client: The LMCache multiprocess mode message queue client
Returns:
An integer representing the LMCache chunk size
"""
future = send_lmcache_request(mq_client, RequestType.GET_CHUNK_SIZE, [])
chunk_size = future.result()
return chunk_size
def striding_block_hashes(
block_hashes: list[bytes],
blocks_in_chunk,
) -> Iterable[bytes]:
"""Striding the block hashes to get the block hashes for each chunk.
For example, if blocks_in_chunk is 16, then we will get the block hashes
for the 16th, 32nd, 48th, ... blocks.
"""
return islice(block_hashes, blocks_in_chunk - 1, None, blocks_in_chunk)
@dataclass
class LoadStoreOp:
block_hashes: list[bytes]
block_ids: list[int]
"""Block ids for the load/store operation"""
token_ids: list[int] | None = None
"""Token IDs for the load/store operation (token mode)"""
block_hashes: list[bytes] | None = None
"""Block hashes for the load/store operation (hash mode)"""
start: int = 0
"""Start token index (token mode only)"""
end: int = 0
"""End token index (token mode only)"""
def __len__(self) -> int:
return len(self.block_hashes)
def __post_init__(self):
assert len(self.block_hashes) == len(self.block_ids), (
"The number of block hashes should be equal to the number of block ids "
f"But got {len(self.block_hashes)} and {len(self.block_ids)}"
)
return len(self.block_ids)
StoreResult = bool
RetrieveResult = list[bool]
LookupResult = list[bool]
LookupResult = int
class LMCacheMPSchedulerAdapter:
@@ -95,10 +125,6 @@ class LMCacheMPSchedulerAdapter:
kv_rank: The kv rank used for LMCache keys
vllm_block_size: The block size used in vLLM
"""
logger.warning(
"Importing LMCacheMPSchedulerAdapter is deprecated. "
"Please update your LMCache to the latest version."
)
self.mq_client = MessageQueueClient(server_url, context)
# Request futures
@@ -116,22 +142,89 @@ class LMCacheMPSchedulerAdapter:
self.blocks_in_chunk = self.chunk_size // vllm_block_size
@_lmcache_nvtx_annotate
def maybe_submit_lookup_request(self, request_id: str, block_hashes: list[bytes]):
def maybe_submit_lookup_request(
self,
request_id: str,
block_hashes: list[bytes] | None = None,
token_ids: list[int] | None = None,
) -> None:
"""
Submit a new lookup request to LMCache if there is no ongoing request.
Supports both token-based and hash-based vLLM:
- token_ids: token IDs (token-based vLLM) -> single token-mode key
- block_hashes: block hashes (hash-based vLLM) -> strided hash-mode keys
Exactly one of block_hashes or token_ids must be provided.
Args:
request_id: The ID of the lookup request. The same ID indicates it's
from the same request
block_hashes: Block hashes to lookup from LMCache (hash mode)
token_ids: Token IDs to lookup from LMCache (token mode)
Returns:
None
Notes:
This function will have a side-effect: submitting a look up request to
LMCache, which will essentially 'lock' the KV cache chunks in the LMCache
for later retrieve operations.
In the meantime, this function will record the lookup request, and the
status of the look up request can be checked by `check_lookup_result`.
"""
if request_id in self.lookup_futures:
# Skip if there is already a lookup request
return
s = striding_block_hashes(block_hashes, self.blocks_in_chunk)
keys = [self._create_key(block_hash) for block_hash in s]
assert (block_hashes is None) != (token_ids is None), (
"Exactly one of block_hashes or token_ids must be provided"
)
if block_hashes is not None:
# Hash mode: stride block hashes -> N hash-mode keys
chunk_hashes = list(
striding_block_hashes(block_hashes, self.blocks_in_chunk)
)
keys = [
self._create_hash_key(ch, request_id=request_id) for ch in chunk_hashes
]
else:
# Token mode: truncate to chunk-aligned length
assert token_ids is not None
aligned_end = (len(token_ids) // self.chunk_size) * self.chunk_size
if aligned_end == 0:
return
keys = [
self._create_key(
token_ids,
start=0,
end=aligned_end,
request_id=request_id,
).no_worker_id_version()
]
future = send_lmcache_request(
self.mq_client,
RequestType.LOOKUP,
[keys, True],
[keys],
)
self.lookup_futures[request_id] = future
@_lmcache_nvtx_annotate
def check_lookup_result(self, request_id: str) -> int | None:
"""
Check the result of a previously submitted lookup request.
Args:
request_id: The ID of the lookup request submitted in
`maybe_submit_lookup_request`
Returns:
An integer representing the total number of tokens matched
in LMCache (prefix matching), or
None if the lookup request is not finished yet.
"""
assert request_id in self.lookup_futures, (
f"Lookup request for request_id={request_id} has not been submitted"
)
@@ -141,7 +234,7 @@ class LMCacheMPSchedulerAdapter:
return None
result = future.result()
num_chunks = sum(result)
num_chunks = result
return num_chunks * self.chunk_size
def num_blocks_per_chunk(self) -> int:
@@ -159,14 +252,47 @@ class LMCacheMPSchedulerAdapter:
"""
self.lookup_futures.pop(request_id, None)
def end_session(self, request_id: str) -> None:
"""
Notify LMCache server to remove the session for a finished request.
Args:
request_id: The ID of the finished request.
"""
send_lmcache_request(
self.mq_client,
RequestType.END_SESSION,
[request_id],
)
# Helper functions
def _create_key(self, block_hash: bytes) -> IPCCacheEngineKey:
"""Convert a block hash to an IPC cache engine key"""
def _create_key(
self,
token_ids: list[int],
start: int = 0,
end: int = 0,
request_id: str | None = None,
) -> IPCCacheEngineKey:
"""Convert token IDs to an IPC cache engine key"""
return IPCCacheEngineKey(
model_name=self.model_name,
world_size=self.world_size,
worker_id=self.worker_id,
chunk_hash=block_hash,
token_ids=tuple(token_ids),
start=start,
end=end,
request_id=request_id,
)
def _create_hash_key(
self, chunk_hash: bytes, request_id: str | None = None
) -> IPCCacheEngineKey:
"""Create a hash-mode IPC cache engine key"""
return IPCCacheEngineKey(
model_name=self.model_name,
world_size=self.world_size,
worker_id=None,
chunk_hash=chunk_hash,
request_id=request_id,
)
@@ -180,10 +306,6 @@ class LMCacheMPWorkerAdapter:
kv_rank: int,
vllm_block_size: int,
):
logger.warning(
"Importing LMCacheMPWorkerAdapter is deprecated. "
"Please update your LMCache to the latest version."
)
self.mq_client = MessageQueueClient(server_url, context)
# Instance id for GPU worker
@@ -201,7 +323,10 @@ class LMCacheMPWorkerAdapter:
str, tuple[MessagingFuture[RetrieveResult], list[str]]
] = {}
# The store requests that have finished execution in LMCache
self.finished_stores: set[str] = set()
# The finished request ids that are passed via vLLM and also
# have corresponding store requests submitted to LMCache before
self.previously_finished: set[str] = set()
self.model_name = model_name
@@ -215,7 +340,14 @@ class LMCacheMPWorkerAdapter:
)
self.blocks_in_chunk = chunk_size // vllm_block_size
def register_kv_caches(self, kv_caches: dict[str, KVCache]):
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
"""
Register the kv caches with LMCache server
Args:
kv_caches: A dict of kv caches to register. The keys are the
layer names and the values are the corresponding tensors.
"""
# Register kv cache and send the request
self.kv_caches = kv_caches
logger.info("Registering kv caches")
@@ -230,7 +362,29 @@ class LMCacheMPWorkerAdapter:
def submit_store_request(
self, request_id: str, op: LoadStoreOp, event: torch.cuda.Event
):
keys = self._block_hashes_to_keys(op.block_hashes)
"""
Submit a KV cache store request to LMCache
Args:
request_id: The ID of the request
op: The LoadStoreOp describing the store operation.
event: The CUDA event that is recorded after the current
model inference step
"""
if op.block_hashes is not None:
# Hash mode
chunk_hashes = list(
striding_block_hashes(op.block_hashes, self.blocks_in_chunk)
)
keys = [
self._create_hash_key(ch, request_id=request_id) for ch in chunk_hashes
]
else:
# Token mode
assert op.token_ids is not None
keys = [
self._create_key(op.token_ids, op.start, op.end, request_id=request_id)
]
future = send_lmcache_request(
self.mq_client,
RequestType.STORE,
@@ -242,7 +396,29 @@ class LMCacheMPWorkerAdapter:
def submit_retrieve_request(
self, request_id: str, op: LoadStoreOp, event: torch.cuda.Event
):
keys = self._block_hashes_to_keys(op.block_hashes)
"""
Submit a KV cache retrieve request to LMCache
Args:
request_id: The ID of the request
op: The LoadStoreOp describing the retrieve operation.
event: The CUDA event that is recorded after the current
model inference step
"""
if op.block_hashes is not None:
# Hash mode
chunk_hashes = list(
striding_block_hashes(op.block_hashes, self.blocks_in_chunk)
)
keys = [
self._create_hash_key(ch, request_id=request_id) for ch in chunk_hashes
]
else:
# Token mode
assert op.token_ids is not None
keys = [
self._create_key(op.token_ids, op.start, op.end, request_id=request_id)
]
future = send_lmcache_request(
self.mq_client,
RequestType.RETRIEVE,
@@ -257,17 +433,47 @@ class LMCacheMPWorkerAdapter:
ops: list[LoadStoreOp],
event: torch.cuda.Event,
):
keys = []
block_ids = []
for op in ops:
keys.extend(self._block_hashes_to_keys(op.block_hashes))
"""
Submit a batched store request to LMCache
Args:
request_ids: The IDs of the requests
ops: The LoadStoreOps describing the store operations. Should have
the same length as request_ids
event: The CUDA event that is recorded after the current
model inference step
"""
all_keys: list[IPCCacheEngineKey] = []
block_ids: list[int] = []
for request_id, op in zip(request_ids, ops, strict=False):
if op.block_hashes is not None:
chunk_hashes = list(
striding_block_hashes(op.block_hashes, self.blocks_in_chunk)
)
keys = [
self._create_hash_key(ch, request_id=request_id)
for ch in chunk_hashes
]
all_keys.extend(keys)
else:
assert op.token_ids is not None
all_keys.append(
self._create_key(
op.token_ids, op.start, op.end, request_id=request_id
)
)
block_ids.extend(op.block_ids)
future = send_lmcache_request(
self.mq_client,
RequestType.STORE,
[keys, self.instance_id, block_ids, event.ipc_handle()],
[
all_keys,
self.instance_id,
block_ids,
event.ipc_handle(),
],
).to_cuda_future()
self.store_futures[request_ids[0]] = (future, request_ids[1:])
self.store_futures[request_ids[0]] = (future, list(request_ids[1:]))
@_lmcache_nvtx_annotate
def batched_submit_retrieve_requests(
@@ -276,34 +482,83 @@ class LMCacheMPWorkerAdapter:
ops: list[LoadStoreOp],
event: torch.cuda.Event,
):
keys = []
block_ids = []
"""
Submit a batched retrieve request to LMCache
for op in ops:
keys.extend(self._block_hashes_to_keys(op.block_hashes))
Args:
request_ids: The IDs of the requests
ops: The LoadStoreOps describing the retrieve operations. Should have
the same length as request_ids
event: The CUDA event that is recorded after the current
model inference step
"""
all_keys: list[IPCCacheEngineKey] = []
block_ids: list[int] = []
for request_id, op in zip(request_ids, ops, strict=False):
if op.block_hashes is not None:
chunk_hashes = list(
striding_block_hashes(op.block_hashes, self.blocks_in_chunk)
)
keys = [
self._create_hash_key(ch, request_id=request_id)
for ch in chunk_hashes
]
all_keys.extend(keys)
else:
assert op.token_ids is not None
all_keys.append(
self._create_key(
op.token_ids, op.start, op.end, request_id=request_id
)
)
block_ids.extend(op.block_ids)
future = send_lmcache_request(
self.mq_client,
RequestType.RETRIEVE,
[keys, self.instance_id, block_ids, event.ipc_handle()],
[
all_keys,
self.instance_id,
block_ids,
event.ipc_handle(),
],
).to_cuda_future()
self.retrieve_futures[request_ids[0]] = (future, request_ids[1:])
self.retrieve_futures[request_ids[0]] = (future, list(request_ids[1:]))
@_lmcache_nvtx_annotate
def get_finished(
self, finished_req_ids: set[str]
self, finished_req_ids_from_engine: set[str]
) -> tuple[set[str] | None, set[str] | None]:
"""
Check and get the finished store and retrieve requests.
Args:
finished_req_ids_from_engine: the set of request ids that are
reported as finished from the vLLM engine side.
Returns:
A tuple of two sets:
- The first set contains the finished store request ids. The returned
store request ids MUST be seen before in the
`finished_req_ids_from_engine`.
- The second set contains the finished retrieve request ids.
Notes:
When enabling async scheduling in vLLM, the same request ID may appear
multiple times in `finished_req_ids_from_engine`. The adapter should
take care of deduplicating the request IDs and only return the request
IDs that have not been returned before.
"""
finished_stores = set()
finished_retrieves = set()
for request_id, (future, other_reqs) in self.store_futures.items():
if not future.query():
for request_id, (s_future, other_reqs) in self.store_futures.items():
if not s_future.query():
continue
result = future.result()
s_result = s_future.result()
finished_stores.add(request_id)
finished_stores.update(other_reqs)
if not result:
if not s_result:
# TODO: add error handling here
logger.error(
"Something went wrong when processing the "
@@ -311,21 +566,21 @@ class LMCacheMPWorkerAdapter:
request_id,
)
for request_id, (future, other_reqs) in self.retrieve_futures.items():
if not future.query():
for request_id, (r_future, other_reqs) in self.retrieve_futures.items():
if not r_future.query():
continue
result = future.result()
r_result = r_future.result()
finished_retrieves.add(request_id)
finished_retrieves.update(other_reqs)
if not all(result):
if not all(r_result):
# TODO: add error handing here
logger.error(
"Something went wrong when processing the "
"retrieve request for request_id=%s, result=%s",
request_id,
result,
r_result,
)
# Remove the finished requests from the tracking dicts
@@ -338,7 +593,7 @@ class LMCacheMPWorkerAdapter:
self.finished_stores.update(finished_stores)
ret_stores = set()
for req_id in finished_req_ids:
for req_id in finished_req_ids_from_engine:
if req_id in self.finished_stores or req_id in self.store_futures:
self.previously_finished.add(req_id)
else:
@@ -357,7 +612,9 @@ class LMCacheMPWorkerAdapter:
return self.blocks_in_chunk
def shutdown(self):
# Unregister kv cache
"""
Shutdown the LMCache MP worker adapter
"""
logger.info("Unregistering kv caches")
send_lmcache_request(
self.mq_client, RequestType.UNREGISTER_KV_CACHE, [self.instance_id]
@@ -378,18 +635,32 @@ class LMCacheMPWorkerAdapter:
return safe_finished_s
def _create_key(self, block_hash: bytes) -> IPCCacheEngineKey:
"""Convert a block hash to an IPC cache engine key"""
def _create_key(
self,
token_ids: list[int],
start: int = 0,
end: int = 0,
request_id: str | None = None,
) -> IPCCacheEngineKey:
"""Convert token IDs to an IPC cache engine key"""
return IPCCacheEngineKey(
model_name=self.model_name,
world_size=self.world_size,
worker_id=self.worker_id,
chunk_hash=block_hash,
token_ids=tuple(token_ids),
start=start,
end=end,
request_id=request_id,
)
def _block_hashes_to_keys(
self, block_hashes: list[bytes]
) -> list[IPCCacheEngineKey]:
"""Convert block hashes to IPC cache engine keys"""
s = striding_block_hashes(block_hashes, self.blocks_in_chunk)
return [self._create_key(block_hash) for block_hash in s]
def _create_hash_key(
self, chunk_hash: bytes, request_id: str | None = None
) -> IPCCacheEngineKey:
"""Create a hash-mode IPC cache engine key"""
return IPCCacheEngineKey(
model_name=self.model_name,
world_size=self.world_size,
worker_id=self.worker_id,
chunk_hash=chunk_hash,
request_id=request_id,
)

View File

@@ -3,7 +3,7 @@
import enum
from collections.abc import Iterable
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Literal, cast
from typing import TYPE_CHECKING, Any, Literal
import torch
import zmq
@@ -130,12 +130,6 @@ def create_worker_adapter(
)
def convert_block_hashes_to_bytes(
block_hashes: list["BlockHash"],
) -> list[bytes]:
return cast(list[bytes], block_hashes)
class LMCacheMPRequestState(enum.Enum):
"""
State machine:
@@ -266,6 +260,7 @@ class LMCacheMPRequestMetadata:
Args:
tracker: The request tracker to generate the metadata from.
blocks_in_chunk: the number of blocks in a LMCache data chunk
vllm_block_size: the block size used in vLLM
"""
# Store the blocks that has block hashes
# NOTE: the invariant here is that `num_stored_blocks` should
@@ -282,15 +277,21 @@ class LMCacheMPRequestMetadata:
if num_chunks >= 1:
start = tracker.num_stored_blocks
end = start + num_chunks * blocks_in_chunk
block_hashes = convert_block_hashes_to_bytes(
tracker.block_hashes[start:end]
)
block_ids = tracker.allocated_block_ids[start:end]
start_token_idx = start * vllm_block_size
end_token_idx = end * vllm_block_size
token_ids = list(tracker.all_token_ids)
op = LoadStoreOp(
token_ids=token_ids,
block_ids=block_ids,
start=start_token_idx,
end=end_token_idx,
)
ret = LMCacheMPRequestMetadata(
request_id=tracker.request_id,
direction="STORE",
op=LoadStoreOp(block_hashes=block_hashes, block_ids=block_ids),
op=op,
)
# Update the request tracker
@@ -303,6 +304,7 @@ class LMCacheMPRequestMetadata:
def GetRetrieveMetadata(
tracker: LMCacheMPRequestTracker,
blocks_in_chunk: int,
vllm_block_size: int,
) -> "LMCacheMPRequestMetadata | None":
"""
Generate the retrieve metadata for the current request tracker.
@@ -310,6 +312,7 @@ class LMCacheMPRequestMetadata:
Args:
tracker: The request tracker to generate the metadata from.
blocks_in_chunk: the number of blocks in a LMCache data chunk
vllm_block_size: the block size used in vLLM
"""
if not tracker.is_ready_for_retrieving():
return None
@@ -330,15 +333,21 @@ class LMCacheMPRequestMetadata:
"number of LMCache hit blocks. "
)
if end > start:
block_hashes = convert_block_hashes_to_bytes(
tracker.block_hashes[start:end]
)
block_ids = tracker.allocated_block_ids[start:end]
start_token_idx = start * vllm_block_size
end_token_idx = end * vllm_block_size
token_ids = list(tracker.all_token_ids)
op = LoadStoreOp(
token_ids=token_ids,
block_ids=block_ids,
start=start_token_idx,
end=end_token_idx,
)
ret = LMCacheMPRequestMetadata(
request_id=tracker.request_id,
direction="RETRIEVE",
op=LoadStoreOp(block_hashes=block_hashes, block_ids=block_ids),
op=op,
)
return ret
@@ -643,7 +652,8 @@ class LMCacheMPConnector(KVConnectorBase_V1):
return 0, False
self.scheduler_adapter.maybe_submit_lookup_request(
request.request_id, convert_block_hashes_to_bytes(request.block_hashes)
request.request_id,
token_ids=list(request.all_token_ids),
)
ret = self.scheduler_adapter.check_lookup_result(request.request_id)
@@ -766,6 +776,9 @@ class LMCacheMPConnector(KVConnectorBase_V1):
"""
# Clean up request tracker to prevent memory leak
self._cleanup_request_tracker(request.request_id)
# Notify LMCache to end the session for this request
self.scheduler_adapter.end_session(request.request_id)
return True, None
def take_events(self) -> Iterable["KVCacheEvent"]:
@@ -846,7 +859,9 @@ class LMCacheMPConnector(KVConnectorBase_V1):
if request_tracker.state != LMCacheMPRequestState.WAITING_FOR_LOAD:
continue
r_metadata = LMCacheMPRequestMetadata.GetRetrieveMetadata(
request_tracker, blocks_per_chunk
request_tracker,
blocks_per_chunk,
vllm_block_size=self.vllm_block_size,
)
if r_metadata is not None:
metadata.add_request_metadata(r_metadata)