[KVConnector] OffloadingConnector: Fix bug in handling of preemptions (#29870)

Signed-off-by: Or Ozeri <oro@il.ibm.com>
This commit is contained in:
Or Ozeri
2026-01-11 10:05:36 +02:00
committed by GitHub
parent bde57ab2ed
commit 4c16ba617f
7 changed files with 248 additions and 57 deletions

View File

@@ -25,6 +25,9 @@ The class provides the following primitives:
Worker-side: runs in each worker, loads/saves KV cache to/from
the Connector based on the metadata.
handle_preemptions() - called if there are preempted requests,
before their blocks are overwritten
start_load_kv() - starts loading all KVs (maybe async)
wait_for_layer_load() - blocks until layer i load is done
@@ -262,6 +265,13 @@ class KVConnectorBase_V1(ABC):
"""
return
def handle_preemptions(self, preempted_req_ids: set[str]):
    """
    Hook invoked for preempted requests BEFORE their blocks are
    overwritten.

    The default implementation does nothing. Connectors that save
    asynchronously (e.g., OffloadingConnector) need to override it.

    Args:
        preempted_req_ids: IDs of the requests that were preempted.
    """
    # Intentional no-op: there is nothing to handle by default.
    return None
@abstractmethod
def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None:
"""

View File

@@ -75,6 +75,10 @@ class OffloadingConnector(KVConnectorBase_V1):
assert self.connector_worker is not None
self.connector_worker.register_cross_layers_kv_cache(kv_cache, attn_backend)
def handle_preemptions(self, preempted_req_ids: set[str]):
    """
    Forward the preempted request IDs to the worker-side connector.

    Args:
        preempted_req_ids: IDs of the requests that were preempted.
    """
    worker_side = self.connector_worker
    # This method requires the worker-side connector to be present.
    assert worker_side is not None
    worker_side.handle_preemptions(preempted_req_ids)
def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
assert self.connector_worker is not None
assert isinstance(self._connector_metadata, OffloadingConnectorMetadata)
@@ -348,6 +352,15 @@ class OffloadingConnectorScheduler:
reqs_to_store=self._get_reqs_to_store(scheduler_output),
)
self._reqs_to_load = {}
# NOTE (orozery): we should move this logic to update_connector_output
# once KVConnectorOutput allows us to report completed transfers
for req_id in scheduler_output.preempted_req_ids or ():
block_hashes = self._reqs_being_stored.get(req_id)
if block_hashes:
self.manager.complete_store(block_hashes)
block_hashes.clear()
return meta
def update_connector_output(self, connector_output: KVConnectorOutput):
@@ -466,6 +479,17 @@ class OffloadingConnectorWorker:
attn_backends = {cross_layer_name: attn_backend}
self._register_handlers(kv_caches, attn_backends)
def handle_preemptions(self, preempted_req_ids: set[str]):
    """
    Submit all queued store jobs, then wait on the store jobs recorded
    for each preempted request.

    Args:
        preempted_req_ids: IDs of the requests that were preempted.
    """
    backend = self.worker

    # Submit every store job that has been queued but not yet handed
    # to the worker; each submission is expected to succeed.
    for pending_id, pending_spec in self._unsubmitted_store_jobs:
        submitted = backend.transfer_async(pending_id, pending_spec)
        assert submitted
    self._unsubmitted_store_jobs.clear()

    # NOTE(review): wait() presumably blocks until the given jobs are
    # complete -- confirm against the worker implementation.
    for preempted_id in preempted_req_ids:
        pending_jobs = self._store_jobs.get(preempted_id)
        if not pending_jobs:
            # No store jobs recorded for this request.
            continue
        backend.wait(pending_jobs)
def start_kv_transfers(self, metadata: OffloadingConnectorMetadata):
for job_id, transfer_spec in self._unsubmitted_store_jobs:
success = self.worker.transfer_async(job_id, transfer_spec)