OffloadingConnector: Prevent redundant loads (#29087)

Signed-off-by: Or Ozeri <oro@il.ibm.com>
This commit is contained in:
Or Ozeri
2026-01-21 03:15:42 +02:00
committed by GitHub
parent c78ee240b3
commit 7013e9ac8f
5 changed files with 109 additions and 8 deletions

View File

@@ -213,7 +213,6 @@ class RequestRunner:
)
def new_request(self, token_ids: list[int]):
assert not self.scheduler.requests
self.req_id += 1
req = Request(
@@ -338,11 +337,20 @@ class RequestRunner:
token_id=token_id or 0,
)
prev_token_id = token_id
if self.scheduler.running:
token_id = next(tokens_iter, None)
self.scheduler.update_from_output(scheduler_output, model_runner_output)
if (
prev_token_id is EOS_TOKEN_ID
and prev_token_id != token_id
and self.scheduler.requests
):
# continue for one more step to allow offloading to kick off
continue
if token_id is None:
break
@@ -651,3 +659,61 @@ def test_request_preemption(request_runner):
decoded_tokens=[EOS_TOKEN_ID],
expected_stored_gpu_block_indexes=(9, 10, 11),
)
def test_concurrent_lookups_of_the_same_prefix(request_runner):
    """Check that a second request for an in-flight prefix adds no transfer.

    Scenario:
      1. Run a request whose prompt fills exactly one offloaded block and
         expect GPU blocks (0, 1, 2) to be stored.
      2. Reset the GPU prefix cache and start an identical request with the
         manager lookup reporting one hit, leaving transfers incomplete.
      3. Start a third identical request while those transfers are pending
         and assert that the handler's transfer specs are unchanged.
      4. Let everything complete and assert the transfer specs are still
         unchanged.
    """
    tokens_per_offloaded_block = 12
    runner = request_runner(
        offloaded_block_size=tokens_per_offloaded_block,
        gpu_block_size=4,
        num_gpu_blocks=100,
    )

    def submit_prefix_request():
        # every request uses the same prompt: one offloaded block of zeros
        runner.new_request(token_ids=[0] * tokens_per_offloaded_block)

    def snapshot_transfers():
        # copy, so later comparisons are against a frozen view
        return list(runner.offloading_spec.handler.transfer_specs)

    # step 1: store 1 block
    submit_prefix_request()
    runner.manager.prepare_store.side_effect = (
        lambda block_hashes: generate_store_output(block_hashes)
    )
    runner.run(
        decoded_tokens=[EOS_TOKEN_ID],
        expected_stored_gpu_block_indexes=(0, 1, 2),
    )

    # step 2: begin loading the first block, but leave the transfer pending
    runner.scheduler.reset_prefix_cache()
    submit_prefix_request()
    runner.manager.lookup.return_value = 1
    runner.run(decoded_tokens=[], complete_transfers=False)

    # the request above must have triggered a load
    jobs_after_first_request = snapshot_transfers()
    assert jobs_after_first_request

    # step 3: a new request for the same first block, transfers still pending
    submit_prefix_request()
    runner.manager.lookup.return_value = 1
    runner.run(decoded_tokens=[], complete_transfers=False)

    # this request did not trigger a load
    assert jobs_after_first_request == snapshot_transfers()

    # step 4: complete transfers
    runner.manager.prepare_store.side_effect = (
        lambda block_hashes: generate_store_output([])
    )
    runner.run(
        decoded_tokens=[EOS_TOKEN_ID],
        expected_loaded_gpu_block_indexes=(0, 1, 2),
    )

    # second request will use the GPU prefix cache
    assert jobs_after_first_request == snapshot_transfers()

View File

@@ -107,7 +107,7 @@ class OffloadingConnector(KVConnectorBase_V1):
def get_num_new_matched_tokens(
self, request: "Request", num_computed_tokens: int
) -> tuple[int, bool]:
) -> tuple[int | None, bool]:
assert self.connector_scheduler is not None
return self.connector_scheduler.get_num_new_matched_tokens(
request, num_computed_tokens
@@ -161,6 +161,11 @@ class OffloadingConnectorScheduler:
# request blocks are stored in order
# index of next block (of size offloaded_block_size) to offload
self._next_stored_block_idx: dict[ReqId, int] = {}
# if GPU prefix caching is enabled,
# track blocks that are currently being loaded to avoid redundant loads
self._blocks_being_loaded: set[BlockHash] | None = (
set() if spec.vllm_config.cache_config.enable_prefix_caching else None
)
# request ID -> set(block hashes being stored/loaded)
self._reqs_being_stored = defaultdict[ReqId, set[BlockHash]](set)
@@ -181,7 +186,7 @@ class OffloadingConnectorScheduler:
def get_num_new_matched_tokens(
self, request: Request, num_computed_tokens: int
) -> tuple[int, bool]:
) -> tuple[int | None, bool]:
"""
Get number of new tokens that can be loaded beyond the
num_computed_tokens.
@@ -195,6 +200,9 @@ class OffloadingConnectorScheduler:
A tuple with the following elements:
- The number of tokens that can be loaded beyond what is
already computed.
If None, it means that the connector needs more time to
determine the number of matched tokens, and the scheduler
should query for this request again later.
- `True` if tokens will be loaded asynchronously
(between scheduler steps).
"""
@@ -214,6 +222,9 @@ class OffloadingConnectorScheduler:
hits = self.manager.lookup(
self._get_block_hashes(request, start_idx=start_block_idx)
)
if hits is None:
# indicates a lookup that should be tried later
return None, False
if hits == 0:
return 0, False
@@ -229,6 +240,22 @@ class OffloadingConnectorScheduler:
if num_hit_tokens < self.offloaded_block_size:
return 0, False
if self._blocks_being_loaded:
block_hashes = self._get_block_hashes(
request, start_idx=start_block_idx, end_idx=start_block_idx + hits
)
if any(
block_hash in self._blocks_being_loaded for block_hash in block_hashes
):
# hit blocks are being loaded, delay request
logger.debug(
"Delaying request %s since some of its blocks are already"
" being loaded",
request.request_id,
)
return None, False
return num_hit_tokens, True
def update_state_after_alloc(
@@ -270,9 +297,13 @@ class OffloadingConnectorScheduler:
)
self._reqs_to_load[request.request_id] = (src_spec, dst_spec)
self._reqs_being_loaded[request.request_id].update(block_hashes)
req_blocks_being_loaded = self._reqs_being_loaded[request.request_id]
req_blocks_being_loaded.update(block_hashes)
self._next_stored_block_idx[request.request_id] = num_blocks
if self._blocks_being_loaded is not None:
self._blocks_being_loaded.update(req_blocks_being_loaded)
def _get_reqs_to_store(self, scheduler_output: SchedulerOutput):
reqs_to_store: dict[ReqId, TransferSpec] = {}
# iterate over both new and cached requests
@@ -379,6 +410,8 @@ class OffloadingConnectorScheduler:
for req_id in connector_output.finished_recving or []:
block_hashes = self._reqs_being_loaded.pop(req_id, None)
if block_hashes:
if self._blocks_being_loaded:
self._blocks_being_loaded.difference_update(block_hashes)
self.manager.complete_load(block_hashes)
def request_finished(

View File

@@ -68,7 +68,7 @@ class OffloadingEvent:
class OffloadingManager(ABC):
@abstractmethod
def lookup(self, block_hashes: Iterable[BlockHash]) -> int:
def lookup(self, block_hashes: Iterable[BlockHash]) -> int | None:
"""
Finds the length of the maximal series of blocks, starting from the
first one, that are all offloaded.
@@ -78,7 +78,9 @@ class OffloadingManager(ABC):
Returns:
An integer representing the maximal number of blocks that
are currently offloaded.
are currently offloaded, or None if the lookup should be retried
later. Returning None will delay the request handling by the vLLM
scheduler.
"""
pass

View File

@@ -63,7 +63,7 @@ class ARCOffloadingManager(OffloadingManager):
self.events: list[OffloadingEvent] | None = [] if enable_events else None
self.cache_capacity: int = self.backend.get_num_free_blocks()
def lookup(self, block_hashes: Iterable[BlockHash]) -> int:
def lookup(self, block_hashes: Iterable[BlockHash]) -> int | None:
hit_count = 0
for block_hash in block_hashes:
block = self.t1.get(block_hash) or self.t2.get(block_hash)

View File

@@ -24,7 +24,7 @@ class LRUOffloadingManager(OffloadingManager):
self.blocks: OrderedDict[BlockHash, BlockStatus] = OrderedDict()
self.events: list[OffloadingEvent] | None = [] if enable_events else None
def lookup(self, block_hashes: Iterable[BlockHash]) -> int:
def lookup(self, block_hashes: Iterable[BlockHash]) -> int | None:
hit_count = 0
for block_hash in block_hashes:
block = self.blocks.get(block_hash)