[EC Connector] Optimize remote cache check in scheduler (#32585)
Signed-off-by: knlnguyen1802 <knlnguyen1802@gmail.com>
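In short: the request-level `ECConnectorBase.has_caches(request) -> list[bool]` API is replaced by a per-item `has_cache_item(identifier) -> bool`, so the scheduler queries the remote encoder cache lazily, one multimodal item at a time, instead of eagerly for every item of a request up front. A minimal sketch of the change in shape, assuming only the names visible in the diff below (the `InMemoryConnector` class and the sample hashes are illustrative, not vLLM code):

```python
# Toy stand-in for an ECConnectorBase implementation; illustrative only.
class InMemoryConnector:
    def __init__(self, stored: set[str]) -> None:
        self._stored = stored

    def has_cache_item(self, identifier: str) -> bool:
        # New API: a single lazy lookup keyed by the item's mm hash.
        return identifier in self._stored


connector = InMemoryConnector({"hash1_B", "hash1_F"})

# The old has_caches(request) returned a list[bool] covering every mm item
# of a request; callers that still need that shape (as the updated tests
# below do) can rebuild it with a comprehension over the identifiers:
hits = [connector.has_cache_item(h) for h in ["hash1_A", "hash1_B", "hash1_F"]]
assert hits == [False, True, True]
```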
@@ -2560,15 +2560,14 @@ def test_ec_connector_cache_hit_external_load(use_kv_connector):
         mm_positions=mm_positions,
     )[0]

     # Mock cache hit - encoder cache exists externally
-    scheduler.ec_connector.has_caches = Mock(return_value=[True])
+    scheduler.ec_connector.has_cache_item = Mock(return_value=True)
     scheduler.ec_connector.update_state_after_alloc = Mock(
         wraps=scheduler.ec_connector.update_state_after_alloc
     )

     scheduler.add_request(request)
     output = scheduler.schedule()

     # Should schedule prompt tokens
     scheduled_tokens = output.num_scheduled_tokens[request.request_id]
     assert scheduled_tokens == NUM_TOKENS
@@ -2611,7 +2610,7 @@ def test_ec_connector_cache_miss_computes_locally(use_kv_connector):
     )[0]

     # Mock cache miss - encoder cache doesn't exist externally
-    scheduler.ec_connector.has_caches = Mock(return_value=[False])
+    scheduler.ec_connector.has_cache_item = Mock(return_value=False)

     scheduler.add_request(request_mm_missed)
     output = scheduler.schedule()
@@ -2665,7 +2664,7 @@ def test_ec_connector_with_partial_cache_hit_multi_round(use_kv_connector):
             PlaceholderRange(offset=250, length=NUM_ENCODER_TOKENS_1),
         ]
     ]
-
+    has_cache_item_result_map_1 = {"hash1_A": False, "hash1_B": True, "hash1_F": True}
     # Create request with 4 MM items, with 2 identical items
     request1 = create_requests(
         num_requests=1,
@@ -2676,7 +2675,9 @@ def test_ec_connector_with_partial_cache_hit_multi_round(use_kv_connector):
     )[0]

     # Mock partial cache hit: 1st and 3rd missing, 2nd and 4th exist
-    scheduler.ec_connector.has_caches = Mock(return_value=[False, True, False, True])
+    scheduler.ec_connector.has_cache_item = Mock(
+        side_effect=lambda hash_val: has_cache_item_result_map_1[hash_val]
+    )
     scheduler.ec_connector.update_state_after_alloc = Mock(
         wraps=scheduler.ec_connector.update_state_after_alloc
     )
@@ -2736,7 +2737,12 @@ def test_ec_connector_with_partial_cache_hit_multi_round(use_kv_connector):
             PlaceholderRange(offset=250, length=NUM_ENCODER_TOKENS_2),
         ]
     ]
-
+    has_cache_item_result_map_2 = {
+        "hash1_C": True,
+        "hash1_D": False,
+        "hash1_E": False,
+        "hash1_A": True,
+    }
     request2 = create_requests(
         num_requests=1,
         num_tokens=NUM_TOKENS_2,
@@ -2746,7 +2752,9 @@ def test_ec_connector_with_partial_cache_hit_multi_round(use_kv_connector):
     )[0]

     # Mock partial cache hit: only hash1_A and hash1_C exist in connector
-    scheduler.ec_connector.has_caches = Mock(return_value=[True, False, False, True])
+    scheduler.ec_connector.has_cache_item = Mock(
+        side_effect=lambda hash_val: has_cache_item_result_map_2[hash_val]
+    )

     scheduler.add_request(request2)
     output = scheduler.schedule()
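Aside on the mocking idiom used above: with a per-item API, each `has_cache_item` call carries the item's hash, so the tests switch from `Mock(return_value=[...])` to `Mock(side_effect=...)`, letting one mock answer differently per identifier. A small self-contained illustration (hash names taken from the test; `result_map` is illustrative):

```python
from unittest.mock import Mock

# side_effect forwards the call arguments to the callable and returns its
# result, so the mock can answer per hash while still recording every call.
result_map = {"hash1_A": False, "hash1_B": True, "hash1_F": True}
has_cache_item = Mock(side_effect=lambda hash_val: result_map[hash_val])

assert has_cache_item("hash1_B") is True
assert has_cache_item("hash1_A") is False
assert has_cache_item.call_count == 2  # the mock still records each lookup
```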
@@ -2821,9 +2829,9 @@ def test_ec_connector_schedule_multiple_requests(cache_exist, use_kv_connector):

     if cache_exist == "connector_only":
         # Cache exist in ec_connector
-        scheduler.ec_connector.has_caches = Mock(return_value=[True])
+        scheduler.ec_connector.has_cache_item = Mock(return_value=True)
     elif cache_exist == "no_where":
-        scheduler.ec_connector.has_caches = Mock(return_value=[False])
+        scheduler.ec_connector.has_cache_item = Mock(return_value=False)

     output = scheduler.schedule()
     assert len(output.scheduled_new_reqs) == len(requests)
@@ -2887,7 +2895,7 @@ def test_ec_connector_unable_to_allocate(use_kv_connector):
     )

     # Mock ec_connector load external cache behavior
-    scheduler.ec_connector.has_caches = Mock(return_value=[True])
+    scheduler.ec_connector.has_cache_item = Mock(return_value=True)
     scheduler.ec_connector.update_state_after_alloc = Mock(
         wraps=scheduler.ec_connector.update_state_after_alloc
     )
@@ -2984,7 +2992,7 @@ def test_priority_scheduling_ec_connector_preemption_and_resumption(
     )

     # Mock cache hit: Both cache exist in connector (at E->PD initially)
-    scheduler.ec_connector.has_caches = Mock(return_value=[True])
+    scheduler.ec_connector.has_cache_item = Mock(return_value=True)
     scheduler.ec_connector.update_state_after_alloc = Mock(
         wraps=scheduler.ec_connector.update_state_after_alloc
     )
@@ -3139,9 +3147,9 @@ def test_priority_scheduling_ec_connector_preemption_and_resumption(

     if cache_exist == "connector_only":
         # Cache exist in ec_connector
-        scheduler.ec_connector.has_caches = Mock(return_value=[True])
+        scheduler.ec_connector.has_cache_item = Mock(return_value=True)
     elif cache_exist == "no_where":
-        scheduler.ec_connector.has_caches = Mock(return_value=[False])
+        scheduler.ec_connector.has_cache_item = Mock(return_value=False)

     # 4th Schedule - this should trigger req_low resumption from waiting
     output = scheduler.schedule()
@@ -3259,8 +3267,8 @@ def test_ec_connector_allocate_encoder_tokens_with_external_load(use_kv_connector):
     )[0]

     # Mock cache hit: MM of request1 NOT cached remotely, request2 cached remotely
-    scheduler.ec_connector.has_caches = Mock(
-        side_effect=lambda req: [True, True, True] if req == request2 else [False]
+    scheduler.ec_connector.has_cache_item = Mock(
+        side_effect=lambda hash_value: hash_value in mm_hashes_list_2[0]
     )
     scheduler.ec_connector.update_state_after_alloc = Mock(
         wraps=scheduler.ec_connector.update_state_after_alloc
@@ -123,15 +123,15 @@ class TestECExampleConnectorBasics:


 class TestCacheExistence:
-    """Test cache existence checking using has_caches() API."""
+    """Test cache existence checking using has_cache_item() API."""

-    def test_has_caches_all_exist_3_items(
+    def test_has_cache_item_all_exist_3_items(
         self,
         mock_vllm_config_producer,
         mock_vllm_config_consumer,
         mock_request_with_3_mm,
     ):
-        """Test has_caches returns True when all 3 caches exist."""
+        """Test has_cache_item returns True when all 3 caches exist."""
         # Test for producer first
         producer = ECExampleConnector(
             vllm_config=mock_vllm_config_producer,
@@ -146,8 +146,11 @@ class TestCacheExistence:
             encoder_cache[mm_hash] = torch.randn(10, 768)
             producer.save_caches(encoder_cache, mm_hash)

-        # Test using has_caches API
-        producer_result = producer.has_caches(mock_request_with_3_mm)
+        # Test using has_cache_item API
+        producer_result = [
+            producer.has_cache_item(mm_feature.identifier)
+            for mm_feature in mock_request_with_3_mm.mm_features
+        ]

         # Assert
         assert len(producer_result) == 3
@@ -159,14 +162,17 @@ class TestCacheExistence:
             role=ECConnectorRole.SCHEDULER,
         )

-        # Test using has_caches API
-        consumer_result = consumer.has_caches(mock_request_with_3_mm)
+        # Test using has_cache_item API
+        consumer_result = [
+            consumer.has_cache_item(mm_feature.identifier)
+            for mm_feature in mock_request_with_3_mm.mm_features
+        ]

         # Assert
         assert len(consumer_result) == 3
         assert all(consumer_result), f"Expected all True, got {consumer_result}"

-    def test_has_caches_none_exist(
+    def test_has_cache_item_none_exist(
         self, mock_vllm_config_producer, mock_request_with_3_mm
     ):
         """Test has_caches returns False when no caches exist."""
@@ -176,13 +182,16 @@ class TestCacheExistence:
         )

         # Test without creating any files
-        result = connector.has_caches(mock_request_with_3_mm)
+        result = [
+            connector.has_cache_item(mm_feature.identifier)
+            for mm_feature in mock_request_with_3_mm.mm_features
+        ]

         # Assert
         assert len(result) == 3
         assert not any(result), f"Expected all False, got {result}"

-    def test_has_caches_partial_exist(
+    def test_has_cache_item_partial_exist(
         self, mock_vllm_config_producer, mock_request_with_3_mm
     ):
         """Test has_caches with some caches existing (1 of 3)."""
@@ -197,7 +206,10 @@ class TestCacheExistence:
         connector.save_caches(encoder_cache, mm_hash_second)

         # Test
-        result = connector.has_caches(mock_request_with_3_mm)
+        result = [
+            connector.has_cache_item(mm_feature.identifier)
+            for mm_feature in mock_request_with_3_mm.mm_features
+        ]

         # Assert
         assert len(result) == 3
@@ -323,8 +335,11 @@ class TestCacheSaving:
             encoder_cache[mm_hash] = torch.randn(10, 768)
             connector.save_caches(encoder_cache, mm_hash)

-        # Verify all files exist using has_caches
-        result = connector.has_caches(mock_request_with_3_mm)
+        # Verify all files exist using has_cache_item
+        result = [
+            connector.has_cache_item(mm_feature.identifier)
+            for mm_feature in mock_request_with_3_mm.mm_features
+        ]
         assert all(result), f"Not all caches were saved: {result}"

         # Verify each file's content
@@ -347,10 +362,9 @@ class TestCacheSaving:
         # Save should not raise but also not create file
         connector.save_caches(encoder_cache, mm_hash)

-        # Verify file doesn't exist using has_caches
-        mock_request = MockRequest("req_consumer", [mm_hash], [10])
-        result = connector.has_caches(mock_request)
-        assert not result[0], "Consumer should not save caches"
+        # Verify file doesn't exist using has_cache_item
+        result = connector.has_cache_item(mm_hash)
+        assert not result, "Consumer should not save caches"


 class TestCacheLoading:
@@ -182,19 +182,19 @@ class ECConnectorBase(ABC):
     # ==============================

     @abstractmethod
-    def has_caches(
+    def has_cache_item(
         self,
-        request: "Request",
-    ) -> list[bool]:
+        identifier: str,
+    ) -> bool:
         """
-        Check if encoder cache exists for each mm data of requests
+        Check if a single encoder cache item exists.

         Args:
-            request (Request): the request object.
+            identifier (str): the identifier of the media.

         Returns:
-            A list bool where ith value is True if cache exist for
-            ith mm_data of requests
+            True if the encoder cache exists for the media,
+            False otherwise.
         """
         pass

@@ -117,23 +117,20 @@ class ECExampleConnector(ECConnectorBase):
         safetensors.torch.save_file(tensors, filename)
         logger.debug("Save cache successful for mm_hash %s", mm_hash)

-    def has_caches(
+    def has_cache_item(
         self,
-        request: "Request",
-    ) -> list[bool]:
+        identifier: str,
+    ) -> bool:
         """
-        Check if cache exist externally for each mm_data of request
+        Check if the cache exists externally for the media.

         Args:
-            request (Request): the request object.
+            identifier (str): the identifier of the media.

         Returns:
-            List of bool indicate that ith mm_data exist in cache or not
+            True if the media exists in the cache, False otherwise.
         """
-        result = []
-        for feature in request.mm_features:
-            result.append(self._found_match_for_mm_data(feature.identifier))
-        return result
+        return self._found_match_for_mm_data(identifier)

     def update_state_after_alloc(
         self,
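The scheduler hunks below are where the rename pays off: the eager per-request `has_caches` call before the loop goes away, and the per-item check runs only for encoder inputs that survive the earlier filters, so duplicate identifiers within a step never reach the connector. A hedged, heavily simplified sketch of that flow (`Feature`, `_StubConnector`, and `plan_encoder_inputs` are illustrative, not vLLM code; names otherwise follow the hunks below):

```python
from collections import namedtuple

Feature = namedtuple("Feature", "identifier")


def plan_encoder_inputs(connector, mm_features, mm_hashes_to_schedule):
    """Split encoder inputs into remote loads vs. local computes."""
    external_load_encoder_input, encoder_inputs_to_schedule = [], []
    for i, mm_feature in enumerate(mm_features):
        item_identifier = mm_feature.identifier
        if item_identifier in mm_hashes_to_schedule:
            continue  # duplicate within this step: no remote lookup at all
        # Lazy, per-item existence check replaces the eager
        # has_caches(request) list built before the loop.
        if connector is not None and connector.has_cache_item(item_identifier):
            mm_hashes_to_schedule.add(item_identifier)
            external_load_encoder_input.append(i)  # load from remote cache
            continue
        mm_hashes_to_schedule.add(item_identifier)
        encoder_inputs_to_schedule.append(i)  # compute locally
    return external_load_encoder_input, encoder_inputs_to_schedule


class _StubConnector:
    def has_cache_item(self, identifier: str) -> bool:
        return identifier == "hash1_B"


# Two distinct items plus a duplicate: only two connector calls are made.
features = [Feature("hash1_A"), Feature("hash1_B"), Feature("hash1_A")]
loads, computes = plan_encoder_inputs(_StubConnector(), features, set())
assert loads == [1] and computes == [0]  # duplicate at index 2 is skipped
```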
@@ -947,9 +947,6 @@ class Scheduler(SchedulerInterface):
         assert len(mm_features) > 0
         external_load_encoder_input = []

-        # Check remote cache first
-        if self.ec_connector is not None:
-            remote_cache_has_item = self.ec_connector.has_caches(request)
         # NOTE: since scheduler operates on the request level (possibly with
         # multiple encoder inputs per request), we need to create temporary
         # trackers for accounting at the encoder input level.
@@ -959,6 +956,7 @@
             start_pos = mm_feature.mm_position.offset
             num_encoder_tokens = mm_feature.mm_position.length
             num_encoder_embeds = mm_feature.mm_position.get_num_embeds
+            item_identifier = mm_feature.identifier

             # The encoder output is needed if the two ranges overlap:
             # [num_computed_tokens, num_computed_tokens + num_new_tokens) and
@@ -993,7 +991,7 @@
             if not self.is_encoder_decoder:
                 # We are not using the encoder cache for encoder-decoder models,
                 # yet.
-                if request.mm_features[i].identifier in mm_hashes_to_schedule:
+                if item_identifier in mm_hashes_to_schedule:
                     # The same encoder input has already been scheduled in the
                     # current step.
                     continue
@@ -1051,15 +1049,17 @@
                 if curr_embeds_end - curr_embeds_start == 0:
                     continue

-            if self.ec_connector is not None and remote_cache_has_item[i]:
-                mm_hashes_to_schedule.add(request.mm_features[i].identifier)
+            if self.ec_connector is not None and self.ec_connector.has_cache_item(
+                item_identifier
+            ):
+                mm_hashes_to_schedule.add(item_identifier)
                 external_load_encoder_input.append(i)
                 num_embeds_to_schedule += num_encoder_embeds
                 continue

             num_embeds_to_schedule += num_encoder_embeds
             encoder_compute_budget -= num_encoder_embeds
-            mm_hashes_to_schedule.add(request.mm_features[i].identifier)
+            mm_hashes_to_schedule.add(item_identifier)
             encoder_inputs_to_schedule.append(i)

         return (