diff --git a/tests/v1/kv_connector/unit/test_offloading_connector.py b/tests/v1/kv_connector/unit/test_offloading_connector.py index 5c049301c..0c8a185a9 100644 --- a/tests/v1/kv_connector/unit/test_offloading_connector.py +++ b/tests/v1/kv_connector/unit/test_offloading_connector.py @@ -213,7 +213,6 @@ class RequestRunner: ) def new_request(self, token_ids: list[int]): - assert not self.scheduler.requests self.req_id += 1 req = Request( @@ -338,11 +337,20 @@ class RequestRunner: token_id=token_id or 0, ) + prev_token_id = token_id if self.scheduler.running: token_id = next(tokens_iter, None) self.scheduler.update_from_output(scheduler_output, model_runner_output) + if ( + prev_token_id == EOS_TOKEN_ID + and prev_token_id != token_id + and self.scheduler.requests + ): + # continue for one more step to allow offloading to kick off + continue + if token_id is None: break @@ -651,3 +659,61 @@ def test_request_preemption(request_runner): decoded_tokens=[EOS_TOKEN_ID], expected_stored_gpu_block_indexes=(9, 10, 11), ) + + +def test_concurrent_lookups_of_the_same_prefix(request_runner): + offloaded_block_size = 12 + gpu_block_size = 4 + num_gpu_blocks = 100 + + runner = request_runner( + offloaded_block_size=offloaded_block_size, + gpu_block_size=gpu_block_size, + num_gpu_blocks=num_gpu_blocks, + ) + + # store 1 block + runner.new_request(token_ids=[0] * offloaded_block_size) + runner.manager.prepare_store.side_effect = ( + lambda block_hashes: generate_store_output(block_hashes) + ) + runner.run( + decoded_tokens=[EOS_TOKEN_ID], + expected_stored_gpu_block_indexes=(0, 1, 2), + ) + + # start a request to load the first block, but don't complete + runner.scheduler.reset_prefix_cache() + runner.new_request(token_ids=[0] * offloaded_block_size) + runner.manager.lookup.return_value = 1 + runner.run( + decoded_tokens=[], + complete_transfers=False, + ) + + # request triggered a load + transfer_jobs = list(runner.offloading_spec.handler.transfer_specs) + assert transfer_jobs + 
+ # start a new request to load the same first block + runner.new_request(token_ids=[0] * offloaded_block_size) + runner.manager.lookup.return_value = 1 + runner.run( + decoded_tokens=[], + complete_transfers=False, + ) + + # request did not trigger a load + assert transfer_jobs == list(runner.offloading_spec.handler.transfer_specs) + + # complete transfers + runner.manager.prepare_store.side_effect = ( + lambda block_hashes: generate_store_output([]) + ) + runner.run( + decoded_tokens=[EOS_TOKEN_ID], + expected_loaded_gpu_block_indexes=(0, 1, 2), + ) + + # second request will use the GPU prefix cache + assert transfer_jobs == list(runner.offloading_spec.handler.transfer_specs) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py index 777c71288..5667f1d3f 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py @@ -107,7 +107,7 @@ class OffloadingConnector(KVConnectorBase_V1): def get_num_new_matched_tokens( self, request: "Request", num_computed_tokens: int - ) -> tuple[int, bool]: + ) -> tuple[int | None, bool]: assert self.connector_scheduler is not None return self.connector_scheduler.get_num_new_matched_tokens( request, num_computed_tokens @@ -161,6 +161,11 @@ class OffloadingConnectorScheduler: # request blocks are stored in order # index of next block (of size offloaded_block_size) to offload self._next_stored_block_idx: dict[ReqId, int] = {} + # if GPU prefix caching is enabled, + # track loaded blocks to avoid redundant loads + self._blocks_being_loaded: set[BlockHash] | None = ( + set() if spec.vllm_config.cache_config.enable_prefix_caching else None + ) # request ID -> set(block hashes being stored/load) self._reqs_being_stored = defaultdict[ReqId, set[BlockHash]](set) @@ -181,7 +186,7 @@ class OffloadingConnectorScheduler: def get_num_new_matched_tokens( 
self, request: Request, num_computed_tokens: int - ) -> tuple[int, bool]: + ) -> tuple[int | None, bool]: """ Get number of new tokens that can be loaded beyond the num_computed_tokens. @@ -195,6 +200,9 @@ class OffloadingConnectorScheduler: A tuple with the following elements: - The number of tokens that can be loaded beyond what is already computed. + If None, it means that the connector needs more time to + determine the number of matched tokens, and the scheduler + should query for this request again later. - `True` if tokens will be loaded asynchronously (between scheduler steps). """ @@ -214,6 +222,9 @@ class OffloadingConnectorScheduler: hits = self.manager.lookup( self._get_block_hashes(request, start_idx=start_block_idx) ) + if hits is None: + # indicates a lookup that should be tried later + return None, False if hits == 0: return 0, False @@ -229,6 +240,22 @@ class OffloadingConnectorScheduler: if num_hit_tokens < self.offloaded_block_size: return 0, False + if self._blocks_being_loaded: + block_hashes = self._get_block_hashes( + request, start_idx=start_block_idx, end_idx=start_block_idx + hits + ) + + if any( + block_hash in self._blocks_being_loaded for block_hash in block_hashes + ): + # hit blocks are being loaded, delay request + logger.debug( + "Delaying request %s since some of its blocks are already" + " being loaded", + request.request_id, + ) + return None, False + return num_hit_tokens, True def update_state_after_alloc( @@ -270,9 +297,13 @@ class OffloadingConnectorScheduler: ) self._reqs_to_load[request.request_id] = (src_spec, dst_spec) - self._reqs_being_loaded[request.request_id].update(block_hashes) + req_blocks_being_loaded = self._reqs_being_loaded[request.request_id] + req_blocks_being_loaded.update(block_hashes) self._next_stored_block_idx[request.request_id] = num_blocks + if self._blocks_being_loaded is not None: + self._blocks_being_loaded.update(req_blocks_being_loaded) + def _get_reqs_to_store(self, scheduler_output: 
SchedulerOutput): reqs_to_store: dict[ReqId, TransferSpec] = {} # iterate over both new and cached requests @@ -379,6 +410,8 @@ class OffloadingConnectorScheduler: for req_id in connector_output.finished_recving or []: block_hashes = self._reqs_being_loaded.pop(req_id, None) if block_hashes: + if self._blocks_being_loaded: + self._blocks_being_loaded.difference_update(block_hashes) self.manager.complete_load(block_hashes) def request_finished( diff --git a/vllm/v1/kv_offload/abstract.py b/vllm/v1/kv_offload/abstract.py index c1d1cbebc..27aa1e353 100644 --- a/vllm/v1/kv_offload/abstract.py +++ b/vllm/v1/kv_offload/abstract.py @@ -68,7 +68,7 @@ class OffloadingEvent: class OffloadingManager(ABC): @abstractmethod - def lookup(self, block_hashes: Iterable[BlockHash]) -> int: + def lookup(self, block_hashes: Iterable[BlockHash]) -> int | None: """ Finds the length of the maximal series of blocks, starting from the first one, that are all offloaded. @@ -78,7 +78,9 @@ class OffloadingManager(ABC): Returns: An integer representing the maximal number of blocks that - are currently offloaded. + are currently offloaded, or None if the lookup should be retried + later. Returning None will delay the request handling by the vLLM + scheduler. 
""" pass diff --git a/vllm/v1/kv_offload/arc_manager.py b/vllm/v1/kv_offload/arc_manager.py index 45b166d67..7f2246a69 100644 --- a/vllm/v1/kv_offload/arc_manager.py +++ b/vllm/v1/kv_offload/arc_manager.py @@ -63,7 +63,7 @@ class ARCOffloadingManager(OffloadingManager): self.events: list[OffloadingEvent] | None = [] if enable_events else None self.cache_capacity: int = self.backend.get_num_free_blocks() - def lookup(self, block_hashes: Iterable[BlockHash]) -> int: + def lookup(self, block_hashes: Iterable[BlockHash]) -> int | None: hit_count = 0 for block_hash in block_hashes: block = self.t1.get(block_hash) or self.t2.get(block_hash) diff --git a/vllm/v1/kv_offload/lru_manager.py b/vllm/v1/kv_offload/lru_manager.py index 0a0111f88..ff9a38c53 100644 --- a/vllm/v1/kv_offload/lru_manager.py +++ b/vllm/v1/kv_offload/lru_manager.py @@ -24,7 +24,7 @@ class LRUOffloadingManager(OffloadingManager): self.blocks: OrderedDict[BlockHash, BlockStatus] = OrderedDict() self.events: list[OffloadingEvent] | None = [] if enable_events else None - def lookup(self, block_hashes: Iterable[BlockHash]) -> int: + def lookup(self, block_hashes: Iterable[BlockHash]) -> int | None: hit_count = 0 for block_hash in block_hashes: block = self.blocks.get(block_hash)