[KVConnector] OffloadingConnector: Fix bug in handling of preemptions (#29870)
Signed-off-by: Or Ozeri <oro@il.ibm.com>
This commit is contained in:
@@ -25,6 +25,9 @@ The class provides the following primitives:
|
||||
|
||||
Worker-side: runs in each worker, loads/saves KV cache to/from
|
||||
the Connector based on the metadata.
|
||||
handle_preemptions() - called if there are preempted requests,
|
||||
before their blocks are overwritten
|
||||
|
||||
start_load_kv() - starts loading all KVs (maybe async)
|
||||
wait_for_layer_load() - blocks until layer i load is done
|
||||
|
||||
@@ -262,6 +265,13 @@ class KVConnectorBase_V1(ABC):
|
||||
"""
|
||||
return
|
||||
|
||||
def handle_preemptions(self, preempted_req_ids: set[str]):
|
||||
"""
|
||||
Handle preempted requests BEFORE their blocks are overwritten.
|
||||
Needed for connectors which use async saves (e.g., OffloadingConnector)
|
||||
"""
|
||||
return
|
||||
|
||||
@abstractmethod
|
||||
def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None:
|
||||
"""
|
||||
|
||||
@@ -75,6 +75,10 @@ class OffloadingConnector(KVConnectorBase_V1):
|
||||
assert self.connector_worker is not None
|
||||
self.connector_worker.register_cross_layers_kv_cache(kv_cache, attn_backend)
|
||||
|
||||
def handle_preemptions(self, preempted_req_ids: set[str]):
|
||||
assert self.connector_worker is not None
|
||||
self.connector_worker.handle_preemptions(preempted_req_ids)
|
||||
|
||||
def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
|
||||
assert self.connector_worker is not None
|
||||
assert isinstance(self._connector_metadata, OffloadingConnectorMetadata)
|
||||
@@ -348,6 +352,15 @@ class OffloadingConnectorScheduler:
|
||||
reqs_to_store=self._get_reqs_to_store(scheduler_output),
|
||||
)
|
||||
self._reqs_to_load = {}
|
||||
|
||||
# NOTE (orozery): we should move this logic to update_connector_output
|
||||
# once KVConnectorOutput allows us to report completed transfers
|
||||
for req_id in scheduler_output.preempted_req_ids or ():
|
||||
block_hashes = self._reqs_being_stored.get(req_id)
|
||||
if block_hashes:
|
||||
self.manager.complete_store(block_hashes)
|
||||
block_hashes.clear()
|
||||
|
||||
return meta
|
||||
|
||||
def update_connector_output(self, connector_output: KVConnectorOutput):
|
||||
@@ -466,6 +479,17 @@ class OffloadingConnectorWorker:
|
||||
attn_backends = {cross_layer_name: attn_backend}
|
||||
self._register_handlers(kv_caches, attn_backends)
|
||||
|
||||
def handle_preemptions(self, preempted_req_ids: set[str]):
|
||||
for job_id, transfer_spec in self._unsubmitted_store_jobs:
|
||||
success = self.worker.transfer_async(job_id, transfer_spec)
|
||||
assert success
|
||||
self._unsubmitted_store_jobs.clear()
|
||||
|
||||
for req_id in preempted_req_ids:
|
||||
job_ids = self._store_jobs.get(req_id)
|
||||
if job_ids:
|
||||
self.worker.wait(job_ids)
|
||||
|
||||
def start_kv_transfers(self, metadata: OffloadingConnectorMetadata):
|
||||
for job_id, transfer_spec in self._unsubmitted_store_jobs:
|
||||
success = self.worker.transfer_async(job_id, transfer_spec)
|
||||
|
||||
Reference in New Issue
Block a user