[VLM] Use shared field to pass token ids to model
This commit is contained in:
@@ -564,8 +564,7 @@ class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
|
|||||||
# Since there may be extra tokens in the feature placeholders,
|
# Since there may be extra tokens in the feature placeholders,
|
||||||
# we need to pass the image token ID to the model to select the
|
# we need to pass the image token ID to the model to select the
|
||||||
# tokens to merge from the vision encoder outputs
|
# tokens to merge from the vision encoder outputs
|
||||||
processed_outputs["image_token_id"] = [image_token_id
|
processed_outputs["image_token_id"] = torch.tensor(image_token_id)
|
||||||
] * len(image_data)
|
|
||||||
|
|
||||||
return processed_outputs
|
return processed_outputs
|
||||||
|
|
||||||
@@ -575,13 +574,14 @@ class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
|
|||||||
hf_processor_mm_kwargs: Mapping[str, object],
|
hf_processor_mm_kwargs: Mapping[str, object],
|
||||||
) -> Mapping[str, MultiModalFieldConfig]:
|
) -> Mapping[str, MultiModalFieldConfig]:
|
||||||
image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0))
|
image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0))
|
||||||
|
num_images = len(image_num_patches)
|
||||||
|
|
||||||
return dict(
|
return dict(
|
||||||
pixel_values_flat=MultiModalFieldConfig.flat_from_sizes(
|
pixel_values_flat=MultiModalFieldConfig.flat_from_sizes(
|
||||||
"image", image_num_patches),
|
"image", image_num_patches),
|
||||||
image_num_patches=MultiModalFieldConfig.batched("image"),
|
image_num_patches=MultiModalFieldConfig.batched("image"),
|
||||||
image_embeds=MultiModalFieldConfig.batched("image"),
|
image_embeds=MultiModalFieldConfig.batched("image"),
|
||||||
image_token_id=MultiModalFieldConfig.batched("image"),
|
image_token_id=MultiModalFieldConfig.shared("image", num_images),
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_prompt_replacements(
|
def _get_prompt_replacements(
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ from abc import ABC, abstractmethod
|
|||||||
from collections import UserDict, defaultdict
|
from collections import UserDict, defaultdict
|
||||||
from collections.abc import Mapping, Sequence
|
from collections.abc import Mapping, Sequence
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from functools import partial
|
||||||
from itertools import accumulate
|
from itertools import accumulate
|
||||||
from typing import (TYPE_CHECKING, Any, Literal, Optional, TypedDict, TypeVar,
|
from typing import (TYPE_CHECKING, Any, Literal, Optional, TypedDict, TypeVar,
|
||||||
Union, cast, final)
|
Union, cast, final)
|
||||||
@@ -164,51 +165,112 @@ A dictionary containing nested tensors which have been batched via
|
|||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class MultiModalFieldElem:
|
class MultiModalFieldElem:
|
||||||
"""Contains metadata and data of an item in :class:`MultiModalKwargs`."""
|
"""
|
||||||
field: "BaseMultiModalField"
|
Represents a keyword argument corresponding to a multi-modal item
|
||||||
|
in :class:`MultiModalKwargs`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
modality: str
|
||||||
|
"""
|
||||||
|
The modality of the corresponding multi-modal item.
|
||||||
|
Each multi-modal item can consist of multiple keyword arguments.
|
||||||
|
"""
|
||||||
|
|
||||||
|
key: str
|
||||||
|
"""
|
||||||
|
The key of this field in :class:`MultiModalKwargs`,
|
||||||
|
i.e. the name of the keyword argument to be passed to the model.
|
||||||
|
"""
|
||||||
|
|
||||||
data: NestedTensors
|
data: NestedTensors
|
||||||
|
"""
|
||||||
|
The tensor data of this field in :class:`MultiModalKwargs`,
|
||||||
|
i.e. the value of the keyword argument to be passed to the model.
|
||||||
|
"""
|
||||||
|
|
||||||
|
field: "BaseMultiModalField"
|
||||||
|
"""
|
||||||
|
Defines how to combine the tensor data of this field with others
|
||||||
|
in order to batch multi-modal items together for model inference.
|
||||||
|
"""
|
||||||
|
|
||||||
def __eq__(self, other: object) -> bool:
|
def __eq__(self, other: object) -> bool:
|
||||||
if not isinstance(other, self.__class__):
|
if not isinstance(other, self.__class__):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return (self.field == other.field
|
return ((self.modality, self.key) == (other.modality, other.key)
|
||||||
and nested_tensors_equal(self.data, other.data))
|
and nested_tensors_equal(self.data, other.data)
|
||||||
|
and type(self.field) == type(other.field)) # noqa: E721
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class BaseMultiModalField(ABC):
|
class BaseMultiModalField(ABC):
|
||||||
"""Abstract base class for a field in :class:`MultiModalKwargs`."""
|
"""
|
||||||
key: str
|
Defines how to interpret tensor data belonging to a keyword argument in
|
||||||
modality: str
|
:class:`MultiModalKwargs` for multiple multi-modal items, and vice versa.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _field_factory(self, *, modality: str, key: str):
|
||||||
|
f = partial(
|
||||||
|
MultiModalFieldElem,
|
||||||
|
modality=modality,
|
||||||
|
key=key,
|
||||||
|
field=self,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Allow passing data as positional argument
|
||||||
|
def factory(data: NestedTensors) -> MultiModalFieldElem:
|
||||||
|
return f(data=data)
|
||||||
|
|
||||||
|
return factory
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def build_elems(
|
||||||
|
self,
|
||||||
|
modality: str,
|
||||||
|
key: str,
|
||||||
|
data: NestedTensors,
|
||||||
|
) -> Sequence[MultiModalFieldElem]:
|
||||||
|
"""
|
||||||
|
Construct :class:`MultiModalFieldElem` instances to represent
|
||||||
|
the provided data.
|
||||||
|
|
||||||
|
This is the inverse of :meth:`reduce_data`.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
|
def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def _build_elem(self, data: NestedTensors) -> MultiModalFieldElem:
|
def reduce_data(self, elems: list[MultiModalFieldElem]) -> NestedTensors:
|
||||||
return MultiModalFieldElem(self, data)
|
"""
|
||||||
|
Merge the data from multiple instances of :class:`MultiModalFieldElem`.
|
||||||
|
|
||||||
def reduce(self, batch: list[MultiModalFieldElem]) -> MultiModalFieldElem:
|
This is the inverse of :meth:`build_elems`.
|
||||||
"""Merge multiple instances of :class:`MultiModalFieldElem` together."""
|
"""
|
||||||
fields = [item.field for item in batch]
|
field_types = [type(item.field) for item in elems]
|
||||||
if len(set(fields)) > 1:
|
if len(set(field_types)) > 1:
|
||||||
raise ValueError(f"Cannot merge different {fields=}")
|
raise ValueError(f"Cannot merge different {field_types=}")
|
||||||
|
|
||||||
data = self._reduce_data([item.data for item in batch])
|
return self._reduce_data([item.data for item in elems])
|
||||||
|
|
||||||
return self._build_elem(data)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class MultiModalBatchedField(BaseMultiModalField):
|
class MultiModalBatchedField(BaseMultiModalField):
|
||||||
"""
|
"""
|
||||||
A :class:`BaseMultiModalField` implementation where an element in the batch
|
See also:
|
||||||
is obtained by indexing into the first dimension of the underlying data.
|
:func:`MultiModalFieldConfig.batched`
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def build_elems(self, batch: NestedTensors) -> list[MultiModalFieldElem]:
|
def build_elems(
|
||||||
return [self._build_elem(item) for item in batch]
|
self,
|
||||||
|
modality: str,
|
||||||
|
key: str,
|
||||||
|
data: NestedTensors,
|
||||||
|
) -> Sequence[MultiModalFieldElem]:
|
||||||
|
field_factory = self._field_factory(modality=modality, key=key)
|
||||||
|
return [field_factory(item) for item in data]
|
||||||
|
|
||||||
def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
|
def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
|
||||||
if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
|
if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
|
||||||
@@ -227,16 +289,20 @@ class MultiModalBatchedField(BaseMultiModalField):
|
|||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class MultiModalFlatField(BaseMultiModalField):
|
class MultiModalFlatField(BaseMultiModalField):
|
||||||
"""
|
"""
|
||||||
A :class:`BaseMultiModalField` implementation where an element in the batch
|
See also:
|
||||||
is obtained by slicing along the first dimension of the underlying data.
|
:func:`MultiModalFieldConfig.flat`
|
||||||
|
:func:`MultiModalFieldConfig.flat_from_sizes`
|
||||||
"""
|
"""
|
||||||
|
slices: Sequence[slice]
|
||||||
|
|
||||||
def build_elems(
|
def build_elems(
|
||||||
self,
|
self,
|
||||||
batch: NestedTensors,
|
modality: str,
|
||||||
slices: Sequence[slice],
|
key: str,
|
||||||
) -> list[MultiModalFieldElem]:
|
data: NestedTensors,
|
||||||
return [self._build_elem(batch[slice_]) for slice_ in slices]
|
) -> Sequence[MultiModalFieldElem]:
|
||||||
|
field_factory = self._field_factory(modality=modality, key=key)
|
||||||
|
return [field_factory(data[s]) for s in self.slices]
|
||||||
|
|
||||||
def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
|
def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
|
||||||
if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
|
if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
|
||||||
@@ -252,25 +318,121 @@ class MultiModalFlatField(BaseMultiModalField):
|
|||||||
return [e for elem in batch for e in elem]
|
return [e for elem in batch for e in elem]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class MultiModalSharedField(BaseMultiModalField):
|
||||||
|
"""
|
||||||
|
See also:
|
||||||
|
:func:`MultiModalFieldConfig.shared`
|
||||||
|
"""
|
||||||
|
batch_size: int
|
||||||
|
|
||||||
|
def build_elems(
|
||||||
|
self,
|
||||||
|
modality: str,
|
||||||
|
key: str,
|
||||||
|
data: NestedTensors,
|
||||||
|
) -> Sequence[MultiModalFieldElem]:
|
||||||
|
field_factory = self._field_factory(modality=modality, key=key)
|
||||||
|
return [field_factory(data)] * self.batch_size
|
||||||
|
|
||||||
|
def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
|
||||||
|
return batch[0]
|
||||||
|
|
||||||
|
|
||||||
class MultiModalFieldConfig:
|
class MultiModalFieldConfig:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def batched(modality: str):
|
def batched(modality: str):
|
||||||
|
"""
|
||||||
|
Defines a field where an element in the batch is obtained by
|
||||||
|
indexing into the first dimension of the underlying data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
modality: The modality of the multi-modal item that uses this
|
||||||
|
keyword argument.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
.. code-block::
|
||||||
|
|
||||||
|
Input:
|
||||||
|
Data: [[AAAA]
|
||||||
|
[BBBB]
|
||||||
|
[CCCC]]
|
||||||
|
|
||||||
|
Output:
|
||||||
|
Element 1: [AAAA]
|
||||||
|
Element 2: [BBBB]
|
||||||
|
Element 3: [CCCC]
|
||||||
|
"""
|
||||||
return MultiModalFieldConfig(
|
return MultiModalFieldConfig(
|
||||||
field_cls=MultiModalBatchedField,
|
field=MultiModalBatchedField(),
|
||||||
modality=modality,
|
modality=modality,
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def flat(modality: str, slices: Sequence[slice]):
|
def flat(modality: str, slices: Sequence[slice]):
|
||||||
|
"""
|
||||||
|
Defines a field where an element in the batch is obtained by
|
||||||
|
slicing along the first dimension of the underlying data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
modality: The modality of the multi-modal item that uses this
|
||||||
|
keyword argument.
|
||||||
|
slices: For each multi-modal item, a slice that is used to extract
|
||||||
|
the data corresponding to it.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
.. code-block::
|
||||||
|
|
||||||
|
Given:
|
||||||
|
slices: [slice(0, 3), slice(3, 7), slice(7, 9)]
|
||||||
|
|
||||||
|
Input:
|
||||||
|
Data: [AAABBBBCC]
|
||||||
|
|
||||||
|
Output:
|
||||||
|
Element 1: [AAA]
|
||||||
|
Element 2: [BBBB]
|
||||||
|
Element 3: [CC]
|
||||||
|
"""
|
||||||
return MultiModalFieldConfig(
|
return MultiModalFieldConfig(
|
||||||
field_cls=MultiModalFlatField,
|
field=MultiModalFlatField(slices=slices),
|
||||||
modality=modality,
|
modality=modality,
|
||||||
slices=slices,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def flat_from_sizes(modality: str, size_per_item: torch.Tensor):
|
def flat_from_sizes(modality: str, size_per_item: torch.Tensor):
|
||||||
|
"""
|
||||||
|
Defines a field where an element in the batch is obtained by
|
||||||
|
slicing along the first dimension of the underlying data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
modality: The modality of the multi-modal item that uses this
|
||||||
|
keyword argument.
|
||||||
|
size_per_item: For each multi-modal item, the size of the slice that
|
||||||
|
is used to extract the data corresponding to it.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
.. code-block::
|
||||||
|
|
||||||
|
Given:
|
||||||
|
size_per_item: [3, 4, 2]
|
||||||
|
|
||||||
|
Input:
|
||||||
|
Data: [AAABBBBCC]
|
||||||
|
|
||||||
|
Output:
|
||||||
|
Element 1: [AAA]
|
||||||
|
Element 2: [BBBB]
|
||||||
|
Element 3: [CC]
|
||||||
|
|
||||||
|
See also:
|
||||||
|
:func:`MultiModalFieldConfig.flat`
|
||||||
|
"""
|
||||||
|
|
||||||
slice_idxs = [0, *accumulate(size_per_item)]
|
slice_idxs = [0, *accumulate(size_per_item)]
|
||||||
slices = [
|
slices = [
|
||||||
slice(slice_idxs[i], slice_idxs[i + 1])
|
slice(slice_idxs[i], slice_idxs[i + 1])
|
||||||
@@ -279,25 +441,52 @@ class MultiModalFieldConfig:
|
|||||||
|
|
||||||
return MultiModalFieldConfig.flat(modality, slices)
|
return MultiModalFieldConfig.flat(modality, slices)
|
||||||
|
|
||||||
def __init__(
|
@staticmethod
|
||||||
self,
|
def shared(modality: str, batch_size: int):
|
||||||
field_cls: type[BaseMultiModalField],
|
"""
|
||||||
modality: str,
|
Defines a field where an element in the batch is obtained by
|
||||||
**field_config: Any,
|
taking the entirety of the underlying data.
|
||||||
) -> None:
|
|
||||||
|
This means that the data is the same for each element in the batch.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
modality: The modality of the multi-modal item that uses this
|
||||||
|
keyword argument.
|
||||||
|
batch_size: The number of multi-modal items which share this data.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
.. code-block::
|
||||||
|
|
||||||
|
Given:
|
||||||
|
batch_size: 4
|
||||||
|
|
||||||
|
Input:
|
||||||
|
Data: [XYZ]
|
||||||
|
|
||||||
|
Output:
|
||||||
|
Element 1: [XYZ]
|
||||||
|
Element 2: [XYZ]
|
||||||
|
Element 3: [XYZ]
|
||||||
|
Element 4: [XYZ]
|
||||||
|
"""
|
||||||
|
return MultiModalFieldConfig(
|
||||||
|
field=MultiModalSharedField(batch_size),
|
||||||
|
modality=modality,
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, field: BaseMultiModalField, modality: str) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.field_cls = field_cls
|
self.field = field
|
||||||
self.modality = modality
|
self.modality = modality
|
||||||
self.field_config = field_config
|
|
||||||
|
|
||||||
def build_elems(
|
def build_elems(
|
||||||
self,
|
self,
|
||||||
key: str,
|
key: str,
|
||||||
batch: NestedTensors,
|
batch: NestedTensors,
|
||||||
) -> Sequence[MultiModalFieldElem]:
|
) -> Sequence[MultiModalFieldElem]:
|
||||||
field = self.field_cls(key=key, modality=self.modality)
|
return self.field.build_elems(self.modality, key, batch)
|
||||||
return field.build_elems(batch, **self.field_config) # type: ignore
|
|
||||||
|
|
||||||
|
|
||||||
class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
|
class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
|
||||||
@@ -308,11 +497,11 @@ class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_elems(elems: Sequence[MultiModalFieldElem]):
|
def from_elems(elems: Sequence[MultiModalFieldElem]):
|
||||||
return MultiModalKwargsItem({elem.field.key: elem for elem in elems})
|
return MultiModalKwargsItem({elem.key: elem for elem in elems})
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def modality(self) -> str:
|
def modality(self) -> str:
|
||||||
modalities = {elem.field.modality for elem in self.data.values()}
|
modalities = {elem.modality for elem in self.data.values()}
|
||||||
assert len(modalities) == 1, f"Found different modalities={modalities}"
|
assert len(modalities) == 1, f"Found different modalities={modalities}"
|
||||||
return next(iter(modalities))
|
return next(iter(modalities))
|
||||||
|
|
||||||
@@ -372,7 +561,7 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
|
|||||||
elems_by_key[key].append(elem)
|
elems_by_key[key].append(elem)
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
key: elems[0].field.reduce(elems).data
|
key: elems[0].field.reduce_data(elems)
|
||||||
for key, elems in elems_by_key.items() if len(elems) > 0
|
for key, elems in elems_by_key.items() if len(elems) > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user