[EC Connector] Optimize remote cache check in scheduler (#32585)

Signed-off-by: knlnguyen1802 <knlnguyen1802@gmail.com>
This commit is contained in:
knlnguyen1802
2026-01-22 11:30:59 +08:00
committed by GitHub
parent c5487e2b96
commit 378385b90c
5 changed files with 76 additions and 57 deletions

View File

@@ -2560,15 +2560,14 @@ def test_ec_connector_cache_hit_external_load(use_kv_connector):
mm_positions=mm_positions, mm_positions=mm_positions,
)[0] )[0]
# Mock cache hit - encoder cache exists externally # Mock cache hit - encoder cache exists externally
scheduler.ec_connector.has_caches = Mock(return_value=[True]) scheduler.ec_connector.has_cache_item = Mock(return_value=True)
scheduler.ec_connector.update_state_after_alloc = Mock( scheduler.ec_connector.update_state_after_alloc = Mock(
wraps=scheduler.ec_connector.update_state_after_alloc wraps=scheduler.ec_connector.update_state_after_alloc
) )
scheduler.add_request(request) scheduler.add_request(request)
output = scheduler.schedule() output = scheduler.schedule()
# Should schedule prompt tokens # Should schedule prompt tokens
scheduled_tokens = output.num_scheduled_tokens[request.request_id] scheduled_tokens = output.num_scheduled_tokens[request.request_id]
assert scheduled_tokens == NUM_TOKENS assert scheduled_tokens == NUM_TOKENS
@@ -2611,7 +2610,7 @@ def test_ec_connector_cache_miss_computes_locally(use_kv_connector):
)[0] )[0]
# Mock cache miss - encoder cache doesn't exist externally # Mock cache miss - encoder cache doesn't exist externally
scheduler.ec_connector.has_caches = Mock(return_value=[False]) scheduler.ec_connector.has_cache_item = Mock(return_value=False)
scheduler.add_request(request_mm_missed) scheduler.add_request(request_mm_missed)
output = scheduler.schedule() output = scheduler.schedule()
@@ -2665,7 +2664,7 @@ def test_ec_connector_with_partial_cache_hit_multi_round(use_kv_connector):
PlaceholderRange(offset=250, length=NUM_ENCODER_TOKENS_1), PlaceholderRange(offset=250, length=NUM_ENCODER_TOKENS_1),
] ]
] ]
has_cache_item_result_map_1 = {"hash1_A": False, "hash1_B": True, "hash1_F": True}
# Create request with 4 MM items, with 2 identical items # Create request with 4 MM items, with 2 identical items
request1 = create_requests( request1 = create_requests(
num_requests=1, num_requests=1,
@@ -2676,7 +2675,9 @@ def test_ec_connector_with_partial_cache_hit_multi_round(use_kv_connector):
)[0] )[0]
# Mock partial cache hit: 1st and 3rd missing, 2nd and 4th exist # Mock partial cache hit: 1st and 3rd missing, 2nd and 4th exist
scheduler.ec_connector.has_caches = Mock(return_value=[False, True, False, True]) scheduler.ec_connector.has_cache_item = Mock(
side_effect=lambda hash_val: has_cache_item_result_map_1[hash_val]
)
scheduler.ec_connector.update_state_after_alloc = Mock( scheduler.ec_connector.update_state_after_alloc = Mock(
wraps=scheduler.ec_connector.update_state_after_alloc wraps=scheduler.ec_connector.update_state_after_alloc
) )
@@ -2736,7 +2737,12 @@ def test_ec_connector_with_partial_cache_hit_multi_round(use_kv_connector):
PlaceholderRange(offset=250, length=NUM_ENCODER_TOKENS_2), PlaceholderRange(offset=250, length=NUM_ENCODER_TOKENS_2),
] ]
] ]
has_cache_item_result_map_2 = {
"hash1_C": True,
"hash1_D": False,
"hash1_E": False,
"hash1_A": True,
}
request2 = create_requests( request2 = create_requests(
num_requests=1, num_requests=1,
num_tokens=NUM_TOKENS_2, num_tokens=NUM_TOKENS_2,
@@ -2746,7 +2752,9 @@ def test_ec_connector_with_partial_cache_hit_multi_round(use_kv_connector):
)[0] )[0]
# Mock partial cache hit: only hash1_A and hash1_C exist in connector # Mock partial cache hit: only hash1_A and hash1_C exist in connector
scheduler.ec_connector.has_caches = Mock(return_value=[True, False, False, True]) scheduler.ec_connector.has_cache_item = Mock(
side_effect=lambda hash_val: has_cache_item_result_map_2[hash_val]
)
scheduler.add_request(request2) scheduler.add_request(request2)
output = scheduler.schedule() output = scheduler.schedule()
@@ -2821,9 +2829,9 @@ def test_ec_connector_schedule_multiple_requests(cache_exist, use_kv_connector):
if cache_exist == "connector_only": if cache_exist == "connector_only":
# Cache exist in ec_connector # Cache exist in ec_connector
scheduler.ec_connector.has_caches = Mock(return_value=[True]) scheduler.ec_connector.has_cache_item = Mock(return_value=True)
elif cache_exist == "no_where": elif cache_exist == "no_where":
scheduler.ec_connector.has_caches = Mock(return_value=[False]) scheduler.ec_connector.has_cache_item = Mock(return_value=False)
output = scheduler.schedule() output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == len(requests) assert len(output.scheduled_new_reqs) == len(requests)
@@ -2887,7 +2895,7 @@ def test_ec_connector_unable_to_allocate(use_kv_connector):
) )
# Mock ec_connector load external cache behavior # Mock ec_connector load external cache behavior
scheduler.ec_connector.has_caches = Mock(return_value=[True]) scheduler.ec_connector.has_cache_item = Mock(return_value=True)
scheduler.ec_connector.update_state_after_alloc = Mock( scheduler.ec_connector.update_state_after_alloc = Mock(
wraps=scheduler.ec_connector.update_state_after_alloc wraps=scheduler.ec_connector.update_state_after_alloc
) )
@@ -2984,7 +2992,7 @@ def test_priority_scheduling_ec_connector_preemption_and_resumption(
) )
# Mock cache hit: Both cache exist in connector (at E->PD initially) # Mock cache hit: Both cache exist in connector (at E->PD initially)
scheduler.ec_connector.has_caches = Mock(return_value=[True]) scheduler.ec_connector.has_cache_item = Mock(return_value=True)
scheduler.ec_connector.update_state_after_alloc = Mock( scheduler.ec_connector.update_state_after_alloc = Mock(
wraps=scheduler.ec_connector.update_state_after_alloc wraps=scheduler.ec_connector.update_state_after_alloc
) )
@@ -3139,9 +3147,9 @@ def test_priority_scheduling_ec_connector_preemption_and_resumption(
if cache_exist == "connector_only": if cache_exist == "connector_only":
# Cache exist in ec_connector # Cache exist in ec_connector
scheduler.ec_connector.has_caches = Mock(return_value=[True]) scheduler.ec_connector.has_cache_item = Mock(return_value=True)
elif cache_exist == "no_where": elif cache_exist == "no_where":
scheduler.ec_connector.has_caches = Mock(return_value=[False]) scheduler.ec_connector.has_cache_item = Mock(return_value=False)
# 4th Schedule - this should trigger req_low resumption from waiting # 4th Schedule - this should trigger req_low resumption from waiting
output = scheduler.schedule() output = scheduler.schedule()
@@ -3259,8 +3267,8 @@ def test_ec_connector_allocate_encoder_tokens_with_external_load(use_kv_connecto
)[0] )[0]
# Mock cache hit: MM of request1 NOT cached remotely, request2 cached remotely # Mock cache hit: MM of request1 NOT cached remotely, request2 cached remotely
scheduler.ec_connector.has_caches = Mock( scheduler.ec_connector.has_cache_item = Mock(
side_effect=lambda req: [True, True, True] if req == request2 else [False] side_effect=lambda hash_value: hash_value in mm_hashes_list_2[0]
) )
scheduler.ec_connector.update_state_after_alloc = Mock( scheduler.ec_connector.update_state_after_alloc = Mock(
wraps=scheduler.ec_connector.update_state_after_alloc wraps=scheduler.ec_connector.update_state_after_alloc

View File

@@ -123,15 +123,15 @@ class TestECExampleConnectorBasics:
class TestCacheExistence: class TestCacheExistence:
"""Test cache existence checking using has_caches() API.""" """Test cache existence checking using has_cache_item() API."""
def test_has_caches_all_exist_3_items( def test_has_cache_item_all_exist_3_items(
self, self,
mock_vllm_config_producer, mock_vllm_config_producer,
mock_vllm_config_consumer, mock_vllm_config_consumer,
mock_request_with_3_mm, mock_request_with_3_mm,
): ):
"""Test has_caches returns True when all 3 caches exist.""" """Test has_cache_item returns True when all 3 caches exist."""
# Test for producer first # Test for producer first
producer = ECExampleConnector( producer = ECExampleConnector(
vllm_config=mock_vllm_config_producer, vllm_config=mock_vllm_config_producer,
@@ -146,8 +146,11 @@ class TestCacheExistence:
encoder_cache[mm_hash] = torch.randn(10, 768) encoder_cache[mm_hash] = torch.randn(10, 768)
producer.save_caches(encoder_cache, mm_hash) producer.save_caches(encoder_cache, mm_hash)
# Test using has_caches API # Test using has_cache_item API
producer_result = producer.has_caches(mock_request_with_3_mm) producer_result = [
producer.has_cache_item(mm_feature.identifier)
for mm_feature in mock_request_with_3_mm.mm_features
]
# Assert # Assert
assert len(producer_result) == 3 assert len(producer_result) == 3
@@ -159,14 +162,17 @@ class TestCacheExistence:
role=ECConnectorRole.SCHEDULER, role=ECConnectorRole.SCHEDULER,
) )
# Test using has_caches API # Test using has_cache_item API
consumer_result = consumer.has_caches(mock_request_with_3_mm) consumer_result = [
consumer.has_cache_item(mm_feature.identifier)
for mm_feature in mock_request_with_3_mm.mm_features
]
# Assert # Assert
assert len(consumer_result) == 3 assert len(consumer_result) == 3
assert all(consumer_result), f"Expected all True, got {consumer_result}" assert all(consumer_result), f"Expected all True, got {consumer_result}"
def test_has_caches_none_exist( def test_has_cache_item_none_exist(
self, mock_vllm_config_producer, mock_request_with_3_mm self, mock_vllm_config_producer, mock_request_with_3_mm
): ):
"""Test has_caches returns False when no caches exist.""" """Test has_cache_item returns False when no caches exist."""
@@ -176,13 +182,16 @@ class TestCacheExistence:
) )
# Test without creating any files # Test without creating any files
result = connector.has_caches(mock_request_with_3_mm) result = [
connector.has_cache_item(mm_feature.identifier)
for mm_feature in mock_request_with_3_mm.mm_features
]
# Assert # Assert
assert len(result) == 3 assert len(result) == 3
assert not any(result), f"Expected all False, got {result}" assert not any(result), f"Expected all False, got {result}"
def test_has_caches_partial_exist( def test_has_cache_item_partial_exist(
self, mock_vllm_config_producer, mock_request_with_3_mm self, mock_vllm_config_producer, mock_request_with_3_mm
): ):
"""Test has_caches with some caches existing (1 of 3).""" """Test has_cache_item with some caches existing (1 of 3)."""
@@ -197,7 +206,10 @@ class TestCacheExistence:
connector.save_caches(encoder_cache, mm_hash_second) connector.save_caches(encoder_cache, mm_hash_second)
# Test # Test
result = connector.has_caches(mock_request_with_3_mm) result = [
connector.has_cache_item(mm_feature.identifier)
for mm_feature in mock_request_with_3_mm.mm_features
]
# Assert # Assert
assert len(result) == 3 assert len(result) == 3
@@ -323,8 +335,11 @@ class TestCacheSaving:
encoder_cache[mm_hash] = torch.randn(10, 768) encoder_cache[mm_hash] = torch.randn(10, 768)
connector.save_caches(encoder_cache, mm_hash) connector.save_caches(encoder_cache, mm_hash)
# Verify all files exist using has_caches # Verify all files exist using has_cache_item
result = connector.has_caches(mock_request_with_3_mm) result = [
connector.has_cache_item(mm_feature.identifier)
for mm_feature in mock_request_with_3_mm.mm_features
]
assert all(result), f"Not all caches were saved: {result}" assert all(result), f"Not all caches were saved: {result}"
# Verify each file's content # Verify each file's content
@@ -347,10 +362,9 @@ class TestCacheSaving:
# Save should not raise but also not create file # Save should not raise but also not create file
connector.save_caches(encoder_cache, mm_hash) connector.save_caches(encoder_cache, mm_hash)
# Verify file doesn't exist using has_caches # Verify file doesn't exist using has_cache_item
mock_request = MockRequest("req_consumer", [mm_hash], [10]) result = connector.has_cache_item(mm_hash)
result = connector.has_caches(mock_request) assert not result, "Consumer should not save caches"
assert not result[0], "Consumer should not save caches"
class TestCacheLoading: class TestCacheLoading:

View File

@@ -182,19 +182,19 @@ class ECConnectorBase(ABC):
# ============================== # ==============================
@abstractmethod @abstractmethod
def has_caches( def has_cache_item(
self, self,
request: "Request", identifier: str,
) -> list[bool]: ) -> bool:
""" """
Check if encoder cache exists for each mm data of requests Check if a single encoder cache exists
Args: Args:
request (Request): the request object. identifier (str): the identifier of the media.
Returns: Returns:
A list bool where ith value is True if cache exist for A bool that is True if cache exists for
ith mm_data of requests the media
""" """
pass pass

View File

@@ -117,23 +117,20 @@ class ECExampleConnector(ECConnectorBase):
safetensors.torch.save_file(tensors, filename) safetensors.torch.save_file(tensors, filename)
logger.debug("Save cache successful for mm_hash %s", mm_hash) logger.debug("Save cache successful for mm_hash %s", mm_hash)
def has_caches( def has_cache_item(
self, self,
request: "Request", identifier: str,
) -> list[bool]: ) -> bool:
""" """
Check if cache exist externally for each mm_data of request Check if cache exist externally for the media
Args: Args:
request (Request): the request object. identifier (str): the identifier of the media.
Returns: Returns:
List of bool indicate that ith mm_data exist in cache or not Bool indicating whether the media exists in cache or not
""" """
result = [] return self._found_match_for_mm_data(identifier)
for feature in request.mm_features:
result.append(self._found_match_for_mm_data(feature.identifier))
return result
def update_state_after_alloc( def update_state_after_alloc(
self, self,

View File

@@ -947,9 +947,6 @@ class Scheduler(SchedulerInterface):
assert len(mm_features) > 0 assert len(mm_features) > 0
external_load_encoder_input = [] external_load_encoder_input = []
# Check remote cache first
if self.ec_connector is not None:
remote_cache_has_item = self.ec_connector.has_caches(request)
# NOTE: since scheduler operates on the request level (possibly with # NOTE: since scheduler operates on the request level (possibly with
# multiple encoder inputs per request), we need to create temporary # multiple encoder inputs per request), we need to create temporary
# trackers for accounting at the encoder input level. # trackers for accounting at the encoder input level.
@@ -959,6 +956,7 @@ class Scheduler(SchedulerInterface):
start_pos = mm_feature.mm_position.offset start_pos = mm_feature.mm_position.offset
num_encoder_tokens = mm_feature.mm_position.length num_encoder_tokens = mm_feature.mm_position.length
num_encoder_embeds = mm_feature.mm_position.get_num_embeds num_encoder_embeds = mm_feature.mm_position.get_num_embeds
item_identifier = mm_feature.identifier
# The encoder output is needed if the two ranges overlap: # The encoder output is needed if the two ranges overlap:
# [num_computed_tokens, num_computed_tokens + num_new_tokens) and # [num_computed_tokens, num_computed_tokens + num_new_tokens) and
@@ -993,7 +991,7 @@ class Scheduler(SchedulerInterface):
if not self.is_encoder_decoder: if not self.is_encoder_decoder:
# We are not using the encoder cache for encoder-decoder models, # We are not using the encoder cache for encoder-decoder models,
# yet. # yet.
if request.mm_features[i].identifier in mm_hashes_to_schedule: if item_identifier in mm_hashes_to_schedule:
# The same encoder input has already been scheduled in the # The same encoder input has already been scheduled in the
# current step. # current step.
continue continue
@@ -1051,15 +1049,17 @@ class Scheduler(SchedulerInterface):
if curr_embeds_end - curr_embeds_start == 0: if curr_embeds_end - curr_embeds_start == 0:
continue continue
if self.ec_connector is not None and remote_cache_has_item[i]: if self.ec_connector is not None and self.ec_connector.has_cache_item(
mm_hashes_to_schedule.add(request.mm_features[i].identifier) item_identifier
):
mm_hashes_to_schedule.add(item_identifier)
external_load_encoder_input.append(i) external_load_encoder_input.append(i)
num_embeds_to_schedule += num_encoder_embeds num_embeds_to_schedule += num_encoder_embeds
continue continue
num_embeds_to_schedule += num_encoder_embeds num_embeds_to_schedule += num_encoder_embeds
encoder_compute_budget -= num_encoder_embeds encoder_compute_budget -= num_encoder_embeds
mm_hashes_to_schedule.add(request.mm_features[i].identifier) mm_hashes_to_schedule.add(item_identifier)
encoder_inputs_to_schedule.append(i) encoder_inputs_to_schedule.append(i)
return ( return (