[Refactor] Allow optional MultiModalKwargsItem in IPC (#23022)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
--- a/vllm/v1/engine/__init__.py
+++ b/vllm/v1/engine/__init__.py
@@ -3,6 +3,7 @@
 import enum
 import time
+from collections.abc import Sequence
 from typing import Any, Optional, Union
 
 import msgspec
@@ -47,7 +48,7 @@ class EngineCoreRequest(
     request_id: str
     prompt_token_ids: list[int]
-    mm_kwargs: Optional[list[MultiModalKwargsItem]]
+    mm_kwargs: Optional[Sequence[Optional[MultiModalKwargsItem]]]
     mm_hashes: Optional[list[str]]
     mm_placeholders: Optional[list[PlaceholderRange]]
     sampling_params: Optional[SamplingParams]
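
The field change above is the heart of the refactor: `mm_kwargs` entries may now be `None`, marking items whose data the engine core already holds, so only the hash travels over IPC. A minimal sketch of how such a request could serialize, using toy stand-in types rather than vLLM's actual classes:

from typing import Optional

import msgspec

class ToyItem(msgspec.Struct):
    # stand-in for MultiModalKwargsItem
    modality: str
    data: bytes

class ToyRequest(msgspec.Struct):
    # stand-in for EngineCoreRequest
    request_id: str
    mm_kwargs: Optional[list[Optional[ToyItem]]]
    mm_hashes: Optional[list[str]]

req = ToyRequest(
    request_id="r0",
    # the second item was a cache hit, so None is sent in its place
    mm_kwargs=[ToyItem("image", b"<tensor bytes>"), None],
    mm_hashes=["h0", "h1"],
)
buf = msgspec.msgpack.encode(req)
out = msgspec.msgpack.decode(buf, type=ToyRequest)
assert out.mm_kwargs is not None and out.mm_kwargs[1] is None
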
--- a/vllm/v1/engine/mm_input_cache.py
+++ b/vllm/v1/engine/mm_input_cache.py
@@ -1,11 +1,12 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from collections.abc import Mapping
-from typing import TYPE_CHECKING
+from collections.abc import Sequence
+from typing import TYPE_CHECKING, Optional
 
 from vllm.multimodal import MultiModalRegistry
 from vllm.multimodal.cache import MultiModalCache, MultiModalCacheItemMetadata
-from vllm.multimodal.inputs import MultiModalKwargsItem, NestedTensors
+from vllm.multimodal.inputs import MultiModalKwargsItem
+from vllm.utils import is_list_of
 
 if TYPE_CHECKING:
     from vllm.config import ModelConfig
@@ -58,21 +59,21 @@ class MultiModalInputCacheClient:
 
     def get_and_update(
         self,
-        mm_kwargs: list[MultiModalKwargsItem],
+        mm_kwargs: Sequence[MultiModalKwargsItem],
         mm_hashes: list[str],
-    ) -> list[MultiModalKwargsItem]:
+    ) -> list[Optional[MultiModalKwargsItem]]:
         if not self.enabled:
-            return mm_kwargs
+            return list(mm_kwargs)
 
         assert len(mm_kwargs) == len(mm_hashes)
 
-        out_mm_items = list[MultiModalKwargsItem]()
+        out_mm_items = list[Optional[MultiModalKwargsItem]]()
         for mm_item, mm_hash in zip(mm_kwargs, mm_hashes):
             if self.mm_cache.get(mm_hash) is not None:
-                out_mm_items.append(mm_item.without_data())
+                out_mm_items.append(None)
             else:
                 self.mm_cache[mm_hash] = \
-                    MultiModalCacheItemMetadata.wraps(mm_item.require_data())
+                    MultiModalCacheItemMetadata.wraps(mm_item)
                 out_mm_items.append(mm_item)
 
         return out_mm_items
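
On a cache hit the client now emits `None` instead of a data-stripped copy of the item (the `without_data()` path is gone); on a miss it records the item's metadata and forwards the item in full. A standalone sketch of that protocol, with plain dicts standing in for vLLM's LRU cache, metadata, and kwargs-item types:

from typing import Optional

def client_get_and_update(
    cache: dict[str, int],   # hash -> recorded size (metadata stand-in)
    items: list[dict],       # stand-ins for MultiModalKwargsItem
    hashes: list[str],
) -> list[Optional[dict]]:
    out: list[Optional[dict]] = []
    for item, h in zip(items, hashes):
        if h in cache:
            out.append(None)      # receiver already has the data
        else:
            cache[h] = len(item)  # remember this hash was sent
            out.append(item)
    return out
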
@@ -91,25 +92,27 @@ class MultiModalInputCacheServer:
         self.enabled = mm_registry.enable_mm_input_cache(model_config)
         self.mm_cache = MultiModalCache.get_lru_cache(
             model_config.get_mm_input_cache_gb(),
-            Mapping[str, NestedTensors],
+            MultiModalKwargsItem,
         )
 
     def get_and_update(
         self,
-        mm_kwargs: list[MultiModalKwargsItem],
+        mm_kwargs: Sequence[Optional[MultiModalKwargsItem]],
         mm_hashes: list[str],
     ) -> list[MultiModalKwargsItem]:
         if not self.enabled:
-            return mm_kwargs
+            mm_kwargs_lst = list(mm_kwargs)
+            assert is_list_of(mm_kwargs_lst, MultiModalKwargsItem)
+            return mm_kwargs_lst
 
         assert len(mm_kwargs) == len(mm_hashes)
 
         out_mm_items = list[MultiModalKwargsItem]()
         for mm_item, mm_hash in zip(mm_kwargs, mm_hashes):
-            if (mm_data := mm_item.get_data()) is None:
-                out_mm_items.append(mm_item.with_data(self.mm_cache[mm_hash]))
+            if mm_item is None:
+                out_mm_items.append(self.mm_cache[mm_hash])
             else:
-                self.mm_cache[mm_hash] = mm_data
+                self.mm_cache[mm_hash] = mm_item
                 out_mm_items.append(mm_item)
 
         return out_mm_items
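
The server is the mirror image: `None` entries are re-hydrated from its cache and full items repopulate it, replacing the old `get_data()`/`with_data()` dance. Continuing the client sketch above, a round trip reproduces the original list while the repeated payload never crosses IPC:

def server_get_and_update(
    cache: dict[str, dict],
    items: list[Optional[dict]],
    hashes: list[str],
) -> list[dict]:
    out: list[dict] = []
    for item, h in zip(items, hashes):
        if item is None:
            out.append(cache[h])  # hit: restore the cached item
        else:
            cache[h] = item       # miss: store it for next time
            out.append(item)
    return out

client_cache: dict[str, int] = {}
server_cache: dict[str, dict] = {}
for _ in range(2):
    wire = client_get_and_update(client_cache, [{"pixels": 1}], ["h0"])
    restored = server_get_and_update(server_cache, wire, ["h0"])
    assert restored == [{"pixels": 1}]
assert wire == [None]  # the second send carried no payload
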
--- a/vllm/v1/engine/processor.py
+++ b/vllm/v1/engine/processor.py
@@ -17,6 +17,7 @@ from vllm.multimodal.utils import argsort_mm_positions
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer_group import TokenizerGroup
+from vllm.utils import is_list_of
 from vllm.v1.engine import EngineCoreRequest
 from vllm.v1.engine.mm_input_cache import MultiModalInputCacheClient
 from vllm.v1.structured_output.backend_guidance import (
@@ -295,7 +296,7 @@ class Processor:
         pooling_params = params.clone()
 
         # Multimodal related.
-        sorted_mm_inputs: Optional[list[MultiModalKwargsItem]] = None
+        sorted_mm_inputs: Optional[list[Optional[MultiModalKwargsItem]]] = None
         sorted_mm_positions: Optional[list[PlaceholderRange]] = None
         sorted_mm_hashes: Optional[list[str]] = None
         if decoder_inputs["type"] == "multimodal":
@@ -308,7 +309,7 @@ class Processor:
             # in the input sequence.
             sorted_mm_idxs = argsort_mm_positions(decoder_mm_positions)
 
-            sorted_mm_inputs = [
+            orig_sorted_mm_inputs = [
                 decoder_mm_inputs.get_item(modality, idx)
                 for modality, idx in sorted_mm_idxs
             ]
@@ -323,9 +324,12 @@ class Processor:
 
         if sorted_mm_hashes is not None:
             sorted_mm_inputs = self.mm_input_cache_client.get_and_update(
-                sorted_mm_inputs,
+                orig_sorted_mm_inputs,
                 sorted_mm_hashes,
             )
+        else:
+            assert is_list_of(orig_sorted_mm_inputs, MultiModalKwargsItem)
+            sorted_mm_inputs = orig_sorted_mm_inputs
 
         return decoder_inputs.get("prompt"), EngineCoreRequest(
             request_id=request_id,
--- a/vllm/v1/request.py
+++ b/vllm/v1/request.py
@@ -125,14 +125,17 @@ class Request:
         block_hasher: Optional[Callable[["Request"], list["BlockHash"]]]
     ) -> "Request":
         if request.mm_kwargs is not None:
-            assert is_list_of(request.mm_kwargs, MultiModalKwargsItem), (
+            mm_kwargs_lst = list(request.mm_kwargs)
+            assert is_list_of(mm_kwargs_lst, MultiModalKwargsItem), (
                 "mm_kwargs was not updated in EngineCore.add_request")
+        else:
+            mm_kwargs_lst = None
 
         return cls(
             request_id=request.request_id,
             client_index=request.client_index,
             prompt_token_ids=request.prompt_token_ids,
-            multi_modal_kwargs=request.mm_kwargs,
+            multi_modal_kwargs=mm_kwargs_lst,
             multi_modal_hashes=request.mm_hashes,
             multi_modal_placeholders=request.mm_placeholders,
             sampling_params=request.sampling_params,
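
The assert above narrows `Sequence[Optional[MultiModalKwargsItem]]` back to a plain list of full items: by the time a `Request` is built, `EngineCore.add_request` must have replaced every `None` placeholder. A generic sketch of that narrowing (hypothetical helper, not vLLM's `is_list_of`):

from collections.abc import Sequence
from typing import Optional, TypeVar, cast

T = TypeVar("T")

def resolve_placeholders(seq: Sequence[Optional[T]]) -> list[T]:
    out = list(seq)
    # every None must have been replaced upstream
    assert all(x is not None for x in out), (
        "mm_kwargs was not updated upstream")
    return cast("list[T]", out)
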
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -500,8 +500,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             second_per_grid_ts = []
             audio_feature_lengths = []
             use_audio_in_video = False
-            for item in self.requests[req_id].mm_kwargs:
-                mm_input = item.require_data()
+            for mm_item in self.requests[req_id].mm_kwargs:
+                mm_input = mm_item.get_data()
                 if mm_input.get("image_grid_thw") is not None:
                     image_grid_thw.append(
                         mm_input["image_grid_thw"].tolist())