[BugFix] Wait for compute before offloading KV to CPU (#31341)
Signed-off-by: Or Ozeri <oro@il.ibm.com>
This commit is contained in:
@@ -172,6 +172,7 @@ class RequestRunner:
|
||||
|
||||
self.pending_loads_count: int = 0
|
||||
self.pending_stores_count: int = 0
|
||||
self.unsubmitted_stores_count = 0
|
||||
|
||||
self.completed_loads: list[TransferSummary] = []
|
||||
self.completed_stores: list[TransferSummary] = []
|
||||
@@ -279,7 +280,9 @@ class RequestRunner:
|
||||
assert isinstance(kv_connector_metadata, OffloadingConnectorMetadata)
|
||||
|
||||
self.pending_loads_count += len(kv_connector_metadata.reqs_to_load)
|
||||
self.pending_stores_count += len(kv_connector_metadata.reqs_to_store)
|
||||
|
||||
self.pending_stores_count += self.unsubmitted_stores_count
|
||||
self.unsubmitted_stores_count = len(kv_connector_metadata.reqs_to_store)
|
||||
|
||||
self.worker_connector.bind_connector_metadata(kv_connector_metadata)
|
||||
self.worker_connector.start_load_kv(self._dummy_ctx)
|
||||
@@ -414,10 +417,13 @@ def test_offloading_connector(request_runner):
|
||||
runner.manager.prepare_store.side_effect = (
|
||||
lambda block_hashes: generate_store_output(list(block_hashes)[1:2])
|
||||
)
|
||||
runner.run(decoded_tokens=[0], expected_stored_gpu_block_indexes=(3, 4, 5))
|
||||
runner.run(decoded_tokens=[0])
|
||||
|
||||
# add block missing 1 token -> no offload
|
||||
runner.run(decoded_tokens=[0] * (offloaded_block_size - 1))
|
||||
runner.run(
|
||||
decoded_tokens=[0] * (offloaded_block_size - 1),
|
||||
expected_stored_gpu_block_indexes=(3, 4, 5),
|
||||
)
|
||||
runner.manager.prepare_store.assert_not_called()
|
||||
|
||||
# +1 token -> single block, fail prepare_store
|
||||
@@ -435,23 +441,20 @@ def test_offloading_connector(request_runner):
|
||||
runner.manager.prepare_store.side_effect = (
|
||||
lambda block_hashes: generate_store_output(block_hashes)
|
||||
)
|
||||
runner.run(
|
||||
decoded_tokens=[0] * offloaded_block_size,
|
||||
expected_stored_gpu_block_indexes=(15, 16, 17),
|
||||
)
|
||||
runner.run(decoded_tokens=[0] * offloaded_block_size)
|
||||
runner.manager.touch.assert_called()
|
||||
block_hashes1 = list(runner.manager.touch.call_args.args[0])
|
||||
assert len(block_hashes1) == 6
|
||||
|
||||
# terminate request
|
||||
runner.run(decoded_tokens=[EOS_TOKEN_ID])
|
||||
runner.run(
|
||||
decoded_tokens=[EOS_TOKEN_ID],
|
||||
expected_stored_gpu_block_indexes=(15, 16, 17),
|
||||
)
|
||||
|
||||
# create a new request differing only on the last token
|
||||
runner.new_request(token_ids=[0] * (offloaded_block_size * 6 - 1) + [1])
|
||||
runner.run(
|
||||
decoded_tokens=[0],
|
||||
expected_stored_gpu_block_indexes=tuple(range(6 * block_size_factor)),
|
||||
)
|
||||
runner.run(decoded_tokens=[0])
|
||||
runner.manager.touch.assert_called()
|
||||
block_hashes2 = list(runner.manager.touch.call_args.args[0])
|
||||
assert len(block_hashes2) == 6
|
||||
@@ -461,7 +464,10 @@ def test_offloading_connector(request_runner):
|
||||
assert block_hashes1[5] != block_hashes2[5]
|
||||
|
||||
# terminate request
|
||||
runner.run(decoded_tokens=[EOS_TOKEN_ID])
|
||||
runner.run(
|
||||
decoded_tokens=[EOS_TOKEN_ID],
|
||||
expected_stored_gpu_block_indexes=tuple(range(6 * block_size_factor)),
|
||||
)
|
||||
|
||||
# full_block_tokens - num_computed_tokens < offloaded_block_size
|
||||
runner.new_request(
|
||||
|
||||
@@ -78,7 +78,7 @@ class OffloadingConnector(KVConnectorBase_V1):
|
||||
def start_load_kv(self, forward_context: "ForwardContext", **kwargs) -> None:
|
||||
assert self.connector_worker is not None
|
||||
assert isinstance(self._connector_metadata, OffloadingConnectorMetadata)
|
||||
self.connector_worker.start_load_kv(self._connector_metadata)
|
||||
self.connector_worker.start_kv_transfers(self._connector_metadata)
|
||||
|
||||
def wait_for_layer_load(self, layer_name: str) -> None:
|
||||
pass
|
||||
@@ -95,7 +95,7 @@ class OffloadingConnector(KVConnectorBase_V1):
|
||||
def wait_for_save(self):
|
||||
assert self.connector_worker is not None
|
||||
assert isinstance(self._connector_metadata, OffloadingConnectorMetadata)
|
||||
self.connector_worker.start_store_kv(self._connector_metadata)
|
||||
self.connector_worker.prepare_store_kv(self._connector_metadata)
|
||||
|
||||
def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
|
||||
assert self.connector_worker is not None
|
||||
@@ -427,6 +427,8 @@ class OffloadingConnectorWorker:
|
||||
self._load_job: dict[ReqId, int] = {}
|
||||
# req_id -> set(active job IDs)
|
||||
self._store_jobs = defaultdict[ReqId, set[int]](set)
|
||||
# list of store jobs pending submission (job_id, transfer_spec)
|
||||
self._unsubmitted_store_jobs: list[tuple[int, TransferSpec]] = []
|
||||
|
||||
self._finished_reqs_waiting_for_store: set[ReqId] = set()
|
||||
|
||||
@@ -464,20 +466,29 @@ class OffloadingConnectorWorker:
|
||||
attn_backends = {cross_layer_name: attn_backend}
|
||||
self._register_handlers(kv_caches, attn_backends)
|
||||
|
||||
def start_load_kv(self, metadata: OffloadingConnectorMetadata):
|
||||
def start_kv_transfers(self, metadata: OffloadingConnectorMetadata):
|
||||
for job_id, transfer_spec in self._unsubmitted_store_jobs:
|
||||
success = self.worker.transfer_async(job_id, transfer_spec)
|
||||
assert success
|
||||
self._unsubmitted_store_jobs.clear()
|
||||
|
||||
for req_id, transfer_spec in metadata.reqs_to_load.items():
|
||||
job_id = self._generate_job_id()
|
||||
self._jobs[job_id] = (req_id, False)
|
||||
assert req_id not in self._load_job
|
||||
self._load_job[req_id] = job_id
|
||||
assert self.worker.transfer_async(job_id, transfer_spec)
|
||||
success = self.worker.transfer_async(job_id, transfer_spec)
|
||||
assert success
|
||||
|
||||
def start_store_kv(self, metadata: OffloadingConnectorMetadata):
|
||||
def prepare_store_kv(self, metadata: OffloadingConnectorMetadata):
|
||||
for req_id, transfer_spec in metadata.reqs_to_store.items():
|
||||
job_id = self._generate_job_id()
|
||||
self._jobs[job_id] = (req_id, True)
|
||||
self._store_jobs[req_id].add(job_id)
|
||||
assert self.worker.transfer_async(job_id, transfer_spec)
|
||||
# NOTE(orozery): defer the store to the beginning of the next engine step,
|
||||
# so that offloading starts AFTER transfers related to token sampling,
|
||||
# thereby avoiding delays to token generation due to offloading.
|
||||
self._unsubmitted_store_jobs.append((job_id, transfer_spec))
|
||||
|
||||
def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
|
||||
"""
|
||||
|
||||
@@ -68,7 +68,6 @@ class SingleDirectionOffloadingHandler(OffloadingHandler):
|
||||
kv_dim_before_num_blocks: list[bool],
|
||||
src_block_size_factor: int,
|
||||
dst_block_size_factor: int,
|
||||
priority: int,
|
||||
):
|
||||
"""
|
||||
Initialize a SingleDirectionOffloadingHandler.
|
||||
@@ -85,8 +84,6 @@ class SingleDirectionOffloadingHandler(OffloadingHandler):
|
||||
per KV block in a source tensor.
|
||||
dst_block_size_factor: The number of kernel blocks
|
||||
per KV block in a destination tensor.
|
||||
priority: The priority of the backing CUDA streams.
|
||||
Lower numbers indicate higher priority.
|
||||
"""
|
||||
assert len(src_tensors) == len(dst_tensors) == len(kv_dim_before_num_blocks)
|
||||
|
||||
@@ -95,7 +92,9 @@ class SingleDirectionOffloadingHandler(OffloadingHandler):
|
||||
self.kv_dim_before_num_blocks: list[bool] = kv_dim_before_num_blocks
|
||||
self.src_block_size_factor: int = src_block_size_factor
|
||||
self.dst_block_size_factor: int = dst_block_size_factor
|
||||
self.priority = priority
|
||||
|
||||
assert len(src_tensors) > 0
|
||||
self.gpu_to_cpu: bool = self.src_tensors[0].is_cuda
|
||||
|
||||
# queue of transfers (job_id, stream, event)
|
||||
self._transfers: deque[tuple[int, torch.cuda.Stream, torch.Event]] = deque()
|
||||
@@ -130,12 +129,12 @@ class SingleDirectionOffloadingHandler(OffloadingHandler):
|
||||
expand_block_ids(dst_blocks, self.dst_block_size_factor, src_to_dst[:, 1])
|
||||
src_to_dst_tensor = torch.from_numpy(src_to_dst)
|
||||
|
||||
stream = (
|
||||
self._stream_pool.pop()
|
||||
if self._stream_pool
|
||||
else torch.cuda.Stream(priority=self.priority)
|
||||
)
|
||||
stream = self._stream_pool.pop() if self._stream_pool else torch.cuda.Stream()
|
||||
event = self._event_pool.pop() if self._event_pool else torch.Event()
|
||||
|
||||
if self.gpu_to_cpu:
|
||||
# wait for model computation to finish before offloading
|
||||
stream.wait_stream(torch.cuda.current_stream())
|
||||
if self._transfers:
|
||||
_, _, last_event = self._transfers[-1]
|
||||
# assure job will start only after the previous one completes
|
||||
@@ -267,7 +266,6 @@ class CpuGpuOffloadingHandlers:
|
||||
kv_dim_before_num_blocks=kv_dim_before_num_blocks,
|
||||
src_block_size_factor=gpu_block_size_factor,
|
||||
dst_block_size_factor=cpu_block_size_factor,
|
||||
priority=1,
|
||||
)
|
||||
|
||||
self.cpu_to_gpu_handler = SingleDirectionOffloadingHandler(
|
||||
@@ -276,5 +274,4 @@ class CpuGpuOffloadingHandlers:
|
||||
kv_dim_before_num_blocks=kv_dim_before_num_blocks,
|
||||
src_block_size_factor=cpu_block_size_factor,
|
||||
dst_block_size_factor=gpu_block_size_factor,
|
||||
priority=-1,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user