[Core][MM] Optimize encoder cache manager by operating with embeddings only (#30475)
Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Sun Kim <sunytokki@gmail.com>

@@ -39,20 +39,26 @@ class EncoderCacheManager:
     space for new embeddings.

     Oldest cached embeddings with no request referencing them are evicted first.

+    NOTE: The EncoderCacheManager operates on the level of multimodal
+    embeddings instead of encoder tokens (i.e. all tokens that represent the
+    multimodal data in the input sequence). This means any break/text tokens
+    in between multimodal embeddings do not count toward the cache size or
+    the number of free slots.
+
     Args:
         cache_size: Limit the size of the cache, measured by the number of
-            tokens from the input sequence.
+            encoder embeddings from the input sequence.

     Attributes:
-        cache_size: Total cache capacity in encoder tokens.
-        num_free_slots: Current available cache capacity in encoder tokens.
+        cache_size: Total cache capacity in encoder embeddings.
+        num_free_slots: Current available cache capacity in encoder embeddings.
         num_freeable_slots: Capacity that can be immediately reclaimed by
-            evicting entries with zero references (in encoder tokens).
+            evicting entries with zero references (in encoder embeddings).
         cached: Mapping from mm_hash to a set of request IDs that currently
             reference the cached entry. If the set is empty, the entry exists
             but is not referenced by any request and is eligible for
             reclamation.
-        freeable: List of tuples (mm_hash, num_tokens) representing entries
+        freeable: List of tuples (mm_hash, num_encoder_embeds) representing entries
            that no currently running request needs and that can be freed to
            make space when needed.
        freed: List of mm_hash strings that were actually evicted since the
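
To make the embedding-versus-token distinction above concrete, here is a small illustrative sketch (not part of the patch; the placeholder layout and the is_embed mask are invented for illustration):

# Illustrative only: a multimodal placeholder span whose break tokens do not
# occupy encoder cache under embedding-level accounting.
placeholder = ["<img>", "E", "E", "E", "<row_break>", "E", "E", "E", "</img>"]
is_embed = [tok == "E" for tok in placeholder]

num_encoder_tokens = len(placeholder)  # 9: every token in the span
num_encoder_embeds = sum(is_embed)     # 6: only real embeddings use cache slots

assert num_encoder_embeds <= num_encoder_tokens
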
@@ -67,7 +73,7 @@ class EncoderCacheManager:
         # mm_hash of mm_data => ids of requests that reference the mm_data
         self.cached: dict[str, set[str]] = {}

-        # mm_hash of mm_data => num_encoder_tokens of the mm_data
+        # mm_hash of mm_data => num_encoder_embeds of the mm_data
         self.freeable: OrderedDict[str, int] = OrderedDict()
         self.freed: list[str] = []
@@ -93,8 +99,8 @@ class EncoderCacheManager:

         # Cached but currently not referenced by any request
         if not self.cached[mm_hash]:
-            num_tokens = self.freeable.pop(mm_hash)
-            self.num_freeable_slots -= num_tokens
+            num_encoder_embeds = self.freeable.pop(mm_hash)
+            self.num_freeable_slots -= num_encoder_embeds

         self.cached[mm_hash].add(request.request_id)
         return True
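
A minimal self-contained sketch of the reuse path in this hunk, with plain module-level variables standing in for the manager's attributes and an invented hash and size: an unreferenced entry parked in freeable is pulled back out and its slots re-pinned when a new request references it.

from collections import OrderedDict

cached: dict[str, set[str]] = {"img-abc": set()}  # cached, currently unreferenced
freeable: OrderedDict[str, int] = OrderedDict({"img-abc": 576})
num_freeable_slots = 1024

mm_hash, request_id = "img-abc", "req-1"
if not cached[mm_hash]:  # cached but not referenced by any request
    num_encoder_embeds = freeable.pop(mm_hash)  # entry is no longer evictable
    num_freeable_slots -= num_encoder_embeds    # its slots are pinned again
cached[mm_hash].add(request_id)

assert mm_hash not in freeable and num_freeable_slots == 448
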
@@ -104,7 +110,7 @@ class EncoderCacheManager:
         request: Request,
         input_id: int,
         encoder_compute_budget: int,
-        num_tokens_to_schedule: int,
+        num_embeds_to_schedule: int,
     ) -> bool:
         """Check if there's sufficient cache space for a multimodal input.
         If there is, return True and update EncoderCacheManager state.
@@ -121,9 +127,9 @@ class EncoderCacheManager:
         Args:
             request: The request containing the multimodal input.
             input_id: Index of the multimodal input within the request.
-            encoder_compute_budget: Number of encoder tokens allowed to be
+            encoder_compute_budget: Number of encoder embeddings allowed to be
                 computed when this method is invoked.
-            num_tokens_to_schedule: Number of tokens already scheduled to be
+            num_embeds_to_schedule: Number of encoder embeddings already scheduled to be
                 allocated with cache space when this method is invoked.

         Returns:
@@ -134,30 +140,30 @@ class EncoderCacheManager:
         Note: This method does not allocate physical memory for the encoder
         output but only updates the state of EncoderCacheManager.
         """
-        num_tokens = request.get_num_encoder_tokens(input_id)
+        num_embeds = request.get_num_encoder_embeds(input_id)

         # Not enough compute budget
-        if num_tokens > encoder_compute_budget:
+        if num_embeds > encoder_compute_budget:
             return False

-        num_tokens += num_tokens_to_schedule
+        num_embeds += num_embeds_to_schedule

         # Enough free slots
-        if num_tokens <= self.num_free_slots:
+        if num_embeds <= self.num_free_slots:
             return True

         # Not enough reclaimable slots
-        if num_tokens > self.num_freeable_slots:
+        if num_embeds > self.num_freeable_slots:
             return False

         # Not enough free slots but enough reclaimable slots
         # NOTE: Eviction takes place here, but physical memory is not freed
         # until the model runner is notified by the scheduler output.
-        while num_tokens > self.num_free_slots:
-            mm_hash, num_free_token = self.freeable.popitem(last=False)
+        while num_embeds > self.num_free_slots:
+            mm_hash, num_free_embeds = self.freeable.popitem(last=False)
             del self.cached[mm_hash]
             self.freed.append(mm_hash)
-            self.num_free_slots += num_free_token
+            self.num_free_slots += num_free_embeds
         return True

     def allocate(self, request: Request, input_id: int) -> None:
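
For reference, the branch order of can_allocate() can be restated as a standalone function. This is a sketch that mirrors the patched logic with state passed explicitly instead of held on self; it is not the actual method.

from collections import OrderedDict

def can_allocate_sketch(
    num_embeds: int,
    encoder_compute_budget: int,
    num_embeds_to_schedule: int,
    num_free_slots: int,
    num_freeable_slots: int,
    freeable: OrderedDict[str, int],
) -> bool:
    # Budget check comes first: the encoder cannot run this input at all.
    if num_embeds > encoder_compute_budget:
        return False
    num_embeds += num_embeds_to_schedule
    # Fits in currently free slots: no eviction needed.
    if num_embeds <= num_free_slots:
        return True
    # Cannot fit even if every zero-reference entry were evicted.
    if num_embeds > num_freeable_slots:
        return False
    # Evict LRU zero-reference entries until the input fits. (The real method
    # also deletes the entry from `cached` and records it in `freed`; physical
    # memory is reclaimed later by the model runner.)
    while num_embeds > num_free_slots:
        _, num_free_embeds = freeable.popitem(last=False)
        num_free_slots += num_free_embeds
    return True

freeable = OrderedDict({"img-old": 300, "img-older": 300})
assert can_allocate_sketch(500, 1000, 0, 200, 800, freeable)
assert list(freeable) == ["img-older"]  # only the oldest entry was evicted
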
@@ -176,16 +182,16 @@ class EncoderCacheManager:
         if mm_hash not in self.cached:
             self.cached[mm_hash] = set()

-        num_encoder_tokens = request.get_num_encoder_tokens(input_id)
+        num_encoder_embeds = request.get_num_encoder_embeds(input_id)

         # NOTE: Encoder cache should always have enough space for encoder inputs
         # that are scheduled since eviction takes place at can_allocate().
-        assert self.num_free_slots >= num_encoder_tokens
-        assert self.num_freeable_slots >= num_encoder_tokens
+        assert self.num_free_slots >= num_encoder_embeds
+        assert self.num_freeable_slots >= num_encoder_embeds

         self.cached[mm_hash].add(request_id)
-        self.num_free_slots -= num_encoder_tokens
-        self.num_freeable_slots -= num_encoder_tokens
+        self.num_free_slots -= num_encoder_embeds
+        self.num_freeable_slots -= num_encoder_embeds

     def get_cached_input_ids(self, request: Request) -> set[int]:
         """Get all cached multimodal input IDs for a request.
@@ -206,7 +212,7 @@ class EncoderCacheManager:

         When the reference set for the corresponding `mm_hash` becomes empty,
         the entry is appended to `freeable` and `num_freeable_slots` is
-        increased by the number of encoder tokens for that input.
+        increased by the number of encoder embeddings for that input.

         The entry is NOT physically freed until capacity is needed (e.g., by
         `can_allocate`).
@@ -218,9 +224,9 @@ class EncoderCacheManager:
             return
         self.cached[mm_hash].discard(req_id)
         if not self.cached[mm_hash]:
-            num_tokens = request.get_num_encoder_tokens(input_id)
-            self.freeable[mm_hash] = num_tokens
-            self.num_freeable_slots += num_tokens
+            num_encoder_embeds = request.get_num_encoder_embeds(input_id)
+            self.freeable[mm_hash] = num_encoder_embeds
+            self.num_freeable_slots += num_encoder_embeds

     def free(self, request: Request) -> None:
         """Free all encoder input cache references held by *request*.
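
The deferred-free contract in this hunk can be summarized with a toy invariant check (plain variables stand in for the manager's counters; the numbers are invented):

cache_size = 8192
num_free_slots = cache_size - 576      # one 576-embed input still resident
num_freeable_slots = cache_size - 576  # ...and still referenced, so pinned

# free_encoder_input(): the last referencing request drops the entry.
num_encoder_embeds = 576
num_freeable_slots += num_encoder_embeds

assert num_free_slots == 7616            # unchanged: memory is still occupied
assert num_freeable_slots == cache_size  # but all capacity is now reclaimable
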
@@ -361,20 +367,20 @@ class EncoderDecoderCacheManager(EncoderCacheManager):
         request: Request,
         input_id: int,
         encoder_compute_budget: int,
-        num_tokens_to_schedule: int,
+        num_embeds_to_schedule: int,
     ) -> bool:
-        num_tokens = request.get_num_encoder_tokens(input_id)
+        num_encoder_embeds = request.get_num_encoder_embeds(input_id)
         # Not enough compute budget
-        if num_tokens > encoder_compute_budget:
+        if num_encoder_embeds > encoder_compute_budget:
             return False

-        num_tokens += num_tokens_to_schedule
+        num_encoder_embeds += num_embeds_to_schedule
         # Enough free slots
-        return num_tokens <= self.num_free_slots
+        return num_encoder_embeds <= self.num_free_slots

     def allocate(self, request: Request, input_id: int) -> None:
-        num_encoder_tokens = request.get_num_encoder_tokens(input_id)
-        self.num_free_slots -= num_encoder_tokens
+        num_encoder_embeds = request.get_num_encoder_embeds(input_id)
+        self.num_free_slots -= num_encoder_embeds

         mm_hash = request.mm_features[input_id].identifier
         self.freed.append(mm_hash)
@@ -392,5 +398,5 @@ class EncoderDecoderCacheManager(EncoderCacheManager):
         return freed

     def free_encoder_input(self, request: Request, input_id: int) -> None:
-        num_tokens = request.get_num_encoder_tokens(input_id)
-        self.num_free_slots += num_tokens
+        num_encoder_embeds = request.get_num_encoder_embeds(input_id)
+        self.num_free_slots += num_encoder_embeds
@@ -355,11 +355,11 @@ class Scheduler(SchedulerInterface):
                 if preempted_encoder_inputs:
                     # Restore encoder compute budget if the preempted
                     # request had encoder inputs scheduled in this step.
-                    num_tokens_to_restore = sum(
-                        preempted_req.get_num_encoder_tokens(i)
+                    num_embeds_to_restore = sum(
+                        preempted_req.get_num_encoder_embeds(i)
                         for i in preempted_encoder_inputs
                     )
-                    encoder_compute_budget += num_tokens_to_restore
+                    encoder_compute_budget += num_embeds_to_restore
                 req_index -= 1
             else:
                 preempted_req = self.running.pop()
@@ -911,10 +911,11 @@ class Scheduler(SchedulerInterface):
         # multiple encoder inputs per request), we need to create temporary
         # trackers for accounting at the encoder input level.
         mm_hashes_to_schedule = set()
-        num_tokens_to_schedule = 0
+        num_embeds_to_schedule = 0
         for i, mm_feature in enumerate(mm_features):
             start_pos = mm_feature.mm_position.offset
             num_encoder_tokens = mm_feature.mm_position.length
+            num_encoder_embeds = mm_feature.mm_position.get_num_embeds

             # The encoder output is needed if the two ranges overlap:
             # [num_computed_tokens, num_computed_tokens + num_new_tokens) and
@@ -970,9 +971,8 @@ class Scheduler(SchedulerInterface):
             ):
                 num_new_tokens = start_pos - num_computed_tokens
                 break

             if not self.encoder_cache_manager.can_allocate(
-                request, i, encoder_compute_budget, num_tokens_to_schedule
+                request, i, encoder_compute_budget, num_embeds_to_schedule
             ):
                 # The encoder cache is full or the encoder budget is exhausted.
                 # NOTE(woosuk): We assume that the encoder input tokens should
@@ -992,14 +992,31 @@ class Scheduler(SchedulerInterface):
                 num_new_tokens = 0
                 break

+            # Calculate the number of embeddings to schedule in the current
+            # range of scheduled encoder placeholder tokens.
+            start_idx_rel = max(0, num_computed_tokens - start_pos)
+            end_idx_rel = min(
+                num_encoder_tokens, num_computed_tokens + num_new_tokens - start_pos
+            )
+            curr_embeds_start, curr_embeds_end = (
+                mm_feature.mm_position.get_embeds_indices_in_range(
+                    start_idx_rel,
+                    end_idx_rel,
+                )
+            )
+            # There are no embeddings in the current range of encoder
+            # placeholder tokens, so we can skip the encoder input.
+            if curr_embeds_end - curr_embeds_start == 0:
+                continue
+
             if self.ec_connector is not None and remote_cache_has_item[i]:
                 mm_hashes_to_schedule.add(request.mm_features[i].identifier)
                 external_load_encoder_input.append(i)
-                num_tokens_to_schedule += num_encoder_tokens
+                num_embeds_to_schedule += num_encoder_embeds
                 continue

-            num_tokens_to_schedule += num_encoder_tokens
-            encoder_compute_budget -= num_encoder_tokens
+            num_embeds_to_schedule += num_encoder_embeds
+            encoder_compute_budget -= num_encoder_embeds
             mm_hashes_to_schedule.add(request.mm_features[i].identifier)
             encoder_inputs_to_schedule.append(i)
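
The new range computation relies on mm_position.get_embeds_indices_in_range(), which the diff does not show. A plausible reading of its semantics, sketched here as a hypothetical helper over an assumed is_embed boolean mask (not the actual vLLM implementation), is a prefix-sum lookup:

from itertools import accumulate

def embeds_indices_in_range_sketch(
    is_embed: list[bool], start_idx_rel: int, end_idx_rel: int
) -> tuple[int, int]:
    # prefix[i] = number of embeddings among the first i placeholder tokens.
    prefix = [0, *accumulate(int(b) for b in is_embed)]
    return prefix[start_idx_rel], prefix[end_idx_rel]

# 9-token placeholder whose positions 0, 4, and 8 are break tokens:
mask = [False, True, True, True, False, True, True, True, False]
start, end = embeds_indices_in_range_sketch(mask, 2, 6)
assert (start, end) == (1, 4)  # token range [2, 6) covers embeddings 1..3

Under this reading, curr_embeds_end - curr_embeds_start == 0 exactly when the newly scheduled token range lands entirely on break/text tokens, which is why the scheduler can skip the encoder input in that case.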