[kv_offload+HMA][0/N]: Support block-level preemption handling (#34805)

Signed-off-by: Or Ozeri <oro@il.ibm.com>
Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
This commit is contained in:
Or Ozeri
2026-03-18 08:49:53 +02:00
committed by GitHub
parent 86b7e3c95a
commit fcf0687b27
7 changed files with 27 additions and 25 deletions

View File

@@ -231,10 +231,11 @@ def test_multi_example_connector_consistency():
] ]
# First three events are from initialization (register_kv_caches, # First three events are from initialization (register_kv_caches,
# set_host_xfer_buffer_ops, get_handshake_metadata), then generate() events. # set_host_xfer_buffer_ops, get_handshake_metadata), then generate() events.
assert events["storage1-WORKER"][:7] == [ assert events["storage1-WORKER"][:8] == [
"register_kv_caches", "register_kv_caches",
"set_host_xfer_buffer_ops", "set_host_xfer_buffer_ops",
"get_handshake_metadata", "get_handshake_metadata",
"handle_preemptions",
"bind_connector_metadata", "bind_connector_metadata",
"start_load_kv", "start_load_kv",
"wait_for_layer_load", "wait_for_layer_load",
@@ -246,10 +247,11 @@ def test_multi_example_connector_consistency():
"update_state_after_alloc num_blocks=[0] 0", "update_state_after_alloc num_blocks=[0] 0",
"build_connector_meta", "build_connector_meta",
] ]
assert events["storage2-WORKER"][:7] == [ assert events["storage2-WORKER"][:8] == [
"register_kv_caches", "register_kv_caches",
"set_host_xfer_buffer_ops", "set_host_xfer_buffer_ops",
"get_handshake_metadata", "get_handshake_metadata",
"handle_preemptions",
"bind_connector_metadata", "bind_connector_metadata",
"start_load_kv", "start_load_kv",
"wait_for_layer_load", "wait_for_layer_load",
@@ -399,8 +401,8 @@ def test_multi_connector_handle_preemptions_integration():
# testing the delegation behavior of MultiConnector here. # testing the delegation behavior of MultiConnector here.
# The connector attribute contains the KV connector. # The connector attribute contains the KV connector.
assert scheduler.connector is not None, "Scheduler should have a connector" assert scheduler.connector is not None, "Scheduler should have a connector"
preempted_req_ids = {"req-1", "req-2", "req-3"} connector_md = scheduler.connector.build_connector_meta(scheduler.schedule())
scheduler.connector.handle_preemptions(preempted_req_ids) scheduler.connector.handle_preemptions(connector_md)
# Verify both connectors received the handle_preemptions call # Verify both connectors received the handle_preemptions call
events = get_connector_events() events = get_connector_events()

View File

@@ -363,10 +363,7 @@ class RequestRunner:
assert kv_connector_metadata is not None assert kv_connector_metadata is not None
assert isinstance(kv_connector_metadata, OffloadingConnectorMetadata) assert isinstance(kv_connector_metadata, OffloadingConnectorMetadata)
if scheduler_output.preempted_req_ids: self.worker_connector.handle_preemptions(kv_connector_metadata)
self.worker_connector.handle_preemptions(
scheduler_output.preempted_req_ids
)
self.worker_connector.bind_connector_metadata(kv_connector_metadata) self.worker_connector.bind_connector_metadata(kv_connector_metadata)
self.worker_connector.start_load_kv(self._dummy_ctx) self.worker_connector.start_load_kv(self._dummy_ctx)

View File

@@ -25,8 +25,8 @@ The class provides the following primitives:
Worker-side: runs in each worker, loads/saves KV cache to/from Worker-side: runs in each worker, loads/saves KV cache to/from
the Connector based on the metadata. the Connector based on the metadata.
handle_preemptions() - called if there are preempted requests, handle_preemptions() - called for handling preempted requests
before their blocks are overwritten or request evicted blocks before they are overwritten
start_load_kv() - starts loading all KVs (maybe async) start_load_kv() - starts loading all KVs (maybe async)
wait_for_layer_load() - blocks until layer i load is done wait_for_layer_load() - blocks until layer i load is done
@@ -288,9 +288,9 @@ class KVConnectorBase_V1(ABC):
""" """
return return
def handle_preemptions(self, preempted_req_ids: set[str]): def handle_preemptions(self, kv_connector_metadata: KVConnectorMetadata):
""" """
Handle preempted requests BEFORE their blocks are overwritten. Handle preempted requests or evicted blocks BEFORE they are overwritten.
Needed for connectors which use async saves (e.g., OffloadingConnector) Needed for connectors which use async saves (e.g., OffloadingConnector)
""" """
return return

View File

@@ -315,10 +315,11 @@ class MultiConnector(KVConnectorBase_V1):
for c in self._connectors: for c in self._connectors:
c.set_host_xfer_buffer_ops(copy_operation) c.set_host_xfer_buffer_ops(copy_operation)
def handle_preemptions(self, preempted_req_ids: set[str]): def handle_preemptions(self, kv_connector_metadata: KVConnectorMetadata):
"""Handle preempted requests for all sub-connectors.""" """Handle preempted requests for all sub-connectors."""
for c in self._connectors: assert isinstance(kv_connector_metadata, MultiKVConnectorMetadata)
c.handle_preemptions(preempted_req_ids) for c, cm in zip(self._connectors, kv_connector_metadata.metadata):
c.handle_preemptions(cm)
def get_finished_count(self) -> int | None: def get_finished_count(self) -> int | None:
# TODO(https://github.com/vllm-project/vllm/issues/33400) # TODO(https://github.com/vllm-project/vllm/issues/33400)

View File

@@ -111,6 +111,7 @@ class OffloadingConnectorStats(KVConnectorStats):
class OffloadingConnectorMetadata(KVConnectorMetadata): class OffloadingConnectorMetadata(KVConnectorMetadata):
reqs_to_load: dict[ReqId, TransferSpec] reqs_to_load: dict[ReqId, TransferSpec]
reqs_to_store: dict[ReqId, TransferSpec] reqs_to_store: dict[ReqId, TransferSpec]
reqs_to_flush: set[str] | None = None
class OffloadingConnector(KVConnectorBase_V1): class OffloadingConnector(KVConnectorBase_V1):
@@ -146,9 +147,10 @@ class OffloadingConnector(KVConnectorBase_V1):
assert self.connector_worker is not None assert self.connector_worker is not None
self.connector_worker.register_cross_layers_kv_cache(kv_cache, attn_backend) self.connector_worker.register_cross_layers_kv_cache(kv_cache, attn_backend)
def handle_preemptions(self, preempted_req_ids: set[str]): def handle_preemptions(self, kv_connector_metadata: KVConnectorMetadata):
assert self.connector_worker is not None assert self.connector_worker is not None
self.connector_worker.handle_preemptions(preempted_req_ids) assert isinstance(kv_connector_metadata, OffloadingConnectorMetadata)
self.connector_worker.handle_preemptions(kv_connector_metadata)
def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None: def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
assert self.connector_worker is not None assert self.connector_worker is not None
@@ -482,6 +484,7 @@ class OffloadingConnectorScheduler:
meta = OffloadingConnectorMetadata( meta = OffloadingConnectorMetadata(
reqs_to_load=self._reqs_to_load, reqs_to_load=self._reqs_to_load,
reqs_to_store=self._get_reqs_to_store(scheduler_output), reqs_to_store=self._get_reqs_to_store(scheduler_output),
reqs_to_flush=scheduler_output.preempted_req_ids,
) )
self._reqs_to_load = {} self._reqs_to_load = {}
@@ -619,13 +622,13 @@ class OffloadingConnectorWorker:
attn_backends = {cross_layer_name: attn_backend} attn_backends = {cross_layer_name: attn_backend}
self._register_handlers(kv_caches, attn_backends) self._register_handlers(kv_caches, attn_backends)
def handle_preemptions(self, preempted_req_ids: set[str]): def handle_preemptions(self, kv_connector_metadata: OffloadingConnectorMetadata):
for job_id, transfer_spec in self._unsubmitted_store_jobs: for job_id, transfer_spec in self._unsubmitted_store_jobs:
success = self.worker.transfer_async(job_id, transfer_spec) success = self.worker.transfer_async(job_id, transfer_spec)
assert success assert success
self._unsubmitted_store_jobs.clear() self._unsubmitted_store_jobs.clear()
for req_id in preempted_req_ids: for req_id in kv_connector_metadata.reqs_to_flush or ():
job_ids = self._store_jobs.get(req_id) job_ids = self._store_jobs.get(req_id)
if job_ids: if job_ids:
self.worker.wait(job_ids) self.worker.wait(job_ids)

View File

@@ -63,11 +63,10 @@ class ActiveKVConnector(KVConnector):
if self._disabled: if self._disabled:
return return
if scheduler_output.preempted_req_ids:
self.kv_connector.handle_preemptions(scheduler_output.preempted_req_ids)
kv_connector_metadata = scheduler_output.kv_connector_metadata kv_connector_metadata = scheduler_output.kv_connector_metadata
assert kv_connector_metadata is not None assert kv_connector_metadata is not None
self.kv_connector.bind_connector_metadata(kv_connector_metadata) self.kv_connector.bind_connector_metadata(kv_connector_metadata)
self.kv_connector.handle_preemptions(kv_connector_metadata)
# TODO: sort out KV Connectors' use of forward_context # TODO: sort out KV Connectors' use of forward_context
if is_forward_context_available(): if is_forward_context_available():

View File

@@ -3594,10 +3594,10 @@ class GPUModelRunner(
scheduled_spec_decode_tokens=spec_decode_tokens_copy, scheduled_spec_decode_tokens=spec_decode_tokens_copy,
) )
if scheduler_output.preempted_req_ids and has_kv_transfer_group(): if has_kv_transfer_group():
get_kv_transfer_group().handle_preemptions( kv_connector_metadata = scheduler_output.kv_connector_metadata
scheduler_output.preempted_req_ids assert kv_connector_metadata is not None
) get_kv_transfer_group().handle_preemptions(kv_connector_metadata)
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
with ( with (