[VLM][Bugfix] Multi-modal processor compatible with V1 multi-input (#11674)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -2,7 +2,8 @@ from abc import ABC, abstractmethod
|
|||||||
from collections import UserDict, defaultdict
|
from collections import UserDict, defaultdict
|
||||||
from collections.abc import Mapping, Sequence
|
from collections.abc import Mapping, Sequence
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Literal, TypedDict, TypeVar, Union, cast, final
|
from typing import (Any, Literal, Optional, TypedDict, TypeVar, Union, cast,
|
||||||
|
final)
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -11,7 +12,7 @@ from PIL.Image import Image
|
|||||||
from transformers import BatchFeature
|
from transformers import BatchFeature
|
||||||
from typing_extensions import NotRequired, TypeAlias
|
from typing_extensions import NotRequired, TypeAlias
|
||||||
|
|
||||||
from vllm.utils import JSONTree, is_list_of, json_map_leaves
|
from vllm.utils import JSONTree, full_groupby, is_list_of, json_map_leaves
|
||||||
|
|
||||||
_T = TypeVar("_T")
|
_T = TypeVar("_T")
|
||||||
|
|
||||||
@@ -160,11 +161,8 @@ A dictionary containing nested tensors which have been batched via
|
|||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class MultiModalFieldItem:
|
class MultiModalFieldElem:
|
||||||
"""
|
"""Contains metadata and data of an item in :class:`MultiModalKwargs`."""
|
||||||
Contains metadata and data in :class:`MultiModalKwargs`
|
|
||||||
corresponding to a data item in :class:`MultiModalDataItems`.
|
|
||||||
"""
|
|
||||||
field: "BaseMultiModalField"
|
field: "BaseMultiModalField"
|
||||||
data: NestedTensors
|
data: NestedTensors
|
||||||
|
|
||||||
@@ -186,34 +184,34 @@ class BaseMultiModalField(ABC):
|
|||||||
def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
|
def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def _build_item(self, data: NestedTensors) -> MultiModalFieldItem:
|
def _build_elem(self, data: NestedTensors) -> MultiModalFieldElem:
|
||||||
return MultiModalFieldItem(self, data)
|
return MultiModalFieldElem(self, data)
|
||||||
|
|
||||||
def reduce(self, batch: list[MultiModalFieldItem]) -> MultiModalFieldItem:
|
def reduce(self, batch: list[MultiModalFieldElem]) -> MultiModalFieldElem:
|
||||||
"""Merge multiple instances of :class:`MultiModalFieldItem` together."""
|
"""Merge multiple instances of :class:`MultiModalFieldElem` together."""
|
||||||
fields = [item.field for item in batch]
|
fields = [item.field for item in batch]
|
||||||
if len(set(fields)) > 1:
|
if len(set(fields)) > 1:
|
||||||
raise ValueError(f"Cannot merge different {fields=}")
|
raise ValueError(f"Cannot merge different {fields=}")
|
||||||
|
|
||||||
data = self._reduce_data([item.data for item in batch])
|
data = self._reduce_data([item.data for item in batch])
|
||||||
|
|
||||||
return self._build_item(data)
|
return self._build_elem(data)
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class MultiModalBatchedField(BaseMultiModalField):
|
class MultiModalBatchedField(BaseMultiModalField):
|
||||||
"""
|
"""
|
||||||
A :class:`BaseMultiModalField` implementation where an item is obtained by
|
A :class:`BaseMultiModalField` implementation where an element in the batch
|
||||||
directly indexing into the first dimension of the underlying data.
|
is obtained by indexing into the first dimension of the underlying data.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def build_items(self, batch: NestedTensors) -> list[MultiModalFieldItem]:
|
def build_elems(self, batch: NestedTensors) -> list[MultiModalFieldElem]:
|
||||||
return [self._build_item(item) for item in batch]
|
return [self._build_elem(item) for item in batch]
|
||||||
|
|
||||||
def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
|
def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
|
||||||
if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
|
if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
|
||||||
first_shape = batch[0].shape
|
first_shape = batch[0].shape
|
||||||
if all(item.shape == first_shape for item in batch):
|
if all(elem.shape == first_shape for elem in batch):
|
||||||
return torch.stack(batch)
|
return torch.stack(batch)
|
||||||
|
|
||||||
return batch
|
return batch
|
||||||
@@ -222,24 +220,24 @@ class MultiModalBatchedField(BaseMultiModalField):
|
|||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class MultiModalFlatField(BaseMultiModalField):
|
class MultiModalFlatField(BaseMultiModalField):
|
||||||
"""
|
"""
|
||||||
A :class:`BaseMultiModalField` implementation where an item is obtained by
|
A :class:`BaseMultiModalField` implementation where an element in the batch
|
||||||
slicing along the first dimension of the underlying data.
|
is obtained by slicing along the first dimension of the underlying data.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def build_items(
|
def build_elems(
|
||||||
self,
|
self,
|
||||||
batch: NestedTensors,
|
batch: NestedTensors,
|
||||||
slices: Sequence[slice],
|
slices: Sequence[slice],
|
||||||
) -> list[MultiModalFieldItem]:
|
) -> list[MultiModalFieldElem]:
|
||||||
return [self._build_item(batch[slice_]) for slice_ in slices]
|
return [self._build_elem(batch[slice_]) for slice_ in slices]
|
||||||
|
|
||||||
def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
|
def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
|
||||||
if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
|
if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
|
||||||
first_shape = batch[0].shape
|
first_shape = batch[0].shape
|
||||||
if all(item.shape[1:] == first_shape[1:] for item in batch):
|
if all(elem.shape[1:] == first_shape[1:] for elem in batch):
|
||||||
return torch.concat(batch)
|
return torch.concat(batch)
|
||||||
|
|
||||||
return [elem for item in batch for elem in item]
|
return [e for elem in batch for e in elem]
|
||||||
|
|
||||||
|
|
||||||
class MultiModalFieldConfig:
|
class MultiModalFieldConfig:
|
||||||
@@ -267,115 +265,111 @@ class MultiModalFieldConfig:
|
|||||||
) -> None:
|
) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self._field_cls = field_cls
|
self.field_cls = field_cls
|
||||||
self._modality = modality
|
self.modality = modality
|
||||||
self._field_config = field_config
|
self.field_config = field_config
|
||||||
|
|
||||||
def build_items(
|
def build_elems(
|
||||||
self,
|
self,
|
||||||
key: str,
|
key: str,
|
||||||
batch: NestedTensors,
|
batch: NestedTensors,
|
||||||
) -> list[MultiModalFieldItem]:
|
) -> Sequence[MultiModalFieldElem]:
|
||||||
field = self._field_cls(key=key, modality=self._modality)
|
field = self.field_cls(key=key, modality=self.modality)
|
||||||
return field.build_items(batch, **self._field_config) # type: ignore
|
return field.build_elems(batch, **self.field_config) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
|
||||||
|
"""
|
||||||
|
A collection of :class:`MultiModalFieldElem`
|
||||||
|
corresponding to a data item in :class:`MultiModalDataItems`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_elems(elems: Sequence[MultiModalFieldElem]):
|
||||||
|
return MultiModalKwargsItem({elem.field.key: elem for elem in elems})
|
||||||
|
|
||||||
|
@property
|
||||||
|
def modality(self) -> str:
|
||||||
|
modalities = {elem.field.modality for elem in self.data.values()}
|
||||||
|
assert len(modalities) == 1, f"Found different modalities={modalities}"
|
||||||
|
return next(iter(modalities))
|
||||||
|
|
||||||
|
|
||||||
|
# NOTE: UserDict is for V0 compatibility.
|
||||||
|
# V1 should access individual items via `get_item`.
|
||||||
class MultiModalKwargs(UserDict[str, NestedTensors]):
|
class MultiModalKwargs(UserDict[str, NestedTensors]):
|
||||||
"""
|
"""
|
||||||
A dictionary that represents the keyword arguments to
|
A dictionary that represents the keyword arguments to
|
||||||
:meth:`~torch.nn.Module.forward`.
|
:meth:`~torch.nn.Module.forward`.
|
||||||
|
|
||||||
The metadata :code:`items_by_key` defines how to split batched keyword
|
The metadata :code:`items` enables us to obtain the keyword arguments
|
||||||
arguments corresponding to each data item in :class:`MultiModalDataItems`:
|
corresponding to each data item in :class:`MultiModalDataItems`, via
|
||||||
|
:meth:`get_item` and :meth:`get_items`.
|
||||||
- For a keyword argument, we can access the :code:`i` th item in the batch
|
|
||||||
via :code:`items_by_key[key][i]`.
|
|
||||||
- We can gather the keyword arguments belonging to a modality by finding
|
|
||||||
the keys with items that belong to that modality, then accessing
|
|
||||||
the :code:`i` th item in the batch for each such key.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
|
|
||||||
.. code-block:: python
|
|
||||||
|
|
||||||
# All items belong to the "image" modality
|
|
||||||
items_by_key={
|
|
||||||
"pixel_values": [a, b, c, d], # "image" modality
|
|
||||||
"image_grid_thw": [e, f, g, h], # "image" modality
|
|
||||||
"pixel_values_video": [h, i, j], # "video" modality
|
|
||||||
"video_grid_thw": [k, l, m], # "video" modality
|
|
||||||
}
|
|
||||||
|
|
||||||
- The keyword arguments belonging to the first image are
|
|
||||||
:code:`{"pixel_values": a, "image_grid_thw": e}`.
|
|
||||||
- The keyword arguments belonging to the second video are
|
|
||||||
:code:`{"pixel_values_video": i, "video_grid_thw": l}`.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_hf_inputs(
|
def from_hf_inputs(
|
||||||
hf_inputs: BatchFeature,
|
hf_inputs: BatchFeature,
|
||||||
config_by_key: Mapping[str, MultiModalFieldConfig],
|
config_by_key: Mapping[str, MultiModalFieldConfig],
|
||||||
*,
|
|
||||||
enable_sanity_checks: bool = False,
|
|
||||||
):
|
):
|
||||||
# NOTE: This skips fields in `hf_inputs` that are not in `config_by_key`
|
# NOTE: This skips fields in `hf_inputs` that are not in `config_by_key`
|
||||||
# We assume that those fields are not used in vLLM
|
# We assume that those fields are not used in vLLM
|
||||||
items_by_key = {
|
elems_by_key = dict[str, Sequence[MultiModalFieldElem]]()
|
||||||
key: config.build_items(key, batch)
|
keys_by_modality = defaultdict[str, set[str]](set)
|
||||||
for key, config in config_by_key.items()
|
for key, config in config_by_key.items():
|
||||||
if (batch := hf_inputs.get(key)) is not None
|
batch = hf_inputs.get(key)
|
||||||
}
|
if batch is not None:
|
||||||
|
elems = config.build_elems(key, batch)
|
||||||
|
if len(elems) > 0:
|
||||||
|
elems_by_key[key] = elems
|
||||||
|
keys_by_modality[config.modality].add(key)
|
||||||
|
|
||||||
return MultiModalKwargs.from_items_by_key(
|
items = list[MultiModalKwargsItem]()
|
||||||
items_by_key,
|
for modality, keys in keys_by_modality.items():
|
||||||
enable_sanity_checks=enable_sanity_checks,
|
elems_in_modality = {k: elems_by_key[k] for k in keys}
|
||||||
)
|
batch_sizes = {k: len(v) for k, v in elems_in_modality.items()}
|
||||||
|
|
||||||
|
if len(set(batch_sizes.values())) > 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"Cannot merge different batch sizes for {modality=}! "
|
||||||
|
f"Found: {batch_sizes=}")
|
||||||
|
|
||||||
|
batch_size = next(iter(batch_sizes.values()))
|
||||||
|
for item_idx in range(batch_size):
|
||||||
|
elems = [v[item_idx] for v in elems_in_modality.values()]
|
||||||
|
items.append(MultiModalKwargsItem.from_elems(elems))
|
||||||
|
|
||||||
|
return MultiModalKwargs.from_items(items)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_items_by_key(
|
def from_items(items: Sequence[MultiModalKwargsItem]):
|
||||||
items_by_key: Mapping[str, list[MultiModalFieldItem]],
|
"""Construct a new :class:`MultiModalKwargs` from multiple items."""
|
||||||
*,
|
elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list)
|
||||||
enable_sanity_checks: bool = False,
|
for item in items:
|
||||||
) -> "MultiModalKwargs":
|
for key, elem in item.items():
|
||||||
|
elems_by_key[key].append(elem)
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
key: items[0].field.reduce(items).data
|
key: elems[0].field.reduce(elems).data
|
||||||
for key, items in items_by_key.items() if len(items) > 0
|
for key, elems in elems_by_key.items() if len(elems) > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
return MultiModalKwargs(data,
|
return MultiModalKwargs(data, items=items)
|
||||||
items_by_key=items_by_key,
|
|
||||||
enable_sanity_checks=enable_sanity_checks)
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
data: Mapping[str, NestedTensors],
|
data: Mapping[str, NestedTensors],
|
||||||
*,
|
*,
|
||||||
items_by_key: Mapping[str, list[MultiModalFieldItem]] = {},
|
items: Optional[Sequence[MultiModalKwargsItem]] = None,
|
||||||
enable_sanity_checks: bool = False,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(data)
|
super().__init__(data)
|
||||||
|
|
||||||
# Shallow copy to avoid footgun in case a defaultdict is passed in
|
items_by_modality = full_groupby(items or [], key=lambda x: x.modality)
|
||||||
self._items_by_key = dict(items_by_key)
|
self._items_by_modality = dict(items_by_modality)
|
||||||
|
|
||||||
keys_by_modality = defaultdict[str, set[str]](set)
|
@property
|
||||||
for key, items in items_by_key.items():
|
def modalities(self):
|
||||||
for item in items:
|
return self._items_by_modality.keys()
|
||||||
keys_by_modality[item.field.modality].add(key)
|
|
||||||
|
|
||||||
self._keys_by_modality = dict(keys_by_modality)
|
|
||||||
|
|
||||||
if enable_sanity_checks:
|
|
||||||
for modality, keys in keys_by_modality.items():
|
|
||||||
items_in_modality = {k: items_by_key[k] for k in keys}
|
|
||||||
batch_sizes = {k: len(v) for k, v in items_in_modality.items()}
|
|
||||||
batch_size = next(iter(batch_sizes.values()), 0)
|
|
||||||
assert all(bs == batch_size
|
|
||||||
for bs in batch_sizes.values()), dict(
|
|
||||||
modality=modality,
|
|
||||||
batch_sizes=batch_sizes,
|
|
||||||
items_by_key=items_by_key)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _try_stack(nested_tensors: NestedTensors) -> NestedTensors:
|
def _try_stack(nested_tensors: NestedTensors) -> NestedTensors:
|
||||||
@@ -452,58 +446,44 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
|
|||||||
def __eq__(self, other: object) -> bool:
|
def __eq__(self, other: object) -> bool:
|
||||||
if not isinstance(other, self.__class__):
|
if not isinstance(other, self.__class__):
|
||||||
return False
|
return False
|
||||||
if self._items_by_key != other._items_by_key:
|
if self._items_by_modality != other._items_by_modality:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
ks = self.keys()
|
ks = self.keys()
|
||||||
return (ks == other.keys()
|
return (ks == other.keys()
|
||||||
and all(nested_tensors_equal(self[k], other[k]) for k in ks))
|
and all(nested_tensors_equal(self[k], other[k]) for k in ks))
|
||||||
|
|
||||||
def get_item(self, key: str, item_index: int) -> MultiModalFieldItem:
|
def _validate_modality(self, method_name: str, modality: str) -> None:
|
||||||
return self._items_by_key[key][item_index]
|
if not self._items_by_modality:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"`{method_name}` is not supported when "
|
||||||
|
"MultiModalKwargs is not initialized with `items`")
|
||||||
|
|
||||||
def get_items_by_modality(
|
if modality not in self._items_by_modality:
|
||||||
self,
|
available_modalities = set(self._items_by_modality.keys())
|
||||||
modality: str,
|
raise KeyError(f"Modality {modality!r} not found. "
|
||||||
item_index: int,
|
f"Available modalities: {available_modalities}")
|
||||||
) -> Mapping[str, MultiModalFieldItem]:
|
|
||||||
|
def get_item_count(self, modality: str) -> int:
|
||||||
|
"""Get the number of items belonging to a modality."""
|
||||||
|
self._validate_modality("get_item_count", modality)
|
||||||
|
return len(self._items_by_modality[modality])
|
||||||
|
|
||||||
|
def get_item(self, modality: str, item_index: int) -> MultiModalKwargsItem:
|
||||||
"""
|
"""
|
||||||
Get the keyword arguments corresponding to an item identified by
|
Get the keyword arguments corresponding to an item identified by
|
||||||
its modality and index.
|
its modality and index.
|
||||||
"""
|
"""
|
||||||
if modality not in self._keys_by_modality:
|
self._validate_modality("get_item", modality)
|
||||||
available_modalities = set(self._keys_by_modality.keys())
|
return self._items_by_modality[modality][item_index]
|
||||||
raise KeyError(f"Modality {modality!r} not found. "
|
|
||||||
f"Available modalities: {available_modalities}")
|
|
||||||
|
|
||||||
keys_to_gather = self._keys_by_modality[modality]
|
def get_items(self, modality: str) -> Sequence[MultiModalKwargsItem]:
|
||||||
|
|
||||||
return {
|
|
||||||
key: self.get_item(key, item_index)
|
|
||||||
for key in keys_to_gather if key in self
|
|
||||||
}
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def from_items_by_modality(
|
|
||||||
items_by_modality: Mapping[str, list[Mapping[str,
|
|
||||||
MultiModalFieldItem]]],
|
|
||||||
*,
|
|
||||||
enable_sanity_checks: bool = False,
|
|
||||||
) -> "MultiModalKwargs":
|
|
||||||
"""
|
"""
|
||||||
Construct a new :class:`MultiModalKwargs` from multiple items returned
|
Get the keyword arguments corresponding to each item belonging to
|
||||||
by :meth:`get_fields_by_modality`.
|
a modality.
|
||||||
"""
|
"""
|
||||||
items_by_key = defaultdict[str, list[MultiModalFieldItem]](list)
|
self._validate_modality("get_items", modality)
|
||||||
for fields in items_by_modality.values():
|
return self._items_by_modality[modality]
|
||||||
for field in fields:
|
|
||||||
for k, v in field.items():
|
|
||||||
items_by_key[k].append(v)
|
|
||||||
|
|
||||||
return MultiModalKwargs.from_items_by_key(
|
|
||||||
items_by_key,
|
|
||||||
enable_sanity_checks=enable_sanity_checks,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
MultiModalPlaceholderDict = Mapping[str, Sequence[PlaceholderRange]]
|
MultiModalPlaceholderDict = Mapping[str, Sequence[PlaceholderRange]]
|
||||||
|
|||||||
@@ -20,8 +20,8 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
|||||||
from vllm.utils import LRUCache, flatten_2d_lists, full_groupby
|
from vllm.utils import LRUCache, flatten_2d_lists, full_groupby
|
||||||
|
|
||||||
from .inputs import (MultiModalDataDict, MultiModalFieldConfig,
|
from .inputs import (MultiModalDataDict, MultiModalFieldConfig,
|
||||||
MultiModalFieldItem, MultiModalInputsV2, MultiModalKwargs,
|
MultiModalInputsV2, MultiModalKwargs,
|
||||||
PlaceholderRange)
|
MultiModalKwargsItem, PlaceholderRange)
|
||||||
from .parse import MultiModalDataItems, MultiModalDataParser
|
from .parse import MultiModalDataItems, MultiModalDataParser
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@@ -496,8 +496,7 @@ class ProcessingCache:
|
|||||||
# DEBUG: Set to None to disable
|
# DEBUG: Set to None to disable
|
||||||
self.debug_cache_hit_ratio_steps: Optional[int] = None
|
self.debug_cache_hit_ratio_steps: Optional[int] = None
|
||||||
|
|
||||||
self._cache = LRUCache[str, Mapping[str,
|
self._cache = LRUCache[str, MultiModalKwargsItem](capacity)
|
||||||
MultiModalFieldItem]](capacity)
|
|
||||||
|
|
||||||
def _maybe_log_cache_stats(self) -> None:
|
def _maybe_log_cache_stats(self) -> None:
|
||||||
steps = self.debug_cache_hit_ratio_steps
|
steps = self.debug_cache_hit_ratio_steps
|
||||||
@@ -565,7 +564,7 @@ class ProcessingCache:
|
|||||||
modality: str,
|
modality: str,
|
||||||
input_item: object,
|
input_item: object,
|
||||||
input_kwargs: Mapping[str, object],
|
input_kwargs: Mapping[str, object],
|
||||||
) -> Optional[Mapping[str, MultiModalFieldItem]]:
|
) -> Optional[MultiModalKwargsItem]:
|
||||||
"""
|
"""
|
||||||
Get a processed multi-modal item from the cache
|
Get a processed multi-modal item from the cache
|
||||||
according to its dependencies, including:
|
according to its dependencies, including:
|
||||||
@@ -588,7 +587,7 @@ class ProcessingCache:
|
|||||||
modality: str,
|
modality: str,
|
||||||
input_item: object,
|
input_item: object,
|
||||||
input_kwargs: Mapping[str, object],
|
input_kwargs: Mapping[str, object],
|
||||||
output_kwargs: Mapping[str, MultiModalFieldItem],
|
output_kwargs: MultiModalKwargsItem,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Put a processed multi-modal item into the cache
|
Put a processed multi-modal item into the cache
|
||||||
@@ -784,7 +783,6 @@ class BaseMultiModalProcessor(ABC):
|
|||||||
mm_kwargs = MultiModalKwargs.from_hf_inputs(
|
mm_kwargs = MultiModalKwargs.from_hf_inputs(
|
||||||
processed_data,
|
processed_data,
|
||||||
self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs),
|
self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs),
|
||||||
enable_sanity_checks=self.enable_sanity_checks,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return prompt_ids, mm_kwargs
|
return prompt_ids, mm_kwargs
|
||||||
@@ -846,7 +844,7 @@ class BaseMultiModalProcessor(ABC):
|
|||||||
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
|
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
mm_maybe_cached_field_items = {
|
mm_maybe_cached_kw_items = {
|
||||||
modality: [
|
modality: [
|
||||||
cache.get(model_id, modality, item, hf_processor_mm_kwargs)
|
cache.get(model_id, modality, item, hf_processor_mm_kwargs)
|
||||||
for item in items
|
for item in items
|
||||||
@@ -855,8 +853,9 @@ class BaseMultiModalProcessor(ABC):
|
|||||||
}
|
}
|
||||||
|
|
||||||
mm_missing_idxs = {
|
mm_missing_idxs = {
|
||||||
modality: [idx for idx, out in enumerate(fields) if out is None]
|
modality:
|
||||||
for modality, fields in mm_maybe_cached_field_items.items()
|
[idx for idx, item in enumerate(kw_items) if item is None]
|
||||||
|
for modality, kw_items in mm_maybe_cached_kw_items.items()
|
||||||
}
|
}
|
||||||
mm_missing_data = {
|
mm_missing_data = {
|
||||||
modality: [mm_data_items[modality][idx] for idx in idxs]
|
modality: [mm_data_items[modality][idx] for idx in idxs]
|
||||||
@@ -875,14 +874,11 @@ class BaseMultiModalProcessor(ABC):
|
|||||||
for modality in mm_missing_data_items
|
for modality in mm_missing_data_items
|
||||||
}
|
}
|
||||||
|
|
||||||
mm_merged_field_items = dict[str, list[Mapping[str,
|
merged_kw_items = list[MultiModalKwargsItem]()
|
||||||
MultiModalFieldItem]]]()
|
for modality, kw_items in mm_maybe_cached_kw_items.items():
|
||||||
for modality, modal_items_lst in mm_maybe_cached_field_items.items():
|
for idx, kw_item in enumerate(kw_items):
|
||||||
merged_modal_items_lst = list[Mapping[str, MultiModalFieldItem]]()
|
if kw_item is None:
|
||||||
|
kw_item = mm_missing_kwargs.get_item(
|
||||||
for idx, modal_items in enumerate(modal_items_lst):
|
|
||||||
if modal_items is None:
|
|
||||||
modal_items = mm_missing_kwargs.get_items_by_modality(
|
|
||||||
modality,
|
modality,
|
||||||
mm_missing_next_idx[modality],
|
mm_missing_next_idx[modality],
|
||||||
)
|
)
|
||||||
@@ -892,14 +888,12 @@ class BaseMultiModalProcessor(ABC):
|
|||||||
modality,
|
modality,
|
||||||
mm_data_items[modality][idx],
|
mm_data_items[modality][idx],
|
||||||
hf_processor_mm_kwargs,
|
hf_processor_mm_kwargs,
|
||||||
modal_items,
|
kw_item,
|
||||||
)
|
)
|
||||||
|
|
||||||
mm_missing_next_idx[modality] += 1
|
mm_missing_next_idx[modality] += 1
|
||||||
|
|
||||||
merged_modal_items_lst.append(modal_items)
|
merged_kw_items.append(kw_item)
|
||||||
|
|
||||||
mm_merged_field_items[modality] = merged_modal_items_lst
|
|
||||||
|
|
||||||
if self.enable_sanity_checks:
|
if self.enable_sanity_checks:
|
||||||
mm_missing_counts = mm_missing_data_items.get_all_counts()
|
mm_missing_counts = mm_missing_data_items.get_all_counts()
|
||||||
@@ -909,10 +903,7 @@ class BaseMultiModalProcessor(ABC):
|
|||||||
mm_missing_next_idx=mm_missing_next_idx,
|
mm_missing_next_idx=mm_missing_next_idx,
|
||||||
mm_missing_counts=mm_missing_counts)
|
mm_missing_counts=mm_missing_counts)
|
||||||
|
|
||||||
mm_kwargs = MultiModalKwargs.from_items_by_modality(
|
mm_kwargs = MultiModalKwargs.from_items(merged_kw_items)
|
||||||
mm_merged_field_items,
|
|
||||||
enable_sanity_checks=self.enable_sanity_checks,
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.enable_sanity_checks:
|
if self.enable_sanity_checks:
|
||||||
mm_item_counts = mm_data_items.get_all_counts()
|
mm_item_counts = mm_data_items.get_all_counts()
|
||||||
@@ -920,7 +911,7 @@ class BaseMultiModalProcessor(ABC):
|
|||||||
for modality, item_count in mm_item_counts.items():
|
for modality, item_count in mm_item_counts.items():
|
||||||
for item_idx in range(item_count):
|
for item_idx in range(item_count):
|
||||||
try:
|
try:
|
||||||
mm_kwargs.get_items_by_modality(modality, item_idx)
|
mm_kwargs.get_item(modality, item_idx)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Make it easy to set a breakpoint in the debugger
|
# Make it easy to set a breakpoint in the debugger
|
||||||
raise e
|
raise e
|
||||||
|
|||||||
@@ -113,15 +113,27 @@ class Processor:
|
|||||||
|
|
||||||
# For merged preprocessor, mm_data is already mm_inputs
|
# For merged preprocessor, mm_data is already mm_inputs
|
||||||
precomputed_mm_inputs = None
|
precomputed_mm_inputs = None
|
||||||
if isinstance(decoder_inputs.multi_modal_data, MultiModalKwargs):
|
decoder_mm_data = decoder_inputs.multi_modal_data
|
||||||
precomputed_mm_inputs = [decoder_inputs.multi_modal_data]
|
if isinstance(decoder_mm_data, MultiModalKwargs):
|
||||||
|
# The output of merged multi-modal processor (`decoder_mm_data`)
|
||||||
|
# contains the kwargs for all items from all modalities.
|
||||||
|
# This code separates them so that there is one set of kwargs
|
||||||
|
# per item per modality.
|
||||||
|
precomputed_mm_inputs = [
|
||||||
|
MultiModalKwargs.from_items([item])
|
||||||
|
for modality in decoder_mm_data.modalities
|
||||||
|
for item in decoder_mm_data.get_items(modality)
|
||||||
|
]
|
||||||
|
|
||||||
# Apply MM mapper
|
# Apply MM mapper
|
||||||
mm_inputs = None
|
mm_inputs = None
|
||||||
if len(decoder_inputs.multi_modal_data) > 0:
|
if len(decoder_mm_data) > 0:
|
||||||
mm_inputs = self.mm_input_mapper_client.process_inputs(
|
mm_inputs = self.mm_input_mapper_client.process_inputs(
|
||||||
decoder_inputs.multi_modal_data, mm_hashes,
|
decoder_mm_data,
|
||||||
decoder_inputs.mm_processor_kwargs, precomputed_mm_inputs)
|
mm_hashes,
|
||||||
|
decoder_inputs.mm_processor_kwargs,
|
||||||
|
precomputed_mm_inputs,
|
||||||
|
)
|
||||||
|
|
||||||
return EngineCoreRequest(
|
return EngineCoreRequest(
|
||||||
request_id,
|
request_id,
|
||||||
|
|||||||
Reference in New Issue
Block a user