[Core][MM] Optimize encoder cache manager by operating with embeddings only (#30475)

Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Sun Kim <sunytokki@gmail.com>
Author: Roger Wang
Date: 2025-12-16 14:18:17 -08:00
Committed by: GitHub
Parent: 9fec0e13d5
Commit: f5f51e5931
14 changed files with 306 additions and 130 deletions

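Background for the hunks below: a multimodal feature's placeholder span in the prompt may include tokens that carry no encoder embedding (for example, layout or separator tokens interleaved with image patches), so budgeting the encoder cache by placeholder-token count overstates the real cost. The sketch below is a minimal stand-in for the placeholder-range structure this diff operates on; the field names and the get_num_embeds property are inferred from how the diff uses them, not copied from vLLM.

from dataclasses import dataclass

import torch


@dataclass
class PlaceholderRange:
    """Minimal stand-in for a multimodal placeholder span in the prompt."""

    offset: int  # index of the first placeholder token in the prompt
    length: int  # total number of placeholder tokens in the span
    # Optional boolean mask over the span: True marks positions backed by a
    # real encoder embedding; None means every position is.
    is_embed: torch.Tensor | None = None

    @property
    def get_num_embeds(self) -> int:
        # Without a mask, every placeholder token has an embedding.
        if self.is_embed is None:
            return self.length
        return int(self.is_embed.sum().item())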

@@ -355,11 +355,11 @@ class Scheduler(SchedulerInterface):
             if preempted_encoder_inputs:
                 # Restore encoder compute budget if the preempted
                 # request had encoder inputs scheduled in this step.
-                num_tokens_to_restore = sum(
-                    preempted_req.get_num_encoder_tokens(i)
+                num_embeds_to_restore = sum(
+                    preempted_req.get_num_encoder_embeds(i)
                     for i in preempted_encoder_inputs
                 )
-                encoder_compute_budget += num_tokens_to_restore
+                encoder_compute_budget += num_embeds_to_restore
             req_index -= 1
         else:
             preempted_req = self.running.pop()
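The preemption path must credit back exactly what scheduling debited: since the budget is now debited in embeddings (last hunk below), the restore must use embeddings too, or the budget drifts across steps. A toy check of that invariant, using the hypothetical PlaceholderRange sketch above:

import torch

# One feature: 4 placeholder tokens, 3 of them backed by embeddings.
mask = torch.tensor([True, True, False, True])
pos = PlaceholderRange(offset=0, length=4, is_embed=mask)

encoder_compute_budget = 10
encoder_compute_budget -= pos.get_num_embeds  # scheduled: debit 3
encoder_compute_budget += pos.get_num_embeds  # preempted: credit 3 back
assert encoder_compute_budget == 10  # same unit on both sides, no drift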
@@ -911,10 +911,11 @@ class Scheduler(SchedulerInterface):
         # multiple encoder inputs per request), we need to create temporary
         # trackers for accounting at the encoder input level.
         mm_hashes_to_schedule = set()
-        num_tokens_to_schedule = 0
+        num_embeds_to_schedule = 0
         for i, mm_feature in enumerate(mm_features):
             start_pos = mm_feature.mm_position.offset
             num_encoder_tokens = mm_feature.mm_position.length
+            num_encoder_embeds = mm_feature.mm_position.get_num_embeds

             # The encoder output is needed if the two ranges overlap:
             # [num_computed_tokens, num_computed_tokens + num_new_tokens) and
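The loop now tracks both units per feature: num_encoder_tokens (the span length, still needed for range arithmetic against token positions) and num_encoder_embeds (the new budgeting unit). With the sketch above, the two can diverge:

import torch

# 16 placeholder tokens, but every 4th one is a layout token with no
# encoder embedding behind it.
mask = torch.ones(16, dtype=torch.bool)
mask[3::4] = False

pos = PlaceholderRange(offset=42, length=16, is_embed=mask)
assert pos.length == 16          # unit for prompt-position arithmetic
assert pos.get_num_embeds == 12  # unit for encoder budget/cache accounting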
@@ -970,9 +971,8 @@ class Scheduler(SchedulerInterface):
             ):
                 num_new_tokens = start_pos - num_computed_tokens
                 break

             if not self.encoder_cache_manager.can_allocate(
-                request, i, encoder_compute_budget, num_tokens_to_schedule
+                request, i, encoder_compute_budget, num_embeds_to_schedule
             ):
                 # The encoder cache is full or the encoder budget is exhausted.
                 # NOTE(woosuk): We assume that the encoder input tokens should
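can_allocate now receives num_embeds_to_schedule instead of num_tokens_to_schedule. The manager's body is not part of this diff; the hypothetical check below only illustrates what the renamed argument suggests, namely that both the per-step compute budget and the cache headroom are measured in embeddings. It is not vLLM's EncoderCacheManager.can_allocate.

def can_allocate_sketch(
    num_embeds: int,              # embeddings this input would produce
    encoder_compute_budget: int,  # per-step budget remaining, in embeddings
    num_embeds_to_schedule: int,  # embeddings already claimed this step
    num_free_embed_slots: int,    # free cache capacity, in embeddings
) -> bool:
    # Hypothetical stand-in: the input must fit both the step budget and
    # the cache space left after inputs already scheduled this step.
    return (
        num_embeds <= encoder_compute_budget
        and num_embeds_to_schedule + num_embeds <= num_free_embed_slots
    )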
@@ -992,14 +992,31 @@ class Scheduler(SchedulerInterface):
                     num_new_tokens = 0
                     break

+            # Calculate the number of embeddings to schedule in the current
+            # range of scheduled encoder placeholder tokens.
+            start_idx_rel = max(0, num_computed_tokens - start_pos)
+            end_idx_rel = min(
+                num_encoder_tokens, num_computed_tokens + num_new_tokens - start_pos
+            )
+            curr_embeds_start, curr_embeds_end = (
+                mm_feature.mm_position.get_embeds_indices_in_range(
+                    start_idx_rel,
+                    end_idx_rel,
+                )
+            )
+
+            # There are no embeddings in the current range of encoder
+            # placeholder tokens, so we can skip the encoder input.
+            if curr_embeds_end - curr_embeds_start == 0:
+                continue
+
             if self.ec_connector is not None and remote_cache_has_item[i]:
                 mm_hashes_to_schedule.add(request.mm_features[i].identifier)
                 external_load_encoder_input.append(i)
-                num_tokens_to_schedule += num_encoder_tokens
+                num_embeds_to_schedule += num_encoder_embeds
                 continue

-            num_tokens_to_schedule += num_encoder_tokens
-            encoder_compute_budget -= num_encoder_tokens
+            num_embeds_to_schedule += num_encoder_embeds
+            encoder_compute_budget -= num_encoder_embeds
             mm_hashes_to_schedule.add(request.mm_features[i].identifier)
             encoder_inputs_to_schedule.append(i)
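get_embeds_indices_in_range maps a sub-range of placeholder tokens (when an input is only partially covered by this step's token window) onto indices into the feature's embedding sequence. Its implementation is not shown in this diff; below is one plausible version under the is_embed-mask assumption above, using prefix sums over the mask, followed by a small usage check.

import torch


def get_embeds_indices_in_range(
    is_embed: torch.Tensor | None, start: int, end: int
) -> tuple[int, int]:
    # Hypothetical counterpart of mm_position.get_embeds_indices_in_range:
    # translate the token sub-range [start, end), relative to the placeholder
    # span, into [start, end) indices over the embedding sequence.
    if is_embed is None:
        # Every placeholder token is an embedding; the indices coincide.
        return start, end
    # cumsum[i] = number of embeddings among tokens 0..i inclusive.
    cumsum = torch.cumsum(is_embed.int(), dim=0)
    embeds_before = int(cumsum[start - 1].item()) if start > 0 else 0
    embeds_through = int(cumsum[end - 1].item()) if end > 0 else 0
    return embeds_before, embeds_through


mask = torch.tensor([True, True, False, True, True, False, True, True])
# Tokens [2, 6) contain the embeddings at mask positions 3 and 4, which are
# the 3rd and 4th embeddings overall -> embedding indices [2, 4).
assert get_embeds_indices_in_range(mask, 2, 6) == (2, 4)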