[Core][MM] Optimize encoder cache manager by operating with embeddings only (#30475)

Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Sun Kim <sunytokki@gmail.com>
This commit is contained in:
Roger Wang
2025-12-16 14:18:17 -08:00
committed by GitHub
parent 9fec0e13d5
commit f5f51e5931
14 changed files with 306 additions and 130 deletions

View File

@@ -39,20 +39,26 @@ class EncoderCacheManager:
space for new embeddings.
Oldest cached embeddings with no request referenced will be first evicted.
NOTE: The EncoderCacheManager operates on the level of multimodal embeddings
instead of encoder tokens (i.e. all tokens that represent the multimodal data
in the input sequence). This means all break/text tokens in-between multimodal
embeddings are not considered with respect to the cache size and the number
of free slots.
Args:
cache_size: Limit the size of the cache, measured by the number of
tokens from the input sequence.
encoder embeddings from the input sequence.
Attributes:
cache_size: Total cache capacity in encoder tokens.
num_free_slots: Current available cache capacity in encoder tokens.
cache_size: Total cache capacity in encoder embeddings.
num_free_slots: Current available cache capacity in encoder embeddings.
num_freeable_slots: Capacity that can be immediately reclaimed by
evicting entries with zero references (in encoder tokens).
evicting entries with zero references (in encoder embeddings).
cached: Mapping from mm_hash to a set of request IDs that currently
reference the cached entry. If the set is empty, the entry exists
but is not referenced by any request and is eligible for
reclamation.
freeable: List of tuples (mm_hash, num_tokens) representing entries
freeable: List of tuples (mm_hash, num_encoder_embeds) representing entries
whose no current running request is needed and that can be freed to
make space when needed.
freed: List of mm_hash strings that were actually evicted since the
@@ -67,7 +73,7 @@ class EncoderCacheManager:
# mm_hash of mm_data => ids of requests that reference the mm_data
self.cached: dict[str, set[str]] = {}
# mm_hash of mm_data => num_encoder_tokens of the mm_data
# mm_hash of mm_data => num_encoder_embeds of the mm_data
self.freeable: OrderedDict[str, int] = OrderedDict()
self.freed: list[str] = []
@@ -93,8 +99,8 @@ class EncoderCacheManager:
# Cached but currently not referenced by any request
if not self.cached[mm_hash]:
num_tokens = self.freeable.pop(mm_hash)
self.num_freeable_slots -= num_tokens
num_encoder_embeds = self.freeable.pop(mm_hash)
self.num_freeable_slots -= num_encoder_embeds
self.cached[mm_hash].add(request.request_id)
return True
@@ -104,7 +110,7 @@ class EncoderCacheManager:
request: Request,
input_id: int,
encoder_compute_budget: int,
num_tokens_to_schedule: int,
num_embeds_to_schedule: int,
) -> bool:
"""Check if there's sufficient cache space for a multimodal input.
If there is, return True and update EncoderCacheManager state.
@@ -121,9 +127,9 @@ class EncoderCacheManager:
Args:
request: The request containing the multimodal input.
input_id: Index of the multimodal input within the request.
encoder_compute_budget: Number of encoder tokens allowed to be
encoder_compute_budget: Number of encoder embeddings allowed to be
computed when this method is invoked.
num_tokens_to_schedule: Number of tokens already scheduled to be
num_embeds_to_schedule: Number of encoder embeddings already scheduled to be
allocated with cache space when this method is invoked.
Returns:
@@ -134,30 +140,30 @@ class EncoderCacheManager:
Note: This method does not allocate physical memory for the encoder
output but only the state of EncoderCacheManager.
"""
num_tokens = request.get_num_encoder_tokens(input_id)
num_embeds = request.get_num_encoder_embeds(input_id)
# Not enough compute budget
if num_tokens > encoder_compute_budget:
if num_embeds > encoder_compute_budget:
return False
num_tokens += num_tokens_to_schedule
num_embeds += num_embeds_to_schedule
# Enough free slots
if num_tokens <= self.num_free_slots:
if num_embeds <= self.num_free_slots:
return True
# Not enough reclaimable slots
if num_tokens > self.num_freeable_slots:
if num_embeds > self.num_freeable_slots:
return False
# Not enough free slots but enough reclaimable slots
# NOTE: Eviction takes place here, but physical memory is not freed
# until model runner is notified by the scheduler output.
while num_tokens > self.num_free_slots:
mm_hash, num_free_token = self.freeable.popitem(last=False)
while num_embeds > self.num_free_slots:
mm_hash, num_free_embeds = self.freeable.popitem(last=False)
del self.cached[mm_hash]
self.freed.append(mm_hash)
self.num_free_slots += num_free_token
self.num_free_slots += num_free_embeds
return True
def allocate(self, request: Request, input_id: int) -> None:
@@ -176,16 +182,16 @@ class EncoderCacheManager:
if mm_hash not in self.cached:
self.cached[mm_hash] = set()
num_encoder_tokens = request.get_num_encoder_tokens(input_id)
num_encoder_embeds = request.get_num_encoder_embeds(input_id)
# NOTE: Encoder cache should always have enough space for encoder inputs
# that are scheduled since eviction takes place at can_allocate().
assert self.num_free_slots >= num_encoder_tokens
assert self.num_freeable_slots >= num_encoder_tokens
assert self.num_free_slots >= num_encoder_embeds
assert self.num_freeable_slots >= num_encoder_embeds
self.cached[mm_hash].add(request_id)
self.num_free_slots -= num_encoder_tokens
self.num_freeable_slots -= num_encoder_tokens
self.num_free_slots -= num_encoder_embeds
self.num_freeable_slots -= num_encoder_embeds
def get_cached_input_ids(self, request: Request) -> set[int]:
"""Get all cached multimodal input IDs for a request.
@@ -206,7 +212,7 @@ class EncoderCacheManager:
When the reference set for the corresponding `mm_hash` becomes empty,
the entry is appended to `freeable` and `num_freeable_slots` is
increased by the number of encoder tokens for that input.
increased by the number of encoder embeddings for that input.
The entry is NOT physically freed until capacity is needed (e.g., by
`can_allocate`).
@@ -218,9 +224,9 @@ class EncoderCacheManager:
return
self.cached[mm_hash].discard(req_id)
if not self.cached[mm_hash]:
num_tokens = request.get_num_encoder_tokens(input_id)
self.freeable[mm_hash] = num_tokens
self.num_freeable_slots += num_tokens
num_encoder_embeds = request.get_num_encoder_embeds(input_id)
self.freeable[mm_hash] = num_encoder_embeds
self.num_freeable_slots += num_encoder_embeds
def free(self, request: Request) -> None:
"""Free all encoder input cache reference held by *request*.
@@ -361,20 +367,20 @@ class EncoderDecoderCacheManager(EncoderCacheManager):
request: Request,
input_id: int,
encoder_compute_budget: int,
num_tokens_to_schedule: int,
num_embeds_to_schedule: int,
) -> bool:
num_tokens = request.get_num_encoder_tokens(input_id)
num_encoder_embeds = request.get_num_encoder_embeds(input_id)
# Not enough compute budget
if num_tokens > encoder_compute_budget:
if num_encoder_embeds > encoder_compute_budget:
return False
num_tokens += num_tokens_to_schedule
num_encoder_embeds += num_embeds_to_schedule
# Enough free slots
return num_tokens <= self.num_free_slots
return num_encoder_embeds <= self.num_free_slots
def allocate(self, request: Request, input_id: int) -> None:
num_encoder_tokens = request.get_num_encoder_tokens(input_id)
self.num_free_slots -= num_encoder_tokens
num_encoder_embeds = request.get_num_encoder_embeds(input_id)
self.num_free_slots -= num_encoder_embeds
mm_hash = request.mm_features[input_id].identifier
self.freed.append(mm_hash)
@@ -392,5 +398,5 @@ class EncoderDecoderCacheManager(EncoderCacheManager):
return freed
def free_encoder_input(self, request: Request, input_id: int) -> None:
num_tokens = request.get_num_encoder_tokens(input_id)
self.num_free_slots += num_tokens
num_encoder_embeds = request.get_num_encoder_embeds(input_id)
self.num_free_slots += num_encoder_embeds