[Core][Multimodal] Track encode cache entries by mm_hash and enable embedding sharing between requests (#22711)

Signed-off-by: knlnguyen1802 <knlnguyen1802@gmail.com>
Signed-off-by: Roger Wang <hey@rogerw.io>
Co-authored-by: knlnguyen1802 <knlnguyen1802@gmail.com>
Co-authored-by: Roger Wang <hey@rogerw.io>
This commit is contained in:
Chenguang Zheng
2025-08-25 15:41:17 +08:00
committed by GitHub
parent 712d0f88d8
commit d765cf01fe
12 changed files with 365 additions and 154 deletions

View File

@@ -176,8 +176,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.attn_groups: list[list[AttentionGroup]] = []
# self.kv_cache_config: KVCacheConfig
# req_id -> (input_id -> encoder_output)
self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}
# mm_hash -> encoder_output
self.encoder_cache: dict[str, torch.Tensor] = {}
self.use_aux_hidden_state_outputs = False
# Set up speculative decoding.
@@ -436,7 +436,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Remove finished requests from the cached states.
for req_id in scheduler_output.finished_req_ids:
self.requests.pop(req_id, None)
self.encoder_cache.pop(req_id, None)
# Remove the finished requests from the persistent batch.
# NOTE(woosuk): There could be an edge case where finished_req_ids and
# scheduled_req_ids overlap. This happens when a request is aborted and
@@ -447,12 +446,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.input_batch.remove_request(req_id)
# Free the cached encoder outputs.
for req_id, input_id in scheduler_output.free_encoder_input_ids:
encoder_outputs = self.encoder_cache.get(req_id)
if encoder_outputs is not None:
encoder_outputs.pop(input_id, None)
if not encoder_outputs:
self.encoder_cache.pop(req_id, None)
for mm_hash in scheduler_output.free_encoder_mm_hashes:
self.encoder_cache.pop(mm_hash, None)
# Remove the unscheduled requests from the persistent batch.
# NOTE(woosuk): The unscheduled requests are either preempted requests
@@ -496,6 +491,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
prompt_token_ids=new_req_data.prompt_token_ids,
mm_kwargs=new_req_data.mm_kwargs,
mm_positions=new_req_data.mm_positions,
mm_hashes=new_req_data.mm_hashes,
sampling_params=sampling_params,
pooling_params=pooling_params,
generator=generator,
@@ -1161,17 +1157,18 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
if not scheduled_encoder_inputs:
return
# Batch the multi-modal inputs.
mm_kwargs = list[MultiModalKwargsItem]()
req_ids_pos = list[tuple[str, int, PlaceholderRange]]()
# list of tuple (mm_hash, position_info)
mm_hashes_pos = list[tuple[str, PlaceholderRange]]()
for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
req_state = self.requests[req_id]
for mm_input_id in encoder_input_ids:
mm_hash = req_state.mm_hashes[mm_input_id]
mm_kwargs.append(req_state.mm_kwargs[mm_input_id])
req_ids_pos.append(
(req_id, mm_input_id, req_state.mm_positions[mm_input_id]))
mm_hashes_pos.append(
(mm_hash, req_state.mm_positions[mm_input_id]))
# Batch mm inputs as much as we can: if a request in the batch has
# multiple modalities or a different modality than the previous one,
@@ -1204,15 +1201,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
for output in curr_group_outputs:
encoder_outputs.append(output)
# Cache the encoder outputs.
for (req_id, input_id, pos_info), output in zip(
req_ids_pos,
encoder_outputs,
):
if req_id not in self.encoder_cache:
self.encoder_cache[req_id] = {}
self.encoder_cache[req_id][input_id] = scatter_mm_placeholders(
# Cache the encoder outputs by mm_hash
for (mm_hash, pos_info), output in zip(mm_hashes_pos, encoder_outputs):
self.encoder_cache[mm_hash] = scatter_mm_placeholders(
output,
is_embed=pos_info.is_embed,
)
@@ -1230,6 +1221,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_computed_tokens = \
req_state.num_computed_tokens + shift_computed_tokens
mm_positions = req_state.mm_positions
mm_hashes = req_state.mm_hashes
for i, pos_info in enumerate(mm_positions):
start_pos = pos_info.offset
num_encoder_tokens = pos_info.length
@@ -1249,11 +1241,14 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
start_idx = max(num_computed_tokens - start_pos, 0)
end_idx = min(
num_computed_tokens - start_pos + num_scheduled_tokens,
num_encoder_tokens)
num_encoder_tokens,
)
assert start_idx < end_idx
assert req_id in self.encoder_cache
assert i in self.encoder_cache[req_id]
encoder_output = self.encoder_cache[req_id][i]
mm_hash = mm_hashes[i]
encoder_output = self.encoder_cache.get(mm_hash, None)
assert encoder_output is not None,\
f"Encoder cache miss for {mm_hash}."
if (is_embed := pos_info.is_embed) is not None:
is_embed = is_embed[start_idx:end_idx]