[kv_offload+HMA][0/N]: Support block-level preemption handling (#34805)
Signed-off-by: Or Ozeri <oro@il.ibm.com> Co-authored-by: Nicolò Lucchesi <nlucches@redhat.com>
This commit is contained in:
@@ -231,10 +231,11 @@ def test_multi_example_connector_consistency():
|
|||||||
]
|
]
|
||||||
# First three events are from initialization (register_kv_caches,
|
# First three events are from initialization (register_kv_caches,
|
||||||
# set_host_xfer_buffer_ops, get_handshake_metadata), then generate() events.
|
# set_host_xfer_buffer_ops, get_handshake_metadata), then generate() events.
|
||||||
assert events["storage1-WORKER"][:7] == [
|
assert events["storage1-WORKER"][:8] == [
|
||||||
"register_kv_caches",
|
"register_kv_caches",
|
||||||
"set_host_xfer_buffer_ops",
|
"set_host_xfer_buffer_ops",
|
||||||
"get_handshake_metadata",
|
"get_handshake_metadata",
|
||||||
|
"handle_preemptions",
|
||||||
"bind_connector_metadata",
|
"bind_connector_metadata",
|
||||||
"start_load_kv",
|
"start_load_kv",
|
||||||
"wait_for_layer_load",
|
"wait_for_layer_load",
|
||||||
@@ -246,10 +247,11 @@ def test_multi_example_connector_consistency():
|
|||||||
"update_state_after_alloc num_blocks=[0] 0",
|
"update_state_after_alloc num_blocks=[0] 0",
|
||||||
"build_connector_meta",
|
"build_connector_meta",
|
||||||
]
|
]
|
||||||
assert events["storage2-WORKER"][:7] == [
|
assert events["storage2-WORKER"][:8] == [
|
||||||
"register_kv_caches",
|
"register_kv_caches",
|
||||||
"set_host_xfer_buffer_ops",
|
"set_host_xfer_buffer_ops",
|
||||||
"get_handshake_metadata",
|
"get_handshake_metadata",
|
||||||
|
"handle_preemptions",
|
||||||
"bind_connector_metadata",
|
"bind_connector_metadata",
|
||||||
"start_load_kv",
|
"start_load_kv",
|
||||||
"wait_for_layer_load",
|
"wait_for_layer_load",
|
||||||
@@ -399,8 +401,8 @@ def test_multi_connector_handle_preemptions_integration():
|
|||||||
# testing the delegation behavior of MultiConnector here.
|
# testing the delegation behavior of MultiConnector here.
|
||||||
# The connector attribute contains the KV connector.
|
# The connector attribute contains the KV connector.
|
||||||
assert scheduler.connector is not None, "Scheduler should have a connector"
|
assert scheduler.connector is not None, "Scheduler should have a connector"
|
||||||
preempted_req_ids = {"req-1", "req-2", "req-3"}
|
connector_md = scheduler.connector.build_connector_meta(scheduler.schedule())
|
||||||
scheduler.connector.handle_preemptions(preempted_req_ids)
|
scheduler.connector.handle_preemptions(connector_md)
|
||||||
|
|
||||||
# Verify both connectors received the handle_preemptions call
|
# Verify both connectors received the handle_preemptions call
|
||||||
events = get_connector_events()
|
events = get_connector_events()
|
||||||
|
|||||||
@@ -363,10 +363,7 @@ class RequestRunner:
|
|||||||
assert kv_connector_metadata is not None
|
assert kv_connector_metadata is not None
|
||||||
assert isinstance(kv_connector_metadata, OffloadingConnectorMetadata)
|
assert isinstance(kv_connector_metadata, OffloadingConnectorMetadata)
|
||||||
|
|
||||||
if scheduler_output.preempted_req_ids:
|
self.worker_connector.handle_preemptions(kv_connector_metadata)
|
||||||
self.worker_connector.handle_preemptions(
|
|
||||||
scheduler_output.preempted_req_ids
|
|
||||||
)
|
|
||||||
|
|
||||||
self.worker_connector.bind_connector_metadata(kv_connector_metadata)
|
self.worker_connector.bind_connector_metadata(kv_connector_metadata)
|
||||||
self.worker_connector.start_load_kv(self._dummy_ctx)
|
self.worker_connector.start_load_kv(self._dummy_ctx)
|
||||||
|
|||||||
@@ -25,8 +25,8 @@ The class provides the following primitives:
|
|||||||
|
|
||||||
Worker-side: runs in each worker, loads/saves KV cache to/from
|
Worker-side: runs in each worker, loads/saves KV cache to/from
|
||||||
the Connector based on the metadata.
|
the Connector based on the metadata.
|
||||||
handle_preemptions() - called if there are preempted requests,
|
handle_preemptions() - called for handling preempted requests
|
||||||
before their blocks are overwritten
|
or request evicted blocks before they are overwritten
|
||||||
|
|
||||||
start_load_kv() - starts loading all KVs (maybe async)
|
start_load_kv() - starts loading all KVs (maybe async)
|
||||||
wait_for_layer_load() - blocks until layer i load is done
|
wait_for_layer_load() - blocks until layer i load is done
|
||||||
@@ -288,9 +288,9 @@ class KVConnectorBase_V1(ABC):
|
|||||||
"""
|
"""
|
||||||
return
|
return
|
||||||
|
|
||||||
def handle_preemptions(self, preempted_req_ids: set[str]):
|
def handle_preemptions(self, kv_connector_metadata: KVConnectorMetadata):
|
||||||
"""
|
"""
|
||||||
Handle preempted requests BEFORE their blocks are overwritten.
|
Handle preempted requests or evicted blocks BEFORE they are overwritten.
|
||||||
Needed for connectors which use async saves (e.g., OffloadingConnector)
|
Needed for connectors which use async saves (e.g., OffloadingConnector)
|
||||||
"""
|
"""
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -315,10 +315,11 @@ class MultiConnector(KVConnectorBase_V1):
|
|||||||
for c in self._connectors:
|
for c in self._connectors:
|
||||||
c.set_host_xfer_buffer_ops(copy_operation)
|
c.set_host_xfer_buffer_ops(copy_operation)
|
||||||
|
|
||||||
def handle_preemptions(self, preempted_req_ids: set[str]):
|
def handle_preemptions(self, kv_connector_metadata: KVConnectorMetadata):
|
||||||
"""Handle preempted requests for all sub-connectors."""
|
"""Handle preempted requests for all sub-connectors."""
|
||||||
for c in self._connectors:
|
assert isinstance(kv_connector_metadata, MultiKVConnectorMetadata)
|
||||||
c.handle_preemptions(preempted_req_ids)
|
for c, cm in zip(self._connectors, kv_connector_metadata.metadata):
|
||||||
|
c.handle_preemptions(cm)
|
||||||
|
|
||||||
def get_finished_count(self) -> int | None:
|
def get_finished_count(self) -> int | None:
|
||||||
# TODO(https://github.com/vllm-project/vllm/issues/33400)
|
# TODO(https://github.com/vllm-project/vllm/issues/33400)
|
||||||
|
|||||||
@@ -111,6 +111,7 @@ class OffloadingConnectorStats(KVConnectorStats):
|
|||||||
class OffloadingConnectorMetadata(KVConnectorMetadata):
|
class OffloadingConnectorMetadata(KVConnectorMetadata):
|
||||||
reqs_to_load: dict[ReqId, TransferSpec]
|
reqs_to_load: dict[ReqId, TransferSpec]
|
||||||
reqs_to_store: dict[ReqId, TransferSpec]
|
reqs_to_store: dict[ReqId, TransferSpec]
|
||||||
|
reqs_to_flush: set[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
class OffloadingConnector(KVConnectorBase_V1):
|
class OffloadingConnector(KVConnectorBase_V1):
|
||||||
@@ -146,9 +147,10 @@ class OffloadingConnector(KVConnectorBase_V1):
|
|||||||
assert self.connector_worker is not None
|
assert self.connector_worker is not None
|
||||||
self.connector_worker.register_cross_layers_kv_cache(kv_cache, attn_backend)
|
self.connector_worker.register_cross_layers_kv_cache(kv_cache, attn_backend)
|
||||||
|
|
||||||
def handle_preemptions(self, preempted_req_ids: set[str]):
|
def handle_preemptions(self, kv_connector_metadata: KVConnectorMetadata):
|
||||||
assert self.connector_worker is not None
|
assert self.connector_worker is not None
|
||||||
self.connector_worker.handle_preemptions(preempted_req_ids)
|
assert isinstance(kv_connector_metadata, OffloadingConnectorMetadata)
|
||||||
|
self.connector_worker.handle_preemptions(kv_connector_metadata)
|
||||||
|
|
||||||
def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
|
def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
|
||||||
assert self.connector_worker is not None
|
assert self.connector_worker is not None
|
||||||
@@ -482,6 +484,7 @@ class OffloadingConnectorScheduler:
|
|||||||
meta = OffloadingConnectorMetadata(
|
meta = OffloadingConnectorMetadata(
|
||||||
reqs_to_load=self._reqs_to_load,
|
reqs_to_load=self._reqs_to_load,
|
||||||
reqs_to_store=self._get_reqs_to_store(scheduler_output),
|
reqs_to_store=self._get_reqs_to_store(scheduler_output),
|
||||||
|
reqs_to_flush=scheduler_output.preempted_req_ids,
|
||||||
)
|
)
|
||||||
self._reqs_to_load = {}
|
self._reqs_to_load = {}
|
||||||
|
|
||||||
@@ -619,13 +622,13 @@ class OffloadingConnectorWorker:
|
|||||||
attn_backends = {cross_layer_name: attn_backend}
|
attn_backends = {cross_layer_name: attn_backend}
|
||||||
self._register_handlers(kv_caches, attn_backends)
|
self._register_handlers(kv_caches, attn_backends)
|
||||||
|
|
||||||
def handle_preemptions(self, preempted_req_ids: set[str]):
|
def handle_preemptions(self, kv_connector_metadata: OffloadingConnectorMetadata):
|
||||||
for job_id, transfer_spec in self._unsubmitted_store_jobs:
|
for job_id, transfer_spec in self._unsubmitted_store_jobs:
|
||||||
success = self.worker.transfer_async(job_id, transfer_spec)
|
success = self.worker.transfer_async(job_id, transfer_spec)
|
||||||
assert success
|
assert success
|
||||||
self._unsubmitted_store_jobs.clear()
|
self._unsubmitted_store_jobs.clear()
|
||||||
|
|
||||||
for req_id in preempted_req_ids:
|
for req_id in kv_connector_metadata.reqs_to_flush or ():
|
||||||
job_ids = self._store_jobs.get(req_id)
|
job_ids = self._store_jobs.get(req_id)
|
||||||
if job_ids:
|
if job_ids:
|
||||||
self.worker.wait(job_ids)
|
self.worker.wait(job_ids)
|
||||||
|
|||||||
@@ -63,11 +63,10 @@ class ActiveKVConnector(KVConnector):
|
|||||||
if self._disabled:
|
if self._disabled:
|
||||||
return
|
return
|
||||||
|
|
||||||
if scheduler_output.preempted_req_ids:
|
|
||||||
self.kv_connector.handle_preemptions(scheduler_output.preempted_req_ids)
|
|
||||||
kv_connector_metadata = scheduler_output.kv_connector_metadata
|
kv_connector_metadata = scheduler_output.kv_connector_metadata
|
||||||
assert kv_connector_metadata is not None
|
assert kv_connector_metadata is not None
|
||||||
self.kv_connector.bind_connector_metadata(kv_connector_metadata)
|
self.kv_connector.bind_connector_metadata(kv_connector_metadata)
|
||||||
|
self.kv_connector.handle_preemptions(kv_connector_metadata)
|
||||||
|
|
||||||
# TODO: sort out KV Connectors' use of forward_context
|
# TODO: sort out KV Connectors' use of forward_context
|
||||||
if is_forward_context_available():
|
if is_forward_context_available():
|
||||||
|
|||||||
@@ -3594,10 +3594,10 @@ class GPUModelRunner(
|
|||||||
scheduled_spec_decode_tokens=spec_decode_tokens_copy,
|
scheduled_spec_decode_tokens=spec_decode_tokens_copy,
|
||||||
)
|
)
|
||||||
|
|
||||||
if scheduler_output.preempted_req_ids and has_kv_transfer_group():
|
if has_kv_transfer_group():
|
||||||
get_kv_transfer_group().handle_preemptions(
|
kv_connector_metadata = scheduler_output.kv_connector_metadata
|
||||||
scheduler_output.preempted_req_ids
|
assert kv_connector_metadata is not None
|
||||||
)
|
get_kv_transfer_group().handle_preemptions(kv_connector_metadata)
|
||||||
|
|
||||||
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||||
with (
|
with (
|
||||||
|
|||||||
Reference in New Issue
Block a user