[EC Connector] Optimize remote cache check in scheduler (#32585)
Signed-off-by: knlnguyen1802 <knlnguyen1802@gmail.com>
This commit is contained in:
@@ -182,19 +182,19 @@ class ECConnectorBase(ABC):
|
||||
# ==============================
|
||||
|
||||
@abstractmethod
|
||||
def has_caches(
|
||||
def has_cache_item(
|
||||
self,
|
||||
request: "Request",
|
||||
) -> list[bool]:
|
||||
identifier: str,
|
||||
) -> bool:
|
||||
"""
|
||||
Check if encoder cache exists for each mm data of requests
|
||||
Check if a single encoder cache exists
|
||||
|
||||
Args:
|
||||
request (Request): the request object.
|
||||
identifier (str): the identifier of the media.
|
||||
|
||||
Returns:
|
||||
A list bool where ith value is True if cache exist for
|
||||
ith mm_data of requests
|
||||
A bool whose value is True if a cache exists for
|
||||
the media
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@@ -117,23 +117,20 @@ class ECExampleConnector(ECConnectorBase):
|
||||
safetensors.torch.save_file(tensors, filename)
|
||||
logger.debug("Save cache successful for mm_hash %s", mm_hash)
|
||||
|
||||
def has_caches(
|
||||
def has_cache_item(
|
||||
self,
|
||||
request: "Request",
|
||||
) -> list[bool]:
|
||||
identifier: str,
|
||||
) -> bool:
|
||||
"""
|
||||
Check if cache exist externally for each mm_data of request
|
||||
Check if a cache exists externally for the media
|
||||
|
||||
Args:
|
||||
request (Request): the request object.
|
||||
identifier (str): the identifier of the media.
|
||||
|
||||
Returns:
|
||||
List of bool indicate that ith mm_data exist in cache or not
|
||||
Bool indicating whether the media exists in the cache
|
||||
"""
|
||||
result = []
|
||||
for feature in request.mm_features:
|
||||
result.append(self._found_match_for_mm_data(feature.identifier))
|
||||
return result
|
||||
return self._found_match_for_mm_data(identifier)
|
||||
|
||||
def update_state_after_alloc(
|
||||
self,
|
||||
|
||||
@@ -947,9 +947,6 @@ class Scheduler(SchedulerInterface):
|
||||
assert len(mm_features) > 0
|
||||
external_load_encoder_input = []
|
||||
|
||||
# Check remote cache first
|
||||
if self.ec_connector is not None:
|
||||
remote_cache_has_item = self.ec_connector.has_caches(request)
|
||||
# NOTE: since scheduler operates on the request level (possibly with
|
||||
# multiple encoder inputs per request), we need to create temporary
|
||||
# trackers for accounting at the encoder input level.
|
||||
@@ -959,6 +956,7 @@ class Scheduler(SchedulerInterface):
|
||||
start_pos = mm_feature.mm_position.offset
|
||||
num_encoder_tokens = mm_feature.mm_position.length
|
||||
num_encoder_embeds = mm_feature.mm_position.get_num_embeds
|
||||
item_identifier = mm_feature.identifier
|
||||
|
||||
# The encoder output is needed if the two ranges overlap:
|
||||
# [num_computed_tokens, num_computed_tokens + num_new_tokens) and
|
||||
@@ -993,7 +991,7 @@ class Scheduler(SchedulerInterface):
|
||||
if not self.is_encoder_decoder:
|
||||
# We are not using the encoder cache for encoder-decoder models,
|
||||
# yet.
|
||||
if request.mm_features[i].identifier in mm_hashes_to_schedule:
|
||||
if item_identifier in mm_hashes_to_schedule:
|
||||
# The same encoder input has already been scheduled in the
|
||||
# current step.
|
||||
continue
|
||||
@@ -1051,15 +1049,17 @@ class Scheduler(SchedulerInterface):
|
||||
if curr_embeds_end - curr_embeds_start == 0:
|
||||
continue
|
||||
|
||||
if self.ec_connector is not None and remote_cache_has_item[i]:
|
||||
mm_hashes_to_schedule.add(request.mm_features[i].identifier)
|
||||
if self.ec_connector is not None and self.ec_connector.has_cache_item(
|
||||
item_identifier
|
||||
):
|
||||
mm_hashes_to_schedule.add(item_identifier)
|
||||
external_load_encoder_input.append(i)
|
||||
num_embeds_to_schedule += num_encoder_embeds
|
||||
continue
|
||||
|
||||
num_embeds_to_schedule += num_encoder_embeds
|
||||
encoder_compute_budget -= num_encoder_embeds
|
||||
mm_hashes_to_schedule.add(request.mm_features[i].identifier)
|
||||
mm_hashes_to_schedule.add(item_identifier)
|
||||
encoder_inputs_to_schedule.append(i)
|
||||
|
||||
return (
|
||||
|
||||
Reference in New Issue
Block a user