[VLM] Use shared field to pass token ids to model

This commit is contained in:
Cyrus Leung
2025-02-06 05:30:46 +08:00
committed by GitHub
parent 3b2005e1db
commit a4ce74c14a
2 changed files with 235 additions and 46 deletions

View File

@@ -564,8 +564,7 @@ class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
         # Since there may be extra tokens in the feature placeholders,
         # we need to pass the image token ID to the model to select the
         # tokens to merge from the vision encoder outputs
-        processed_outputs["image_token_id"] = [image_token_id
-                                               ] * len(image_data)
+        processed_outputs["image_token_id"] = torch.tensor(image_token_id)

         return processed_outputs
@@ -575,13 +574,14 @@ class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
         image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0))
+        num_images = len(image_num_patches)

         return dict(
             pixel_values_flat=MultiModalFieldConfig.flat_from_sizes(
                 "image", image_num_patches),
             image_num_patches=MultiModalFieldConfig.batched("image"),
             image_embeds=MultiModalFieldConfig.batched("image"),
-            image_token_id=MultiModalFieldConfig.batched("image"),
+            image_token_id=MultiModalFieldConfig.shared("image", num_images),
         )

     def _get_prompt_replacements(

View File

@@ -4,6 +4,7 @@
 from abc import ABC, abstractmethod
 from collections import UserDict, defaultdict
 from collections.abc import Mapping, Sequence
 from dataclasses import dataclass
+from functools import partial
 from itertools import accumulate
 from typing import (TYPE_CHECKING, Any, Literal, Optional, TypedDict, TypeVar,
                     Union, cast, final)
@@ -164,51 +165,112 @@ A dictionary containing nested tensors which have been batched via
 @dataclass(frozen=True)
 class MultiModalFieldElem:
-    """Contains metadata and data of an item in :class:`MultiModalKwargs`."""
-    field: "BaseMultiModalField"
+    """
+    Represents a keyword argument corresponding to a multi-modal item
+    in :class:`MultiModalKwargs`.
+    """
+
+    modality: str
+    """
+    The modality of the corresponding multi-modal item.
+    Each multi-modal item can consist of multiple keyword arguments.
+    """
+
+    key: str
+    """
+    The key of this field in :class:`MultiModalKwargs`,
+    i.e. the name of the keyword argument to be passed to the model.
+    """
+
     data: NestedTensors
+    """
+    The tensor data of this field in :class:`MultiModalKwargs`,
+    i.e. the value of the keyword argument to be passed to the model.
+    """
+
+    field: "BaseMultiModalField"
+    """
+    Defines how to combine the tensor data of this field with others
+    in order to batch multi-modal items together for model inference.
+    """

     def __eq__(self, other: object) -> bool:
         if not isinstance(other, self.__class__):
             return False

-        return (self.field == other.field
-                and nested_tensors_equal(self.data, other.data))
+        return ((self.modality, self.key) == (other.modality, other.key)
+                and nested_tensors_equal(self.data, other.data)
+                and type(self.field) == type(other.field))  # noqa: E721
 @dataclass(frozen=True)
 class BaseMultiModalField(ABC):
-    """Abstract base class for a field in :class:`MultiModalKwargs`."""
-    key: str
-    modality: str
+    """
+    Defines how to interpret tensor data belonging to a keyword argument in
+    :class:`MultiModalKwargs` for multiple multi-modal items, and vice versa.
+    """
+
+    def _field_factory(self, *, modality: str, key: str):
+        f = partial(
+            MultiModalFieldElem,
+            modality=modality,
+            key=key,
+            field=self,
+        )
+
+        # Allow passing data as positional argument
+        def factory(data: NestedTensors) -> MultiModalFieldElem:
+            return f(data=data)
+
+        return factory
+
+    @abstractmethod
+    def build_elems(
+        self,
+        modality: str,
+        key: str,
+        data: NestedTensors,
+    ) -> Sequence[MultiModalFieldElem]:
+        """
+        Construct :class:`MultiModalFieldElem` instances to represent
+        the provided data.
+
+        This is the inverse of :meth:`reduce_data`.
+        """
+        raise NotImplementedError

     @abstractmethod
     def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
         raise NotImplementedError

-    def _build_elem(self, data: NestedTensors) -> MultiModalFieldElem:
-        return MultiModalFieldElem(self, data)
-
-    def reduce(self, batch: list[MultiModalFieldElem]) -> MultiModalFieldElem:
-        """Merge multiple instances of :class:`MultiModalFieldElem` together."""
-        fields = [item.field for item in batch]
-        if len(set(fields)) > 1:
-            raise ValueError(f"Cannot merge different {fields=}")
-
-        data = self._reduce_data([item.data for item in batch])
-
-        return self._build_elem(data)
+    def reduce_data(self, elems: list[MultiModalFieldElem]) -> NestedTensors:
+        """
+        Merge the data from multiple instances of :class:`MultiModalFieldElem`.
+
+        This is the inverse of :meth:`build_elems`.
+        """
+        field_types = [type(item.field) for item in elems]
+        if len(set(field_types)) > 1:
+            raise ValueError(f"Cannot merge different {field_types=}")
+
+        return self._reduce_data([item.data for item in elems])
 @dataclass(frozen=True)
 class MultiModalBatchedField(BaseMultiModalField):
     """
-    A :class:`BaseMultiModalField` implementation where an element in the batch
-    is obtained by indexing into the first dimension of the underlying data.
+    See also:
+        :func:`MultiModalFieldConfig.batched`
     """

-    def build_elems(self, batch: NestedTensors) -> list[MultiModalFieldElem]:
-        return [self._build_elem(item) for item in batch]
+    def build_elems(
+        self,
+        modality: str,
+        key: str,
+        data: NestedTensors,
+    ) -> Sequence[MultiModalFieldElem]:
+        field_factory = self._field_factory(modality=modality, key=key)
+        return [field_factory(item) for item in data]

     def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
         if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
@@ -227,16 +289,20 @@ class MultiModalBatchedField(BaseMultiModalField):
 @dataclass(frozen=True)
 class MultiModalFlatField(BaseMultiModalField):
     """
-    A :class:`BaseMultiModalField` implementation where an element in the batch
-    is obtained by slicing along the first dimension of the underlying data.
+    See also:
+        :func:`MultiModalFieldConfig.flat`
+        :func:`MultiModalFieldConfig.flat_from_sizes`
     """
+    slices: Sequence[slice]

     def build_elems(
         self,
-        batch: NestedTensors,
-        slices: Sequence[slice],
-    ) -> list[MultiModalFieldElem]:
-        return [self._build_elem(batch[slice_]) for slice_ in slices]
+        modality: str,
+        key: str,
+        data: NestedTensors,
+    ) -> Sequence[MultiModalFieldElem]:
+        field_factory = self._field_factory(modality=modality, key=key)
+        return [field_factory(data[s]) for s in self.slices]

     def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
         if len(batch) > 0 and is_list_of(batch, torch.Tensor, check="all"):
@@ -252,25 +318,121 @@ class MultiModalFlatField(BaseMultiModalField):
         return [e for elem in batch for e in elem]

+@dataclass(frozen=True)
+class MultiModalSharedField(BaseMultiModalField):
+    """
+    See also:
+        :func:`MultiModalFieldConfig.shared`
+    """
+    batch_size: int
+
+    def build_elems(
+        self,
+        modality: str,
+        key: str,
+        data: NestedTensors,
+    ) -> Sequence[MultiModalFieldElem]:
+        field_factory = self._field_factory(modality=modality, key=key)
+        return [field_factory(data)] * self.batch_size
+
+    def _reduce_data(self, batch: list[NestedTensors]) -> NestedTensors:
+        return batch[0]
 class MultiModalFieldConfig:

     @staticmethod
     def batched(modality: str):
+        """
+        Defines a field where an element in the batch is obtained by
+        indexing into the first dimension of the underlying data.
+
+        Args:
+            modality: The modality of the multi-modal item that uses this
+                keyword argument.
+
+        Example:
+
+            .. code-block::
+
+                Input:
+                    Data: [[AAAA]
+                           [BBBB]
+                           [CCCC]]
+
+                Output:
+                    Element 1: [AAAA]
+                    Element 2: [BBBB]
+                    Element 3: [CCCC]
+        """
         return MultiModalFieldConfig(
-            field_cls=MultiModalBatchedField,
+            field=MultiModalBatchedField(),
             modality=modality,
         )
     @staticmethod
     def flat(modality: str, slices: Sequence[slice]):
+        """
+        Defines a field where an element in the batch is obtained by
+        slicing along the first dimension of the underlying data.
+
+        Args:
+            modality: The modality of the multi-modal item that uses this
+                keyword argument.
+            slices: For each multi-modal item, a slice that is used to extract
+                the data corresponding to it.
+
+        Example:
+
+            .. code-block::
+
+                Given:
+                    slices: [slice(0, 3), slice(3, 7), slice(7, 9)]
+
+                Input:
+                    Data: [AAABBBBCC]
+
+                Output:
+                    Element 1: [AAA]
+                    Element 2: [BBBB]
+                    Element 3: [CC]
+        """
         return MultiModalFieldConfig(
-            field_cls=MultiModalFlatField,
+            field=MultiModalFlatField(slices=slices),
             modality=modality,
-            slices=slices,
         )
     @staticmethod
     def flat_from_sizes(modality: str, size_per_item: torch.Tensor):
+        """
+        Defines a field where an element in the batch is obtained by
+        slicing along the first dimension of the underlying data.
+
+        Args:
+            modality: The modality of the multi-modal item that uses this
+                keyword argument.
+            slices: For each multi-modal item, the size of the slice that
+                is used to extract the data corresponding to it.
+
+        Example:
+
+            .. code-block::
+
+                Given:
+                    size_per_item: [3, 4, 2]
+
+                Input:
+                    Data: [AAABBBBCC]
+
+                Output:
+                    Element 1: [AAA]
+                    Element 2: [BBBB]
+                    Element 3: [CC]
+
+        See also:
+            :func:`MultiModalFieldConfig.flat`
+        """
         slice_idxs = [0, *accumulate(size_per_item)]
         slices = [
             slice(slice_idxs[i], slice_idxs[i + 1])
@@ -279,25 +441,52 @@ class MultiModalFieldConfig:
         return MultiModalFieldConfig.flat(modality, slices)

-    def __init__(
-        self,
-        field_cls: type[BaseMultiModalField],
-        modality: str,
-        **field_config: Any,
-    ) -> None:
+    @staticmethod
+    def shared(modality: str, batch_size: int):
+        """
+        Defines a field where an element in the batch is obtained by
+        taking the entirety of the underlying data.
+
+        This means that the data is the same for each element in the batch.
+
+        Args:
+            modality: The modality of the multi-modal item that uses this
+                keyword argument.
+            batch_size: The number of multi-modal items which share this data.
+
+        Example:
+
+            .. code-block::
+
+                Given:
+                    batch_size: 4
+
+                Input:
+                    Data: [XYZ]
+
+                Output:
+                    Element 1: [XYZ]
+                    Element 2: [XYZ]
+                    Element 3: [XYZ]
+                    Element 4: [XYZ]
+        """
+        return MultiModalFieldConfig(
+            field=MultiModalSharedField(batch_size),
+            modality=modality,
+        )
+
+    def __init__(self, field: BaseMultiModalField, modality: str) -> None:
         super().__init__()

-        self.field_cls = field_cls
+        self.field = field
         self.modality = modality
-        self.field_config = field_config

     def build_elems(
         self,
         key: str,
         batch: NestedTensors,
     ) -> Sequence[MultiModalFieldElem]:
-        field = self.field_cls(key=key, modality=self.modality)
-        return field.build_elems(batch, **self.field_config)  # type: ignore
+        return self.field.build_elems(self.modality, key, batch)
 class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
@@ -308,11 +497,11 @@ class MultiModalKwargsItem(UserDict[str, MultiModalFieldElem]):
     @staticmethod
     def from_elems(elems: Sequence[MultiModalFieldElem]):
-        return MultiModalKwargsItem({elem.field.key: elem for elem in elems})
+        return MultiModalKwargsItem({elem.key: elem for elem in elems})

     @property
     def modality(self) -> str:
-        modalities = {elem.field.modality for elem in self.data.values()}
+        modalities = {elem.modality for elem in self.data.values()}
         assert len(modalities) == 1, f"Found different modalities={modalities}"
         return next(iter(modalities))
@@ -372,7 +561,7 @@ class MultiModalKwargs(UserDict[str, NestedTensors]):
             elems_by_key[key].append(elem)

         data = {
-            key: elems[0].field.reduce(elems).data
+            key: elems[0].field.reduce_data(elems)
             for key, elems in elems_by_key.items() if len(elems) > 0
         }