[Core] Use key-only cache for BaseMultiModalProcessor (#23018)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
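In short: Processor stops round-tripping multimodal inputs through MultiModalInputCacheClient. A processor cache is now built once via processor_cache_from_config and handed to InputPreprocessor, which applies it internally; Processor gains a clear_cache() that simply delegates to the preprocessor. As the title suggests, the cache is "key-only": it remembers which content hashes it has already seen rather than holding the processed multimodal data itself, so duplicates can be detected without keeping a second copy.

A minimal sketch of the key-only idea, for orientation only — the class KeyOnlyCache and its check_and_add method are hypothetical names, not vLLM's actual API:

    from collections import OrderedDict

    class KeyOnlyCache:
        """Remember which multimodal item hashes were seen (LRU over keys only)."""

        def __init__(self, capacity: int = 1024) -> None:
            self._seen: OrderedDict[str, None] = OrderedDict()
            self._capacity = capacity

        def check_and_add(self, key: str) -> bool:
            """Return True if `key` was seen before; record it either way."""
            if key in self._seen:
                self._seen.move_to_end(key)  # refresh LRU recency
                return True
            self._seen[key] = None
            if len(self._seen) > self._capacity:
                self._seen.popitem(last=False)  # evict least recently used key
            return False

    # A caller would then ship the full item only on a miss, e.g.:
    #     payload = None if cache.check_and_add(mm_hash) else mm_item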
@@ -11,6 +11,7 @@ from vllm.inputs.parse import split_enc_dec_inputs
 from vllm.inputs.preprocess import InputPreprocessor
 from vllm.lora.request import LoRARequest
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
+from vllm.multimodal.cache import processor_cache_from_config
 from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange
 from vllm.multimodal.processing import EncDecMultiModalProcessor
 from vllm.multimodal.utils import argsort_mm_positions
@@ -18,7 +19,6 @@ from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer_group import TokenizerGroup
 from vllm.v1.engine import EngineCoreRequest
-from vllm.v1.engine.mm_input_cache import MultiModalInputCacheClient
 from vllm.v1.structured_output.backend_guidance import (
     validate_guidance_grammar)
 from vllm.v1.structured_output.backend_lm_format_enforcer import (
@@ -47,16 +47,17 @@ class Processor:
 
         self.generation_config_fields = (
             self.model_config.try_get_generation_config())
-        self.input_preprocessor = InputPreprocessor(self.model_config,
-                                                    self.tokenizer,
-                                                    mm_registry)
-
-        self.mm_input_cache_client = MultiModalInputCacheClient(
-            self.model_config, mm_registry)
-
-    @property
-    def mm_registry(self):
-        return self.input_preprocessor.mm_registry
+        self.mm_registry = mm_registry
+        self.mm_processor_cache = processor_cache_from_config(
+            vllm_config, mm_registry)
+
+        self.input_preprocessor = InputPreprocessor(
+            self.model_config,
+            self.tokenizer,
+            mm_registry,
+            mm_processor_cache=self.mm_processor_cache,
+        )
 
     def _validate_logprobs(
         self,
@@ -310,7 +311,7 @@ class Processor:
         # in the input sequence.
         sorted_mm_idxs = argsort_mm_positions(decoder_mm_positions)
 
-        orig_sorted_mm_inputs = [
+        sorted_mm_inputs = [
             decoder_mm_inputs[modality][idx]
             for modality, idx in sorted_mm_idxs
         ]
@@ -323,11 +324,6 @@ class Processor:
             for modality, idx in sorted_mm_idxs
         ]
 
-        sorted_mm_inputs = self.mm_input_cache_client.get_and_update(
-            orig_sorted_mm_inputs,
-            sorted_mm_hashes,
-        )
-
         return decoder_inputs.get("prompt"), EngineCoreRequest(
             request_id=request_id,
             prompt_token_ids=decoder_inputs["prompt_token_ids"],
@@ -415,3 +411,6 @@ class Processor:
         # TODO: Find out how many placeholder tokens are there so we can
         # check that chunked prefill does not truncate them
         # max_batch_len = self.scheduler_config.max_num_batched_tokens
+
+    def clear_cache(self) -> None:
+        self.input_preprocessor.clear_cache()
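With the cache owned by InputPreprocessor, callers reset it through the new delegating method instead of reaching for a cache client on Processor. A hypothetical call site:

    # `processor` is assumed to be a vllm.v1.engine.processor.Processor instance.
    processor.clear_cache()  # forwards to InputPreprocessor.clear_cache()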