[LMCache] Token Base IPC API (#34175)
Signed-off-by: Oasis-Git <ayw.sirius19@gmail.com>
This commit is contained in:
@@ -20,16 +20,42 @@ from lmcache.v1.multiprocess.protocol import RequestType, get_response_class
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def wrap_kv_caches(kv_caches: dict[str, KVCache]) -> KVCache:
|
||||
def wrap_kv_caches(kv_caches: dict[str, torch.Tensor]) -> KVCache:
|
||||
logger.info("KV caches keys are %s", list(kv_caches.keys()))
|
||||
return [CudaIPCWrapper(tensor) for tensor in kv_caches.values()]
|
||||
|
||||
|
||||
def striding_block_hashes(
    block_hashes: list[bytes], blocks_in_chunk: int
) -> Iterable[bytes]:
    """Extract chunk-level hashes from block hashes by striding.

    In hash-based vLLM, each vLLM block has its own hash. LMCache chunks
    span ``blocks_in_chunk`` consecutive blocks. The representative hash
    for a chunk is the hash of the **last** block in that chunk (because
    each block hash already encodes its prefix). So selection starts at
    index ``blocks_in_chunk - 1`` and strides by ``blocks_in_chunk``. A
    trailing group with fewer than ``blocks_in_chunk`` hashes yields
    nothing.

    Args:
        block_hashes: Per-block hashes, in block order.
        blocks_in_chunk: Number of consecutive blocks that form one chunk.
            Must be at least 1.

    Returns:
        A lazy iterator over one hash per complete chunk.

    Raises:
        ValueError: If ``blocks_in_chunk`` is smaller than 1.
    """
    # islice would also reject these values, but only with a message about
    # its own indices/step; fail with one that names the real culprit.
    if blocks_in_chunk < 1:
        raise ValueError(
            f"blocks_in_chunk must be a positive integer, got {blocks_in_chunk}"
        )
    return islice(block_hashes, blocks_in_chunk - 1, None, blocks_in_chunk)
|
||||
|
||||
|
||||
def send_lmcache_request(
|
||||
mq_client: MessageQueueClient,
|
||||
request_type: RequestType,
|
||||
payloads: list[Any],
|
||||
) -> MessagingFuture[Any]:
|
||||
"""
|
||||
Helper function to send the request to the LMCache multiprocess server
|
||||
|
||||
Args:
|
||||
mq_client: The LMCache multiprocess mode message queue client
|
||||
request_type: The request type
|
||||
payloads: The request payloads
|
||||
|
||||
Returns:
|
||||
A messaging future for the request
|
||||
"""
|
||||
|
||||
future = mq_client.submit_request(
|
||||
request_type, payloads, get_response_class(request_type)
|
||||
)
|
||||
@@ -39,40 +65,44 @@ def send_lmcache_request(
|
||||
def get_lmcache_chunk_size(
    mq_client: MessageQueueClient,
) -> int:
    """
    Query the LMCache server for its chunk size.

    Sends a ``GET_CHUNK_SIZE`` request with no payloads and waits on the
    returned future for the answer.

    Args:
        mq_client: The LMCache multiprocess mode message queue client

    Returns:
        An integer representing the LMCache chunk size
    """
    pending = send_lmcache_request(mq_client, RequestType.GET_CHUNK_SIZE, [])
    return pending.result()
|
||||
|
||||
|
||||
def striding_block_hashes(
    block_hashes: list[bytes],
    blocks_in_chunk,
) -> Iterable[bytes]:
    """Select every ``blocks_in_chunk``-th block hash, one per chunk.

    For example, with ``blocks_in_chunk`` equal to 16 the result covers the
    hashes of the 16th, 32nd, 48th, ... blocks (indices 15, 31, 47, ...).
    """
    stride = blocks_in_chunk
    # The first selected hash is the last block of the first chunk.
    return islice(block_hashes, stride - 1, None, stride)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoadStoreOp:
|
||||
block_hashes: list[bytes]
|
||||
block_ids: list[int]
|
||||
"""Block ids for the load/store operation"""
|
||||
|
||||
token_ids: list[int] | None = None
|
||||
"""Token IDs for the load/store operation (token mode)"""
|
||||
|
||||
block_hashes: list[bytes] | None = None
|
||||
"""Block hashes for the load/store operation (hash mode)"""
|
||||
|
||||
start: int = 0
|
||||
"""Start token index (token mode only)"""
|
||||
|
||||
end: int = 0
|
||||
"""End token index (token mode only)"""
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.block_hashes)
|
||||
|
||||
def __post_init__(self):
|
||||
assert len(self.block_hashes) == len(self.block_ids), (
|
||||
"The number of block hashes should be equal to the number of block ids "
|
||||
f"But got {len(self.block_hashes)} and {len(self.block_ids)}"
|
||||
)
|
||||
return len(self.block_ids)
|
||||
|
||||
|
||||
StoreResult = bool
|
||||
RetrieveResult = list[bool]
|
||||
LookupResult = list[bool]
|
||||
LookupResult = int
|
||||
|
||||
|
||||
class LMCacheMPSchedulerAdapter:
|
||||
@@ -95,10 +125,6 @@ class LMCacheMPSchedulerAdapter:
|
||||
kv_rank: The kv rank used for LMCache keys
|
||||
vllm_block_size: The block size used in vLLM
|
||||
"""
|
||||
logger.warning(
|
||||
"Importing LMCacheMPSchedulerAdapter is deprecated. "
|
||||
"Please update your LMCache to the latest version."
|
||||
)
|
||||
self.mq_client = MessageQueueClient(server_url, context)
|
||||
|
||||
# Request futures
|
||||
@@ -116,22 +142,89 @@ class LMCacheMPSchedulerAdapter:
|
||||
self.blocks_in_chunk = self.chunk_size // vllm_block_size
|
||||
|
||||
@_lmcache_nvtx_annotate
|
||||
def maybe_submit_lookup_request(self, request_id: str, block_hashes: list[bytes]):
|
||||
def maybe_submit_lookup_request(
|
||||
self,
|
||||
request_id: str,
|
||||
block_hashes: list[bytes] | None = None,
|
||||
token_ids: list[int] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Submit a new lookup request to LMCache if there is no ongoing request.
|
||||
|
||||
Supports both token-based and hash-based vLLM:
|
||||
- token_ids: token IDs (token-based vLLM) -> single token-mode key
|
||||
- block_hashes: block hashes (hash-based vLLM) -> strided hash-mode keys
|
||||
|
||||
Exactly one of block_hashes or token_ids must be provided.
|
||||
|
||||
Args:
|
||||
request_id: The ID of the lookup request. The same ID indicates it's
|
||||
from the same request
|
||||
block_hashes: Block hashes to lookup from LMCache (hash mode)
|
||||
token_ids: Token IDs to lookup from LMCache (token mode)
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
Notes:
|
||||
This function will have a side-effect: submitting a look up request to
|
||||
LMCache, which will essentially 'lock' the KV cache chunks in the LMCache
|
||||
for later retrieve operations.
|
||||
In the meantime, this function will record the lookup request, and the
|
||||
status of the look up request can be checked by `check_lookup_result`.
|
||||
"""
|
||||
if request_id in self.lookup_futures:
|
||||
# Skip if there is already a lookup request
|
||||
return
|
||||
|
||||
s = striding_block_hashes(block_hashes, self.blocks_in_chunk)
|
||||
keys = [self._create_key(block_hash) for block_hash in s]
|
||||
assert (block_hashes is None) != (token_ids is None), (
|
||||
"Exactly one of block_hashes or token_ids must be provided"
|
||||
)
|
||||
|
||||
if block_hashes is not None:
|
||||
# Hash mode: stride block hashes -> N hash-mode keys
|
||||
chunk_hashes = list(
|
||||
striding_block_hashes(block_hashes, self.blocks_in_chunk)
|
||||
)
|
||||
keys = [
|
||||
self._create_hash_key(ch, request_id=request_id) for ch in chunk_hashes
|
||||
]
|
||||
else:
|
||||
# Token mode: truncate to chunk-aligned length
|
||||
assert token_ids is not None
|
||||
aligned_end = (len(token_ids) // self.chunk_size) * self.chunk_size
|
||||
if aligned_end == 0:
|
||||
return
|
||||
keys = [
|
||||
self._create_key(
|
||||
token_ids,
|
||||
start=0,
|
||||
end=aligned_end,
|
||||
request_id=request_id,
|
||||
).no_worker_id_version()
|
||||
]
|
||||
|
||||
future = send_lmcache_request(
|
||||
self.mq_client,
|
||||
RequestType.LOOKUP,
|
||||
[keys, True],
|
||||
[keys],
|
||||
)
|
||||
self.lookup_futures[request_id] = future
|
||||
|
||||
@_lmcache_nvtx_annotate
|
||||
def check_lookup_result(self, request_id: str) -> int | None:
|
||||
"""
|
||||
Check the result of a previously submitted lookup request.
|
||||
|
||||
Args:
|
||||
request_id: The ID of the lookup request submitted in
|
||||
`maybe_submit_lookup_request`
|
||||
|
||||
Returns:
|
||||
An integer representing the total number of tokens matched
|
||||
in LMCache (prefix matching), or
|
||||
None if the lookup request is not finished yet.
|
||||
"""
|
||||
assert request_id in self.lookup_futures, (
|
||||
f"Lookup request for request_id={request_id} has not been submitted"
|
||||
)
|
||||
@@ -141,7 +234,7 @@ class LMCacheMPSchedulerAdapter:
|
||||
return None
|
||||
|
||||
result = future.result()
|
||||
num_chunks = sum(result)
|
||||
num_chunks = result
|
||||
return num_chunks * self.chunk_size
|
||||
|
||||
def num_blocks_per_chunk(self) -> int:
|
||||
@@ -159,14 +252,47 @@ class LMCacheMPSchedulerAdapter:
|
||||
"""
|
||||
self.lookup_futures.pop(request_id, None)
|
||||
|
||||
def end_session(self, request_id: str) -> None:
    """
    Tell the LMCache server to remove the session of a finished request.

    The future returned for the request is neither stored nor awaited.

    Args:
        request_id: The ID of the finished request.
    """
    payloads = [request_id]
    send_lmcache_request(self.mq_client, RequestType.END_SESSION, payloads)
|
||||
|
||||
# Helper functions
|
||||
def _create_key(self, block_hash: bytes) -> IPCCacheEngineKey:
|
||||
"""Convert a block hash to an IPC cache engine key"""
|
||||
def _create_key(
|
||||
self,
|
||||
token_ids: list[int],
|
||||
start: int = 0,
|
||||
end: int = 0,
|
||||
request_id: str | None = None,
|
||||
) -> IPCCacheEngineKey:
|
||||
"""Convert token IDs to an IPC cache engine key"""
|
||||
return IPCCacheEngineKey(
|
||||
model_name=self.model_name,
|
||||
world_size=self.world_size,
|
||||
worker_id=self.worker_id,
|
||||
chunk_hash=block_hash,
|
||||
token_ids=tuple(token_ids),
|
||||
start=start,
|
||||
end=end,
|
||||
request_id=request_id,
|
||||
)
|
||||
|
||||
def _create_hash_key(
    self, chunk_hash: bytes, request_id: str | None = None
) -> IPCCacheEngineKey:
    """Build a hash-mode IPC cache engine key for a single chunk hash.

    The key takes the model name and world size from this adapter and is
    created with ``worker_id`` set to ``None``.
    """
    key_fields = {
        "model_name": self.model_name,
        "world_size": self.world_size,
        "worker_id": None,
        "chunk_hash": chunk_hash,
        "request_id": request_id,
    }
    return IPCCacheEngineKey(**key_fields)
|
||||
|
||||
|
||||
@@ -180,10 +306,6 @@ class LMCacheMPWorkerAdapter:
|
||||
kv_rank: int,
|
||||
vllm_block_size: int,
|
||||
):
|
||||
logger.warning(
|
||||
"Importing LMCacheMPWorkerAdapter is deprecated. "
|
||||
"Please update your LMCache to the latest version."
|
||||
)
|
||||
self.mq_client = MessageQueueClient(server_url, context)
|
||||
|
||||
# Instance id for GPU worker
|
||||
@@ -201,7 +323,10 @@ class LMCacheMPWorkerAdapter:
|
||||
str, tuple[MessagingFuture[RetrieveResult], list[str]]
|
||||
] = {}
|
||||
|
||||
# The store requests that have finished execution in LMCache
|
||||
self.finished_stores: set[str] = set()
|
||||
# The finished request ids that are passed via vLLM and also
|
||||
# have corresponding store requests submitted to LMCache before
|
||||
self.previously_finished: set[str] = set()
|
||||
|
||||
self.model_name = model_name
|
||||
@@ -215,7 +340,14 @@ class LMCacheMPWorkerAdapter:
|
||||
)
|
||||
self.blocks_in_chunk = chunk_size // vllm_block_size
|
||||
|
||||
def register_kv_caches(self, kv_caches: dict[str, KVCache]):
|
||||
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
||||
"""
|
||||
Register the kv caches with LMCache server
|
||||
|
||||
Args:
|
||||
kv_caches: A dict of kv caches to register. The keys are the
|
||||
layer names and the values are the corresponding tensors.
|
||||
"""
|
||||
# Register kv cache and send the request
|
||||
self.kv_caches = kv_caches
|
||||
logger.info("Registering kv caches")
|
||||
@@ -230,7 +362,29 @@ class LMCacheMPWorkerAdapter:
|
||||
def submit_store_request(
|
||||
self, request_id: str, op: LoadStoreOp, event: torch.cuda.Event
|
||||
):
|
||||
keys = self._block_hashes_to_keys(op.block_hashes)
|
||||
"""
|
||||
Submit a KV cache store request to LMCache
|
||||
|
||||
Args:
|
||||
request_id: The ID of the request
|
||||
op: The LoadStoreOp describing the store operation.
|
||||
event: The CUDA event that is recorded after the current
|
||||
model inference step
|
||||
"""
|
||||
if op.block_hashes is not None:
|
||||
# Hash mode
|
||||
chunk_hashes = list(
|
||||
striding_block_hashes(op.block_hashes, self.blocks_in_chunk)
|
||||
)
|
||||
keys = [
|
||||
self._create_hash_key(ch, request_id=request_id) for ch in chunk_hashes
|
||||
]
|
||||
else:
|
||||
# Token mode
|
||||
assert op.token_ids is not None
|
||||
keys = [
|
||||
self._create_key(op.token_ids, op.start, op.end, request_id=request_id)
|
||||
]
|
||||
future = send_lmcache_request(
|
||||
self.mq_client,
|
||||
RequestType.STORE,
|
||||
@@ -242,7 +396,29 @@ class LMCacheMPWorkerAdapter:
|
||||
def submit_retrieve_request(
|
||||
self, request_id: str, op: LoadStoreOp, event: torch.cuda.Event
|
||||
):
|
||||
keys = self._block_hashes_to_keys(op.block_hashes)
|
||||
"""
|
||||
Submit a KV cache retrieve request to LMCache
|
||||
|
||||
Args:
|
||||
request_id: The ID of the request
|
||||
op: The LoadStoreOp describing the retrieve operation.
|
||||
event: The CUDA event that is recorded after the current
|
||||
model inference step
|
||||
"""
|
||||
if op.block_hashes is not None:
|
||||
# Hash mode
|
||||
chunk_hashes = list(
|
||||
striding_block_hashes(op.block_hashes, self.blocks_in_chunk)
|
||||
)
|
||||
keys = [
|
||||
self._create_hash_key(ch, request_id=request_id) for ch in chunk_hashes
|
||||
]
|
||||
else:
|
||||
# Token mode
|
||||
assert op.token_ids is not None
|
||||
keys = [
|
||||
self._create_key(op.token_ids, op.start, op.end, request_id=request_id)
|
||||
]
|
||||
future = send_lmcache_request(
|
||||
self.mq_client,
|
||||
RequestType.RETRIEVE,
|
||||
@@ -257,17 +433,47 @@ class LMCacheMPWorkerAdapter:
|
||||
ops: list[LoadStoreOp],
|
||||
event: torch.cuda.Event,
|
||||
):
|
||||
keys = []
|
||||
block_ids = []
|
||||
for op in ops:
|
||||
keys.extend(self._block_hashes_to_keys(op.block_hashes))
|
||||
"""
|
||||
Submit a batched store request to LMCache
|
||||
|
||||
Args:
|
||||
request_ids: The IDs of the requests
|
||||
ops: The LoadStoreOps describing the store operations. Should have
|
||||
the same length as request_ids
|
||||
event: The CUDA event that is recorded after the current
|
||||
model inference step
|
||||
"""
|
||||
all_keys: list[IPCCacheEngineKey] = []
|
||||
block_ids: list[int] = []
|
||||
for request_id, op in zip(request_ids, ops, strict=False):
|
||||
if op.block_hashes is not None:
|
||||
chunk_hashes = list(
|
||||
striding_block_hashes(op.block_hashes, self.blocks_in_chunk)
|
||||
)
|
||||
keys = [
|
||||
self._create_hash_key(ch, request_id=request_id)
|
||||
for ch in chunk_hashes
|
||||
]
|
||||
all_keys.extend(keys)
|
||||
else:
|
||||
assert op.token_ids is not None
|
||||
all_keys.append(
|
||||
self._create_key(
|
||||
op.token_ids, op.start, op.end, request_id=request_id
|
||||
)
|
||||
)
|
||||
block_ids.extend(op.block_ids)
|
||||
future = send_lmcache_request(
|
||||
self.mq_client,
|
||||
RequestType.STORE,
|
||||
[keys, self.instance_id, block_ids, event.ipc_handle()],
|
||||
[
|
||||
all_keys,
|
||||
self.instance_id,
|
||||
block_ids,
|
||||
event.ipc_handle(),
|
||||
],
|
||||
).to_cuda_future()
|
||||
self.store_futures[request_ids[0]] = (future, request_ids[1:])
|
||||
self.store_futures[request_ids[0]] = (future, list(request_ids[1:]))
|
||||
|
||||
@_lmcache_nvtx_annotate
|
||||
def batched_submit_retrieve_requests(
|
||||
@@ -276,34 +482,83 @@ class LMCacheMPWorkerAdapter:
|
||||
ops: list[LoadStoreOp],
|
||||
event: torch.cuda.Event,
|
||||
):
|
||||
keys = []
|
||||
block_ids = []
|
||||
"""
|
||||
Submit a batched retrieve request to LMCache
|
||||
|
||||
for op in ops:
|
||||
keys.extend(self._block_hashes_to_keys(op.block_hashes))
|
||||
Args:
|
||||
request_ids: The IDs of the requests
|
||||
ops: The LoadStoreOps describing the retrieve operations. Should have
|
||||
the same length as request_ids
|
||||
event: The CUDA event that is recorded after the current
|
||||
model inference step
|
||||
"""
|
||||
all_keys: list[IPCCacheEngineKey] = []
|
||||
block_ids: list[int] = []
|
||||
for request_id, op in zip(request_ids, ops, strict=False):
|
||||
if op.block_hashes is not None:
|
||||
chunk_hashes = list(
|
||||
striding_block_hashes(op.block_hashes, self.blocks_in_chunk)
|
||||
)
|
||||
keys = [
|
||||
self._create_hash_key(ch, request_id=request_id)
|
||||
for ch in chunk_hashes
|
||||
]
|
||||
all_keys.extend(keys)
|
||||
else:
|
||||
assert op.token_ids is not None
|
||||
all_keys.append(
|
||||
self._create_key(
|
||||
op.token_ids, op.start, op.end, request_id=request_id
|
||||
)
|
||||
)
|
||||
block_ids.extend(op.block_ids)
|
||||
future = send_lmcache_request(
|
||||
self.mq_client,
|
||||
RequestType.RETRIEVE,
|
||||
[keys, self.instance_id, block_ids, event.ipc_handle()],
|
||||
[
|
||||
all_keys,
|
||||
self.instance_id,
|
||||
block_ids,
|
||||
event.ipc_handle(),
|
||||
],
|
||||
).to_cuda_future()
|
||||
self.retrieve_futures[request_ids[0]] = (future, request_ids[1:])
|
||||
self.retrieve_futures[request_ids[0]] = (future, list(request_ids[1:]))
|
||||
|
||||
@_lmcache_nvtx_annotate
|
||||
def get_finished(
|
||||
self, finished_req_ids: set[str]
|
||||
self, finished_req_ids_from_engine: set[str]
|
||||
) -> tuple[set[str] | None, set[str] | None]:
|
||||
"""
|
||||
Check and get the finished store and retrieve requests.
|
||||
|
||||
Args:
|
||||
finished_req_ids_from_engine: the set of request ids that are
|
||||
reported as finished from the vLLM engine side.
|
||||
|
||||
Returns:
|
||||
A tuple of two sets:
|
||||
- The first set contains the finished store request ids. The returned
|
||||
store request ids MUST be seen before in the
|
||||
`finished_req_ids_from_engine`.
|
||||
- The second set contains the finished retrieve request ids.
|
||||
|
||||
Notes:
|
||||
When enabling async scheduling in vLLM, the same request ID may appear
|
||||
multiple times in `finished_req_ids_from_engine`. The adapter should
|
||||
take care of deduplicating the request IDs and only return the request
|
||||
IDs that have not been returned before.
|
||||
"""
|
||||
finished_stores = set()
|
||||
finished_retrieves = set()
|
||||
for request_id, (future, other_reqs) in self.store_futures.items():
|
||||
if not future.query():
|
||||
for request_id, (s_future, other_reqs) in self.store_futures.items():
|
||||
if not s_future.query():
|
||||
continue
|
||||
|
||||
result = future.result()
|
||||
s_result = s_future.result()
|
||||
finished_stores.add(request_id)
|
||||
finished_stores.update(other_reqs)
|
||||
|
||||
if not result:
|
||||
if not s_result:
|
||||
# TODO: add error handling here
|
||||
logger.error(
|
||||
"Something went wrong when processing the "
|
||||
@@ -311,21 +566,21 @@ class LMCacheMPWorkerAdapter:
|
||||
request_id,
|
||||
)
|
||||
|
||||
for request_id, (future, other_reqs) in self.retrieve_futures.items():
|
||||
if not future.query():
|
||||
for request_id, (r_future, other_reqs) in self.retrieve_futures.items():
|
||||
if not r_future.query():
|
||||
continue
|
||||
|
||||
result = future.result()
|
||||
r_result = r_future.result()
|
||||
finished_retrieves.add(request_id)
|
||||
finished_retrieves.update(other_reqs)
|
||||
|
||||
if not all(result):
|
||||
if not all(r_result):
|
||||
# TODO: add error handling here
|
||||
logger.error(
|
||||
"Something went wrong when processing the "
|
||||
"retrieve request for request_id=%s, result=%s",
|
||||
request_id,
|
||||
result,
|
||||
r_result,
|
||||
)
|
||||
|
||||
# Remove the finished requests from the tracking dicts
|
||||
@@ -338,7 +593,7 @@ class LMCacheMPWorkerAdapter:
|
||||
self.finished_stores.update(finished_stores)
|
||||
|
||||
ret_stores = set()
|
||||
for req_id in finished_req_ids:
|
||||
for req_id in finished_req_ids_from_engine:
|
||||
if req_id in self.finished_stores or req_id in self.store_futures:
|
||||
self.previously_finished.add(req_id)
|
||||
else:
|
||||
@@ -357,7 +612,9 @@ class LMCacheMPWorkerAdapter:
|
||||
return self.blocks_in_chunk
|
||||
|
||||
def shutdown(self):
|
||||
# Unregister kv cache
|
||||
"""
|
||||
Shutdown the LMCache MP worker adapter
|
||||
"""
|
||||
logger.info("Unregistering kv caches")
|
||||
send_lmcache_request(
|
||||
self.mq_client, RequestType.UNREGISTER_KV_CACHE, [self.instance_id]
|
||||
@@ -378,18 +635,32 @@ class LMCacheMPWorkerAdapter:
|
||||
|
||||
return safe_finished_s
|
||||
|
||||
def _create_key(self, block_hash: bytes) -> IPCCacheEngineKey:
|
||||
"""Convert a block hash to an IPC cache engine key"""
|
||||
def _create_key(
|
||||
self,
|
||||
token_ids: list[int],
|
||||
start: int = 0,
|
||||
end: int = 0,
|
||||
request_id: str | None = None,
|
||||
) -> IPCCacheEngineKey:
|
||||
"""Convert token IDs to an IPC cache engine key"""
|
||||
return IPCCacheEngineKey(
|
||||
model_name=self.model_name,
|
||||
world_size=self.world_size,
|
||||
worker_id=self.worker_id,
|
||||
chunk_hash=block_hash,
|
||||
token_ids=tuple(token_ids),
|
||||
start=start,
|
||||
end=end,
|
||||
request_id=request_id,
|
||||
)
|
||||
|
||||
def _block_hashes_to_keys(
|
||||
self, block_hashes: list[bytes]
|
||||
) -> list[IPCCacheEngineKey]:
|
||||
"""Convert block hashes to IPC cache engine keys"""
|
||||
s = striding_block_hashes(block_hashes, self.blocks_in_chunk)
|
||||
return [self._create_key(block_hash) for block_hash in s]
|
||||
def _create_hash_key(
    self, chunk_hash: bytes, request_id: str | None = None
) -> IPCCacheEngineKey:
    """Build a hash-mode IPC cache engine key for a single chunk hash.

    The key takes the model name, world size and worker id from this
    adapter.
    """
    key_fields = {
        "model_name": self.model_name,
        "world_size": self.world_size,
        "worker_id": self.worker_id,
        "chunk_hash": chunk_hash,
        "request_id": request_id,
    }
    return IPCCacheEngineKey(**key_fields)
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
import enum
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
import torch
|
||||
import zmq
|
||||
@@ -130,12 +130,6 @@ def create_worker_adapter(
|
||||
)
|
||||
|
||||
|
||||
def convert_block_hashes_to_bytes(
    block_hashes: list["BlockHash"],
) -> list[bytes]:
    """Re-type a list of block hashes as ``list[bytes]`` for type checkers.

    This is a static cast only: the very same list object is returned and
    nothing is copied or converted at runtime.
    """
    as_bytes = cast(list[bytes], block_hashes)
    return as_bytes
|
||||
|
||||
|
||||
class LMCacheMPRequestState(enum.Enum):
|
||||
"""
|
||||
State machine:
|
||||
@@ -266,6 +260,7 @@ class LMCacheMPRequestMetadata:
|
||||
Args:
|
||||
tracker: The request tracker to generate the metadata from.
|
||||
blocks_in_chunk: the number of blocks in a LMCache data chunk
|
||||
vllm_block_size: the block size used in vLLM
|
||||
"""
|
||||
# Store the blocks that have block hashes
|
||||
# NOTE: the invariant here is that `num_stored_blocks` should
|
||||
@@ -282,15 +277,21 @@ class LMCacheMPRequestMetadata:
|
||||
if num_chunks >= 1:
|
||||
start = tracker.num_stored_blocks
|
||||
end = start + num_chunks * blocks_in_chunk
|
||||
block_hashes = convert_block_hashes_to_bytes(
|
||||
tracker.block_hashes[start:end]
|
||||
)
|
||||
block_ids = tracker.allocated_block_ids[start:end]
|
||||
start_token_idx = start * vllm_block_size
|
||||
end_token_idx = end * vllm_block_size
|
||||
token_ids = list(tracker.all_token_ids)
|
||||
op = LoadStoreOp(
|
||||
token_ids=token_ids,
|
||||
block_ids=block_ids,
|
||||
start=start_token_idx,
|
||||
end=end_token_idx,
|
||||
)
|
||||
|
||||
ret = LMCacheMPRequestMetadata(
|
||||
request_id=tracker.request_id,
|
||||
direction="STORE",
|
||||
op=LoadStoreOp(block_hashes=block_hashes, block_ids=block_ids),
|
||||
op=op,
|
||||
)
|
||||
|
||||
# Update the request tracker
|
||||
@@ -303,6 +304,7 @@ class LMCacheMPRequestMetadata:
|
||||
def GetRetrieveMetadata(
|
||||
tracker: LMCacheMPRequestTracker,
|
||||
blocks_in_chunk: int,
|
||||
vllm_block_size: int,
|
||||
) -> "LMCacheMPRequestMetadata | None":
|
||||
"""
|
||||
Generate the retrieve metadata for the current request tracker.
|
||||
@@ -310,6 +312,7 @@ class LMCacheMPRequestMetadata:
|
||||
Args:
|
||||
tracker: The request tracker to generate the metadata from.
|
||||
blocks_in_chunk: the number of blocks in a LMCache data chunk
|
||||
vllm_block_size: the block size used in vLLM
|
||||
"""
|
||||
if not tracker.is_ready_for_retrieving():
|
||||
return None
|
||||
@@ -330,15 +333,21 @@ class LMCacheMPRequestMetadata:
|
||||
"number of LMCache hit blocks. "
|
||||
)
|
||||
if end > start:
|
||||
block_hashes = convert_block_hashes_to_bytes(
|
||||
tracker.block_hashes[start:end]
|
||||
)
|
||||
block_ids = tracker.allocated_block_ids[start:end]
|
||||
start_token_idx = start * vllm_block_size
|
||||
end_token_idx = end * vllm_block_size
|
||||
token_ids = list(tracker.all_token_ids)
|
||||
op = LoadStoreOp(
|
||||
token_ids=token_ids,
|
||||
block_ids=block_ids,
|
||||
start=start_token_idx,
|
||||
end=end_token_idx,
|
||||
)
|
||||
|
||||
ret = LMCacheMPRequestMetadata(
|
||||
request_id=tracker.request_id,
|
||||
direction="RETRIEVE",
|
||||
op=LoadStoreOp(block_hashes=block_hashes, block_ids=block_ids),
|
||||
op=op,
|
||||
)
|
||||
return ret
|
||||
|
||||
@@ -643,7 +652,8 @@ class LMCacheMPConnector(KVConnectorBase_V1):
|
||||
return 0, False
|
||||
|
||||
self.scheduler_adapter.maybe_submit_lookup_request(
|
||||
request.request_id, convert_block_hashes_to_bytes(request.block_hashes)
|
||||
request.request_id,
|
||||
token_ids=list(request.all_token_ids),
|
||||
)
|
||||
|
||||
ret = self.scheduler_adapter.check_lookup_result(request.request_id)
|
||||
@@ -766,6 +776,9 @@ class LMCacheMPConnector(KVConnectorBase_V1):
|
||||
"""
|
||||
# Clean up request tracker to prevent memory leak
|
||||
self._cleanup_request_tracker(request.request_id)
|
||||
# Notify LMCache to end the session for this request
|
||||
self.scheduler_adapter.end_session(request.request_id)
|
||||
|
||||
return True, None
|
||||
|
||||
def take_events(self) -> Iterable["KVCacheEvent"]:
|
||||
@@ -846,7 +859,9 @@ class LMCacheMPConnector(KVConnectorBase_V1):
|
||||
if request_tracker.state != LMCacheMPRequestState.WAITING_FOR_LOAD:
|
||||
continue
|
||||
r_metadata = LMCacheMPRequestMetadata.GetRetrieveMetadata(
|
||||
request_tracker, blocks_per_chunk
|
||||
request_tracker,
|
||||
blocks_per_chunk,
|
||||
vllm_block_size=self.vllm_block_size,
|
||||
)
|
||||
if r_metadata is not None:
|
||||
metadata.add_request_metadata(r_metadata)
|
||||
|
||||
Reference in New Issue
Block a user