[Core] Use key-only cache for BaseMultiModalProcessor (#23018)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
commit 69244e67e6
parent 8dbf6ed7be
Author: Cyrus Leung
Date: 2025-08-27 14:19:13 +08:00
Committed by: GitHub
29 changed files with 954 additions and 394 deletions


@@ -11,6 +11,7 @@ from vllm.inputs.parse import split_enc_dec_inputs
 from vllm.inputs.preprocess import InputPreprocessor
 from vllm.lora.request import LoRARequest
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
+from vllm.multimodal.cache import processor_cache_from_config
 from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
 from vllm.multimodal.processing import EncDecMultiModalProcessor
 from vllm.multimodal.utils import argsort_mm_positions
@@ -18,7 +19,6 @@ from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer_group import TokenizerGroup
 from vllm.v1.engine import EngineCoreRequest
-from vllm.v1.engine.mm_input_cache import MultiModalInputCacheClient
 from vllm.v1.structured_output.backend_guidance import (
     validate_guidance_grammar)
 from vllm.v1.structured_output.backend_lm_format_enforcer import (
@@ -47,16 +47,17 @@ class Processor:
         self.generation_config_fields = (
             self.model_config.try_get_generation_config())
 
-        self.input_preprocessor = InputPreprocessor(self.model_config,
-                                                    self.tokenizer,
-                                                    mm_registry)
-        self.mm_input_cache_client = MultiModalInputCacheClient(
-            self.model_config, mm_registry)
+        self.mm_registry = mm_registry
+        self.mm_processor_cache = processor_cache_from_config(
+            vllm_config, mm_registry)
 
-    @property
-    def mm_registry(self):
-        return self.input_preprocessor.mm_registry
+        self.input_preprocessor = InputPreprocessor(
+            self.model_config,
+            self.tokenizer,
+            mm_registry,
+            mm_processor_cache=self.mm_processor_cache,
+        )
 
     def _validate_logprobs(
         self,
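Aside (not part of the diff): the commit title refers to a "key-only" cache, and the hunk above wires a processor cache built by processor_cache_from_config into InputPreprocessor instead of keeping a separate MultiModalInputCacheClient. For orientation only, below is a minimal, hypothetical sketch of the key-only idea; the names KeyOnlyCache, touch, and maybe_send are illustrative and are not vLLM APIs. The cache remembers only which multimodal item hashes have been seen (an LRU set), so repeated items can be referenced by hash without the cache ever storing the processed tensors.

from collections import OrderedDict


class KeyOnlyCache:
    """Toy LRU set of hashes: remembers keys, never stores processed data."""

    def __init__(self, capacity: int = 256) -> None:
        self.capacity = capacity
        self._keys = OrderedDict()          # key -> None, ordered by recency

    def touch(self, key: str) -> bool:
        """Return True on a hit and refresh recency; record the key on a miss."""
        if key in self._keys:
            self._keys.move_to_end(key)
            return True
        self._keys[key] = None
        if len(self._keys) > self.capacity:
            self._keys.popitem(last=False)  # evict the least recently used key
        return False


def maybe_send(cache: KeyOnlyCache, item_hash: str, processed_item):
    """Ship the heavy processed payload only on a miss; a hit sends the hash alone."""
    if cache.touch(item_hash):
        return {"mm_hash": item_hash}                       # receiver already holds it
    return {"mm_hash": item_hash, "data": processed_item}   # first occurrence: send data

Because only keys are tracked, the frontend-side memory cost stays small regardless of how large the processed multimodal items are.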
@@ -310,7 +311,7 @@ class Processor:
         # in the input sequence.
         sorted_mm_idxs = argsort_mm_positions(decoder_mm_positions)
 
-        orig_sorted_mm_inputs = [
+        sorted_mm_inputs = [
             decoder_mm_inputs[modality][idx]
             for modality, idx in sorted_mm_idxs
         ]
@@ -323,11 +324,6 @@ class Processor:
             for modality, idx in sorted_mm_idxs
         ]
 
-        sorted_mm_inputs = self.mm_input_cache_client.get_and_update(
-            orig_sorted_mm_inputs,
-            sorted_mm_hashes,
-        )
-
         return decoder_inputs.get("prompt"), EngineCoreRequest(
             request_id=request_id,
             prompt_token_ids=decoder_inputs["prompt_token_ids"],
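Aside (not part of the diff): the two hunks above keep the argsort_mm_positions pattern, where per-modality items are flattened into prompt order via (modality, index) pairs, while dropping the extra get_and_update pass. A hypothetical stand-in illustrating that sorting pattern under assumed data shapes (plain offsets instead of PlaceholderRange objects; not the real function):

def sort_mm_items_by_position(
    mm_positions: dict[str, list[int]],   # modality -> placeholder start offsets
    mm_inputs: dict[str, list[object]],   # modality -> processed items, same order
) -> list[object]:
    # Pair every item with its prompt offset, then flatten in prompt order.
    pairs = [(offset, modality, idx)
             for modality, offsets in mm_positions.items()
             for idx, offset in enumerate(offsets)]
    pairs.sort()                          # ascending by position in the prompt
    return [mm_inputs[modality][idx] for _, modality, idx in pairs]


# Example: an audio clip placed at offset 3 comes before an image at offset 10.
items = sort_mm_items_by_position(
    {"image": [10], "audio": [3]},
    {"image": ["<image item>"], "audio": ["<audio item>"]},
)
assert items == ["<audio item>", "<image item>"]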
@@ -415,3 +411,6 @@ class Processor:
         # TODO: Find out how many placeholder tokens are there so we can
         # check that chunked prefill does not truncate them
         # max_batch_len = self.scheduler_config.max_num_batched_tokens
+
+    def clear_cache(self) -> None:
+        self.input_preprocessor.clear_cache()
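Aside (not part of the diff): after this change the Processor no longer holds cache state of its own, so the new clear_cache simply forwards to the preprocessor that owns the cache built by processor_cache_from_config. A toy sketch of that ownership and delegation pattern; all names below are illustrative, not vLLM code:

class ToyCache:
    def __init__(self) -> None:
        self._keys = set()

    def clear(self) -> None:
        self._keys.clear()


class ToyPreprocessor:
    """Owns the (possibly absent) processor cache."""

    def __init__(self, cache=None) -> None:
        self.cache = cache

    def clear_cache(self) -> None:
        if self.cache is not None:        # caching may be disabled by configuration
            self.cache.clear()


class ToyProcessor:
    """Holds no cache state itself; clearing is delegated to the preprocessor."""

    def __init__(self, preprocessor: ToyPreprocessor) -> None:
        self.preprocessor = preprocessor

    def clear_cache(self) -> None:
        self.preprocessor.clear_cache()

Keeping the cache behind the preprocessor means callers only ever interact with one clear_cache entry point, regardless of how the cache is configured.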