From fcf0687b27b78c3b214504f5e9525f3f66a2d04a Mon Sep 17 00:00:00 2001 From: Or Ozeri Date: Wed, 18 Mar 2026 08:49:53 +0200 Subject: [PATCH] [kv_offload+HMA][0/N]: Support block-level preemption handling (#34805) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Or Ozeri Co-authored-by: Nicolò Lucchesi --- tests/v1/kv_connector/unit/test_multi_connector.py | 10 ++++++---- .../v1/kv_connector/unit/test_offloading_connector.py | 5 +---- vllm/distributed/kv_transfer/kv_connector/v1/base.py | 8 ++++---- .../kv_transfer/kv_connector/v1/multi_connector.py | 7 ++++--- .../kv_connector/v1/offloading_connector.py | 11 +++++++---- vllm/v1/worker/gpu/kv_connector.py | 3 +-- vllm/v1/worker/gpu_model_runner.py | 8 ++++---- 7 files changed, 27 insertions(+), 25 deletions(-) diff --git a/tests/v1/kv_connector/unit/test_multi_connector.py b/tests/v1/kv_connector/unit/test_multi_connector.py index 6acc48629..671a80137 100644 --- a/tests/v1/kv_connector/unit/test_multi_connector.py +++ b/tests/v1/kv_connector/unit/test_multi_connector.py @@ -231,10 +231,11 @@ def test_multi_example_connector_consistency(): ] # First three events are from initialization (register_kv_caches, # set_host_xfer_buffer_ops, get_handshake_metadata), then generate() events. 
- assert events["storage1-WORKER"][:7] == [ + assert events["storage1-WORKER"][:8] == [ "register_kv_caches", "set_host_xfer_buffer_ops", "get_handshake_metadata", + "handle_preemptions", "bind_connector_metadata", "start_load_kv", "wait_for_layer_load", @@ -246,10 +247,11 @@ def test_multi_example_connector_consistency(): "update_state_after_alloc num_blocks=[0] 0", "build_connector_meta", ] - assert events["storage2-WORKER"][:7] == [ + assert events["storage2-WORKER"][:8] == [ "register_kv_caches", "set_host_xfer_buffer_ops", "get_handshake_metadata", + "handle_preemptions", "bind_connector_metadata", "start_load_kv", "wait_for_layer_load", @@ -399,8 +401,8 @@ def test_multi_connector_handle_preemptions_integration(): # testing the delegation behavior of MultiConnector here. # The connector attribute contains the KV connector. assert scheduler.connector is not None, "Scheduler should have a connector" - preempted_req_ids = {"req-1", "req-2", "req-3"} - scheduler.connector.handle_preemptions(preempted_req_ids) + connector_md = scheduler.connector.build_connector_meta(scheduler.schedule()) + scheduler.connector.handle_preemptions(connector_md) # Verify both connectors received the handle_preemptions call events = get_connector_events() diff --git a/tests/v1/kv_connector/unit/test_offloading_connector.py b/tests/v1/kv_connector/unit/test_offloading_connector.py index 893a5d8d4..c6365886f 100644 --- a/tests/v1/kv_connector/unit/test_offloading_connector.py +++ b/tests/v1/kv_connector/unit/test_offloading_connector.py @@ -363,10 +363,7 @@ class RequestRunner: assert kv_connector_metadata is not None assert isinstance(kv_connector_metadata, OffloadingConnectorMetadata) - if scheduler_output.preempted_req_ids: - self.worker_connector.handle_preemptions( - scheduler_output.preempted_req_ids - ) + self.worker_connector.handle_preemptions(kv_connector_metadata) self.worker_connector.bind_connector_metadata(kv_connector_metadata) 
self.worker_connector.start_load_kv(self._dummy_ctx) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index 2abbe6bf6..ef143cba7 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -25,8 +25,8 @@ The class provides the following primitives: Worker-side: runs in each worker, loads/saves KV cache to/from the Connector based on the metadata. - handle_preemptions() - called if there are preempted requests, - before their blocks are overwritten + handle_preemptions() - called for handling preempted requests + or evicted request blocks before they are overwritten start_load_kv() - starts loading all KVs (maybe async) wait_for_layer_load() - blocks until layer i load is done @@ -288,9 +288,9 @@ class KVConnectorBase_V1(ABC): """ return - def handle_preemptions(self, preempted_req_ids: set[str]): + def handle_preemptions(self, kv_connector_metadata: KVConnectorMetadata): """ - Handle preempted requests BEFORE their blocks are overwritten. + Handle preempted requests or evicted blocks BEFORE they are overwritten. 
Needed for connectors which use async saves (e.g., OffloadingConnector) """ return diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py index 7cc80129a..3888d2e0f 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py @@ -315,10 +315,11 @@ class MultiConnector(KVConnectorBase_V1): for c in self._connectors: c.set_host_xfer_buffer_ops(copy_operation) - def handle_preemptions(self, preempted_req_ids: set[str]): + def handle_preemptions(self, kv_connector_metadata: KVConnectorMetadata): """Handle preempted requests for all sub-connectors.""" - for c in self._connectors: - c.handle_preemptions(preempted_req_ids) + assert isinstance(kv_connector_metadata, MultiKVConnectorMetadata) + for c, cm in zip(self._connectors, kv_connector_metadata.metadata): + c.handle_preemptions(cm) def get_finished_count(self) -> int | None: # TODO(https://github.com/vllm-project/vllm/issues/33400) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py index 4c850fd2f..d2eebca2c 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py @@ -111,6 +111,7 @@ class OffloadingConnectorStats(KVConnectorStats): class OffloadingConnectorMetadata(KVConnectorMetadata): reqs_to_load: dict[ReqId, TransferSpec] reqs_to_store: dict[ReqId, TransferSpec] + reqs_to_flush: set[str] | None = None class OffloadingConnector(KVConnectorBase_V1): @@ -146,9 +147,10 @@ class OffloadingConnector(KVConnectorBase_V1): assert self.connector_worker is not None self.connector_worker.register_cross_layers_kv_cache(kv_cache, attn_backend) - def handle_preemptions(self, preempted_req_ids: set[str]): + def handle_preemptions(self, kv_connector_metadata: 
KVConnectorMetadata): assert self.connector_worker is not None - self.connector_worker.handle_preemptions(preempted_req_ids) + assert isinstance(kv_connector_metadata, OffloadingConnectorMetadata) + self.connector_worker.handle_preemptions(kv_connector_metadata) def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: assert self.connector_worker is not None @@ -482,6 +484,7 @@ class OffloadingConnectorScheduler: meta = OffloadingConnectorMetadata( reqs_to_load=self._reqs_to_load, reqs_to_store=self._get_reqs_to_store(scheduler_output), + reqs_to_flush=scheduler_output.preempted_req_ids, ) self._reqs_to_load = {} @@ -619,13 +622,13 @@ class OffloadingConnectorWorker: attn_backends = {cross_layer_name: attn_backend} self._register_handlers(kv_caches, attn_backends) - def handle_preemptions(self, preempted_req_ids: set[str]): + def handle_preemptions(self, kv_connector_metadata: OffloadingConnectorMetadata): for job_id, transfer_spec in self._unsubmitted_store_jobs: success = self.worker.transfer_async(job_id, transfer_spec) assert success self._unsubmitted_store_jobs.clear() - for req_id in preempted_req_ids: + for req_id in kv_connector_metadata.reqs_to_flush or (): job_ids = self._store_jobs.get(req_id) if job_ids: self.worker.wait(job_ids) diff --git a/vllm/v1/worker/gpu/kv_connector.py b/vllm/v1/worker/gpu/kv_connector.py index 7e4e27e1f..bcbeef1ae 100644 --- a/vllm/v1/worker/gpu/kv_connector.py +++ b/vllm/v1/worker/gpu/kv_connector.py @@ -63,11 +63,10 @@ class ActiveKVConnector(KVConnector): if self._disabled: return - if scheduler_output.preempted_req_ids: - self.kv_connector.handle_preemptions(scheduler_output.preempted_req_ids) kv_connector_metadata = scheduler_output.kv_connector_metadata assert kv_connector_metadata is not None self.kv_connector.bind_connector_metadata(kv_connector_metadata) + self.kv_connector.handle_preemptions(kv_connector_metadata) # TODO: sort out KV Connectors' use of forward_context if 
is_forward_context_available(): diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 22459bc49..a97a0d2dd 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -3594,10 +3594,10 @@ class GPUModelRunner( scheduled_spec_decode_tokens=spec_decode_tokens_copy, ) - if scheduler_output.preempted_req_ids and has_kv_transfer_group(): - get_kv_transfer_group().handle_preemptions( - scheduler_output.preempted_req_ids - ) + if has_kv_transfer_group(): + kv_connector_metadata = scheduler_output.kv_connector_metadata + assert kv_connector_metadata is not None + get_kv_transfer_group().handle_preemptions(kv_connector_metadata) num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens with (