[V1] VLM preprocessor hashing (#11020)
Signed-off-by: Roger Wang <ywang@roblox.com> Signed-off-by: Alexander Matveev <alexm@neuralmagic.com> Co-authored-by: Michael Goin <michael@neuralmagic.com> Co-authored-by: Roger Wang <ywang@roblox.com>
This commit is contained in:
committed by
GitHub
parent
452a723bf2
commit
4e11683368
@@ -147,6 +147,9 @@ class ModelConfig:
|
||||
HuggingFace config.
|
||||
mm_processor_kwargs: Arguments to be forwarded to the model's processor
|
||||
for multi-modal data, e.g., image processor.
|
||||
mm_cache_preprocessor: If true, then enables caching of the multi-modal
|
||||
preprocessor/mapper. Otherwise, the mapper executes each time, and
|
||||
for better performance consider enabling frontend process.
|
||||
override_neuron_config: Initialize non default neuron config or
|
||||
override default neuron config that are specific to Neuron devices,
|
||||
this argument will be used to configure the neuron config that
|
||||
@@ -185,6 +188,7 @@ class ModelConfig:
|
||||
config_format: ConfigFormat = ConfigFormat.AUTO,
|
||||
hf_overrides: Optional[HfOverrides] = None,
|
||||
mm_processor_kwargs: Optional[Dict[str, Any]] = None,
|
||||
mm_cache_preprocessor: bool = False,
|
||||
override_neuron_config: Optional[Dict[str, Any]] = None,
|
||||
override_pooler_config: Optional["PoolerConfig"] = None) -> None:
|
||||
self.model = model
|
||||
@@ -251,6 +255,7 @@ class ModelConfig:
|
||||
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
|
||||
self.use_async_output_proc = use_async_output_proc
|
||||
self.mm_processor_kwargs = mm_processor_kwargs
|
||||
self.mm_cache_preprocessor = mm_cache_preprocessor
|
||||
|
||||
# Set enforce_eager to False if the value is unset.
|
||||
if self.enforce_eager is None:
|
||||
@@ -2686,9 +2691,10 @@ class VllmConfig:
|
||||
f"enable_prefix_caching={self.cache_config.enable_prefix_caching}, "
|
||||
f"chunked_prefill_enabled={self.scheduler_config.chunked_prefill_enabled}, " # noqa
|
||||
f"use_async_output_proc={self.model_config.use_async_output_proc}, "
|
||||
f"mm_cache_preprocessor={self.model_config.mm_cache_preprocessor!r}, " # noqa
|
||||
f"mm_processor_kwargs={self.model_config.mm_processor_kwargs}, "
|
||||
f"pooler_config={self.model_config.pooler_config!r},"
|
||||
f" compilation_config={self.compilation_config!r}")
|
||||
f"pooler_config={self.model_config.pooler_config!r}, "
|
||||
f"compilation_config={self.compilation_config!r}")
|
||||
|
||||
|
||||
_current_vllm_config: Optional[VllmConfig] = None
|
||||
|
||||
@@ -143,6 +143,7 @@ class EngineArgs:
|
||||
tokenizer_pool_extra_config: Optional[Dict[str, Any]] = None
|
||||
limit_mm_per_prompt: Optional[Mapping[str, int]] = None
|
||||
mm_processor_kwargs: Optional[Dict[str, Any]] = None
|
||||
mm_cache_preprocessor: bool = False
|
||||
enable_lora: bool = False
|
||||
enable_lora_bias: bool = False
|
||||
max_loras: int = 1
|
||||
@@ -593,6 +594,12 @@ class EngineArgs:
|
||||
type=json.loads,
|
||||
help=('Overrides for the multimodal input mapping/processing, '
|
||||
'e.g., image processor. For example: {"num_crops": 4}.'))
|
||||
parser.add_argument(
|
||||
'--mm-cache-preprocessor',
|
||||
action='store_true',
|
||||
help='If true, then enables caching of the multi-modal '
|
||||
'preprocessor/mapper. Otherwise, the mapper executes each time'
|
||||
', and for better performance consider enabling frontend process.')
|
||||
|
||||
# LoRA related configs
|
||||
parser.add_argument('--enable-lora',
|
||||
@@ -965,6 +972,7 @@ class EngineArgs:
|
||||
use_async_output_proc=not self.disable_async_output_proc,
|
||||
config_format=self.config_format,
|
||||
mm_processor_kwargs=self.mm_processor_kwargs,
|
||||
mm_cache_preprocessor=self.mm_cache_preprocessor,
|
||||
override_neuron_config=self.override_neuron_config,
|
||||
override_pooler_config=self.override_pooler_config,
|
||||
)
|
||||
|
||||
@@ -35,7 +35,8 @@ class EngineCoreRequest:
|
||||
# always be tokenized?
|
||||
prompt: Optional[str]
|
||||
prompt_token_ids: List[int]
|
||||
mm_inputs: Optional[List[MultiModalKwargs]]
|
||||
mm_inputs: Optional[List[Optional[MultiModalKwargs]]]
|
||||
mm_hashes: Optional[List[Optional[str]]]
|
||||
mm_placeholders: Optional[MultiModalPlaceholderDict]
|
||||
sampling_params: SamplingParams
|
||||
eos_token_id: Optional[int]
|
||||
|
||||
@@ -18,7 +18,7 @@ from vllm.v1.core.scheduler import Scheduler
|
||||
from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs,
|
||||
EngineCoreProfile, EngineCoreRequest,
|
||||
EngineCoreRequestType)
|
||||
from vllm.v1.engine.mm_input_mapper import MMInputMapper
|
||||
from vllm.v1.engine.mm_input_mapper import MMInputMapperServer
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
from vllm.v1.request import Request, RequestStatus
|
||||
from vllm.v1.serial_utils import PickleEncoder
|
||||
@@ -55,9 +55,6 @@ class EngineCore:
|
||||
vllm_config.cache_config.num_gpu_blocks = num_gpu_blocks
|
||||
vllm_config.cache_config.num_cpu_blocks = num_cpu_blocks
|
||||
|
||||
# Set up multimodal input mapper (e.g., convert PIL images to tensors).
|
||||
self.mm_input_mapper = MMInputMapper(vllm_config.model_config)
|
||||
|
||||
# Setup scheduler.
|
||||
self.scheduler = Scheduler(vllm_config.scheduler_config,
|
||||
vllm_config.cache_config,
|
||||
@@ -65,6 +62,8 @@ class EngineCore:
|
||||
|
||||
self._last_logging_time = time.time()
|
||||
|
||||
self.mm_input_mapper_server = MMInputMapperServer()
|
||||
|
||||
def _initialize_kv_caches(self,
|
||||
cache_config: CacheConfig) -> Tuple[int, int]:
|
||||
start = time.time()
|
||||
@@ -88,7 +87,18 @@ class EngineCore:
|
||||
|
||||
def add_request(self, request: EngineCoreRequest):
|
||||
"""Add request to the scheduler."""
|
||||
|
||||
if request.mm_hashes is not None:
|
||||
# Here, if hash exists for an image, then it will be fetched
|
||||
# from the cache, else it will be added to the cache.
|
||||
# Note that the cache here is mirrored with the client side of the
|
||||
# MM mapper, so anything that has a hash must have a HIT cache
|
||||
# entry here as well.
|
||||
request.mm_inputs = self.mm_input_mapper_server.process_inputs(
|
||||
request.mm_inputs, request.mm_hashes)
|
||||
|
||||
req = Request.from_engine_core_request(request)
|
||||
|
||||
self.scheduler.add_request(req)
|
||||
|
||||
def abort_requests(self, request_ids: List[str]):
|
||||
|
||||
@@ -1,11 +1,35 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import PIL
|
||||
from blake3 import blake3
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.inputs import PromptType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict,
|
||||
MultiModalKwargs, MultiModalRegistry)
|
||||
from vllm.v1.utils import LRUDictCache
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# The idea of MM preprocessor caching is based on having a client and a server,
|
||||
# where the client executes in the frontend process (=P0) and the server in the
|
||||
# core process (=P1).
|
||||
#
|
||||
# -- Client: Executes the MM mapper and performs caching of the results.
|
||||
# -- Server: Performs caching of the results
|
||||
#
|
||||
# The caching for both client and server is mirrored/similar, and this allows us
|
||||
# to avoid the serialization of "mm_inputs" (like pixel values) between
|
||||
# client (=P0) and server (=P1) processes.
|
||||
|
||||
# Both Client and Server must use the same cache size
|
||||
# (to perform mirrored caching)
|
||||
# TODO: Tune the MM cache size
|
||||
MM_CACHE_SIZE = 256
|
||||
|
||||
|
||||
class MMInputMapper:
|
||||
class MMInputMapperClient:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -18,23 +42,131 @@ class MMInputMapper:
|
||||
model_config)
|
||||
self.mm_registry.init_mm_limits_per_prompt(model_config)
|
||||
|
||||
self.mm_cache = LRUDictCache(MM_CACHE_SIZE)
|
||||
|
||||
# DEBUG: Set to None to disable
|
||||
self.mm_debug_cache_hit_ratio_steps = None
|
||||
self.mm_cache_hits = 0
|
||||
self.mm_cache_total = 0
|
||||
|
||||
def cache_hit_ratio(self, steps) -> float:
|
||||
if self.mm_cache_total > 0 and self.mm_cache_total % steps == 0:
|
||||
logger.debug("MMInputMapper: cache_hit_ratio = %.2f ",
|
||||
self.mm_cache_hits / self.mm_cache_total)
|
||||
|
||||
def process_inputs(
|
||||
self,
|
||||
mm_data: MultiModalDataDict,
|
||||
mm_hashes: Optional[List[str]],
|
||||
mm_processor_kwargs: Optional[Dict[str, Any]],
|
||||
precomputed_mm_inputs: Optional[List[MultiModalKwargs]],
|
||||
) -> List[MultiModalKwargs]:
|
||||
if precomputed_mm_inputs is None:
|
||||
image_inputs = mm_data["image"]
|
||||
if not isinstance(image_inputs, list):
|
||||
image_inputs = [image_inputs]
|
||||
num_inputs = len(image_inputs)
|
||||
else:
|
||||
num_inputs = len(precomputed_mm_inputs)
|
||||
|
||||
# Check if hash is enabled
|
||||
use_hash = mm_hashes is not None
|
||||
if use_hash:
|
||||
assert num_inputs == len(
|
||||
mm_hashes), "num_inputs = {} len(mm_hashes) = {}".format(
|
||||
num_inputs, len(mm_hashes))
|
||||
|
||||
# Process each image input separately, so that later we can schedule
|
||||
# them in a fine-grained manner.
|
||||
# Apply caching (if enabled) and reuse precomputed inputs (if provided)
|
||||
ret_hashes = [] if use_hash else None
|
||||
ret_inputs: List[MultiModalKwargs] = []
|
||||
for input_id in range(num_inputs):
|
||||
if self.mm_debug_cache_hit_ratio_steps is not None:
|
||||
self.cache_hit_ratio(self.mm_debug_cache_hit_ratio_steps)
|
||||
|
||||
mm_hash = None
|
||||
mm_input = None
|
||||
if use_hash:
|
||||
mm_hash = mm_hashes[input_id]
|
||||
mm_input = self.mm_cache.get(mm_hash)
|
||||
|
||||
self.mm_cache_total += 1
|
||||
if mm_input is None:
|
||||
if precomputed_mm_inputs is not None:
|
||||
# Reuse precomputed input (for merged preprocessor)
|
||||
mm_input = precomputed_mm_inputs[input_id]
|
||||
else:
|
||||
# Apply MM mapper
|
||||
mm_input = self.multi_modal_input_mapper(
|
||||
{"image": [image_inputs[input_id]]},
|
||||
mm_processor_kwargs=mm_processor_kwargs,
|
||||
)
|
||||
|
||||
if use_hash:
|
||||
# Add to cache
|
||||
self.mm_cache.put(mm_hash, mm_input)
|
||||
else:
|
||||
self.mm_cache_hits += 1
|
||||
mm_input = None # Avoids sending mm_input to Server
|
||||
|
||||
if use_hash:
|
||||
ret_hashes.append(mm_hash)
|
||||
ret_inputs.append(mm_input)
|
||||
|
||||
return ret_inputs, ret_hashes
|
||||
|
||||
|
||||
class MMInputMapperServer:
|
||||
|
||||
def __init__(self, ):
|
||||
self.mm_cache = LRUDictCache(MM_CACHE_SIZE)
|
||||
|
||||
def process_inputs(
|
||||
self,
|
||||
mm_inputs: List[Optional[MultiModalKwargs]],
|
||||
mm_hashes: List[Optional[str]],
|
||||
) -> List[MultiModalKwargs]:
|
||||
assert len(mm_inputs) == len(mm_hashes)
|
||||
|
||||
full_mm_inputs = []
|
||||
for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
|
||||
if mm_input is None:
|
||||
mm_input = self.mm_cache.get(mm_hash)
|
||||
assert mm_input is not None
|
||||
else:
|
||||
self.mm_cache.put(mm_hash, mm_input)
|
||||
|
||||
full_mm_inputs.append(mm_input)
|
||||
|
||||
return full_mm_inputs
|
||||
|
||||
|
||||
class MMHasher:
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def hash(self, prompt: PromptType) -> Optional[List[str]]:
|
||||
if "multi_modal_data" not in prompt:
|
||||
return None
|
||||
|
||||
mm_data = prompt["multi_modal_data"]
|
||||
image_inputs = mm_data["image"]
|
||||
if not isinstance(image_inputs, list):
|
||||
image_inputs = [image_inputs]
|
||||
assert len(image_inputs) > 0
|
||||
|
||||
# Process each image input separately so that later we can schedule
|
||||
# them in a fine-grained manner.
|
||||
mm_inputs: List[MultiModalKwargs] = []
|
||||
num_images = len(image_inputs)
|
||||
for i in range(num_images):
|
||||
mm_input = self.multi_modal_input_mapper(
|
||||
{"image": image_inputs[i]},
|
||||
mm_processor_kwargs=mm_processor_kwargs,
|
||||
)
|
||||
mm_inputs.append(mm_input)
|
||||
return mm_inputs
|
||||
ret = []
|
||||
for image in image_inputs:
|
||||
assert isinstance(image, PIL.Image.Image)
|
||||
|
||||
# Convert image to bytes
|
||||
bytes = image.tobytes()
|
||||
|
||||
# Hash image bytes
|
||||
hasher = blake3()
|
||||
hasher.update(bytes)
|
||||
ret.append(hasher.hexdigest())
|
||||
|
||||
return ret
|
||||
|
||||
@@ -15,7 +15,7 @@ from vllm.sampling_params import SamplingParams
|
||||
from vllm.transformers_utils.config import try_get_generation_config
|
||||
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
|
||||
from vllm.v1.engine import DetokenizerRequest, EngineCoreRequest
|
||||
from vllm.v1.engine.mm_input_mapper import MMInputMapper
|
||||
from vllm.v1.engine.mm_input_mapper import MMHasher, MMInputMapperClient
|
||||
|
||||
|
||||
class Processor:
|
||||
@@ -42,7 +42,11 @@ class Processor:
|
||||
model_config)
|
||||
|
||||
# Multi-modal (huggingface) input mapper
|
||||
self.mm_input_mapper = MMInputMapper(model_config)
|
||||
self.mm_input_mapper_client = MMInputMapperClient(model_config)
|
||||
|
||||
# Multi-modal hasher (for images)
|
||||
self.mm_hasher = MMHasher(
|
||||
) if model_config.mm_cache_preprocessor else None
|
||||
|
||||
# TODO: run in an ThreadpoolExecutor or BackgroundProcess.
|
||||
# This ideally should releases the GIL, so we should not block the
|
||||
@@ -71,6 +75,11 @@ class Processor:
|
||||
assert priority == 0, "vLLM V1 does not support priority at the moment."
|
||||
assert trace_headers is None, "vLLM V1 does not support tracing yet."
|
||||
|
||||
# Compute MM hashes (if enabled)
|
||||
mm_hashes = None
|
||||
if self.mm_hasher is not None:
|
||||
mm_hashes = self.mm_hasher.hash(prompt)
|
||||
|
||||
# Process inputs.
|
||||
preprocessed_inputs = self.input_preprocessor.preprocess(
|
||||
prompt,
|
||||
@@ -101,16 +110,17 @@ class Processor:
|
||||
sampling_params.update_from_generation_config(
|
||||
self.generation_config_fields, eos_token_id)
|
||||
|
||||
# Preprocess multi-modal data
|
||||
if len(decoder_inputs.multi_modal_data) == 0:
|
||||
mm_inputs = None
|
||||
elif isinstance(decoder_inputs.multi_modal_data, MultiModalKwargs):
|
||||
mm_inputs = [decoder_inputs.multi_modal_data]
|
||||
else:
|
||||
mm_inputs = self.mm_input_mapper.process_inputs(
|
||||
decoder_inputs.multi_modal_data,
|
||||
decoder_inputs.mm_processor_kwargs,
|
||||
)
|
||||
# For merged preprocessor, mm_data is already mm_inputs
|
||||
precomputed_mm_inputs = None
|
||||
if isinstance(decoder_inputs.multi_modal_data, MultiModalKwargs):
|
||||
precomputed_mm_inputs = [decoder_inputs.multi_modal_data]
|
||||
|
||||
# Apply MM mapper
|
||||
mm_inputs = None
|
||||
if len(decoder_inputs.multi_modal_data) > 0:
|
||||
mm_inputs, mm_hashes = self.mm_input_mapper_client.process_inputs(
|
||||
decoder_inputs.multi_modal_data, mm_hashes,
|
||||
decoder_inputs.mm_processor_kwargs, precomputed_mm_inputs)
|
||||
|
||||
# Make Request for Detokenizer.
|
||||
detokenizer_request = DetokenizerRequest(
|
||||
@@ -130,6 +140,7 @@ class Processor:
|
||||
decoder_inputs.prompt,
|
||||
decoder_inputs.prompt_token_ids,
|
||||
mm_inputs,
|
||||
mm_hashes,
|
||||
decoder_inputs.multi_modal_placeholders,
|
||||
sampling_params,
|
||||
eos_token_id,
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from collections import OrderedDict
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Generic, Iterator, List, TypeVar, overload
|
||||
|
||||
@@ -93,3 +94,23 @@ def make_zmq_socket(path: str, type: Any) -> Iterator[zmq.Socket]:
|
||||
|
||||
finally:
|
||||
ctx.destroy(linger=0)
|
||||
|
||||
|
||||
class LRUDictCache:
|
||||
|
||||
def __init__(self, size: int):
|
||||
self.cache = OrderedDict()
|
||||
self.size = size
|
||||
|
||||
def get(self, key, default=None):
|
||||
if key not in self.cache:
|
||||
return default
|
||||
|
||||
self.cache.move_to_end(key)
|
||||
return self.cache[key]
|
||||
|
||||
def put(self, key, value):
|
||||
self.cache[key] = value
|
||||
self.cache.move_to_end(key)
|
||||
if len(self.cache) > self.size:
|
||||
self.cache.popitem(last=False)
|
||||
|
||||
Reference in New Issue
Block a user