From 378385b90cddbe8cbc6e51d4ed59ce83e499530a Mon Sep 17 00:00:00 2001 From: knlnguyen1802 Date: Thu, 22 Jan 2026 11:30:59 +0800 Subject: [PATCH] [EC Connector] Optimize remote cache check in scheduler (#32585) Signed-off-by: knlnguyen1802 --- tests/v1/core/test_scheduler.py | 40 +++++++++------- .../unit/test_ec_example_connector.py | 48 ++++++++++++------- .../ec_transfer/ec_connector/base.py | 14 +++--- .../ec_connector/example_connector.py | 17 +++---- vllm/v1/core/sched/scheduler.py | 14 +++--- 5 files changed, 76 insertions(+), 57 deletions(-) diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index d1b4b5596..e5fcdf518 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -2560,15 +2560,14 @@ def test_ec_connector_cache_hit_external_load(use_kv_connector): mm_positions=mm_positions, )[0] - # Mock cache hit - encoder cache exists externally - scheduler.ec_connector.has_caches = Mock(return_value=[True]) + # Mock cache hit - encoder cache exists externally + scheduler.ec_connector.has_cache_item = Mock(return_value=True) scheduler.ec_connector.update_state_after_alloc = Mock( wraps=scheduler.ec_connector.update_state_after_alloc ) scheduler.add_request(request) output = scheduler.schedule() - # Should schedule prompt tokens scheduled_tokens = output.num_scheduled_tokens[request.request_id] assert scheduled_tokens == NUM_TOKENS @@ -2611,7 +2610,7 @@ def test_ec_connector_cache_miss_computes_locally(use_kv_connector): )[0] # Mock cache miss - encoder cache doesn't exist externally - scheduler.ec_connector.has_caches = Mock(return_value=[False]) + scheduler.ec_connector.has_cache_item = Mock(return_value=False) scheduler.add_request(request_mm_missed) output = scheduler.schedule() @@ -2665,7 +2664,7 @@ def test_ec_connector_with_partial_cache_hit_multi_round(use_kv_connector): PlaceholderRange(offset=250, length=NUM_ENCODER_TOKENS_1), ] ] - + has_cache_item_result_map_1 = {"hash1_A": False, 
"hash1_B": True, "hash1_F": True} # Create request with 4 MM items, with 2 identical items request1 = create_requests( num_requests=1, @@ -2676,7 +2675,9 @@ def test_ec_connector_with_partial_cache_hit_multi_round(use_kv_connector): )[0] # Mock partial cache hit: 1st and 3rd missing, 2nd and 4th exist - scheduler.ec_connector.has_caches = Mock(return_value=[False, True, False, True]) + scheduler.ec_connector.has_cache_item = Mock( + side_effect=lambda hash_val: has_cache_item_result_map_1[hash_val] + ) scheduler.ec_connector.update_state_after_alloc = Mock( wraps=scheduler.ec_connector.update_state_after_alloc ) @@ -2736,7 +2737,12 @@ def test_ec_connector_with_partial_cache_hit_multi_round(use_kv_connector): PlaceholderRange(offset=250, length=NUM_ENCODER_TOKENS_2), ] ] - + has_cache_item_result_map_2 = { + "hash1_C": True, + "hash1_D": False, + "hash1_E": False, + "hash1_A": True, + } request2 = create_requests( num_requests=1, num_tokens=NUM_TOKENS_2, @@ -2746,7 +2752,9 @@ def test_ec_connector_with_partial_cache_hit_multi_round(use_kv_connector): )[0] # Mock partial cache hit: only hash1_A and hash1_C exist in connector - scheduler.ec_connector.has_caches = Mock(return_value=[True, False, False, True]) + scheduler.ec_connector.has_cache_item = Mock( + side_effect=lambda hash_val: has_cache_item_result_map_2[hash_val] + ) scheduler.add_request(request2) output = scheduler.schedule() @@ -2821,9 +2829,9 @@ def test_ec_connector_schedule_multiple_requests(cache_exist, use_kv_connector): if cache_exist == "connector_only": # Cache exist in ec_connector - scheduler.ec_connector.has_caches = Mock(return_value=[True]) + scheduler.ec_connector.has_cache_item = Mock(return_value=True) elif cache_exist == "no_where": - scheduler.ec_connector.has_caches = Mock(return_value=[False]) + scheduler.ec_connector.has_cache_item = Mock(return_value=False) output = scheduler.schedule() assert len(output.scheduled_new_reqs) == len(requests) @@ -2887,7 +2895,7 @@ def 
test_ec_connector_unable_to_allocate(use_kv_connector): ) # Mock ec_connector load external cache behavior - scheduler.ec_connector.has_caches = Mock(return_value=[True]) + scheduler.ec_connector.has_cache_item = Mock(return_value=True) scheduler.ec_connector.update_state_after_alloc = Mock( wraps=scheduler.ec_connector.update_state_after_alloc ) @@ -2984,7 +2992,7 @@ def test_priority_scheduling_ec_connector_preemption_and_resumption( ) # Mock cache hit: Both cache exist in connector (at E->PD initially) - scheduler.ec_connector.has_caches = Mock(return_value=[True]) + scheduler.ec_connector.has_cache_item = Mock(return_value=True) scheduler.ec_connector.update_state_after_alloc = Mock( wraps=scheduler.ec_connector.update_state_after_alloc ) @@ -3139,9 +3147,9 @@ def test_priority_scheduling_ec_connector_preemption_and_resumption( if cache_exist == "connector_only": # Cache exist in ec_connector - scheduler.ec_connector.has_caches = Mock(return_value=[True]) + scheduler.ec_connector.has_cache_item = Mock(return_value=True) elif cache_exist == "no_where": - scheduler.ec_connector.has_caches = Mock(return_value=[False]) + scheduler.ec_connector.has_cache_item = Mock(return_value=False) # 4th Schedule - this should trigger req_low resumption from waiting output = scheduler.schedule() @@ -3259,8 +3267,8 @@ def test_ec_connector_allocate_encoder_tokens_with_external_load(use_kv_connecto )[0] # Mock cache hit: MM of request1 NOT cached remotely, request2 cached remotely - scheduler.ec_connector.has_caches = Mock( - side_effect=lambda req: [True, True, True] if req == request2 else [False] + scheduler.ec_connector.has_cache_item = Mock( + side_effect=lambda hash_value: hash_value in mm_hashes_list_2[0] ) scheduler.ec_connector.update_state_after_alloc = Mock( wraps=scheduler.ec_connector.update_state_after_alloc diff --git a/tests/v1/ec_connector/unit/test_ec_example_connector.py b/tests/v1/ec_connector/unit/test_ec_example_connector.py index 9ed82e1ce..c5686cf9f 100644 
--- a/tests/v1/ec_connector/unit/test_ec_example_connector.py +++ b/tests/v1/ec_connector/unit/test_ec_example_connector.py @@ -123,15 +123,15 @@ class TestECExampleConnectorBasics: class TestCacheExistence: - """Test cache existence checking using has_caches() API.""" + """Test cache existence checking using has_cache_item() API.""" - def test_has_caches_all_exist_3_items( + def test_has_cache_item_all_exist_3_items( self, mock_vllm_config_producer, mock_vllm_config_consumer, mock_request_with_3_mm, ): - """Test has_caches returns True when all 3 caches exist.""" + """Test has_cache_item returns True when all 3 caches exist.""" # Test for producer first producer = ECExampleConnector( vllm_config=mock_vllm_config_producer, @@ -146,8 +146,11 @@ class TestCacheExistence: encoder_cache[mm_hash] = torch.randn(10, 768) producer.save_caches(encoder_cache, mm_hash) - # Test using has_caches API - producer_result = producer.has_caches(mock_request_with_3_mm) + # Test using has_cache_item API + producer_result = [ + producer.has_cache_item(mm_feature.identifier) + for mm_feature in mock_request_with_3_mm.mm_features + ] # Assert assert len(producer_result) == 3 @@ -159,14 +162,17 @@ class TestCacheExistence: role=ECConnectorRole.SCHEDULER, ) - # Test using has_caches API - consumer_result = consumer.has_caches(mock_request_with_3_mm) + # Test using has_cache_item API + consumer_result = [ + consumer.has_cache_item(mm_feature.identifier) + for mm_feature in mock_request_with_3_mm.mm_features + ] # Assert assert len(consumer_result) == 3 assert all(consumer_result), f"Expected all True, got {consumer_result}" - def test_has_caches_none_exist( + def test_has_cache_item_none_exist( self, mock_vllm_config_producer, mock_request_with_3_mm ): """Test has_caches returns False when no caches exist.""" @@ -176,13 +182,16 @@ class TestCacheExistence: ) # Test without creating any files - result = connector.has_caches(mock_request_with_3_mm) + result = [ + 
connector.has_cache_item(mm_feature.identifier) + for mm_feature in mock_request_with_3_mm.mm_features + ] # Assert assert len(result) == 3 assert not any(result), f"Expected all False, got {result}" - def test_has_caches_partial_exist( + def test_has_cache_item_partial_exist( self, mock_vllm_config_producer, mock_request_with_3_mm ): """Test has_caches with some caches existing (1 of 3).""" @@ -197,7 +206,10 @@ class TestCacheExistence: connector.save_caches(encoder_cache, mm_hash_second) # Test - result = connector.has_caches(mock_request_with_3_mm) + result = [ + connector.has_cache_item(mm_feature.identifier) + for mm_feature in mock_request_with_3_mm.mm_features + ] # Assert assert len(result) == 3 @@ -323,8 +335,11 @@ class TestCacheSaving: encoder_cache[mm_hash] = torch.randn(10, 768) connector.save_caches(encoder_cache, mm_hash) - # Verify all files exist using has_caches - result = connector.has_caches(mock_request_with_3_mm) + # Verify all files exist using has_cache_item + result = [ + connector.has_cache_item(mm_feature.identifier) + for mm_feature in mock_request_with_3_mm.mm_features + ] assert all(result), f"Not all caches were saved: {result}" # Verify each file's content @@ -347,10 +362,9 @@ class TestCacheSaving: # Save should not raise but also not create file connector.save_caches(encoder_cache, mm_hash) - # Verify file doesn't exist using has_caches - mock_request = MockRequest("req_consumer", [mm_hash], [10]) - result = connector.has_caches(mock_request) - assert not result[0], "Consumer should not save caches" + # Verify file doesn't exist using has_cache_item + result = connector.has_cache_item(mm_hash) + assert not result, "Consumer should not save caches" class TestCacheLoading: diff --git a/vllm/distributed/ec_transfer/ec_connector/base.py b/vllm/distributed/ec_transfer/ec_connector/base.py index 2b7b14d89..2c212c29c 100644 --- a/vllm/distributed/ec_transfer/ec_connector/base.py +++ b/vllm/distributed/ec_transfer/ec_connector/base.py @@ 
-182,19 +182,19 @@ class ECConnectorBase(ABC): # ============================== @abstractmethod - def has_caches( + def has_cache_item( self, - request: "Request", - ) -> list[bool]: + identifier: str, + ) -> bool: """ - Check if encoder cache exists for each mm data of requests + Check if a single encoder cache exists Args: - request (Request): the request object. + identifier (str): the identifier of the media. Returns: - A list bool where ith value is True if cache exist for - ith mm_data of requests + A bool where value is True if cache exists for + the media """ pass diff --git a/vllm/distributed/ec_transfer/ec_connector/example_connector.py b/vllm/distributed/ec_transfer/ec_connector/example_connector.py index 48a7d4190..92f190b54 100644 --- a/vllm/distributed/ec_transfer/ec_connector/example_connector.py +++ b/vllm/distributed/ec_transfer/ec_connector/example_connector.py @@ -117,23 +117,20 @@ class ECExampleConnector(ECConnectorBase): safetensors.torch.save_file(tensors, filename) logger.debug("Save cache successful for mm_hash %s", mm_hash) - def has_caches( + def has_cache_item( self, - request: "Request", - ) -> list[bool]: + identifier: str, + ) -> bool: """ - Check if cache exist externally for each mm_data of request + Check if cache exists externally for the media Args: - request (Request): the request object. + identifier (str): the identifier of the media. 
Returns: - List of bool indicate that ith mm_data exist in cache or not + Bool indicating whether the media exists in cache or not """ - result = [] - for feature in request.mm_features: - result.append(self._found_match_for_mm_data(feature.identifier)) - return result + return self._found_match_for_mm_data(identifier) def update_state_after_alloc( self, diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 0cb65bd0f..56c10dab5 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -947,9 +947,6 @@ class Scheduler(SchedulerInterface): assert len(mm_features) > 0 external_load_encoder_input = [] - # Check remote cache first - if self.ec_connector is not None: - remote_cache_has_item = self.ec_connector.has_caches(request) # NOTE: since scheduler operates on the request level (possibly with # multiple encoder inputs per request), we need to create temporary # trackers for accounting at the encoder input level. @@ -959,6 +956,7 @@ class Scheduler(SchedulerInterface): start_pos = mm_feature.mm_position.offset num_encoder_tokens = mm_feature.mm_position.length num_encoder_embeds = mm_feature.mm_position.get_num_embeds + item_identifier = mm_feature.identifier # The encoder output is needed if the two ranges overlap: # [num_computed_tokens, num_computed_tokens + num_new_tokens) and @@ -993,7 +991,7 @@ class Scheduler(SchedulerInterface): if not self.is_encoder_decoder: # We are not using the encoder cache for encoder-decoder models, # yet. - if request.mm_features[i].identifier in mm_hashes_to_schedule: + if item_identifier in mm_hashes_to_schedule: # The same encoder input has already been scheduled in the # current step. 
continue @@ -1051,15 +1049,17 @@ class Scheduler(SchedulerInterface): if curr_embeds_end - curr_embeds_start == 0: continue - if self.ec_connector is not None and remote_cache_has_item[i]: - mm_hashes_to_schedule.add(request.mm_features[i].identifier) + if self.ec_connector is not None and self.ec_connector.has_cache_item( + item_identifier + ): + mm_hashes_to_schedule.add(item_identifier) external_load_encoder_input.append(i) num_embeds_to_schedule += num_encoder_embeds continue num_embeds_to_schedule += num_encoder_embeds encoder_compute_budget -= num_encoder_embeds - mm_hashes_to_schedule.add(request.mm_features[i].identifier) + mm_hashes_to_schedule.add(item_identifier) encoder_inputs_to_schedule.append(i) return (