[KVConnector] OffloadingConnector: Fix bug in handling of preemptions (#29870)

Signed-off-by: Or Ozeri <oro@il.ibm.com>
This commit is contained in:
Or Ozeri
2026-01-11 10:05:36 +02:00
committed by GitHub
parent bde57ab2ed
commit 4c16ba617f
7 changed files with 248 additions and 57 deletions

View File

@@ -64,8 +64,11 @@ class MockLoadStoreSpec(LoadStoreSpec):
class MockOffloadingHandler(OffloadingHandler):
def __init__(self):
    # Per-job bookkeeping for the mock offloading handler used by these
    # connector tests.
    # job_id -> the TransferSpec submitted for that job.
    self.transfer_specs: dict[int, TransferSpec] = {}
    # (job_id, success) results drained by get_finished().
    self.completed_transfers: list[TransferResult] = []
    # NOTE(review): looks vestigial — the updated accessors read
    # transfer_specs/completed_jobs instead; confirm before removing.
    self.completed_specs: list[TransferSpec] = []
    # Jobs submitted via transfer_async() but not yet completed.
    self.waiting_jobs: set[int] = set()
    # Jobs completed so far, in completion order.
    self.completed_jobs: list[int] = []
    # Jobs that have been flushed via wait().
    self.flushed_jobs: set[int] = set()
def get_finished(self) -> list[TransferResult]:
finished = self.completed_transfers
@@ -73,10 +76,21 @@ class MockOffloadingHandler(OffloadingHandler):
return finished
def transfer_async(self, job_id: int, spec: TransferSpec) -> bool:
    """Record *spec* for *job_id* and park the job as waiting.

    Always reports success (returns True); actual completion is driven
    later through complete_jobs()/wait().
    """
    # NOTE(review): the two appends below mark the job finished
    # immediately, while the waiting_jobs bookkeeping defers completion —
    # this looks like merged pre/post-refactor diff lines; confirm intent.
    self.completed_specs.append(spec)
    self.completed_transfers.append((job_id, True))
    self.transfer_specs[job_id] = spec
    self.waiting_jobs.add(job_id)
    return True
def complete_jobs(self, job_ids: set[int]) -> None:
    """Mark each given job that is still waiting as finished.

    Every affected job is removed from the waiting set, appended to the
    completion log, and reported as a successful transfer result.
    Job ids that are not currently waiting are ignored.
    """
    for pending in job_ids:
        if pending not in self.waiting_jobs:
            continue
        self.waiting_jobs.remove(pending)
        self.completed_jobs.append(pending)
        self.completed_transfers.append((pending, True))
def wait(self, job_ids: set[int]) -> None:
    """Flush the given jobs, then complete them immediately."""
    self.flushed_jobs.update(job_ids)
    self.complete_jobs(job_ids)
class MockOffloadingSpec(OffloadingSpec):
def __init__(self, vllm_config: VllmConfig):
@@ -98,9 +112,22 @@ class MockOffloadingSpec(OffloadingSpec):
yield GPULoadStoreSpec, MockLoadStoreSpec, self.handler
yield MockLoadStoreSpec, GPULoadStoreSpec, self.handler
def complete_transfers(self):
    """Finish every transfer the mock handler still has in flight."""
    # Snapshot first: complete_jobs mutates waiting_jobs while iterating.
    still_waiting = self.handler.waiting_jobs.copy()
    self.handler.complete_jobs(still_waiting)
def get_completed_transfers(self) -> list[TransferSpec]:
    """Drain and return the specs of jobs completed since the last call."""
    # NOTE(review): the next two statements are dead — `specs` is
    # immediately rebound below. Looks like leftover pre-refactor lines
    # merged into this view; confirm before removing.
    specs = self.handler.completed_specs
    self.handler.completed_specs = []
    specs = [
        self.handler.transfer_specs[job_id]
        for job_id in self.handler.completed_jobs
    ]
    self.handler.completed_jobs.clear()
    return specs
def get_flushed_transfers(self):
    """Return the specs of all flushed transfers, draining the flushed set."""
    spec_by_job = self.handler.transfer_specs
    flushed = self.handler.flushed_jobs
    result = [spec_by_job[jid] for jid in flushed]
    flushed.clear()
    return result
@@ -170,12 +197,9 @@ class RequestRunner:
# mapping (offloading address) -> gpu_block_index
self.offloaded: dict[Any, int] = {}
self.pending_loads_count: int = 0
self.pending_stores_count: int = 0
self.unsubmitted_stores_count = 0
self.completed_loads: list[TransferSummary] = []
self.completed_stores: list[TransferSummary] = []
self.flushed_gpu_block_indexes: set[int] = set()
# maps {block_id: block_offset}
self.gpu_block_index: dict[int, int] = {}
@@ -202,54 +226,60 @@ class RequestRunner:
self.scheduler.add_request(req)
def _wait_for_transfers(self):
def _parse_transfers(self):
for transfer_spec in self.offloading_spec.get_flushed_transfers():
src_spec, dst_spec = transfer_spec
assert isinstance(src_spec, GPULoadStoreSpec)
for block_id in src_spec.block_ids:
self.flushed_gpu_block_indexes.add(
self.gpu_block_index[block_id.item()]
)
block_size_factor = self.offloaded_block_size // self.gpu_block_size
while self.pending_loads_count or self.pending_stores_count:
for transfer_spec in self.offloading_spec.get_completed_transfers():
src_spec, dst_spec = transfer_spec
for transfer_spec in self.offloading_spec.get_completed_transfers():
src_spec, dst_spec = transfer_spec
if isinstance(src_spec, GPULoadStoreSpec):
store = True
gpu_spec = src_spec
offload_spec = dst_spec
else:
store = False
gpu_spec = dst_spec
offload_spec = src_spec
if isinstance(src_spec, GPULoadStoreSpec):
store = True
gpu_spec = src_spec
offload_spec = dst_spec
else:
store = False
gpu_spec = dst_spec
offload_spec = src_spec
assert isinstance(offload_spec, MockLoadStoreSpec)
assert isinstance(gpu_spec, GPULoadStoreSpec)
assert isinstance(offload_spec, MockLoadStoreSpec)
assert isinstance(gpu_spec, GPULoadStoreSpec)
gpu_block_indices: list[int] = []
for block_id in gpu_spec.block_ids:
gpu_block_indices.append(self.gpu_block_index[block_id.item()])
gpu_block_indices: list[int] = []
for block_id in gpu_spec.block_ids:
gpu_block_indices.append(self.gpu_block_index[block_id.item()])
# list of (block_hash, sub_block_offset)
offload_addresses: list[Any] = []
for block_hash in offload_spec.block_hashes:
for sub_block_idx in range(block_size_factor):
offload_addresses.append((block_hash, sub_block_idx))
# list of (block_hash, sub_block_offset)
offload_addresses: list[Any] = []
for block_hash in offload_spec.block_hashes:
for sub_block_idx in range(block_size_factor):
offload_addresses.append((block_hash, sub_block_idx))
if store:
assert len(gpu_block_indices) == len(offload_addresses)
if store:
assert len(gpu_block_indices) == len(offload_addresses)
self.completed_stores.append(
TransferSummary(gpu_block_indices, offload_addresses)
)
self.pending_stores_count -= 1
else:
remainder_sub_block_count = len(offload_addresses) - len(
gpu_block_indices
)
assert remainder_sub_block_count >= 0
assert remainder_sub_block_count < block_size_factor
offload_addresses = offload_addresses[remainder_sub_block_count:]
self.completed_stores.append(
TransferSummary(gpu_block_indices, offload_addresses)
)
else:
remainder_sub_block_count = len(offload_addresses) - len(
gpu_block_indices
)
assert remainder_sub_block_count >= 0
assert remainder_sub_block_count < block_size_factor
offload_addresses = offload_addresses[remainder_sub_block_count:]
self.completed_loads.append(
TransferSummary(gpu_block_indices, offload_addresses)
)
self.pending_loads_count -= 1
self.completed_loads.append(
TransferSummary(gpu_block_indices, offload_addresses)
)
def _update_gpu_block_idx(self):
for blocks in self.scheduler.kv_cache_manager.coordinator.single_type_managers[
@@ -258,18 +288,19 @@ class RequestRunner:
for block_idx, block in enumerate(blocks):
self.gpu_block_index[block.block_id] = block_idx
def _run(self, decoded_tokens: list[int]):
def _run(self, decoded_tokens: list[int], complete_transfers: bool):
"""
Runs multiple engine (scheduler + worker) steps.
Assumes a single request is running.
Args:
decoded_tokens: the tokens to yield at each step.
complete_transfers: complete transfers immediately
"""
tokens_iter = iter(decoded_tokens)
token_id = next(tokens_iter, None)
while token_id is not None:
while True:
assert self.scheduler.requests
scheduler_output = self.scheduler.schedule()
@@ -279,10 +310,10 @@ class RequestRunner:
assert kv_connector_metadata is not None
assert isinstance(kv_connector_metadata, OffloadingConnectorMetadata)
self.pending_loads_count += len(kv_connector_metadata.reqs_to_load)
self.pending_stores_count += self.unsubmitted_stores_count
self.unsubmitted_stores_count = len(kv_connector_metadata.reqs_to_store)
if scheduler_output.preempted_req_ids:
self.worker_connector.handle_preemptions(
scheduler_output.preempted_req_ids
)
self.worker_connector.bind_connector_metadata(kv_connector_metadata)
self.worker_connector.start_load_kv(self._dummy_ctx)
@@ -290,6 +321,9 @@ class RequestRunner:
if scheduler_output.total_num_scheduled_tokens > 0:
self.worker_connector.wait_for_save()
if complete_transfers:
self.offloading_spec.complete_transfers()
finished_sending, finished_recving = self.worker_connector.get_finished(
scheduler_output.finished_req_ids
)
@@ -300,7 +334,7 @@ class RequestRunner:
reqs=self.scheduler.running,
finished_sending=finished_sending,
finished_recving=finished_recving,
token_id=token_id,
token_id=token_id or 0,
)
if self.scheduler.running:
@@ -308,7 +342,10 @@ class RequestRunner:
self.scheduler.update_from_output(scheduler_output, model_runner_output)
self._wait_for_transfers()
if token_id is None:
break
self._parse_transfers()
# run one more step to update finished stores
if EOS_TOKEN_ID in decoded_tokens:
@@ -333,8 +370,10 @@ class RequestRunner:
def run(
self,
decoded_tokens: list[int],
complete_transfers: bool = True,
expected_stored_gpu_block_indexes: tuple[int, ...] = (),
expected_loaded_gpu_block_indexes: tuple[int, ...] = (),
expected_flushed_gpu_block_indexes: tuple[int, ...] = (),
):
"""
Runs multiple engine (scheduler + worker) steps.
@@ -342,14 +381,17 @@ class RequestRunner:
Args:
decoded_tokens: the tokens to yield at each step.
complete_transfers: complete transfers immediately
expected_stored_gpu_block_indexes: GPU block indexes
that are expected to be written during the run.
expected_loaded_gpu_block_indexes: GPU block indexes
that are expected to be loaded during the run.
expected_flushed_gpu_block_indexes: GPU block indexes
that are expected to be flushed during the run.
"""
self.manager.reset_mock()
self._run(decoded_tokens)
self._run(decoded_tokens, complete_transfers)
loaded_gpu_block_indexes: set[int] = set()
for transfer in self.completed_loads:
@@ -373,6 +415,9 @@ class RequestRunner:
assert set(expected_stored_gpu_block_indexes) == stored_gpu_block_indexes
self.completed_stores.clear()
assert set(expected_flushed_gpu_block_indexes) == self.flushed_gpu_block_indexes
self.flushed_gpu_block_indexes.clear()
@pytest.fixture
def request_runner():
@@ -539,3 +584,69 @@ def test_offloading_connector(request_runner):
assert isinstance(event, BlockRemoved)
assert event.block_hashes == to_hashes([4, 5, 6])
assert event.medium == "B"
def test_request_preemption(request_runner):
    """Exercise the offloading connector across a scheduler preemption.

    A request stores three offloaded blocks without its transfers being
    completed (no flush), the KV cache is then artificially exhausted to
    force a preemption, and the preemption step must flush and store all
    pending GPU blocks. After space is restored, the request resumes:
    its blocks are re-loaded from the offloading medium and the final
    block is stored on completion (EOS).
    """
    offloaded_block_size = 12
    gpu_block_size = 4
    num_gpu_blocks = 100
    runner = request_runner(
        offloaded_block_size=offloaded_block_size,
        gpu_block_size=gpu_block_size,
        num_gpu_blocks=num_gpu_blocks,
    )
    free_block_queue = runner.scheduler.kv_cache_manager.block_pool.free_block_queue
    # remember the empty-cache capacity so it can be restored later
    num_free_blocks_empty = free_block_queue.num_free_blocks

    # 2 blocks, store all, without flushing
    # blocks = [0, 1, 2], [3, 4, 5]
    runner.new_request(token_ids=[0] * offloaded_block_size * 2)
    runner.manager.prepare_store.side_effect = (
        lambda block_hashes: generate_store_output(block_hashes)
    )
    runner.run(
        decoded_tokens=[0],
        complete_transfers=False,
    )

    # decode 2 more blocks - 1 gpu block, storing [6, 7, 8] (no flush)
    runner.manager.prepare_store.side_effect = (
        lambda block_hashes: generate_store_output(block_hashes)
    )
    runner.run(
        decoded_tokens=[0] * (2 * offloaded_block_size - gpu_block_size),
        complete_transfers=False,
    )

    # simulate KV cache running out of space
    free_block_queue.num_free_blocks = 0

    # request should be preempted now
    runner.run(
        decoded_tokens=[],
        complete_transfers=False,
        expected_flushed_gpu_block_indexes=(0, 1, 2, 3, 4, 5, 6, 7, 8),
        expected_stored_gpu_block_indexes=(0, 1, 2, 3, 4, 5, 6, 7, 8),
    )

    # restore KV cache space and reset GPU prefix cache
    free_block_queue.num_free_blocks = num_free_blocks_empty
    runner.scheduler.reset_prefix_cache()

    # request should now return from preemption
    # re-load [0, ..., 8] from the CPU and store [9, 10, 11]
    runner.manager.lookup.return_value = 3
    runner.manager.prepare_store.side_effect = (
        lambda block_hashes: generate_store_output(block_hashes)
    )
    runner.run(
        decoded_tokens=[0] * gpu_block_size,
        expected_loaded_gpu_block_indexes=(0, 1, 2, 3, 4, 5, 6, 7, 8),
    )
    runner.run(
        decoded_tokens=[EOS_TOKEN_ID],
        expected_stored_gpu_block_indexes=(9, 10, 11),
    )