[VLM][Bugfix] Multi-modal processor compatible with V1 multi-input (#11674)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-01-02 17:00:00 +08:00
committed by GitHub
parent a115ac46b5
commit 23c1b10a4c
3 changed files with 151 additions and 168 deletions

View File

@@ -2,7 +2,8 @@ from abc import ABC, abstractmethod
from collections import UserDict, defaultdict from collections import UserDict, defaultdict
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Literal, TypedDict, TypeVar, Union, cast, final from typing import (Any, Literal, Optional, TypedDict, TypeVar, Union, cast,
final)
import numpy as np import numpy as np
import torch import torch
@@ -11,7 +12,7 @@ from PIL.Image import Image
from transformers import BatchFeature from transformers import BatchFeature
from typing_extensions import NotRequired, TypeAlias from typing_extensions import NotRequired, TypeAlias
from vllm.utils import JSONTree, is_list_of, json_map_leaves from vllm.utils import JSONTree, full_groupby, is_list_of, json_map_leaves
_T = TypeVar("_T") _T = TypeVar("_T")
@@ -160,11 +161,8 @@ A dictionary containing nested tensors which have been batched via
@dataclass(frozen=True) @dataclass(frozen=True)
class MultiModalFieldItem: class MultiModalFieldElem:
""" """Contains metadata and data of an item in :class:`MultiModalKwargs`."""
Contains metadata and data in :class:`MultiModalKwargs`
corresponding to a data item in :class:`MultiModalDataItems`.
"""
field: "BaseMultiModalField" field: "BaseMultiModalField"
data: NestedTensors data: NestedTensors
@@ -186,34 +184,34 @@ class BaseMultiModalField(ABC):
def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors: def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
raise NotImplementedError raise NotImplementedError
def _build_item(self, data: NestedTensors) -> MultiModalFieldItem: def _build_elem(self, data: NestedTensors) -> MultiModalFieldElem:
return MultiModalFieldItem(self, data) return MultiModalFieldElem(self, data)
def reduce(self, batch: list[MultiModalFieldItem]) -> MultiModalFieldItem: def reduce(self, batch: list[MultiModalFieldElem]) -> MultiModalFieldElem:
"""Merge multiple instances of :class:`MultiModalFieldItem` together.""" """Merge multiple instances of :class:`MultiModalFieldElem` together."""
fields = [item.field for item in batch] fields = [item.field for item in batch]
if len(set(fields)) > 1: if len(set(fields)) > 1:
raise ValueError(f"Cannot merge different {fields=}") raise ValueError(f"Cannot merge different {fields=}")
data = self._reduce_data([item.data for item in batch]) data = self._reduce_data([item.data for item in batch])
return self._build_item(data) return self._build_elem(data)
@dataclass(frozen=True) @dataclass(frozen=True)
class MultiModalBatchedField(BaseMultiModalField): class MultiModalBatchedField(BaseMultiModalField):
""" """
A :class:`BaseMultiModalField` implementation where an item is obtained by A :class:`BaseMultiModalField` implementation where an element in the batch
directly indexing into the first dimension of the underlying data. is obtained by indexing into the first dimension of the underlying data.
""" """
def build_items(self, batch: NestedTensors) -> list[MultiModalFieldItem]: def build_elems(self, batch: NestedTensors) -> list[MultiModalFieldElem]:
return [self._build_item(item) for item in batch] return [self._build_elem(item) for item in batch]
def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors: def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"): if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
first_shape = batch[0].shape first_shape = batch[0].shape
if all(item.shape == first_shape for item in batch): if all(elem.shape == first_shape for elem in batch):
return torch.stack(batch) return torch.stack(batch)
return batch return batch
@@ -222,24 +220,24 @@ class MultiModalBatchedField(BaseMultiModalField):
@dataclass(frozen=True) @dataclass(frozen=True)
class MultiModalFlatField(BaseMultiModalField): class MultiModalFlatField(BaseMultiModalField):
""" """
A :class:`BaseMultiModalField` implementation where an item is obtained by A :class:`BaseMultiModalField` implementation where an element in the batch
slicing along the first dimension of the underlying data. is obtained by slicing along the first dimension of the underlying data.
""" """
def build_items( def build_elems(
self, self,
batch: NestedTensors, batch: NestedTensors,
slices: Sequence[slice], slices: Sequence[slice],
) -> list[MultiModalFieldItem]: ) -> list[MultiModalFieldElem]:
return [self._build_item(batch[slice_]) for slice_ in slices] return [self._build_elem(batch[slice_]) for slice_ in slices]
def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors: def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"): if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
first_shape = batch[0].shape first_shape = batch[0].shape
if all(item.shape[1:] == first_shape[1:] for item in batch): if all(elem.shape[1:] == first_shape[1:] for elem in batch):
return torch.concat(batch) return torch.concat(batch)
return [elem for item in batch for elem in item] return [e for elem in batch for e in elem]
class MultiModalFieldConfig: class MultiModalFieldConfig:
@@ -267,115 +265,111 @@ class MultiModalFieldConfig:
) -> None: ) -> None:
super().__init__() super().__init__()
self._field_cls = field_cls self.field_cls = field_cls
self._modality = modality self.modality = modality
self._field_config = field_config self.field_config = field_config
def build_items( def build_elems(
self, self,
key: str, key: str,
batch: NestedTensors, batch: NestedTensors,
) -> list[MultiModalFieldItem]: ) -> Sequence[MultiModalFieldElem]:
field = self._field_cls(key=key, modality=self._modality) field = self.field_cls(key=key, modality=self.modality)
return field.build_items(batch, **self._field_config) # type: ignore return field.build_elems(batch, **self.field_config) # type: ignore
class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
    """
    Mapping from field key to :class:`MultiModalFieldElem` that together
    correspond to a data item in :class:`MultiModalDataItems`.
    """

    @staticmethod
    def from_elems(elems: Sequence[MultiModalFieldElem]):
        """
        Build an item from ``elems``, keyed by each element's ``field.key``.

        If several elements share a key, the last one is kept.
        """
        keyed_elems = {}
        for elem in elems:
            keyed_elems[elem.field.key] = elem

        return MultiModalKwargsItem(keyed_elems)

    @property
    def modality(self) -> str:
        """
        The modality shared by every element stored in this item.

        An assertion fails unless exactly one distinct modality is found
        among the stored elements.
        """
        modalities = set()
        for elem in self.data.values():
            modalities.add(elem.field.modality)

        assert len(modalities) == 1, f"Found different modalities={modalities}"

        return next(iter(modalities))
# NOTE: UserDict is for V0 compatibility.
# V1 should access individual items via `get_item`.
class MultiModalKwargs(UserDict[str, NestedTensors]): class MultiModalKwargs(UserDict[str, NestedTensors]):
""" """
A dictionary that represents the keyword arguments to A dictionary that represents the keyword arguments to
:meth:`~torch.nn.Module.forward`. :meth:`~torch.nn.Module.forward`.
The metadata :code:`items_by_key` defines how to split batched keyword The metadata :code:`items` enables us to obtain the keyword arguments
arguments corresponding to each data item in :class:`MultiModalDataItems`: corresponding to each data item in :class:`MultiModalDataItems`, via
:meth:`get_item` and :meth:`get_items`.
- For a keyword argument, we can access the :code:`i` th item in the batch
via :code:`items_by_key[key][i]`.
- We can gather the keyword arguments belonging to a modality by finding
the keys with items that belong to that modality, then accessing
the :code:`i` th item in the batch for each such key.
Example:
.. code-block:: python
# All items belong to the "image" modality
items_by_key={
"pixel_values": [a, b, c, d], # "image" modality
"image_grid_thw": [e, f, g, h], # "image" modality
"pixel_values_video": [h, i, j], # "video" modality
"video_grid_thw": [k, l, m], # "video" modality
}
- The keyword arguments belonging to the first image are
:code:`{"pixel_values": a, "image_grid_thw": e}`.
- The keyword arguments belonging to the second video are
:code:`{"pixel_values_video": i, "video_grid_thw": l}`.
""" """
@staticmethod @staticmethod
def from_hf_inputs( def from_hf_inputs(
hf_inputs: BatchFeature, hf_inputs: BatchFeature,
config_by_key: Mapping[str, MultiModalFieldConfig], config_by_key: Mapping[str, MultiModalFieldConfig],
*,
enable_sanity_checks: bool = False,
): ):
# NOTE: This skips fields in `hf_inputs` that are not in `config_by_key` # NOTE: This skips fields in `hf_inputs` that are not in `config_by_key`
# We assume that those fields are not used in vLLM # We assume that those fields are not used in vLLM
items_by_key = { elems_by_key = dict[str, Sequence[MultiModalFieldElem]]()
key: config.build_items(key, batch) keys_by_modality = defaultdict[str, set[str]](set)
for key, config in config_by_key.items() for key, config in config_by_key.items():
if (batch := hf_inputs.get(key)) is not None batch = hf_inputs.get(key)
} if batch is not None:
elems = config.build_elems(key, batch)
if len(elems) > 0:
elems_by_key[key] = elems
keys_by_modality[config.modality].add(key)
return MultiModalKwargs.from_items_by_key( items = list[MultiModalKwargsItem]()
items_by_key, for modality, keys in keys_by_modality.items():
enable_sanity_checks=enable_sanity_checks, elems_in_modality = {k: elems_by_key[k] for k in keys}
) batch_sizes = {k: len(v) for k, v in elems_in_modality.items()}
if len(set(batch_sizes.values())) > 1:
raise ValueError(
f"Cannot merge different batch sizes for {modality=}! "
f"Found: {batch_sizes=}")
batch_size = next(iter(batch_sizes.values()))
for item_idx in range(batch_size):
elems = [v[item_idx] for v in elems_in_modality.values()]
items.append(MultiModalKwargsItem.from_elems(elems))
return MultiModalKwargs.from_items(items)
@staticmethod @staticmethod
def from_items_by_key( def from_items(items: Sequence[MultiModalKwargsItem]):
items_by_key: Mapping[str, list[MultiModalFieldItem]], """Construct a new :class:`MultiModalKwargs` from multiple items."""
*, elems_by_key = defaultdict[str, list[MultiModalFieldElem]](list)
enable_sanity_checks: bool = False, for item in items:
) -> "MultiModalKwargs": for key, elem in item.items():
elems_by_key[key].append(elem)
data = { data = {
key: items[0].field.reduce(items).data key: elems[0].field.reduce(elems).data
for key, items in items_by_key.items() if len(items) > 0 for key, elems in elems_by_key.items() if len(elems) > 0
} }
return MultiModalKwargs(data, return MultiModalKwargs(data, items=items)
items_by_key=items_by_key,
enable_sanity_checks=enable_sanity_checks)
def __init__( def __init__(
self, self,
data: Mapping[str, NestedTensors], data: Mapping[str, NestedTensors],
*, *,
items_by_key: Mapping[str, list[MultiModalFieldItem]] = {}, items: Optional[Sequence[MultiModalKwargsItem]] = None,
enable_sanity_checks: bool = False,
) -> None: ) -> None:
super().__init__(data) super().__init__(data)
# Shallow copy to avoid footgun in case a defaultdict is passed in items_by_modality = full_groupby(items or [], key=lambda x: x.modality)
self._items_by_key = dict(items_by_key) self._items_by_modality = dict(items_by_modality)
keys_by_modality = defaultdict[str, set[str]](set) @property
for key, items in items_by_key.items(): def modalities(self):
for item in items: return self._items_by_modality.keys()
keys_by_modality[item.field.modality].add(key)
self._keys_by_modality = dict(keys_by_modality)
if enable_sanity_checks:
for modality, keys in keys_by_modality.items():
items_in_modality = {k: items_by_key[k] for k in keys}
batch_sizes = {k: len(v) for k, v in items_in_modality.items()}
batch_size = next(iter(batch_sizes.values()), 0)
assert all(bs == batch_size
for bs in batch_sizes.values()), dict(
modality=modality,
batch_sizes=batch_sizes,
items_by_key=items_by_key)
@staticmethod @staticmethod
def _try_stack(nested_tensors: NestedTensors) -> NestedTensors: def _try_stack(nested_tensors: NestedTensors) -> NestedTensors:
@@ -452,58 +446,44 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
def __eq__(self, other: object) -> bool: def __eq__(self, other: object) -> bool:
if not isinstance(other, self.__class__): if not isinstance(other, self.__class__):
return False return False
if self._items_by_key != other._items_by_key: if self._items_by_modality != other._items_by_modality:
return False return False
ks = self.keys() ks = self.keys()
return (ks == other.keys() return (ks == other.keys()
and all(nested_tensors_equal(self[k], other[k]) for k in ks)) and all(nested_tensors_equal(self[k], other[k]) for k in ks))
def get_item(self, key: str, item_index: int) -> MultiModalFieldItem: def _validate_modality(self, method_name: str, modality: str) -> None:
return self._items_by_key[key][item_index] if not self._items_by_modality:
raise RuntimeError(
f"`{method_name}` is not supported when "
"MultiModalKwargs is not initialized with `items`")
def get_items_by_modality( if modality not in self._items_by_modality:
self, available_modalities = set(self._items_by_modality.keys())
modality: str, raise KeyError(f"Modality {modality!r} not found. "
item_index: int, f"Available modalities: {available_modalities}")
) -> Mapping[str, MultiModalFieldItem]:
def get_item_count(self, modality: str) -> int:
"""Get the number of items belonging to a modality."""
self._validate_modality("get_item_count", modality)
return len(self._items_by_modality[modality])
def get_item(self, modality: str, item_index: int) -> MultiModalKwargsItem:
""" """
Get the keyword arguments corresponding to an item identified by Get the keyword arguments corresponding to an item identified by
its modality and index. its modality and index.
""" """
if modality not in self._keys_by_modality: self._validate_modality("get_item", modality)
available_modalities = set(self._keys_by_modality.keys()) return self._items_by_modality[modality][item_index]
raise KeyError(f"Modality {modality!r} not found. "
f"Available modalities: {available_modalities}")
keys_to_gather = self._keys_by_modality[modality] def get_items(self, modality: str) -> Sequence[MultiModalKwargsItem]:
return {
key: self.get_item(key, item_index)
for key in keys_to_gather if key in self
}
@staticmethod
def from_items_by_modality(
items_by_modality: Mapping[str, list[Mapping[str,
MultiModalFieldItem]]],
*,
enable_sanity_checks: bool = False,
) -> "MultiModalKwargs":
""" """
Construct a new :class:`MultiModalKwargs` from multiple items returned Get the keyword arguments corresponding to each item belonging to
by :meth:`get_fields_by_modality`. a modality.
""" """
items_by_key = defaultdict[str, list[MultiModalFieldItem]](list) self._validate_modality("get_items", modality)
for fields in items_by_modality.values(): return self._items_by_modality[modality]
for field in fields:
for k, v in field.items():
items_by_key[k].append(v)
return MultiModalKwargs.from_items_by_key(
items_by_key,
enable_sanity_checks=enable_sanity_checks,
)
MultiModalPlaceholderDict = Mapping[str, Sequence[PlaceholderRange]] MultiModalPlaceholderDict = Mapping[str, Sequence[PlaceholderRange]]

View File

@@ -20,8 +20,8 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import LRUCache, flatten_2d_lists, full_groupby from vllm.utils import LRUCache, flatten_2d_lists, full_groupby
from .inputs import (MultiModalDataDict, MultiModalFieldConfig, from .inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalFieldItem, MultiModalInputsV2, MultiModalKwargs, MultiModalInputsV2, MultiModalKwargs,
PlaceholderRange) MultiModalKwargsItem, PlaceholderRange)
from .parse import MultiModalDataItems, MultiModalDataParser from .parse import MultiModalDataItems, MultiModalDataParser
logger = init_logger(__name__) logger = init_logger(__name__)
@@ -496,8 +496,7 @@ class ProcessingCache:
# DEBUG: Set to None to disable # DEBUG: Set to None to disable
self.debug_cache_hit_ratio_steps: Optional[int] = None self.debug_cache_hit_ratio_steps: Optional[int] = None
self._cache = LRUCache[str, Mapping[str, self._cache = LRUCache[str, MultiModalKwargsItem](capacity)
MultiModalFieldItem]](capacity)
def _maybe_log_cache_stats(self) -> None: def _maybe_log_cache_stats(self) -> None:
steps = self.debug_cache_hit_ratio_steps steps = self.debug_cache_hit_ratio_steps
@@ -565,7 +564,7 @@ class ProcessingCache:
modality: str, modality: str,
input_item: object, input_item: object,
input_kwargs: Mapping[str, object], input_kwargs: Mapping[str, object],
) -> Optional[Mapping[str, MultiModalFieldItem]]: ) -> Optional[MultiModalKwargsItem]:
""" """
Get a processed multi-modal item from the cache Get a processed multi-modal item from the cache
according to its dependencies, including: according to its dependencies, including:
@@ -588,7 +587,7 @@ class ProcessingCache:
modality: str, modality: str,
input_item: object, input_item: object,
input_kwargs: Mapping[str, object], input_kwargs: Mapping[str, object],
output_kwargs: Mapping[str, MultiModalFieldItem], output_kwargs: MultiModalKwargsItem,
) -> None: ) -> None:
""" """
Put a processed multi-modal item into the cache Put a processed multi-modal item into the cache
@@ -784,7 +783,6 @@ class BaseMultiModalProcessor(ABC):
mm_kwargs = MultiModalKwargs.from_hf_inputs( mm_kwargs = MultiModalKwargs.from_hf_inputs(
processed_data, processed_data,
self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs), self._get_mm_fields_config(processed_data, hf_processor_mm_kwargs),
enable_sanity_checks=self.enable_sanity_checks,
) )
return prompt_ids, mm_kwargs return prompt_ids, mm_kwargs
@@ -846,7 +844,7 @@ class BaseMultiModalProcessor(ABC):
hf_processor_mm_kwargs=hf_processor_mm_kwargs, hf_processor_mm_kwargs=hf_processor_mm_kwargs,
) )
mm_maybe_cached_field_items = { mm_maybe_cached_kw_items = {
modality: [ modality: [
cache.get(model_id, modality, item, hf_processor_mm_kwargs) cache.get(model_id, modality, item, hf_processor_mm_kwargs)
for item in items for item in items
@@ -855,8 +853,9 @@ class BaseMultiModalProcessor(ABC):
} }
mm_missing_idxs = { mm_missing_idxs = {
modality: [idx for idx, out in enumerate(fields) if out is None] modality:
for modality, fields in mm_maybe_cached_field_items.items() [idx for idx, item in enumerate(kw_items) if item is None]
for modality, kw_items in mm_maybe_cached_kw_items.items()
} }
mm_missing_data = { mm_missing_data = {
modality: [mm_data_items[modality][idx] for idx in idxs] modality: [mm_data_items[modality][idx] for idx in idxs]
@@ -875,14 +874,11 @@ class BaseMultiModalProcessor(ABC):
for modality in mm_missing_data_items for modality in mm_missing_data_items
} }
mm_merged_field_items = dict[str, list[Mapping[str, merged_kw_items = list[MultiModalKwargsItem]()
MultiModalFieldItem]]]() for modality, kw_items in mm_maybe_cached_kw_items.items():
for modality, modal_items_lst in mm_maybe_cached_field_items.items(): for idx, kw_item in enumerate(kw_items):
merged_modal_items_lst = list[Mapping[str, MultiModalFieldItem]]() if kw_item is None:
kw_item = mm_missing_kwargs.get_item(
for idx, modal_items in enumerate(modal_items_lst):
if modal_items is None:
modal_items = mm_missing_kwargs.get_items_by_modality(
modality, modality,
mm_missing_next_idx[modality], mm_missing_next_idx[modality],
) )
@@ -892,14 +888,12 @@ class BaseMultiModalProcessor(ABC):
modality, modality,
mm_data_items[modality][idx], mm_data_items[modality][idx],
hf_processor_mm_kwargs, hf_processor_mm_kwargs,
modal_items, kw_item,
) )
mm_missing_next_idx[modality] += 1 mm_missing_next_idx[modality] += 1
merged_modal_items_lst.append(modal_items) merged_kw_items.append(kw_item)
mm_merged_field_items[modality] = merged_modal_items_lst
if self.enable_sanity_checks: if self.enable_sanity_checks:
mm_missing_counts = mm_missing_data_items.get_all_counts() mm_missing_counts = mm_missing_data_items.get_all_counts()
@@ -909,10 +903,7 @@ class BaseMultiModalProcessor(ABC):
mm_missing_next_idx=mm_missing_next_idx, mm_missing_next_idx=mm_missing_next_idx,
mm_missing_counts=mm_missing_counts) mm_missing_counts=mm_missing_counts)
mm_kwargs = MultiModalKwargs.from_items_by_modality( mm_kwargs = MultiModalKwargs.from_items(merged_kw_items)
mm_merged_field_items,
enable_sanity_checks=self.enable_sanity_checks,
)
if self.enable_sanity_checks: if self.enable_sanity_checks:
mm_item_counts = mm_data_items.get_all_counts() mm_item_counts = mm_data_items.get_all_counts()
@@ -920,7 +911,7 @@ class BaseMultiModalProcessor(ABC):
for modality, item_count in mm_item_counts.items(): for modality, item_count in mm_item_counts.items():
for item_idx in range(item_count): for item_idx in range(item_count):
try: try:
mm_kwargs.get_items_by_modality(modality, item_idx) mm_kwargs.get_item(modality, item_idx)
except Exception as e: except Exception as e:
# Make it easy to set a breakpoint in the debugger # Make it easy to set a breakpoint in the debugger
raise e raise e

View File

@@ -113,15 +113,27 @@ class Processor:
# For merged preprocessor, mm_data is already mm_inputs # For merged preprocessor, mm_data is already mm_inputs
precomputed_mm_inputs = None precomputed_mm_inputs = None
if isinstance(decoder_inputs.multi_modal_data, MultiModalKwargs): decoder_mm_data = decoder_inputs.multi_modal_data
precomputed_mm_inputs = [decoder_inputs.multi_modal_data] if isinstance(decoder_mm_data, MultiModalKwargs):
# The output of merged multi-modal processor (`decoder_mm_data`)
# contains the kwargs for all items from all modalities.
# This code separates them so that there is one set of kwargs
# per item per modality.
precomputed_mm_inputs = [
MultiModalKwargs.from_items([item])
for modality in decoder_mm_data.modalities
for item in decoder_mm_data.get_items(modality)
]
# Apply MM mapper # Apply MM mapper
mm_inputs = None mm_inputs = None
if len(decoder_inputs.multi_modal_data) > 0: if len(decoder_mm_data) > 0:
mm_inputs = self.mm_input_mapper_client.process_inputs( mm_inputs = self.mm_input_mapper_client.process_inputs(
decoder_inputs.multi_modal_data, mm_hashes, decoder_mm_data,
decoder_inputs.mm_processor_kwargs, precomputed_mm_inputs) mm_hashes,
decoder_inputs.mm_processor_kwargs,
precomputed_mm_inputs,
)
return EngineCoreRequest( return EngineCoreRequest(
request_id, request_id,