[EC Connector] Optimize remote cache check in scheduler (#32585)

Signed-off-by: knlnguyen1802 <knlnguyen1802@gmail.com>
This commit is contained in:
knlnguyen1802
2026-01-22 11:30:59 +08:00
committed by GitHub
parent c5487e2b96
commit 378385b90c
5 changed files with 76 additions and 57 deletions

View File

@@ -182,19 +182,19 @@ class ECConnectorBase(ABC):
# ==============================
@abstractmethod
def has_caches(
def has_cache_item(
self,
request: "Request",
) -> list[bool]:
identifier: str,
) -> bool:
"""
Check if encoder cache exists for each mm data of requests
Check if a single encoder cache exists
Args:
request (Request): the request object.
identifier (str): the identifier of the media.
Returns:
A list bool where ith value is True if cache exist for
ith mm_data of requests
A bool that is True if a cache exists for
the media
"""
pass

View File

@@ -117,23 +117,20 @@ class ECExampleConnector(ECConnectorBase):
safetensors.torch.save_file(tensors, filename)
logger.debug("Save cache successful for mm_hash %s", mm_hash)
def has_caches(
def has_cache_item(
self,
request: "Request",
) -> list[bool]:
identifier: str,
) -> bool:
"""
Check if cache exist externally for each mm_data of request
Check if a cache exists externally for the media
Args:
request (Request): the request object.
identifier (str): the identifier of the media.
Returns:
List of bool indicate that ith mm_data exist in cache or not
Bool indicating whether the media exists in the cache
"""
result = []
for feature in request.mm_features:
result.append(self._found_match_for_mm_data(feature.identifier))
return result
return self._found_match_for_mm_data(identifier)
def update_state_after_alloc(
self,

View File

@@ -947,9 +947,6 @@ class Scheduler(SchedulerInterface):
assert len(mm_features) > 0
external_load_encoder_input = []
# Check remote cache first
if self.ec_connector is not None:
remote_cache_has_item = self.ec_connector.has_caches(request)
# NOTE: since scheduler operates on the request level (possibly with
# multiple encoder inputs per request), we need to create temporary
# trackers for accounting at the encoder input level.
@@ -959,6 +956,7 @@ class Scheduler(SchedulerInterface):
start_pos = mm_feature.mm_position.offset
num_encoder_tokens = mm_feature.mm_position.length
num_encoder_embeds = mm_feature.mm_position.get_num_embeds
item_identifier = mm_feature.identifier
# The encoder output is needed if the two ranges overlap:
# [num_computed_tokens, num_computed_tokens + num_new_tokens) and
@@ -993,7 +991,7 @@ class Scheduler(SchedulerInterface):
if not self.is_encoder_decoder:
# We are not using the encoder cache for encoder-decoder models,
# yet.
if request.mm_features[i].identifier in mm_hashes_to_schedule:
if item_identifier in mm_hashes_to_schedule:
# The same encoder input has already been scheduled in the
# current step.
continue
@@ -1051,15 +1049,17 @@ class Scheduler(SchedulerInterface):
if curr_embeds_end - curr_embeds_start == 0:
continue
if self.ec_connector is not None and remote_cache_has_item[i]:
mm_hashes_to_schedule.add(request.mm_features[i].identifier)
if self.ec_connector is not None and self.ec_connector.has_cache_item(
item_identifier
):
mm_hashes_to_schedule.add(item_identifier)
external_load_encoder_input.append(i)
num_embeds_to_schedule += num_encoder_embeds
continue
num_embeds_to_schedule += num_encoder_embeds
encoder_compute_budget -= num_encoder_embeds
mm_hashes_to_schedule.add(request.mm_features[i].identifier)
mm_hashes_to_schedule.add(item_identifier)
encoder_inputs_to_schedule.append(i)
return (