[BugFix] Wait for compute before offloading KV to CPU (#31341)

Signed-off-by: Or Ozeri <oro@il.ibm.com>
This commit is contained in:
Or Ozeri
2026-01-11 00:25:08 +02:00
committed by GitHub
parent 8020a60402
commit 2a4dbe24ea
3 changed files with 44 additions and 30 deletions

View File

@@ -68,7 +68,6 @@ class SingleDirectionOffloadingHandler(OffloadingHandler):
kv_dim_before_num_blocks: list[bool],
src_block_size_factor: int,
dst_block_size_factor: int,
priority: int,
):
"""
Initialize a SingleDirectionOffloadingHandler.
@@ -85,8 +84,6 @@ class SingleDirectionOffloadingHandler(OffloadingHandler):
per KV block in a source tensor.
dst_block_size_factor: The number of kernel blocks
per KV block in a destination tensor.
priority: The priority of the backing CUDA streams.
Lower numbers indicate higher priority.
"""
assert len(src_tensors) == len(dst_tensors) == len(kv_dim_before_num_blocks)
@@ -95,7 +92,9 @@ class SingleDirectionOffloadingHandler(OffloadingHandler):
self.kv_dim_before_num_blocks: list[bool] = kv_dim_before_num_blocks
self.src_block_size_factor: int = src_block_size_factor
self.dst_block_size_factor: int = dst_block_size_factor
self.priority = priority
assert len(src_tensors) > 0
self.gpu_to_cpu: bool = self.src_tensors[0].is_cuda
# queue of transfers (job_id, stream, event)
self._transfers: deque[tuple[int, torch.cuda.Stream, torch.Event]] = deque()
@@ -130,12 +129,12 @@ class SingleDirectionOffloadingHandler(OffloadingHandler):
expand_block_ids(dst_blocks, self.dst_block_size_factor, src_to_dst[:, 1])
src_to_dst_tensor = torch.from_numpy(src_to_dst)
stream = (
self._stream_pool.pop()
if self._stream_pool
else torch.cuda.Stream(priority=self.priority)
)
stream = self._stream_pool.pop() if self._stream_pool else torch.cuda.Stream()
event = self._event_pool.pop() if self._event_pool else torch.Event()
if self.gpu_to_cpu:
# wait for model computation to finish before offloading
stream.wait_stream(torch.cuda.current_stream())
if self._transfers:
_, _, last_event = self._transfers[-1]
# assure job will start only after the previous one completes
@@ -267,7 +266,6 @@ class CpuGpuOffloadingHandlers:
kv_dim_before_num_blocks=kv_dim_before_num_blocks,
src_block_size_factor=gpu_block_size_factor,
dst_block_size_factor=cpu_block_size_factor,
priority=1,
)
self.cpu_to_gpu_handler = SingleDirectionOffloadingHandler(
@@ -276,5 +274,4 @@ class CpuGpuOffloadingHandlers:
kv_dim_before_num_blocks=kv_dim_before_num_blocks,
src_block_size_factor=cpu_block_size_factor,
dst_block_size_factor=gpu_block_size_factor,
priority=-1,
)