[KVConnector] OffloadingConnector: Fix bug in handling of preemptions (#29870)

Signed-off-by: Or Ozeri <oro@il.ibm.com>
This commit is contained in:
Or Ozeri
2026-01-11 10:05:36 +02:00
committed by GitHub
parent bde57ab2ed
commit 4c16ba617f
7 changed files with 248 additions and 57 deletions

View File

@@ -64,8 +64,11 @@ class MockLoadStoreSpec(LoadStoreSpec):
class MockOffloadingHandler(OffloadingHandler):
def __init__(self):
self.transfer_specs: dict[int, TransferSpec] = {}
self.completed_transfers: list[TransferResult] = []
self.completed_specs: list[TransferSpec] = []
self.waiting_jobs: set[int] = set()
self.completed_jobs: list[int] = []
self.flushed_jobs: set[int] = set()
def get_finished(self) -> list[TransferResult]:
finished = self.completed_transfers
@@ -73,10 +76,21 @@ class MockOffloadingHandler(OffloadingHandler):
return finished
def transfer_async(self, job_id: int, spec: TransferSpec) -> bool:
self.completed_specs.append(spec)
self.completed_transfers.append((job_id, True))
self.transfer_specs[job_id] = spec
self.waiting_jobs.add(job_id)
return True
def complete_jobs(self, job_ids: set[int]) -> None:
for job_id in job_ids:
if job_id in self.waiting_jobs:
self.waiting_jobs.remove(job_id)
self.completed_jobs.append(job_id)
self.completed_transfers.append((job_id, True))
def wait(self, job_ids: set[int]) -> None:
self.flushed_jobs |= job_ids
self.complete_jobs(job_ids)
class MockOffloadingSpec(OffloadingSpec):
def __init__(self, vllm_config: VllmConfig):
@@ -98,9 +112,22 @@ class MockOffloadingSpec(OffloadingSpec):
yield GPULoadStoreSpec, MockLoadStoreSpec, self.handler
yield MockLoadStoreSpec, GPULoadStoreSpec, self.handler
def complete_transfers(self):
self.handler.complete_jobs(self.handler.waiting_jobs.copy())
def get_completed_transfers(self) -> list[TransferSpec]:
specs = self.handler.completed_specs
self.handler.completed_specs = []
specs = [
self.handler.transfer_specs[job_id]
for job_id in self.handler.completed_jobs
]
self.handler.completed_jobs.clear()
return specs
def get_flushed_transfers(self):
specs = [
self.handler.transfer_specs[job_id] for job_id in self.handler.flushed_jobs
]
self.handler.flushed_jobs.clear()
return specs
@@ -170,12 +197,9 @@ class RequestRunner:
# mapping (offloading address) -> gpu_block_index
self.offloaded: dict[Any, int] = {}
self.pending_loads_count: int = 0
self.pending_stores_count: int = 0
self.unsubmitted_stores_count = 0
self.completed_loads: list[TransferSummary] = []
self.completed_stores: list[TransferSummary] = []
self.flushed_gpu_block_indexes: set[int] = set()
# maps {block_id: block_offset}
self.gpu_block_index: dict[int, int] = {}
@@ -202,54 +226,60 @@ class RequestRunner:
self.scheduler.add_request(req)
def _wait_for_transfers(self):
def _parse_transfers(self):
for transfer_spec in self.offloading_spec.get_flushed_transfers():
src_spec, dst_spec = transfer_spec
assert isinstance(src_spec, GPULoadStoreSpec)
for block_id in src_spec.block_ids:
self.flushed_gpu_block_indexes.add(
self.gpu_block_index[block_id.item()]
)
block_size_factor = self.offloaded_block_size // self.gpu_block_size
while self.pending_loads_count or self.pending_stores_count:
for transfer_spec in self.offloading_spec.get_completed_transfers():
src_spec, dst_spec = transfer_spec
for transfer_spec in self.offloading_spec.get_completed_transfers():
src_spec, dst_spec = transfer_spec
if isinstance(src_spec, GPULoadStoreSpec):
store = True
gpu_spec = src_spec
offload_spec = dst_spec
else:
store = False
gpu_spec = dst_spec
offload_spec = src_spec
if isinstance(src_spec, GPULoadStoreSpec):
store = True
gpu_spec = src_spec
offload_spec = dst_spec
else:
store = False
gpu_spec = dst_spec
offload_spec = src_spec
assert isinstance(offload_spec, MockLoadStoreSpec)
assert isinstance(gpu_spec, GPULoadStoreSpec)
assert isinstance(offload_spec, MockLoadStoreSpec)
assert isinstance(gpu_spec, GPULoadStoreSpec)
gpu_block_indices: list[int] = []
for block_id in gpu_spec.block_ids:
gpu_block_indices.append(self.gpu_block_index[block_id.item()])
gpu_block_indices: list[int] = []
for block_id in gpu_spec.block_ids:
gpu_block_indices.append(self.gpu_block_index[block_id.item()])
# list of (block_hash, sub_block_offset)
offload_addresses: list[Any] = []
for block_hash in offload_spec.block_hashes:
for sub_block_idx in range(block_size_factor):
offload_addresses.append((block_hash, sub_block_idx))
# list of (block_hash, sub_block_offset)
offload_addresses: list[Any] = []
for block_hash in offload_spec.block_hashes:
for sub_block_idx in range(block_size_factor):
offload_addresses.append((block_hash, sub_block_idx))
if store:
assert len(gpu_block_indices) == len(offload_addresses)
if store:
assert len(gpu_block_indices) == len(offload_addresses)
self.completed_stores.append(
TransferSummary(gpu_block_indices, offload_addresses)
)
self.pending_stores_count -= 1
else:
remainder_sub_block_count = len(offload_addresses) - len(
gpu_block_indices
)
assert remainder_sub_block_count >= 0
assert remainder_sub_block_count < block_size_factor
offload_addresses = offload_addresses[remainder_sub_block_count:]
self.completed_stores.append(
TransferSummary(gpu_block_indices, offload_addresses)
)
else:
remainder_sub_block_count = len(offload_addresses) - len(
gpu_block_indices
)
assert remainder_sub_block_count >= 0
assert remainder_sub_block_count < block_size_factor
offload_addresses = offload_addresses[remainder_sub_block_count:]
self.completed_loads.append(
TransferSummary(gpu_block_indices, offload_addresses)
)
self.pending_loads_count -= 1
self.completed_loads.append(
TransferSummary(gpu_block_indices, offload_addresses)
)
def _update_gpu_block_idx(self):
for blocks in self.scheduler.kv_cache_manager.coordinator.single_type_managers[
@@ -258,18 +288,19 @@ class RequestRunner:
for block_idx, block in enumerate(blocks):
self.gpu_block_index[block.block_id] = block_idx
def _run(self, decoded_tokens: list[int]):
def _run(self, decoded_tokens: list[int], complete_transfers: bool):
"""
Runs multiple engine (scheduler + worker) steps.
Assumes a single request is running.
Args:
decoded_tokens: the tokens to yield at each step.
complete_transfers: complete transfers immediately
"""
tokens_iter = iter(decoded_tokens)
token_id = next(tokens_iter, None)
while token_id is not None:
while True:
assert self.scheduler.requests
scheduler_output = self.scheduler.schedule()
@@ -279,10 +310,10 @@ class RequestRunner:
assert kv_connector_metadata is not None
assert isinstance(kv_connector_metadata, OffloadingConnectorMetadata)
self.pending_loads_count += len(kv_connector_metadata.reqs_to_load)
self.pending_stores_count += self.unsubmitted_stores_count
self.unsubmitted_stores_count = len(kv_connector_metadata.reqs_to_store)
if scheduler_output.preempted_req_ids:
self.worker_connector.handle_preemptions(
scheduler_output.preempted_req_ids
)
self.worker_connector.bind_connector_metadata(kv_connector_metadata)
self.worker_connector.start_load_kv(self._dummy_ctx)
@@ -290,6 +321,9 @@ class RequestRunner:
if scheduler_output.total_num_scheduled_tokens > 0:
self.worker_connector.wait_for_save()
if complete_transfers:
self.offloading_spec.complete_transfers()
finished_sending, finished_recving = self.worker_connector.get_finished(
scheduler_output.finished_req_ids
)
@@ -300,7 +334,7 @@ class RequestRunner:
reqs=self.scheduler.running,
finished_sending=finished_sending,
finished_recving=finished_recving,
token_id=token_id,
token_id=token_id or 0,
)
if self.scheduler.running:
@@ -308,7 +342,10 @@ class RequestRunner:
self.scheduler.update_from_output(scheduler_output, model_runner_output)
self._wait_for_transfers()
if token_id is None:
break
self._parse_transfers()
# run one more step to update finished stores
if EOS_TOKEN_ID in decoded_tokens:
@@ -333,8 +370,10 @@ class RequestRunner:
def run(
self,
decoded_tokens: list[int],
complete_transfers: bool = True,
expected_stored_gpu_block_indexes: tuple[int, ...] = (),
expected_loaded_gpu_block_indexes: tuple[int, ...] = (),
expected_flushed_gpu_block_indexes: tuple[int, ...] = (),
):
"""
Runs multiple engine (scheduler + worker) steps.
@@ -342,14 +381,17 @@ class RequestRunner:
Args:
decoded_tokens: the tokens to yield at each step.
complete_transfers: complete transfers immediately
expected_stored_gpu_block_indexes: GPU block indexes
that are expected to be written during the run.
expected_loaded_gpu_block_indexes: GPU block indexes
that are expected to be loaded during the run.
expected_flushed_gpu_block_indexes: GPU block indexes
that are expected to be flushed during the run.
"""
self.manager.reset_mock()
self._run(decoded_tokens)
self._run(decoded_tokens, complete_transfers)
loaded_gpu_block_indexes: set[int] = set()
for transfer in self.completed_loads:
@@ -373,6 +415,9 @@ class RequestRunner:
assert set(expected_stored_gpu_block_indexes) == stored_gpu_block_indexes
self.completed_stores.clear()
assert set(expected_flushed_gpu_block_indexes) == self.flushed_gpu_block_indexes
self.flushed_gpu_block_indexes.clear()
@pytest.fixture
def request_runner():
@@ -539,3 +584,69 @@ def test_offloading_connector(request_runner):
assert isinstance(event, BlockRemoved)
assert event.block_hashes == to_hashes([4, 5, 6])
assert event.medium == "B"
def test_request_preemption(request_runner):
offloaded_block_size = 12
gpu_block_size = 4
num_gpu_blocks = 100
runner = request_runner(
offloaded_block_size=offloaded_block_size,
gpu_block_size=gpu_block_size,
num_gpu_blocks=num_gpu_blocks,
)
free_block_queue = runner.scheduler.kv_cache_manager.block_pool.free_block_queue
num_free_blocks_empty = free_block_queue.num_free_blocks
# 2 blocks, store all, without flushing
# blocks = [0, 1, 2], [3, 4, 5]
runner.new_request(token_ids=[0] * offloaded_block_size * 2)
runner.manager.prepare_store.side_effect = (
lambda block_hashes: generate_store_output(block_hashes)
)
runner.run(
decoded_tokens=[0],
complete_transfers=False,
)
# decode 2 more blocks - 1 gpu block, storing [6, 7, 8] (no flush)
runner.manager.prepare_store.side_effect = (
lambda block_hashes: generate_store_output(block_hashes)
)
runner.run(
decoded_tokens=[0] * (2 * offloaded_block_size - gpu_block_size),
complete_transfers=False,
)
# simulate KV cache running out of space
free_block_queue.num_free_blocks = 0
# request should be preempted now
runner.run(
decoded_tokens=[],
complete_transfers=False,
expected_flushed_gpu_block_indexes=(0, 1, 2, 3, 4, 5, 6, 7, 8),
expected_stored_gpu_block_indexes=(0, 1, 2, 3, 4, 5, 6, 7, 8),
)
# restore KV cache space and reset GPU prefix cache
free_block_queue.num_free_blocks = num_free_blocks_empty
runner.scheduler.reset_prefix_cache()
# request should now return from preemption
# re-load [0, ..., 8] from the CPU and store [9, 10, 11]
runner.manager.lookup.return_value = 3
runner.manager.prepare_store.side_effect = (
lambda block_hashes: generate_store_output(block_hashes)
)
runner.run(
decoded_tokens=[0] * gpu_block_size,
expected_loaded_gpu_block_indexes=(0, 1, 2, 3, 4, 5, 6, 7, 8),
)
runner.run(
decoded_tokens=[EOS_TOKEN_ID],
expected_stored_gpu_block_indexes=(9, 10, 11),
)

View File

@@ -63,6 +63,12 @@ class OffloadingHandler1To2(OffloadingHandler):
del self.transfers[job_id]
return finished
def wait(self, job_ids: set[int]) -> None:
for job_id in job_ids:
spec = self.transfers.get(job_id)
if spec:
assert spec.finished
class OffloadingHandler2To1(OffloadingHandler):
def __init__(self):
@@ -84,6 +90,12 @@ class OffloadingHandler2To1(OffloadingHandler):
del self.transfers[job_id]
return finished
def wait(self, job_ids: set[int]) -> None:
for job_id in job_ids:
spec = self.transfers.get(job_id)
if spec:
assert spec.finished
def test_offloading_worker():
"""

View File

@@ -25,6 +25,9 @@ The class provides the following primitives:
Worker-side: runs in each worker, loads/saves KV cache to/from
the Connector based on the metadata.
handle_preemptions() - called if there are preempted requests,
before their blocks are overwritten
start_load_kv() - starts loading all KVs (maybe async)
wait_for_layer_load() - blocks until layer i load is done
@@ -262,6 +265,13 @@ class KVConnectorBase_V1(ABC):
"""
return
def handle_preemptions(self, preempted_req_ids: set[str]):
"""
Handle preempted requests BEFORE their blocks are overwritten.
Needed for connectors which use async saves (e.g., OffloadingConnector)
"""
return
@abstractmethod
def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None:
"""

View File

@@ -75,6 +75,10 @@ class OffloadingConnector(KVConnectorBase_V1):
assert self.connector_worker is not None
self.connector_worker.register_cross_layers_kv_cache(kv_cache, attn_backend)
def handle_preemptions(self, preempted_req_ids: set[str]):
assert self.connector_worker is not None
self.connector_worker.handle_preemptions(preempted_req_ids)
def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
assert self.connector_worker is not None
assert isinstance(self._connector_metadata, OffloadingConnectorMetadata)
@@ -348,6 +352,15 @@ class OffloadingConnectorScheduler:
reqs_to_store=self._get_reqs_to_store(scheduler_output),
)
self._reqs_to_load = {}
# NOTE (orozery): we should move this logic to update_connector_output
# once KVConnectorOutput allows us to report completed transfers
for req_id in scheduler_output.preempted_req_ids or ():
block_hashes = self._reqs_being_stored.get(req_id)
if block_hashes:
self.manager.complete_store(block_hashes)
block_hashes.clear()
return meta
def update_connector_output(self, connector_output: KVConnectorOutput):
@@ -466,6 +479,17 @@ class OffloadingConnectorWorker:
attn_backends = {cross_layer_name: attn_backend}
self._register_handlers(kv_caches, attn_backends)
def handle_preemptions(self, preempted_req_ids: set[str]):
for job_id, transfer_spec in self._unsubmitted_store_jobs:
success = self.worker.transfer_async(job_id, transfer_spec)
assert success
self._unsubmitted_store_jobs.clear()
for req_id in preempted_req_ids:
job_ids = self._store_jobs.get(req_id)
if job_ids:
self.worker.wait(job_ids)
def start_kv_transfers(self, metadata: OffloadingConnectorMetadata):
for job_id, transfer_spec in self._unsubmitted_store_jobs:
success = self.worker.transfer_async(job_id, transfer_spec)

View File

@@ -96,6 +96,8 @@ class SingleDirectionOffloadingHandler(OffloadingHandler):
assert len(src_tensors) > 0
self.gpu_to_cpu: bool = self.src_tensors[0].is_cuda
# job_id -> event
self._transfer_events: dict[int, torch.Event] = {}
# queue of transfers (job_id, stream, event)
self._transfers: deque[tuple[int, torch.cuda.Stream, torch.Event]] = deque()
# list of CUDA streams available for re-use
@@ -152,6 +154,7 @@ class SingleDirectionOffloadingHandler(OffloadingHandler):
ops.swap_blocks(src_tensor, dst_tensor, src_to_dst_tensor)
event.record(stream)
self._transfer_events[job_id] = event
self._transfers.append((job_id, stream, event))
# success
@@ -164,8 +167,15 @@ class SingleDirectionOffloadingHandler(OffloadingHandler):
results.append((job_id, True))
self._stream_pool.append(stream)
self._event_pool.append(event)
del self._transfer_events[job_id]
return results
def wait(self, job_ids: set[int]):
for job_id in job_ids:
event = self._transfer_events.get(job_id)
if event is not None:
event.synchronize()
class CpuGpuOffloadingHandlers:
def __init__(

View File

@@ -53,6 +53,15 @@ class OffloadingHandler(ABC):
"""
pass
@abstractmethod
def wait(self, job_ids: set[int]) -> None:
"""
Wait for jobs to finish (blocking).
Args:
job_ids: The set of job IDs to wait for.
"""
class OffloadingWorker:
"""
@@ -142,3 +151,13 @@ class OffloadingWorker:
for handler in self.handlers:
finished.extend(handler.get_finished())
return finished
def wait(self, job_ids: set[int]) -> None:
"""
Wait for jobs to finish (blocking).
Args:
job_ids: The set of job IDs to wait for.
"""
for handler in self.handlers:
handler.wait(job_ids)

View File

@@ -3112,6 +3112,11 @@ class GPUModelRunner(
"after execute_model() returns None."
)
if scheduler_output.preempted_req_ids and has_kv_transfer_group():
get_kv_transfer_group().handle_preemptions(
scheduler_output.preempted_req_ids
)
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
with (
record_function_or_nullcontext("gpu_model_runner: preprocess"),