OffloadingConnector: Prevent redundant loads (#29087)
Signed-off-by: Or Ozeri <oro@il.ibm.com>
This commit is contained in:
@@ -213,7 +213,6 @@ class RequestRunner:
|
||||
)
|
||||
|
||||
def new_request(self, token_ids: list[int]):
|
||||
assert not self.scheduler.requests
|
||||
self.req_id += 1
|
||||
|
||||
req = Request(
|
||||
@@ -338,11 +337,20 @@ class RequestRunner:
|
||||
token_id=token_id or 0,
|
||||
)
|
||||
|
||||
prev_token_id = token_id
|
||||
if self.scheduler.running:
|
||||
token_id = next(tokens_iter, None)
|
||||
|
||||
self.scheduler.update_from_output(scheduler_output, model_runner_output)
|
||||
|
||||
if (
|
||||
prev_token_id is EOS_TOKEN_ID
|
||||
and prev_token_id != token_id
|
||||
and self.scheduler.requests
|
||||
):
|
||||
# continue for one more step to allow offloading to kick off
|
||||
continue
|
||||
|
||||
if token_id is None:
|
||||
break
|
||||
|
||||
@@ -651,3 +659,61 @@ def test_request_preemption(request_runner):
|
||||
decoded_tokens=[EOS_TOKEN_ID],
|
||||
expected_stored_gpu_block_indexes=(9, 10, 11),
|
||||
)
|
||||
|
||||
|
||||
def test_concurrent_lookups_of_the_same_prefix(request_runner):
    """Check that two requests sharing an offloaded prefix trigger one load.

    Scenario:
      1. A request stores a single offloaded block (three GPU blocks).
      2. After resetting the GPU prefix cache, a second request starts
         loading that block, but its transfer is left incomplete.
      3. A third request with the same tokens is submitted while the load
         is still in flight; it must not add any transfer job.
      4. Transfers are completed, and the set of transfer jobs must still
         be unchanged.
    """
    gpu_block_size = 4
    offloaded_block_size = 12
    num_gpu_blocks = 100

    runner = request_runner(
        offloaded_block_size=offloaded_block_size,
        gpu_block_size=gpu_block_size,
        num_gpu_blocks=num_gpu_blocks,
    )

    def submit_prompt():
        # build a fresh token list per request, exactly one offloaded block long
        runner.new_request(token_ids=[0] * offloaded_block_size)

    def current_transfer_jobs():
        # snapshot (copy) of the transfer specs recorded by the handler so far
        return list(runner.offloading_spec.handler.transfer_specs)

    # step 1: store a single offloaded block
    submit_prompt()
    runner.manager.prepare_store.side_effect = (
        lambda block_hashes: generate_store_output(block_hashes)
    )
    runner.run(
        decoded_tokens=[EOS_TOKEN_ID],
        expected_stored_gpu_block_indexes=(0, 1, 2),
    )

    # step 2: begin loading the first block, leaving the transfer pending
    runner.scheduler.reset_prefix_cache()
    submit_prompt()
    runner.manager.lookup.return_value = 1
    runner.run(
        decoded_tokens=[],
        complete_transfers=False,
    )

    # the request above must have triggered a load
    jobs_after_first_load = current_transfer_jobs()
    assert jobs_after_first_load

    # step 3: a new request asks for the very same first block
    submit_prompt()
    runner.manager.lookup.return_value = 1
    runner.run(
        decoded_tokens=[],
        complete_transfers=False,
    )

    # no additional load may have been triggered by the new request
    assert jobs_after_first_load == current_transfer_jobs()

    # step 4: let the pending transfers complete (nothing new gets stored)
    runner.manager.prepare_store.side_effect = (
        lambda block_hashes: generate_store_output([])
    )
    runner.run(
        decoded_tokens=[EOS_TOKEN_ID],
        expected_loaded_gpu_block_indexes=(0, 1, 2),
    )

    # the second request is served from the GPU prefix cache, so the
    # list of transfer jobs is still the one recorded after the first load
    assert jobs_after_first_load == current_transfer_jobs()
|
||||
|
||||
@@ -107,7 +107,7 @@ class OffloadingConnector(KVConnectorBase_V1):
|
||||
|
||||
def get_num_new_matched_tokens(
|
||||
self, request: "Request", num_computed_tokens: int
|
||||
) -> tuple[int, bool]:
|
||||
) -> tuple[int | None, bool]:
|
||||
assert self.connector_scheduler is not None
|
||||
return self.connector_scheduler.get_num_new_matched_tokens(
|
||||
request, num_computed_tokens
|
||||
@@ -161,6 +161,11 @@ class OffloadingConnectorScheduler:
|
||||
# request blocks are stored in order
|
||||
# index of next block (of size offloaded_block_size) to offload
|
||||
self._next_stored_block_idx: dict[ReqId, int] = {}
|
||||
# if GPU prefix caching is enabled,
|
||||
# track loaded blocks to avoid redundant loads
|
||||
self._blocks_being_loaded: set[BlockHash] | None = (
|
||||
set() if spec.vllm_config.cache_config.enable_prefix_caching else None
|
||||
)
|
||||
|
||||
# request ID -> set(block hashes being stored/loaded)
|
||||
self._reqs_being_stored = defaultdict[ReqId, set[BlockHash]](set)
|
||||
@@ -181,7 +186,7 @@ class OffloadingConnectorScheduler:
|
||||
|
||||
def get_num_new_matched_tokens(
|
||||
self, request: Request, num_computed_tokens: int
|
||||
) -> tuple[int, bool]:
|
||||
) -> tuple[int | None, bool]:
|
||||
"""
|
||||
Get number of new tokens that can be loaded beyond the
|
||||
num_computed_tokens.
|
||||
@@ -195,6 +200,9 @@ class OffloadingConnectorScheduler:
|
||||
A tuple with the following elements:
|
||||
- The number of tokens that can be loaded beyond what is
|
||||
already computed.
|
||||
If None, it means that the connector needs more time to
|
||||
determine the number of matched tokens, and the scheduler
|
||||
should query for this request again later.
|
||||
- `True` if tokens will be loaded asynchronously
|
||||
(between scheduler steps).
|
||||
"""
|
||||
@@ -214,6 +222,9 @@ class OffloadingConnectorScheduler:
|
||||
hits = self.manager.lookup(
|
||||
self._get_block_hashes(request, start_idx=start_block_idx)
|
||||
)
|
||||
if hits is None:
|
||||
# indicates a lookup that should be tried later
|
||||
return None, False
|
||||
if hits == 0:
|
||||
return 0, False
|
||||
|
||||
@@ -229,6 +240,22 @@ class OffloadingConnectorScheduler:
|
||||
if num_hit_tokens < self.offloaded_block_size:
|
||||
return 0, False
|
||||
|
||||
if self._blocks_being_loaded:
|
||||
block_hashes = self._get_block_hashes(
|
||||
request, start_idx=start_block_idx, end_idx=start_block_idx + hits
|
||||
)
|
||||
|
||||
if any(
|
||||
block_hash in self._blocks_being_loaded for block_hash in block_hashes
|
||||
):
|
||||
# hit blocks are being loaded, delay request
|
||||
logger.debug(
|
||||
"Delaying request %s since some of its blocks are already"
|
||||
" being loaded",
|
||||
request.request_id,
|
||||
)
|
||||
return None, False
|
||||
|
||||
return num_hit_tokens, True
|
||||
|
||||
def update_state_after_alloc(
|
||||
@@ -270,9 +297,13 @@ class OffloadingConnectorScheduler:
|
||||
)
|
||||
|
||||
self._reqs_to_load[request.request_id] = (src_spec, dst_spec)
|
||||
self._reqs_being_loaded[request.request_id].update(block_hashes)
|
||||
req_blocks_being_loaded = self._reqs_being_loaded[request.request_id]
|
||||
req_blocks_being_loaded.update(block_hashes)
|
||||
self._next_stored_block_idx[request.request_id] = num_blocks
|
||||
|
||||
if self._blocks_being_loaded is not None:
|
||||
self._blocks_being_loaded.update(req_blocks_being_loaded)
|
||||
|
||||
def _get_reqs_to_store(self, scheduler_output: SchedulerOutput):
|
||||
reqs_to_store: dict[ReqId, TransferSpec] = {}
|
||||
# iterate over both new and cached requests
|
||||
@@ -379,6 +410,8 @@ class OffloadingConnectorScheduler:
|
||||
for req_id in connector_output.finished_recving or []:
|
||||
block_hashes = self._reqs_being_loaded.pop(req_id, None)
|
||||
if block_hashes:
|
||||
if self._blocks_being_loaded:
|
||||
self._blocks_being_loaded.difference_update(block_hashes)
|
||||
self.manager.complete_load(block_hashes)
|
||||
|
||||
def request_finished(
|
||||
|
||||
@@ -68,7 +68,7 @@ class OffloadingEvent:
|
||||
|
||||
class OffloadingManager(ABC):
|
||||
@abstractmethod
|
||||
def lookup(self, block_hashes: Iterable[BlockHash]) -> int:
|
||||
def lookup(self, block_hashes: Iterable[BlockHash]) -> int | None:
|
||||
"""
|
||||
Finds the length of the maximal series of blocks, starting from the
|
||||
first one, that are all offloaded.
|
||||
@@ -78,7 +78,9 @@ class OffloadingManager(ABC):
|
||||
|
||||
Returns:
|
||||
An integer representing the maximal number of blocks that
|
||||
are currently offloaded.
|
||||
are currently offloaded, or None if the lookup should be retried
|
||||
later. Returning None will delay the request handling by the vLLM
|
||||
scheduler.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@@ -63,7 +63,7 @@ class ARCOffloadingManager(OffloadingManager):
|
||||
self.events: list[OffloadingEvent] | None = [] if enable_events else None
|
||||
self.cache_capacity: int = self.backend.get_num_free_blocks()
|
||||
|
||||
def lookup(self, block_hashes: Iterable[BlockHash]) -> int:
|
||||
def lookup(self, block_hashes: Iterable[BlockHash]) -> int | None:
|
||||
hit_count = 0
|
||||
for block_hash in block_hashes:
|
||||
block = self.t1.get(block_hash) or self.t2.get(block_hash)
|
||||
|
||||
@@ -24,7 +24,7 @@ class LRUOffloadingManager(OffloadingManager):
|
||||
self.blocks: OrderedDict[BlockHash, BlockStatus] = OrderedDict()
|
||||
self.events: list[OffloadingEvent] | None = [] if enable_events else None
|
||||
|
||||
def lookup(self, block_hashes: Iterable[BlockHash]) -> int:
|
||||
def lookup(self, block_hashes: Iterable[BlockHash]) -> int | None:
|
||||
hit_count = 0
|
||||
for block_hash in block_hashes:
|
||||
block = self.blocks.get(block_hash)
|
||||
|
||||
Reference in New Issue
Block a user