[Bugfix] Avoid transferring cached multi-modal items from P0 to P1 (#16273)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung authored on 2025-04-09 15:51:27 +08:00, committed by GitHub
parent 24f6b9a713
commit e484e02857
5 changed files with 65 additions and 22 deletions

vllm/v1/engine/__init__.py

@@ -2,6 +2,7 @@
 import enum
 import time
+from collections.abc import Sequence
 from typing import Any, Optional, Union

 import msgspec
@@ -52,7 +53,7 @@ class EngineCoreRequest(
     # Detokenizer, but set to None when it is added to EngineCoreClient.
     prompt: Optional[str]
     prompt_token_ids: list[int]
-    mm_inputs: Optional[list[MultiModalKwargs]]
+    mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]]
     mm_hashes: Optional[list[str]]
     mm_placeholders: Optional[list[PlaceholderRange]]
     sampling_params: SamplingParams

vllm/v1/engine/core.py

@@ -31,7 +31,7 @@ from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler
 from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
                             EngineCoreRequestType, UtilityOutput)
-from vllm.v1.engine.mm_input_cache import MMInputCacheServer
+from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
 from vllm.v1.executor.abstract import Executor
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.outputs import ModelRunnerOutput
@@ -105,7 +105,7 @@ class EngineCore:
         )

         # Setup MM Input Mapper.
-        self.mm_input_cache_server = MMInputCacheServer(
+        self.mm_input_cache_server = MirroredProcessingCache(
             vllm_config.model_config)

         # Setup batch queue for pipeline parallelism.
@@ -173,7 +173,7 @@ class EngineCore:
             # anything that has a hash must have a HIT cache entry here
             # as well.
             assert request.mm_inputs is not None
-            request.mm_inputs = self.mm_input_cache_server.get_and_update(
+            request.mm_inputs = self.mm_input_cache_server.get_and_update_p1(
                 request.mm_inputs, request.mm_hashes)

         req = Request.from_engine_core_request(request)

vllm/v1/engine/mm_input_cache.py

@@ -1,8 +1,11 @@
 # SPDX-License-Identifier: Apache-2.0
+from collections.abc import Sequence
+from typing import Optional

 from vllm.envs import VLLM_MM_INPUT_CACHE_GIB
 from vllm.multimodal import MultiModalKwargs
 from vllm.multimodal.processing import ProcessingCache
+from vllm.utils import is_list_of

 # The idea of multimodal preprocessing caching is based on having a client and
 # a server, where the client executes in the frontend process (=P0) and the
@@ -11,9 +14,11 @@ from vllm.multimodal.processing import ProcessingCache
 # -- Client:
 #  - BaseMultiModalProcessor to process MultiModalData into MultiModalKwargs
 #    with built-in caching functionality, with mm_hash as its identifier.
+#  - MirroredProcessingCache to keep track of the cached entries and
+#    determine whether to send the MultiModalKwargs to P1.
 #
 # -- Server:
-#  - MMInputCacheServer to perform caching of the received MultiModalKwargs.
+#  - MirroredProcessingCache to store the MultiModalKwargs from P0.
 #
 # The caching for both client and server is mirrored, and this allows us
 # to avoid the serialization of "mm_inputs" (like pixel values) between
@@ -25,26 +30,48 @@ from vllm.multimodal.processing import ProcessingCache
 # variable VLLM_MM_INPUT_CACHE_GIB.


-class MMInputCacheServer:
+class MirroredProcessingCache:

     def __init__(self, model_config):
         self.use_cache = not model_config.disable_mm_preprocessor_cache
         self.mm_cache = ProcessingCache.get_lru_cache(VLLM_MM_INPUT_CACHE_GIB,
                                                       MultiModalKwargs)

-    def get_and_update(
+    def get_and_update_p0(
         self,
-        mm_inputs: list[MultiModalKwargs],
+        mm_inputs: Sequence[MultiModalKwargs],
         mm_hashes: list[str],
-    ) -> list[MultiModalKwargs]:
+    ) -> Sequence[Optional[MultiModalKwargs]]:
         assert len(mm_inputs) == len(mm_hashes)

         if not self.use_cache:
+            assert is_list_of(mm_inputs, MultiModalKwargs)
             return mm_inputs

-        full_mm_inputs = []
+        full_mm_inputs = list[Optional[MultiModalKwargs]]()
+        for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
+            if mm_hash in self.mm_cache:
+                mm_input = None
+            else:
+                self.mm_cache[mm_hash] = mm_input
+            full_mm_inputs.append(mm_input)
+
+        return full_mm_inputs
+
+    def get_and_update_p1(
+        self,
+        mm_inputs: Sequence[Optional[MultiModalKwargs]],
+        mm_hashes: list[str],
+    ) -> Sequence[MultiModalKwargs]:
+        assert len(mm_inputs) == len(mm_hashes)
+
+        if not self.use_cache:
+            assert is_list_of(mm_inputs, MultiModalKwargs)
+            return mm_inputs
+
+        full_mm_inputs = list[MultiModalKwargs]()
         for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
+            assert mm_hash is not None
             if mm_input is None:
                 mm_input = self.mm_cache[mm_hash]
             else:
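The comment block in this file describes the mirrored cache only in prose; below is a minimal sketch (not part of this commit) of how the two new methods are intended to cooperate. It assumes one MirroredProcessingCache instance lives in each process, that the preprocessor cache is enabled, and that `model_config` is supplied by the caller; the helper name `p0_to_p1` is made up for illustration.

from vllm.v1.engine.mm_input_cache import MirroredProcessingCache

p0_cache = MirroredProcessingCache(model_config)  # frontend process (P0)
p1_cache = MirroredProcessingCache(model_config)  # engine-core process (P1)

def p0_to_p1(mm_inputs, mm_hashes):
    # P0: items whose hash is already cached are replaced with None, so only
    # new MultiModalKwargs get serialized across the process boundary.
    wire_inputs = p0_cache.get_and_update_p0(mm_inputs, mm_hashes)
    # ... wire_inputs and mm_hashes travel from P0 to P1 here ...
    # P1: the None placeholders are filled back in from P1's own cache, which
    # has seen the same hashes in the same order as P0's cache.
    return p1_cache.get_and_update_p1(wire_inputs, mm_hashes)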

vllm/v1/engine/processor.py

@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0

 import time
-from collections.abc import Mapping
+from collections.abc import Mapping, Sequence
 from typing import Literal, Optional, Union

 from vllm.config import VllmConfig
@@ -19,6 +19,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
 from vllm.v1.engine import EngineCoreRequest
+from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
 from vllm.v1.structured_output.backend_guidance import (
     validate_guidance_grammar)
 from vllm.v1.structured_output.utils import (
@@ -47,6 +48,8 @@ class Processor:
                                                         self.tokenizer,
                                                         mm_registry)

+        self.mm_input_cache_client = MirroredProcessingCache(self.model_config)
+
         # Multi-modal hasher (for images)
         self.use_hash = (
             not self.model_config.disable_mm_preprocessor_cache) or \
@@ -231,7 +234,7 @@ class Processor:
             self.tokenizer.get_lora_tokenizer(lora_request))

         # Multimodal related.
-        sorted_mm_inputs: Optional[list[MultiModalKwargs]] = None
+        sorted_mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]] = None
         sorted_mm_positions: Optional[list[PlaceholderRange]] = None
         sorted_mm_hashes: Optional[list[str]] = None
         if decoder_inputs["type"] == "multimodal":
@@ -256,20 +259,28 @@ class Processor:
             # are multiple modalities.
             unique_modalities = set(sorted_item_modalities)
             if len(unique_modalities) > 1:
-                sorted_mm_inputs = []
+                orig_sorted_mm_inputs = []
                 used_indices = {modality: 0 for modality in unique_modalities}

                 for modality in sorted_item_modalities:
                     items = decoder_mm_inputs.get_items(modality)
                     item = items[used_indices[modality]]
-                    sorted_mm_inputs.append(MultiModalKwargs.from_items([item
-                                                                         ]))
+                    orig_sorted_mm_inputs.append(
+                        MultiModalKwargs.from_items([item]))
                     used_indices[modality] += 1
             else:
-                sorted_mm_inputs = [
+                orig_sorted_mm_inputs = [
                     MultiModalKwargs.from_items([item]) for item in
                     decoder_mm_inputs.get_items(sorted_item_modalities[0])
                 ]
+
+            if sorted_mm_hashes is not None:
+                sorted_mm_inputs = self.mm_input_cache_client.get_and_update_p0(
+                    orig_sorted_mm_inputs, sorted_mm_hashes)
+            else:
+                sorted_mm_inputs = orig_sorted_mm_inputs

         return EngineCoreRequest(
             request_id=request_id,
             prompt=decoder_inputs.get("prompt"),
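As a consequence of the Processor change above, a multi-modal item that repeats across requests is shipped to the engine core only once, provided hashing and the preprocessor cache are enabled. A hypothetical illustration; `client`, `image_kwargs`, and `image_hash` are made-up names:

client = MirroredProcessingCache(model_config)

# First request carrying the image: the hash is newly cached, so the full
# MultiModalKwargs go into EngineCoreRequest.mm_inputs.
client.get_and_update_p0([image_kwargs], [image_hash])   # -> [image_kwargs]

# A later request reusing the same image: only a None placeholder is sent;
# the engine core (P1) restores the real kwargs via get_and_update_p1.
client.get_and_update_p0([image_kwargs], [image_hash])   # -> [None]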

vllm/v1/request.py

@@ -3,17 +3,16 @@
 import enum
 from typing import TYPE_CHECKING, Optional, Union

+from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
 from vllm.sampling_params import SamplingParams
+from vllm.utils import is_list_of
 from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType,
                             EngineCoreRequest, FinishReason)
 from vllm.v1.structured_output.request import StructuredOutputRequest
 from vllm.v1.utils import ConstantList

 if TYPE_CHECKING:
     from vllm.lora.request import LoRARequest
-    from vllm.multimodal import MultiModalKwargs
-    from vllm.multimodal.inputs import PlaceholderRange


 class Request:
@@ -23,9 +22,9 @@ class Request:
         request_id: str,
         prompt: Optional[str],
         prompt_token_ids: list[int],
-        multi_modal_inputs: Optional[list["MultiModalKwargs"]],
+        multi_modal_inputs: Optional[list[MultiModalKwargs]],
         multi_modal_hashes: Optional[list[str]],
-        multi_modal_placeholders: Optional[list["PlaceholderRange"]],
+        multi_modal_placeholders: Optional[list[PlaceholderRange]],
         sampling_params: SamplingParams,
         eos_token_id: Optional[int],
         arrival_time: float,
@@ -75,6 +74,11 @@ class Request:

     @classmethod
     def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
+        if request.mm_inputs is not None:
+            assert isinstance(request.mm_inputs, list)
+            assert is_list_of(request.mm_inputs, MultiModalKwargs), (
+                "mm_inputs was not updated in EngineCore.add_request")
+
         return cls(
             request_id=request.request_id,
             prompt=request.prompt,