[BugFix] Wait for compute before offloading KV to CPU (#31341)

Signed-off-by: Or Ozeri <oro@il.ibm.com>
This commit is contained in:
Or Ozeri
2026-01-11 00:25:08 +02:00
committed by GitHub
parent 8020a60402
commit 2a4dbe24ea
3 changed files with 44 additions and 30 deletions

View File

@@ -172,6 +172,7 @@ class RequestRunner:
self.pending_loads_count: int = 0
self.pending_stores_count: int = 0
self.unsubmitted_stores_count = 0
self.completed_loads: list[TransferSummary] = []
self.completed_stores: list[TransferSummary] = []
@@ -279,7 +280,9 @@ class RequestRunner:
assert isinstance(kv_connector_metadata, OffloadingConnectorMetadata)
self.pending_loads_count += len(kv_connector_metadata.reqs_to_load)
self.pending_stores_count += len(kv_connector_metadata.reqs_to_store)
self.pending_stores_count += self.unsubmitted_stores_count
self.unsubmitted_stores_count = len(kv_connector_metadata.reqs_to_store)
self.worker_connector.bind_connector_metadata(kv_connector_metadata)
self.worker_connector.start_load_kv(self._dummy_ctx)
@@ -414,10 +417,13 @@ def test_offloading_connector(request_runner):
runner.manager.prepare_store.side_effect = (
lambda block_hashes: generate_store_output(list(block_hashes)[1:2])
)
runner.run(decoded_tokens=[0], expected_stored_gpu_block_indexes=(3, 4, 5))
runner.run(decoded_tokens=[0])
# add block missing 1 token -> no offload
runner.run(decoded_tokens=[0] * (offloaded_block_size - 1))
runner.run(
decoded_tokens=[0] * (offloaded_block_size - 1),
expected_stored_gpu_block_indexes=(3, 4, 5),
)
runner.manager.prepare_store.assert_not_called()
# +1 token -> single block, fail prepare_store
@@ -435,23 +441,20 @@ def test_offloading_connector(request_runner):
runner.manager.prepare_store.side_effect = (
lambda block_hashes: generate_store_output(block_hashes)
)
runner.run(
decoded_tokens=[0] * offloaded_block_size,
expected_stored_gpu_block_indexes=(15, 16, 17),
)
runner.run(decoded_tokens=[0] * offloaded_block_size)
runner.manager.touch.assert_called()
block_hashes1 = list(runner.manager.touch.call_args.args[0])
assert len(block_hashes1) == 6
# terminate request
runner.run(decoded_tokens=[EOS_TOKEN_ID])
runner.run(
decoded_tokens=[EOS_TOKEN_ID],
expected_stored_gpu_block_indexes=(15, 16, 17),
)
# create a new request differing only on the last token
runner.new_request(token_ids=[0] * (offloaded_block_size * 6 - 1) + [1])
runner.run(
decoded_tokens=[0],
expected_stored_gpu_block_indexes=tuple(range(6 * block_size_factor)),
)
runner.run(decoded_tokens=[0])
runner.manager.touch.assert_called()
block_hashes2 = list(runner.manager.touch.call_args.args[0])
assert len(block_hashes2) == 6
@@ -461,7 +464,10 @@ def test_offloading_connector(request_runner):
assert block_hashes1[5] != block_hashes2[5]
# terminate request
runner.run(decoded_tokens=[EOS_TOKEN_ID])
runner.run(
decoded_tokens=[EOS_TOKEN_ID],
expected_stored_gpu_block_indexes=tuple(range(6 * block_size_factor)),
)
# full_block_tokens - num_computed_tokens < offloaded_block_size
runner.new_request(

View File

@@ -78,7 +78,7 @@ class OffloadingConnector(KVConnectorBase_V1):
def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
assert self.connector_worker is not None
assert isinstance(self._connector_metadata, OffloadingConnectorMetadata)
self.connector_worker.start_load_kv(self._connector_metadata)
self.connector_worker.start_kv_transfers(self._connector_metadata)
def wait_for_layer_load(self, layer_name: str) -> None:
pass
@@ -95,7 +95,7 @@ class OffloadingConnector(KVConnectorBase_V1):
def wait_for_save(self):
assert self.connector_worker is not None
assert isinstance(self._connector_metadata, OffloadingConnectorMetadata)
self.connector_worker.start_store_kv(self._connector_metadata)
self.connector_worker.prepare_store_kv(self._connector_metadata)
def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
assert self.connector_worker is not None
@@ -427,6 +427,8 @@ class OffloadingConnectorWorker:
self._load_job: dict[ReqId, int] = {}
# req_id -> set(active job IDs)
self._store_jobs = defaultdict[ReqId, set[int]](set)
# list of store jobs pending submission (job_id, transfer_spec)
self._unsubmitted_store_jobs: list[tuple[int, TransferSpec]] = []
self._finished_reqs_waiting_for_store: set[ReqId] = set()
@@ -464,20 +466,29 @@ class OffloadingConnectorWorker:
attn_backends = {cross_layer_name: attn_backend}
self._register_handlers(kv_caches, attn_backends)
def start_load_kv(self, metadata: OffloadingConnectorMetadata):
def start_kv_transfers(self, metadata: OffloadingConnectorMetadata):
for job_id, transfer_spec in self._unsubmitted_store_jobs:
success = self.worker.transfer_async(job_id, transfer_spec)
assert success
self._unsubmitted_store_jobs.clear()
for req_id, transfer_spec in metadata.reqs_to_load.items():
job_id = self._generate_job_id()
self._jobs[job_id] = (req_id, False)
assert req_id not in self._load_job
self._load_job[req_id] = job_id
assert self.worker.transfer_async(job_id, transfer_spec)
success = self.worker.transfer_async(job_id, transfer_spec)
assert success
def start_store_kv(self, metadata: OffloadingConnectorMetadata):
def prepare_store_kv(self, metadata: OffloadingConnectorMetadata):
for req_id, transfer_spec in metadata.reqs_to_store.items():
job_id = self._generate_job_id()
self._jobs[job_id] = (req_id, True)
self._store_jobs[req_id].add(job_id)
assert self.worker.transfer_async(job_id, transfer_spec)
# NOTE(orozery): defer the store to the beginning of the next engine step,
# so that offloading starts AFTER transfers related to token sampling,
# thereby avoiding delays to token generation due to offloading.
self._unsubmitted_store_jobs.append((job_id, transfer_spec))
def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
"""

View File

@@ -68,7 +68,6 @@ class SingleDirectionOffloadingHandler(OffloadingHandler):
kv_dim_before_num_blocks: list[bool],
src_block_size_factor: int,
dst_block_size_factor: int,
priority: int,
):
"""
Initialize a SingleDirectionOffloadingHandler.
@@ -85,8 +84,6 @@ class SingleDirectionOffloadingHandler(OffloadingHandler):
per KV block in a source tensor.
dst_block_size_factor: The number of kernel blocks
per KV block in a destination tensor.
priority: The priority of the backing CUDA streams.
Lower numbers indicate higher priority.
"""
assert len(src_tensors) == len(dst_tensors) == len(kv_dim_before_num_blocks)
@@ -95,7 +92,9 @@ class SingleDirectionOffloadingHandler(OffloadingHandler):
self.kv_dim_before_num_blocks: list[bool] = kv_dim_before_num_blocks
self.src_block_size_factor: int = src_block_size_factor
self.dst_block_size_factor: int = dst_block_size_factor
self.priority = priority
assert len(src_tensors) > 0
self.gpu_to_cpu: bool = self.src_tensors[0].is_cuda
# queue of transfers (job_id, stream, event)
self._transfers: deque[tuple[int, torch.cuda.Stream, torch.Event]] = deque()
@@ -130,12 +129,12 @@ class SingleDirectionOffloadingHandler(OffloadingHandler):
expand_block_ids(dst_blocks, self.dst_block_size_factor, src_to_dst[:, 1])
src_to_dst_tensor = torch.from_numpy(src_to_dst)
stream = (
self._stream_pool.pop()
if self._stream_pool
else torch.cuda.Stream(priority=self.priority)
)
stream = self._stream_pool.pop() if self._stream_pool else torch.cuda.Stream()
event = self._event_pool.pop() if self._event_pool else torch.Event()
if self.gpu_to_cpu:
# wait for model computation to finish before offloading
stream.wait_stream(torch.cuda.current_stream())
if self._transfers:
_, _, last_event = self._transfers[-1]
# assure job will start only after the previous one completes
@@ -267,7 +266,6 @@ class CpuGpuOffloadingHandlers:
kv_dim_before_num_blocks=kv_dim_before_num_blocks,
src_block_size_factor=gpu_block_size_factor,
dst_block_size_factor=cpu_block_size_factor,
priority=1,
)
self.cpu_to_gpu_handler = SingleDirectionOffloadingHandler(
@@ -276,5 +274,4 @@ class CpuGpuOffloadingHandlers:
kv_dim_before_num_blocks=kv_dim_before_num_blocks,
src_block_size_factor=cpu_block_size_factor,
dst_block_size_factor=gpu_block_size_factor,
priority=-1,
)