From 2a4dbe24eadcb8e0354e47f608b53399aec52c43 Mon Sep 17 00:00:00 2001
From: Or Ozeri
Date: Sun, 11 Jan 2026 00:25:08 +0200
Subject: [PATCH] [BugFix] Wait for compute before offloading KV to CPU
 (#31341)

Signed-off-by: Or Ozeri
---
 .../unit/test_offloading_connector.py         | 32 +++++++++++--------
 .../kv_connector/v1/offloading_connector.py   | 23 +++++++++----
 vllm/v1/kv_offload/worker/cpu_gpu.py          | 19 +++++------
 3 files changed, 44 insertions(+), 30 deletions(-)

diff --git a/tests/v1/kv_connector/unit/test_offloading_connector.py b/tests/v1/kv_connector/unit/test_offloading_connector.py
index e1ea0b298..4b0e69406 100644
--- a/tests/v1/kv_connector/unit/test_offloading_connector.py
+++ b/tests/v1/kv_connector/unit/test_offloading_connector.py
@@ -172,6 +172,7 @@ class RequestRunner:
 
         self.pending_loads_count: int = 0
         self.pending_stores_count: int = 0
+        self.unsubmitted_stores_count = 0
 
         self.completed_loads: list[TransferSummary] = []
         self.completed_stores: list[TransferSummary] = []
@@ -279,7 +280,9 @@ class RequestRunner:
         assert isinstance(kv_connector_metadata, OffloadingConnectorMetadata)
 
         self.pending_loads_count += len(kv_connector_metadata.reqs_to_load)
-        self.pending_stores_count += len(kv_connector_metadata.reqs_to_store)
+
+        self.pending_stores_count += self.unsubmitted_stores_count
+        self.unsubmitted_stores_count = len(kv_connector_metadata.reqs_to_store)
 
         self.worker_connector.bind_connector_metadata(kv_connector_metadata)
         self.worker_connector.start_load_kv(self._dummy_ctx)
@@ -414,10 +417,13 @@ def test_offloading_connector(request_runner):
     runner.manager.prepare_store.side_effect = (
         lambda block_hashes: generate_store_output(list(block_hashes)[1:2])
    )
-    runner.run(decoded_tokens=[0], expected_stored_gpu_block_indexes=(3, 4, 5))
+    runner.run(decoded_tokens=[0])
 
     # add block missing 1 token -> no offload
-    runner.run(decoded_tokens=[0] * (offloaded_block_size - 1))
+    runner.run(
+        decoded_tokens=[0] * (offloaded_block_size - 1),
+        expected_stored_gpu_block_indexes=(3, 4, 5),
+    )
     runner.manager.prepare_store.assert_not_called()
 
     # +1 token -> single block, fail prepare_store
@@ -435,23 +441,20 @@ def test_offloading_connector(request_runner):
     runner.manager.prepare_store.side_effect = (
         lambda block_hashes: generate_store_output(block_hashes)
     )
-    runner.run(
-        decoded_tokens=[0] * offloaded_block_size,
-        expected_stored_gpu_block_indexes=(15, 16, 17),
-    )
+    runner.run(decoded_tokens=[0] * offloaded_block_size)
     runner.manager.touch.assert_called()
     block_hashes1 = list(runner.manager.touch.call_args.args[0])
     assert len(block_hashes1) == 6
 
     # terminate request
-    runner.run(decoded_tokens=[EOS_TOKEN_ID])
+    runner.run(
+        decoded_tokens=[EOS_TOKEN_ID],
+        expected_stored_gpu_block_indexes=(15, 16, 17),
+    )
 
     # create a new request differing only on the last token
     runner.new_request(token_ids=[0] * (offloaded_block_size * 6 - 1) + [1])
-    runner.run(
-        decoded_tokens=[0],
-        expected_stored_gpu_block_indexes=tuple(range(6 * block_size_factor)),
-    )
+    runner.run(decoded_tokens=[0])
     runner.manager.touch.assert_called()
     block_hashes2 = list(runner.manager.touch.call_args.args[0])
     assert len(block_hashes2) == 6
@@ -461,7 +464,10 @@ def test_offloading_connector(request_runner):
     assert block_hashes1[5] != block_hashes2[5]
 
     # terminate request
-    runner.run(decoded_tokens=[EOS_TOKEN_ID])
+    runner.run(
+        decoded_tokens=[EOS_TOKEN_ID],
+        expected_stored_gpu_block_indexes=tuple(range(6 * block_size_factor)),
+    )
 
     # full_block_tokens - num_computed_tokens < offloaded_block_size
     runner.new_request(
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
index 67cf4b047..17e084942 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py
@@ -78,7 +78,7 @@ class OffloadingConnector(KVConnectorBase_V1):
     def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
         assert self.connector_worker is not None
         assert isinstance(self._connector_metadata, OffloadingConnectorMetadata)
-        self.connector_worker.start_load_kv(self._connector_metadata)
+        self.connector_worker.start_kv_transfers(self._connector_metadata)
 
     def wait_for_layer_load(self, layer_name: str) -> None:
         pass
@@ -95,7 +95,7 @@ class OffloadingConnector(KVConnectorBase_V1):
     def wait_for_save(self):
         assert self.connector_worker is not None
         assert isinstance(self._connector_metadata, OffloadingConnectorMetadata)
-        self.connector_worker.start_store_kv(self._connector_metadata)
+        self.connector_worker.prepare_store_kv(self._connector_metadata)
 
     def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
         assert self.connector_worker is not None
@@ -427,6 +427,8 @@ class OffloadingConnectorWorker:
         self._load_job: dict[ReqId, int] = {}
         # req_id -> set(active job IDs)
         self._store_jobs = defaultdict[ReqId, set[int]](set)
+        # list of store jobs pending submission (job_id, transfer_spec)
+        self._unsubmitted_store_jobs: list[tuple[int, TransferSpec]] = []
 
         self._finished_reqs_waiting_for_store: set[ReqId] = set()
 
@@ -464,20 +466,29 @@ class OffloadingConnectorWorker:
         attn_backends = {cross_layer_name: attn_backend}
         self._register_handlers(kv_caches, attn_backends)
 
-    def start_load_kv(self, metadata: OffloadingConnectorMetadata):
+    def start_kv_transfers(self, metadata: OffloadingConnectorMetadata):
+        for job_id, transfer_spec in self._unsubmitted_store_jobs:
+            success = self.worker.transfer_async(job_id, transfer_spec)
+            assert success
+        self._unsubmitted_store_jobs.clear()
+
         for req_id, transfer_spec in metadata.reqs_to_load.items():
             job_id = self._generate_job_id()
             self._jobs[job_id] = (req_id, False)
             assert req_id not in self._load_job
             self._load_job[req_id] = job_id
-            assert self.worker.transfer_async(job_id, transfer_spec)
+            success = self.worker.transfer_async(job_id, transfer_spec)
+            assert success
 
-    def start_store_kv(self, metadata: OffloadingConnectorMetadata):
+    def prepare_store_kv(self, metadata: OffloadingConnectorMetadata):
         for req_id, transfer_spec in metadata.reqs_to_store.items():
             job_id = self._generate_job_id()
             self._jobs[job_id] = (req_id, True)
             self._store_jobs[req_id].add(job_id)
-            assert self.worker.transfer_async(job_id, transfer_spec)
+            # NOTE(orozery): defer the store to the beginning of the next engine step,
+            # so that offloading starts AFTER transfers related to token sampling,
+            # thereby avoiding delays to token generation due to offloading.
+            self._unsubmitted_store_jobs.append((job_id, transfer_spec))
 
     def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
         """
diff --git a/vllm/v1/kv_offload/worker/cpu_gpu.py b/vllm/v1/kv_offload/worker/cpu_gpu.py
index dcaecb099..e774dc4c5 100644
--- a/vllm/v1/kv_offload/worker/cpu_gpu.py
+++ b/vllm/v1/kv_offload/worker/cpu_gpu.py
@@ -68,7 +68,6 @@ class SingleDirectionOffloadingHandler(OffloadingHandler):
         kv_dim_before_num_blocks: list[bool],
         src_block_size_factor: int,
         dst_block_size_factor: int,
-        priority: int,
     ):
         """
         Initialize a SingleDirectionOffloadingHandler.
@@ -85,8 +84,6 @@ class SingleDirectionOffloadingHandler(OffloadingHandler):
                 per KV block in a source tensor.
             dst_block_size_factor: The number of kernel blocks
                 per KV block in a destination tensor.
-            priority: The priority of the backing CUDA streams.
-                Lower numbers indicate higher priority.
         """
 
         assert len(src_tensors) == len(dst_tensors) == len(kv_dim_before_num_blocks)
@@ -95,7 +92,9 @@ class SingleDirectionOffloadingHandler(OffloadingHandler):
         self.kv_dim_before_num_blocks: list[bool] = kv_dim_before_num_blocks
         self.src_block_size_factor: int = src_block_size_factor
         self.dst_block_size_factor: int = dst_block_size_factor
-        self.priority = priority
+
+        assert len(src_tensors) > 0
+        self.gpu_to_cpu: bool = self.src_tensors[0].is_cuda
 
         # queue of transfers (job_id, stream, event)
         self._transfers: deque[tuple[int, torch.cuda.Stream, torch.Event]] = deque()
@@ -130,12 +129,12 @@ class SingleDirectionOffloadingHandler(OffloadingHandler):
         expand_block_ids(dst_blocks, self.dst_block_size_factor, src_to_dst[:, 1])
 
         src_to_dst_tensor = torch.from_numpy(src_to_dst)
-        stream = (
-            self._stream_pool.pop()
-            if self._stream_pool
-            else torch.cuda.Stream(priority=self.priority)
-        )
+        stream = self._stream_pool.pop() if self._stream_pool else torch.cuda.Stream()
         event = self._event_pool.pop() if self._event_pool else torch.Event()
+
+        if self.gpu_to_cpu:
+            # wait for model computation to finish before offloading
+            stream.wait_stream(torch.cuda.current_stream())
         if self._transfers:
             _, _, last_event = self._transfers[-1]
             # assure job will start only after the previous one completes
@@ -267,7 +266,6 @@ class CpuGpuOffloadingHandlers:
             kv_dim_before_num_blocks=kv_dim_before_num_blocks,
             src_block_size_factor=gpu_block_size_factor,
             dst_block_size_factor=cpu_block_size_factor,
-            priority=1,
         )
 
         self.cpu_to_gpu_handler = SingleDirectionOffloadingHandler(
@@ -276,5 +274,4 @@ class CpuGpuOffloadingHandlers:
             kv_dim_before_num_blocks=kv_dim_before_num_blocks,
             src_block_size_factor=cpu_block_size_factor,
             dst_block_size_factor=gpu_block_size_factor,
-            priority=-1,
         )
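
Note on the connector change: the worker-side edit follows a defer-and-flush pattern; stores queued during wait_for_save() in step N are only handed to the worker at the start of step N+1 (this is also why the test's expected_stored_gpu_block_indexes assertions shift by one run). A minimal standalone sketch of the pattern follows, assuming a hypothetical submit_fn with the same bool-returning contract as worker.transfer_async; it is not vLLM API, just an illustration.

    from collections import deque

    class DeferredSubmitter:
        """Queue jobs at the end of step N; submit at the start of step N+1."""

        def __init__(self, submit_fn):
            # submit_fn(job_id, spec) -> bool, e.g. worker.transfer_async
            self._submit_fn = submit_fn
            self._pending: deque[tuple[int, object]] = deque()

        def defer(self, job_id: int, spec: object) -> None:
            # Called where wait_for_save() used to submit: record only,
            # so the copy cannot compete with sampling-related transfers.
            self._pending.append((job_id, spec))

        def flush(self) -> None:
            # Called at the start of the next engine step.
            while self._pending:
                job_id, spec = self._pending.popleft()
                success = self._submit_fn(job_id, spec)
                assert success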
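Note on the cpu_gpu.py change: the fix relies on CUDA stream-ordering semantics. A GPU->CPU copy issued on a side stream runs concurrently with the compute stream unless the side stream is explicitly ordered after it; stream.wait_stream(torch.cuda.current_stream()) inserts that dependency without blocking the host. The sketch below (offload_async and the tensor names are illustrative, not part of the patch) shows the pattern, including the recorded event that lets callers poll for completion the way get_finished() drains finished jobs.

    import torch

    def offload_async(gpu_blocks: torch.Tensor, cpu_blocks: torch.Tensor,
                      side_stream: torch.cuda.Stream) -> torch.cuda.Event:
        # Order the copy after all work already queued on the current
        # (compute) stream, without synchronizing the host.
        side_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(side_stream):
            # Pinned CPU memory is required for a truly async D2H copy.
            cpu_blocks.copy_(gpu_blocks, non_blocking=True)
            event = torch.cuda.Event()
            event.record()  # records on side_stream (current inside this block)
        return event

    if torch.cuda.is_available():
        gpu = torch.randn(4, 1024, device="cuda")
        cpu = torch.empty(4, 1024, pin_memory=True)
        ev = offload_async(gpu, cpu, torch.cuda.Stream())
        while not ev.query():  # poll instead of blocking
            pass

Without the wait_stream call, the copy may read KV blocks before the forward pass has finished writing them, which is the race this patch closes.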