[KVConnector] OffloadingConnector: Fix bug in handling of preemptions (#29870)
Signed-off-by: Or Ozeri <oro@il.ibm.com>
This commit is contained in:
@@ -64,8 +64,11 @@ class MockLoadStoreSpec(LoadStoreSpec):
|
||||
|
||||
class MockOffloadingHandler(OffloadingHandler):
|
||||
def __init__(self):
|
||||
self.transfer_specs: dict[int, TransferSpec] = {}
|
||||
self.completed_transfers: list[TransferResult] = []
|
||||
self.completed_specs: list[TransferSpec] = []
|
||||
self.waiting_jobs: set[int] = set()
|
||||
self.completed_jobs: list[int] = []
|
||||
self.flushed_jobs: set[int] = set()
|
||||
|
||||
def get_finished(self) -> list[TransferResult]:
|
||||
finished = self.completed_transfers
|
||||
@@ -73,10 +76,21 @@ class MockOffloadingHandler(OffloadingHandler):
|
||||
return finished
|
||||
|
||||
def transfer_async(self, job_id: int, spec: TransferSpec) -> bool:
|
||||
self.completed_specs.append(spec)
|
||||
self.completed_transfers.append((job_id, True))
|
||||
self.transfer_specs[job_id] = spec
|
||||
self.waiting_jobs.add(job_id)
|
||||
return True
|
||||
|
||||
def complete_jobs(self, job_ids: set[int]) -> None:
|
||||
for job_id in job_ids:
|
||||
if job_id in self.waiting_jobs:
|
||||
self.waiting_jobs.remove(job_id)
|
||||
self.completed_jobs.append(job_id)
|
||||
self.completed_transfers.append((job_id, True))
|
||||
|
||||
def wait(self, job_ids: set[int]) -> None:
|
||||
self.flushed_jobs |= job_ids
|
||||
self.complete_jobs(job_ids)
|
||||
|
||||
|
||||
class MockOffloadingSpec(OffloadingSpec):
|
||||
def __init__(self, vllm_config: VllmConfig):
|
||||
@@ -98,9 +112,22 @@ class MockOffloadingSpec(OffloadingSpec):
|
||||
yield GPULoadStoreSpec, MockLoadStoreSpec, self.handler
|
||||
yield MockLoadStoreSpec, GPULoadStoreSpec, self.handler
|
||||
|
||||
def complete_transfers(self):
|
||||
self.handler.complete_jobs(self.handler.waiting_jobs.copy())
|
||||
|
||||
def get_completed_transfers(self) -> list[TransferSpec]:
|
||||
specs = self.handler.completed_specs
|
||||
self.handler.completed_specs = []
|
||||
specs = [
|
||||
self.handler.transfer_specs[job_id]
|
||||
for job_id in self.handler.completed_jobs
|
||||
]
|
||||
self.handler.completed_jobs.clear()
|
||||
return specs
|
||||
|
||||
def get_flushed_transfers(self):
|
||||
specs = [
|
||||
self.handler.transfer_specs[job_id] for job_id in self.handler.flushed_jobs
|
||||
]
|
||||
self.handler.flushed_jobs.clear()
|
||||
return specs
|
||||
|
||||
|
||||
@@ -170,12 +197,9 @@ class RequestRunner:
|
||||
# mapping (offloading address) -> gpu_block_index
|
||||
self.offloaded: dict[Any, int] = {}
|
||||
|
||||
self.pending_loads_count: int = 0
|
||||
self.pending_stores_count: int = 0
|
||||
self.unsubmitted_stores_count = 0
|
||||
|
||||
self.completed_loads: list[TransferSummary] = []
|
||||
self.completed_stores: list[TransferSummary] = []
|
||||
self.flushed_gpu_block_indexes: set[int] = set()
|
||||
|
||||
# maps {block_id: block_offset}
|
||||
self.gpu_block_index: dict[int, int] = {}
|
||||
@@ -202,54 +226,60 @@ class RequestRunner:
|
||||
|
||||
self.scheduler.add_request(req)
|
||||
|
||||
def _wait_for_transfers(self):
|
||||
def _parse_transfers(self):
|
||||
for transfer_spec in self.offloading_spec.get_flushed_transfers():
|
||||
src_spec, dst_spec = transfer_spec
|
||||
assert isinstance(src_spec, GPULoadStoreSpec)
|
||||
|
||||
for block_id in src_spec.block_ids:
|
||||
self.flushed_gpu_block_indexes.add(
|
||||
self.gpu_block_index[block_id.item()]
|
||||
)
|
||||
|
||||
block_size_factor = self.offloaded_block_size // self.gpu_block_size
|
||||
|
||||
while self.pending_loads_count or self.pending_stores_count:
|
||||
for transfer_spec in self.offloading_spec.get_completed_transfers():
|
||||
src_spec, dst_spec = transfer_spec
|
||||
for transfer_spec in self.offloading_spec.get_completed_transfers():
|
||||
src_spec, dst_spec = transfer_spec
|
||||
|
||||
if isinstance(src_spec, GPULoadStoreSpec):
|
||||
store = True
|
||||
gpu_spec = src_spec
|
||||
offload_spec = dst_spec
|
||||
else:
|
||||
store = False
|
||||
gpu_spec = dst_spec
|
||||
offload_spec = src_spec
|
||||
if isinstance(src_spec, GPULoadStoreSpec):
|
||||
store = True
|
||||
gpu_spec = src_spec
|
||||
offload_spec = dst_spec
|
||||
else:
|
||||
store = False
|
||||
gpu_spec = dst_spec
|
||||
offload_spec = src_spec
|
||||
|
||||
assert isinstance(offload_spec, MockLoadStoreSpec)
|
||||
assert isinstance(gpu_spec, GPULoadStoreSpec)
|
||||
assert isinstance(offload_spec, MockLoadStoreSpec)
|
||||
assert isinstance(gpu_spec, GPULoadStoreSpec)
|
||||
|
||||
gpu_block_indices: list[int] = []
|
||||
for block_id in gpu_spec.block_ids:
|
||||
gpu_block_indices.append(self.gpu_block_index[block_id.item()])
|
||||
gpu_block_indices: list[int] = []
|
||||
for block_id in gpu_spec.block_ids:
|
||||
gpu_block_indices.append(self.gpu_block_index[block_id.item()])
|
||||
|
||||
# list of (block_hash, sub_block_offset)
|
||||
offload_addresses: list[Any] = []
|
||||
for block_hash in offload_spec.block_hashes:
|
||||
for sub_block_idx in range(block_size_factor):
|
||||
offload_addresses.append((block_hash, sub_block_idx))
|
||||
# list of (block_hash, sub_block_offset)
|
||||
offload_addresses: list[Any] = []
|
||||
for block_hash in offload_spec.block_hashes:
|
||||
for sub_block_idx in range(block_size_factor):
|
||||
offload_addresses.append((block_hash, sub_block_idx))
|
||||
|
||||
if store:
|
||||
assert len(gpu_block_indices) == len(offload_addresses)
|
||||
if store:
|
||||
assert len(gpu_block_indices) == len(offload_addresses)
|
||||
|
||||
self.completed_stores.append(
|
||||
TransferSummary(gpu_block_indices, offload_addresses)
|
||||
)
|
||||
self.pending_stores_count -= 1
|
||||
else:
|
||||
remainder_sub_block_count = len(offload_addresses) - len(
|
||||
gpu_block_indices
|
||||
)
|
||||
assert remainder_sub_block_count >= 0
|
||||
assert remainder_sub_block_count < block_size_factor
|
||||
offload_addresses = offload_addresses[remainder_sub_block_count:]
|
||||
self.completed_stores.append(
|
||||
TransferSummary(gpu_block_indices, offload_addresses)
|
||||
)
|
||||
else:
|
||||
remainder_sub_block_count = len(offload_addresses) - len(
|
||||
gpu_block_indices
|
||||
)
|
||||
assert remainder_sub_block_count >= 0
|
||||
assert remainder_sub_block_count < block_size_factor
|
||||
offload_addresses = offload_addresses[remainder_sub_block_count:]
|
||||
|
||||
self.completed_loads.append(
|
||||
TransferSummary(gpu_block_indices, offload_addresses)
|
||||
)
|
||||
self.pending_loads_count -= 1
|
||||
self.completed_loads.append(
|
||||
TransferSummary(gpu_block_indices, offload_addresses)
|
||||
)
|
||||
|
||||
def _update_gpu_block_idx(self):
|
||||
for blocks in self.scheduler.kv_cache_manager.coordinator.single_type_managers[
|
||||
@@ -258,18 +288,19 @@ class RequestRunner:
|
||||
for block_idx, block in enumerate(blocks):
|
||||
self.gpu_block_index[block.block_id] = block_idx
|
||||
|
||||
def _run(self, decoded_tokens: list[int]):
|
||||
def _run(self, decoded_tokens: list[int], complete_transfers: bool):
|
||||
"""
|
||||
Runs multiple engine (scheduler + worker) steps.
|
||||
Assumes a single request is running.
|
||||
|
||||
Args:
|
||||
decoded_tokens: the tokens to yield at each step.
|
||||
complete_transfers: complete transfers immediately
|
||||
"""
|
||||
|
||||
tokens_iter = iter(decoded_tokens)
|
||||
token_id = next(tokens_iter, None)
|
||||
while token_id is not None:
|
||||
while True:
|
||||
assert self.scheduler.requests
|
||||
|
||||
scheduler_output = self.scheduler.schedule()
|
||||
@@ -279,10 +310,10 @@ class RequestRunner:
|
||||
assert kv_connector_metadata is not None
|
||||
assert isinstance(kv_connector_metadata, OffloadingConnectorMetadata)
|
||||
|
||||
self.pending_loads_count += len(kv_connector_metadata.reqs_to_load)
|
||||
|
||||
self.pending_stores_count += self.unsubmitted_stores_count
|
||||
self.unsubmitted_stores_count = len(kv_connector_metadata.reqs_to_store)
|
||||
if scheduler_output.preempted_req_ids:
|
||||
self.worker_connector.handle_preemptions(
|
||||
scheduler_output.preempted_req_ids
|
||||
)
|
||||
|
||||
self.worker_connector.bind_connector_metadata(kv_connector_metadata)
|
||||
self.worker_connector.start_load_kv(self._dummy_ctx)
|
||||
@@ -290,6 +321,9 @@ class RequestRunner:
|
||||
if scheduler_output.total_num_scheduled_tokens > 0:
|
||||
self.worker_connector.wait_for_save()
|
||||
|
||||
if complete_transfers:
|
||||
self.offloading_spec.complete_transfers()
|
||||
|
||||
finished_sending, finished_recving = self.worker_connector.get_finished(
|
||||
scheduler_output.finished_req_ids
|
||||
)
|
||||
@@ -300,7 +334,7 @@ class RequestRunner:
|
||||
reqs=self.scheduler.running,
|
||||
finished_sending=finished_sending,
|
||||
finished_recving=finished_recving,
|
||||
token_id=token_id,
|
||||
token_id=token_id or 0,
|
||||
)
|
||||
|
||||
if self.scheduler.running:
|
||||
@@ -308,7 +342,10 @@ class RequestRunner:
|
||||
|
||||
self.scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||
|
||||
self._wait_for_transfers()
|
||||
if token_id is None:
|
||||
break
|
||||
|
||||
self._parse_transfers()
|
||||
|
||||
# run one more step to update finished stored
|
||||
if EOS_TOKEN_ID in decoded_tokens:
|
||||
@@ -333,8 +370,10 @@ class RequestRunner:
|
||||
def run(
|
||||
self,
|
||||
decoded_tokens: list[int],
|
||||
complete_transfers: bool = True,
|
||||
expected_stored_gpu_block_indexes: tuple[int, ...] = (),
|
||||
expected_loaded_gpu_block_indexes: tuple[int, ...] = (),
|
||||
expected_flushed_gpu_block_indexes: tuple[int, ...] = (),
|
||||
):
|
||||
"""
|
||||
Runs multiple engine (scheduler + worker) steps.
|
||||
@@ -342,14 +381,17 @@ class RequestRunner:
|
||||
|
||||
Args:
|
||||
decoded_tokens: the tokens to yield at each step.
|
||||
complete_transfers: complete transfers immediately
|
||||
expected_stored_gpu_block_indexes: GPU block indexes
|
||||
that are expected to be written during the run.
|
||||
expected_loaded_gpu_block_indexes: GPU block indexes
|
||||
that are expected to be loaded during the run.
|
||||
expected_flushed_gpu_block_indexes: GPU block indexes
|
||||
that are expected to be flushed during the run.
|
||||
"""
|
||||
|
||||
self.manager.reset_mock()
|
||||
self._run(decoded_tokens)
|
||||
self._run(decoded_tokens, complete_transfers)
|
||||
|
||||
loaded_gpu_block_indexes: set[int] = set()
|
||||
for transfer in self.completed_loads:
|
||||
@@ -373,6 +415,9 @@ class RequestRunner:
|
||||
assert set(expected_stored_gpu_block_indexes) == stored_gpu_block_indexes
|
||||
self.completed_stores.clear()
|
||||
|
||||
assert set(expected_flushed_gpu_block_indexes) == self.flushed_gpu_block_indexes
|
||||
self.flushed_gpu_block_indexes.clear()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def request_runner():
|
||||
@@ -539,3 +584,69 @@ def test_offloading_connector(request_runner):
|
||||
assert isinstance(event, BlockRemoved)
|
||||
assert event.block_hashes == to_hashes([4, 5, 6])
|
||||
assert event.medium == "B"
|
||||
|
||||
|
||||
def test_request_preemption(request_runner):
|
||||
offloaded_block_size = 12
|
||||
gpu_block_size = 4
|
||||
num_gpu_blocks = 100
|
||||
|
||||
runner = request_runner(
|
||||
offloaded_block_size=offloaded_block_size,
|
||||
gpu_block_size=gpu_block_size,
|
||||
num_gpu_blocks=num_gpu_blocks,
|
||||
)
|
||||
|
||||
free_block_queue = runner.scheduler.kv_cache_manager.block_pool.free_block_queue
|
||||
num_free_blocks_empty = free_block_queue.num_free_blocks
|
||||
|
||||
# 2 blocks, store all, without flushing
|
||||
# blocks = [0, 1, 2], [3, 4, 5]
|
||||
runner.new_request(token_ids=[0] * offloaded_block_size * 2)
|
||||
runner.manager.prepare_store.side_effect = (
|
||||
lambda block_hashes: generate_store_output(block_hashes)
|
||||
)
|
||||
runner.run(
|
||||
decoded_tokens=[0],
|
||||
complete_transfers=False,
|
||||
)
|
||||
|
||||
# decode 2 more blocks - 1 gpu block, storing [6, 7, 8] (no flush)
|
||||
runner.manager.prepare_store.side_effect = (
|
||||
lambda block_hashes: generate_store_output(block_hashes)
|
||||
)
|
||||
runner.run(
|
||||
decoded_tokens=[0] * (2 * offloaded_block_size - gpu_block_size),
|
||||
complete_transfers=False,
|
||||
)
|
||||
|
||||
# simulate KV cache running out of space
|
||||
free_block_queue.num_free_blocks = 0
|
||||
|
||||
# request should be preempted now
|
||||
runner.run(
|
||||
decoded_tokens=[],
|
||||
complete_transfers=False,
|
||||
expected_flushed_gpu_block_indexes=(0, 1, 2, 3, 4, 5, 6, 7, 8),
|
||||
expected_stored_gpu_block_indexes=(0, 1, 2, 3, 4, 5, 6, 7, 8),
|
||||
)
|
||||
|
||||
# restore KV cache space and reset GPU prefix cache
|
||||
free_block_queue.num_free_blocks = num_free_blocks_empty
|
||||
runner.scheduler.reset_prefix_cache()
|
||||
|
||||
# request should now return from preemption
|
||||
# re-load [0, ..., 8] from the CPU and store [9, 10, 11]
|
||||
runner.manager.lookup.return_value = 3
|
||||
runner.manager.prepare_store.side_effect = (
|
||||
lambda block_hashes: generate_store_output(block_hashes)
|
||||
)
|
||||
runner.run(
|
||||
decoded_tokens=[0] * gpu_block_size,
|
||||
expected_loaded_gpu_block_indexes=(0, 1, 2, 3, 4, 5, 6, 7, 8),
|
||||
)
|
||||
|
||||
runner.run(
|
||||
decoded_tokens=[EOS_TOKEN_ID],
|
||||
expected_stored_gpu_block_indexes=(9, 10, 11),
|
||||
)
|
||||
|
||||
@@ -63,6 +63,12 @@ class OffloadingHandler1To2(OffloadingHandler):
|
||||
del self.transfers[job_id]
|
||||
return finished
|
||||
|
||||
def wait(self, job_ids: set[int]) -> None:
|
||||
for job_id in job_ids:
|
||||
spec = self.transfers.get(job_id)
|
||||
if spec:
|
||||
assert spec.finished
|
||||
|
||||
|
||||
class OffloadingHandler2To1(OffloadingHandler):
|
||||
def __init__(self):
|
||||
@@ -84,6 +90,12 @@ class OffloadingHandler2To1(OffloadingHandler):
|
||||
del self.transfers[job_id]
|
||||
return finished
|
||||
|
||||
def wait(self, job_ids: set[int]) -> None:
|
||||
for job_id in job_ids:
|
||||
spec = self.transfers.get(job_id)
|
||||
if spec:
|
||||
assert spec.finished
|
||||
|
||||
|
||||
def test_offloading_worker():
|
||||
"""
|
||||
|
||||
@@ -25,6 +25,9 @@ The class provides the following primitives:
|
||||
|
||||
Worker-side: runs in each worker, loads/saves KV cache to/from
|
||||
the Connector based on the metadata.
|
||||
handle_preemptions() - called if there are preempted requests,
|
||||
before their blocks are overwritten
|
||||
|
||||
start_load_kv() - starts loading all KVs (maybe async)
|
||||
wait_for_layer_load() - blocks until layer i load is done
|
||||
|
||||
@@ -262,6 +265,13 @@ class KVConnectorBase_V1(ABC):
|
||||
"""
|
||||
return
|
||||
|
||||
def handle_preemptions(self, preempted_req_ids: set[str]):
|
||||
"""
|
||||
Handle preempted requests BEFORE their blocks are overwritten.
|
||||
Needed for connectors which use async saves (e.g., OffloadingConnector)
|
||||
"""
|
||||
return
|
||||
|
||||
@abstractmethod
|
||||
def start_load_kv(self, forward_context: "ForwardContext", **kwargs: Any) -> None:
|
||||
"""
|
||||
|
||||
@@ -75,6 +75,10 @@ class OffloadingConnector(KVConnectorBase_V1):
|
||||
assert self.connector_worker is not None
|
||||
self.connector_worker.register_cross_layers_kv_cache(kv_cache, attn_backend)
|
||||
|
||||
def handle_preemptions(self, preempted_req_ids: set[str]):
|
||||
assert self.connector_worker is not None
|
||||
self.connector_worker.handle_preemptions(preempted_req_ids)
|
||||
|
||||
def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
|
||||
assert self.connector_worker is not None
|
||||
assert isinstance(self._connector_metadata, OffloadingConnectorMetadata)
|
||||
@@ -348,6 +352,15 @@ class OffloadingConnectorScheduler:
|
||||
reqs_to_store=self._get_reqs_to_store(scheduler_output),
|
||||
)
|
||||
self._reqs_to_load = {}
|
||||
|
||||
# NOTE (orozery): we should move this logic to update_connector_output
|
||||
# once KVConnectorOutput allows us to report completed transfers
|
||||
for req_id in scheduler_output.preempted_req_ids or ():
|
||||
block_hashes = self._reqs_being_stored.get(req_id)
|
||||
if block_hashes:
|
||||
self.manager.complete_store(block_hashes)
|
||||
block_hashes.clear()
|
||||
|
||||
return meta
|
||||
|
||||
def update_connector_output(self, connector_output: KVConnectorOutput):
|
||||
@@ -466,6 +479,17 @@ class OffloadingConnectorWorker:
|
||||
attn_backends = {cross_layer_name: attn_backend}
|
||||
self._register_handlers(kv_caches, attn_backends)
|
||||
|
||||
def handle_preemptions(self, preempted_req_ids: set[str]):
|
||||
for job_id, transfer_spec in self._unsubmitted_store_jobs:
|
||||
success = self.worker.transfer_async(job_id, transfer_spec)
|
||||
assert success
|
||||
self._unsubmitted_store_jobs.clear()
|
||||
|
||||
for req_id in preempted_req_ids:
|
||||
job_ids = self._store_jobs.get(req_id)
|
||||
if job_ids:
|
||||
self.worker.wait(job_ids)
|
||||
|
||||
def start_kv_transfers(self, metadata: OffloadingConnectorMetadata):
|
||||
for job_id, transfer_spec in self._unsubmitted_store_jobs:
|
||||
success = self.worker.transfer_async(job_id, transfer_spec)
|
||||
|
||||
@@ -96,6 +96,8 @@ class SingleDirectionOffloadingHandler(OffloadingHandler):
|
||||
assert len(src_tensors) > 0
|
||||
self.gpu_to_cpu: bool = self.src_tensors[0].is_cuda
|
||||
|
||||
# job_id -> event
|
||||
self._transfer_events: dict[int, torch.Event] = {}
|
||||
# queue of transfers (job_id, stream, event)
|
||||
self._transfers: deque[tuple[int, torch.cuda.Stream, torch.Event]] = deque()
|
||||
# list of CUDA streams available for re-use
|
||||
@@ -152,6 +154,7 @@ class SingleDirectionOffloadingHandler(OffloadingHandler):
|
||||
ops.swap_blocks(src_tensor, dst_tensor, src_to_dst_tensor)
|
||||
event.record(stream)
|
||||
|
||||
self._transfer_events[job_id] = event
|
||||
self._transfers.append((job_id, stream, event))
|
||||
|
||||
# success
|
||||
@@ -164,8 +167,15 @@ class SingleDirectionOffloadingHandler(OffloadingHandler):
|
||||
results.append((job_id, True))
|
||||
self._stream_pool.append(stream)
|
||||
self._event_pool.append(event)
|
||||
del self._transfer_events[job_id]
|
||||
return results
|
||||
|
||||
def wait(self, job_ids: set[int]):
|
||||
for job_id in job_ids:
|
||||
event = self._transfer_events.get(job_id)
|
||||
if event is not None:
|
||||
event.synchronize()
|
||||
|
||||
|
||||
class CpuGpuOffloadingHandlers:
|
||||
def __init__(
|
||||
|
||||
@@ -53,6 +53,15 @@ class OffloadingHandler(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def wait(self, job_ids: set[int]) -> None:
|
||||
"""
|
||||
Wait for jobs to finish (blocking).
|
||||
|
||||
Args:
|
||||
job_ids: The set of job IDs to wait for.
|
||||
"""
|
||||
|
||||
|
||||
class OffloadingWorker:
|
||||
"""
|
||||
@@ -142,3 +151,13 @@ class OffloadingWorker:
|
||||
for handler in self.handlers:
|
||||
finished.extend(handler.get_finished())
|
||||
return finished
|
||||
|
||||
def wait(self, job_ids: set[int]) -> None:
|
||||
"""
|
||||
Wait for jobs to finish (blocking).
|
||||
|
||||
Args:
|
||||
job_ids: The set of job IDs to wait for.
|
||||
"""
|
||||
for handler in self.handlers:
|
||||
handler.wait(job_ids)
|
||||
|
||||
@@ -3112,6 +3112,11 @@ class GPUModelRunner(
|
||||
"after execute_model() returns None."
|
||||
)
|
||||
|
||||
if scheduler_output.preempted_req_ids and has_kv_transfer_group():
|
||||
get_kv_transfer_group().handle_preemptions(
|
||||
scheduler_output.preempted_req_ids
|
||||
)
|
||||
|
||||
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||
with (
|
||||
record_function_or_nullcontext("gpu_model_runner: preprocess"),
|
||||
|
||||
Reference in New Issue
Block a user