[Bugfix] Avoid transferring cached multi-modal items from P0 to P1 (#16273)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
import enum
|
import enum
|
||||||
import time
|
import time
|
||||||
|
from collections.abc import Sequence
|
||||||
from typing import Any, Optional, Union
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
import msgspec
|
import msgspec
|
||||||
@@ -52,7 +53,7 @@ class EngineCoreRequest(
|
|||||||
# Detokenizer, but set to None when it is added to EngineCoreClient.
|
# Detokenizer, but set to None when it is added to EngineCoreClient.
|
||||||
prompt: Optional[str]
|
prompt: Optional[str]
|
||||||
prompt_token_ids: list[int]
|
prompt_token_ids: list[int]
|
||||||
mm_inputs: Optional[list[MultiModalKwargs]]
|
mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]]
|
||||||
mm_hashes: Optional[list[str]]
|
mm_hashes: Optional[list[str]]
|
||||||
mm_placeholders: Optional[list[PlaceholderRange]]
|
mm_placeholders: Optional[list[PlaceholderRange]]
|
||||||
sampling_params: SamplingParams
|
sampling_params: SamplingParams
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ from vllm.v1.core.sched.output import SchedulerOutput
|
|||||||
from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler
|
from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler
|
||||||
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
|
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
|
||||||
EngineCoreRequestType, UtilityOutput)
|
EngineCoreRequestType, UtilityOutput)
|
||||||
from vllm.v1.engine.mm_input_cache import MMInputCacheServer
|
from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
|
||||||
from vllm.v1.executor.abstract import Executor
|
from vllm.v1.executor.abstract import Executor
|
||||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||||
from vllm.v1.outputs import ModelRunnerOutput
|
from vllm.v1.outputs import ModelRunnerOutput
|
||||||
@@ -105,7 +105,7 @@ class EngineCore:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Setup MM Input Mapper.
|
# Setup MM Input Mapper.
|
||||||
self.mm_input_cache_server = MMInputCacheServer(
|
self.mm_input_cache_server = MirroredProcessingCache(
|
||||||
vllm_config.model_config)
|
vllm_config.model_config)
|
||||||
|
|
||||||
# Setup batch queue for pipeline parallelism.
|
# Setup batch queue for pipeline parallelism.
|
||||||
@@ -173,7 +173,7 @@ class EngineCore:
|
|||||||
# anything that has a hash must have a HIT cache entry here
|
# anything that has a hash must have a HIT cache entry here
|
||||||
# as well.
|
# as well.
|
||||||
assert request.mm_inputs is not None
|
assert request.mm_inputs is not None
|
||||||
request.mm_inputs = self.mm_input_cache_server.get_and_update(
|
request.mm_inputs = self.mm_input_cache_server.get_and_update_p1(
|
||||||
request.mm_inputs, request.mm_hashes)
|
request.mm_inputs, request.mm_hashes)
|
||||||
|
|
||||||
req = Request.from_engine_core_request(request)
|
req = Request.from_engine_core_request(request)
|
||||||
|
|||||||
@@ -1,8 +1,11 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from vllm.envs import VLLM_MM_INPUT_CACHE_GIB
|
from vllm.envs import VLLM_MM_INPUT_CACHE_GIB
|
||||||
from vllm.multimodal import MultiModalKwargs
|
from vllm.multimodal import MultiModalKwargs
|
||||||
from vllm.multimodal.processing import ProcessingCache
|
from vllm.multimodal.processing import ProcessingCache
|
||||||
|
from vllm.utils import is_list_of
|
||||||
|
|
||||||
# The idea of multimodal preprocessing caching is based on having a client and
|
# The idea of multimodal preprocessing caching is based on having a client and
|
||||||
# a server, where the client executes in the frontend process (=P0) and the
|
# a server, where the client executes in the frontend process (=P0) and the
|
||||||
@@ -11,9 +14,11 @@ from vllm.multimodal.processing import ProcessingCache
|
|||||||
# -- Client:
|
# -- Client:
|
||||||
# - BaseMultiModalProcessor to process MultiModalData into MultiModalKwargs
|
# - BaseMultiModalProcessor to process MultiModalData into MultiModalKwargs
|
||||||
# with built-in caching functionality, with mm_hash as its identifier.
|
# with built-in caching functionality, with mm_hash as its identifier.
|
||||||
|
# - MirroredProcessingCache to keep track of the cached entries and
|
||||||
|
# determine whether to send the MultiModalKwargs to P1.
|
||||||
#
|
#
|
||||||
# -- Server:
|
# -- Server:
|
||||||
# - MMInputCacheServer to perform caching of the received MultiModalKwargs.
|
# - MirroredProcessingCache to store the MultiModalKwargs from P0.
|
||||||
#
|
#
|
||||||
# The caching for both client and server is mirrored, and this allows us
|
# The caching for both client and server is mirrored, and this allows us
|
||||||
# to avoid the serialization of "mm_inputs" (like pixel values) between
|
# to avoid the serialization of "mm_inputs" (like pixel values) between
|
||||||
@@ -25,26 +30,48 @@ from vllm.multimodal.processing import ProcessingCache
|
|||||||
# variable VLLM_MM_INPUT_CACHE_GIB.
|
# variable VLLM_MM_INPUT_CACHE_GIB.
|
||||||
|
|
||||||
|
|
||||||
class MMInputCacheServer:
|
class MirroredProcessingCache:
|
||||||
|
|
||||||
def __init__(self, model_config):
|
def __init__(self, model_config):
|
||||||
self.use_cache = not model_config.disable_mm_preprocessor_cache
|
self.use_cache = not model_config.disable_mm_preprocessor_cache
|
||||||
self.mm_cache = ProcessingCache.get_lru_cache(VLLM_MM_INPUT_CACHE_GIB,
|
self.mm_cache = ProcessingCache.get_lru_cache(VLLM_MM_INPUT_CACHE_GIB,
|
||||||
MultiModalKwargs)
|
MultiModalKwargs)
|
||||||
|
|
||||||
def get_and_update(
|
def get_and_update_p0(
|
||||||
self,
|
self,
|
||||||
mm_inputs: list[MultiModalKwargs],
|
mm_inputs: Sequence[MultiModalKwargs],
|
||||||
mm_hashes: list[str],
|
mm_hashes: list[str],
|
||||||
) -> list[MultiModalKwargs]:
|
) -> Sequence[Optional[MultiModalKwargs]]:
|
||||||
assert len(mm_inputs) == len(mm_hashes)
|
assert len(mm_inputs) == len(mm_hashes)
|
||||||
|
|
||||||
if not self.use_cache:
|
if not self.use_cache:
|
||||||
|
assert is_list_of(mm_inputs, MultiModalKwargs)
|
||||||
return mm_inputs
|
return mm_inputs
|
||||||
|
|
||||||
full_mm_inputs = []
|
full_mm_inputs = list[Optional[MultiModalKwargs]]()
|
||||||
|
for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
|
||||||
|
if mm_hash in self.mm_cache:
|
||||||
|
mm_input = None
|
||||||
|
else:
|
||||||
|
self.mm_cache[mm_hash] = mm_input
|
||||||
|
|
||||||
|
full_mm_inputs.append(mm_input)
|
||||||
|
|
||||||
|
return full_mm_inputs
|
||||||
|
|
||||||
|
def get_and_update_p1(
|
||||||
|
self,
|
||||||
|
mm_inputs: Sequence[Optional[MultiModalKwargs]],
|
||||||
|
mm_hashes: list[str],
|
||||||
|
) -> Sequence[MultiModalKwargs]:
|
||||||
|
assert len(mm_inputs) == len(mm_hashes)
|
||||||
|
|
||||||
|
if not self.use_cache:
|
||||||
|
assert is_list_of(mm_inputs, MultiModalKwargs)
|
||||||
|
return mm_inputs
|
||||||
|
|
||||||
|
full_mm_inputs = list[MultiModalKwargs]()
|
||||||
for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
|
for mm_input, mm_hash in zip(mm_inputs, mm_hashes):
|
||||||
assert mm_hash is not None
|
|
||||||
if mm_input is None:
|
if mm_input is None:
|
||||||
mm_input = self.mm_cache[mm_hash]
|
mm_input = self.mm_cache[mm_hash]
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
import time
|
import time
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping, Sequence
|
||||||
from typing import Literal, Optional, Union
|
from typing import Literal, Optional, Union
|
||||||
|
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
@@ -19,6 +19,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
|
|||||||
from vllm.sampling_params import SamplingParams
|
from vllm.sampling_params import SamplingParams
|
||||||
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
|
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
|
||||||
from vllm.v1.engine import EngineCoreRequest
|
from vllm.v1.engine import EngineCoreRequest
|
||||||
|
from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
|
||||||
from vllm.v1.structured_output.backend_guidance import (
|
from vllm.v1.structured_output.backend_guidance import (
|
||||||
validate_guidance_grammar)
|
validate_guidance_grammar)
|
||||||
from vllm.v1.structured_output.utils import (
|
from vllm.v1.structured_output.utils import (
|
||||||
@@ -47,6 +48,8 @@ class Processor:
|
|||||||
self.tokenizer,
|
self.tokenizer,
|
||||||
mm_registry)
|
mm_registry)
|
||||||
|
|
||||||
|
self.mm_input_cache_client = MirroredProcessingCache(self.model_config)
|
||||||
|
|
||||||
# Multi-modal hasher (for images)
|
# Multi-modal hasher (for images)
|
||||||
self.use_hash = (
|
self.use_hash = (
|
||||||
not self.model_config.disable_mm_preprocessor_cache) or \
|
not self.model_config.disable_mm_preprocessor_cache) or \
|
||||||
@@ -231,7 +234,7 @@ class Processor:
|
|||||||
self.tokenizer.get_lora_tokenizer(lora_request))
|
self.tokenizer.get_lora_tokenizer(lora_request))
|
||||||
|
|
||||||
# Multimodal related.
|
# Multimodal related.
|
||||||
sorted_mm_inputs: Optional[list[MultiModalKwargs]] = None
|
sorted_mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]] = None
|
||||||
sorted_mm_positions: Optional[list[PlaceholderRange]] = None
|
sorted_mm_positions: Optional[list[PlaceholderRange]] = None
|
||||||
sorted_mm_hashes: Optional[list[str]] = None
|
sorted_mm_hashes: Optional[list[str]] = None
|
||||||
if decoder_inputs["type"] == "multimodal":
|
if decoder_inputs["type"] == "multimodal":
|
||||||
@@ -256,20 +259,28 @@ class Processor:
|
|||||||
# are multiple modalities.
|
# are multiple modalities.
|
||||||
unique_modalities = set(sorted_item_modalities)
|
unique_modalities = set(sorted_item_modalities)
|
||||||
if len(unique_modalities) > 1:
|
if len(unique_modalities) > 1:
|
||||||
sorted_mm_inputs = []
|
orig_sorted_mm_inputs = []
|
||||||
used_indices = {modality: 0 for modality in unique_modalities}
|
used_indices = {modality: 0 for modality in unique_modalities}
|
||||||
|
|
||||||
for modality in sorted_item_modalities:
|
for modality in sorted_item_modalities:
|
||||||
items = decoder_mm_inputs.get_items(modality)
|
items = decoder_mm_inputs.get_items(modality)
|
||||||
item = items[used_indices[modality]]
|
item = items[used_indices[modality]]
|
||||||
sorted_mm_inputs.append(MultiModalKwargs.from_items([item
|
|
||||||
]))
|
orig_sorted_mm_inputs.append(
|
||||||
|
MultiModalKwargs.from_items([item]))
|
||||||
used_indices[modality] += 1
|
used_indices[modality] += 1
|
||||||
else:
|
else:
|
||||||
sorted_mm_inputs = [
|
orig_sorted_mm_inputs = [
|
||||||
MultiModalKwargs.from_items([item]) for item in
|
MultiModalKwargs.from_items([item]) for item in
|
||||||
decoder_mm_inputs.get_items(sorted_item_modalities[0])
|
decoder_mm_inputs.get_items(sorted_item_modalities[0])
|
||||||
]
|
]
|
||||||
|
|
||||||
|
if sorted_mm_hashes is not None:
|
||||||
|
sorted_mm_inputs = self.mm_input_cache_client.get_and_update_p0(
|
||||||
|
orig_sorted_mm_inputs, sorted_mm_hashes)
|
||||||
|
else:
|
||||||
|
sorted_mm_inputs = orig_sorted_mm_inputs
|
||||||
|
|
||||||
return EngineCoreRequest(
|
return EngineCoreRequest(
|
||||||
request_id=request_id,
|
request_id=request_id,
|
||||||
prompt=decoder_inputs.get("prompt"),
|
prompt=decoder_inputs.get("prompt"),
|
||||||
|
|||||||
@@ -3,17 +3,16 @@
|
|||||||
import enum
|
import enum
|
||||||
from typing import TYPE_CHECKING, Optional, Union
|
from typing import TYPE_CHECKING, Optional, Union
|
||||||
|
|
||||||
|
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
|
||||||
from vllm.sampling_params import SamplingParams
|
from vllm.sampling_params import SamplingParams
|
||||||
|
from vllm.utils import is_list_of
|
||||||
from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType,
|
from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType,
|
||||||
EngineCoreRequest, FinishReason)
|
EngineCoreRequest, FinishReason)
|
||||||
from vllm.v1.structured_output.request import StructuredOutputRequest
|
from vllm.v1.structured_output.request import StructuredOutputRequest
|
||||||
from vllm.v1.utils import ConstantList
|
from vllm.v1.utils import ConstantList
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
|
||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
from vllm.multimodal import MultiModalKwargs
|
|
||||||
from vllm.multimodal.inputs import PlaceholderRange
|
|
||||||
|
|
||||||
|
|
||||||
class Request:
|
class Request:
|
||||||
@@ -23,9 +22,9 @@ class Request:
|
|||||||
request_id: str,
|
request_id: str,
|
||||||
prompt: Optional[str],
|
prompt: Optional[str],
|
||||||
prompt_token_ids: list[int],
|
prompt_token_ids: list[int],
|
||||||
multi_modal_inputs: Optional[list["MultiModalKwargs"]],
|
multi_modal_inputs: Optional[list[MultiModalKwargs]],
|
||||||
multi_modal_hashes: Optional[list[str]],
|
multi_modal_hashes: Optional[list[str]],
|
||||||
multi_modal_placeholders: Optional[list["PlaceholderRange"]],
|
multi_modal_placeholders: Optional[list[PlaceholderRange]],
|
||||||
sampling_params: SamplingParams,
|
sampling_params: SamplingParams,
|
||||||
eos_token_id: Optional[int],
|
eos_token_id: Optional[int],
|
||||||
arrival_time: float,
|
arrival_time: float,
|
||||||
@@ -75,6 +74,11 @@ class Request:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
|
def from_engine_core_request(cls, request: EngineCoreRequest) -> "Request":
|
||||||
|
if request.mm_inputs is not None:
|
||||||
|
assert isinstance(request.mm_inputs, list)
|
||||||
|
assert is_list_of(request.mm_inputs, MultiModalKwargs), (
|
||||||
|
"mm_inputs was not updated in EngineCore.add_request")
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
request_id=request.request_id,
|
request_id=request.request_id,
|
||||||
prompt=request.prompt,
|
prompt=request.prompt,
|
||||||
|
|||||||
Reference in New Issue
Block a user