|
|
|
|
@@ -3,7 +3,7 @@
|
|
|
|
|
""" PyTorch Ovis model."""
|
|
|
|
|
from collections.abc import Iterable, Mapping
|
|
|
|
|
from functools import partial
|
|
|
|
|
from typing import Optional, Union
|
|
|
|
|
from typing import Literal, Optional, TypedDict, Union
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
import torch.nn as nn
|
|
|
|
|
@@ -50,6 +50,27 @@ IMAGE_PAD_TOKEN_ID_MAP = {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class OvisVideoPatchInputs(TypedDict):
    """Flattened video inputs as built by `_parse_and_validate_video_input`.

    Every key that the parser passes when constructing this dict is declared
    here so that the schema and the constructor call agree.
    """

    type: Literal["video_patches"]

    flat_data: torch.Tensor
    """
    Shape:
    `(batch_size * num_patches, patch_size_x * patch_size_y * num_channels)`
    """

    indicator_tokens: torch.Tensor
    """
    Shape:
    `(batch_size * (num_patches + 1))`
    """

    patches_per_image: list[int]
    """
    List of number of total patches for each frame in the video.
    This is used to restore the first two dimensions of `flat_data`.
    """

    # Flattened `video_grids` input. It is produced by the same
    # `flatten_bn(flatten_bn(...), concat=True)` call pattern as `flat_data`,
    # so it is presumably a single tensor.
    # NOTE(review): the exact shape is not visible from here -- confirm.
    grids: torch.Tensor
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _ovis2_5_field_config():
|
|
|
|
|
return dict(pixel_values=MultiModalFieldConfig.batched("image"),
|
|
|
|
|
grids=MultiModalFieldConfig.batched("image"),
|
|
|
|
|
@@ -429,17 +450,11 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
|
|
|
|
|
self.make_empty_intermediate_tensors = (
|
|
|
|
|
self.get_language_model().make_empty_intermediate_tensors)
|
|
|
|
|
|
|
|
|
|
def _parse_and_validate_visual_input(
|
|
|
|
|
self, is_video,
|
|
|
|
|
**kwargs: object) -> Optional[OvisImagePatchInputs]:
|
|
|
|
|
if is_video:
|
|
|
|
|
pixel_values = kwargs.pop("video_pixel_values", None)
|
|
|
|
|
indicator_tokens = kwargs.pop("video_indicator_tokens", None)
|
|
|
|
|
grids = kwargs.pop("video_grids", None)
|
|
|
|
|
else:
|
|
|
|
|
pixel_values = kwargs.pop("pixel_values", None)
|
|
|
|
|
indicator_tokens = kwargs.pop("indicator_tokens", None)
|
|
|
|
|
grids = kwargs.pop("grids", None)
|
|
|
|
|
def _parse_and_validate_image_input(
|
|
|
|
|
self, **kwargs: object) -> Optional[OvisImagePatchInputs]:
|
|
|
|
|
pixel_values = kwargs.pop("pixel_values", None)
|
|
|
|
|
indicator_tokens = kwargs.pop("indicator_tokens", None)
|
|
|
|
|
grids = kwargs.pop("grids", None)
|
|
|
|
|
if pixel_values is None and indicator_tokens is None:
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
@@ -466,8 +481,40 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
|
|
|
|
|
|
|
|
|
|
raise AssertionError("This line should be unreachable.")
|
|
|
|
|
|
|
|
|
|
def _parse_and_validate_video_input(
        self, **kwargs: object) -> Optional[OvisVideoPatchInputs]:
    """Extract and validate the video inputs found in `kwargs`.

    Reads the `video_pixel_values`, `video_indicator_tokens` and
    `video_grids` entries and packs them into an `OvisVideoPatchInputs`.

    Returns:
        `None` when neither pixel values nor indicator tokens are present,
        otherwise the flattened video inputs.

    Raises:
        ValueError: if the pixel values or the indicator tokens are neither
            a `torch.Tensor` nor a `list`.
        AssertionError: if only one of pixel values / indicator tokens is
            present.
    """
    pixel_values = kwargs.pop("video_pixel_values", None)
    indicator_tokens = kwargs.pop("video_indicator_tokens", None)
    grids = kwargs.pop("video_grids", None)

    if pixel_values is None and indicator_tokens is None:
        return None

    if pixel_values is None or indicator_tokens is None:
        # Exactly one of the two inputs is missing. The exception type and
        # message are kept as-is so existing behaviour is unchanged.
        raise AssertionError("This line should be unreachable.")

    if not isinstance(pixel_values, (torch.Tensor, list)):
        raise ValueError("Incorrect type of pixel values. "
                         f"Got type: {type(pixel_values)}")

    if not isinstance(indicator_tokens, (torch.Tensor, list)):
        raise ValueError("Incorrect type of indicator_tokens. "
                         f"Got type: {type(indicator_tokens)}")

    # Flatten the outer level once and reuse it for both the concatenated
    # data and the per-frame patch counts.
    frames = flatten_bn(pixel_values)
    stride_squared = self.config.vit_config.hidden_stride**2

    return OvisVideoPatchInputs(
        type="video_patches",
        flat_data=flatten_bn(frames, concat=True),
        patches_per_image=[
            frame.shape[0] // stride_squared for frame in frames
        ],
        indicator_tokens=flatten_bn(flatten_bn(indicator_tokens),
                                    concat=True),
        grids=flatten_bn(flatten_bn(grids), concat=True),
    )
|
|
|
|
|
|
|
|
|
|
def _process_image_input(
|
|
|
|
|
self, image_input: OvisImagePatchInputs) -> MultiModalEmbeddings:
|
|
|
|
|
self, image_input: Union[OvisImagePatchInputs, OvisVideoPatchInputs]
|
|
|
|
|
) -> MultiModalEmbeddings:
|
|
|
|
|
image_patches_flat = image_input["flat_data"]
|
|
|
|
|
patches_per_image = image_input["patches_per_image"]
|
|
|
|
|
indicator_tokens = image_input["indicator_tokens"]
|
|
|
|
|
@@ -500,21 +547,44 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
|
|
|
|
|
torch.cat(vision_embeddings_per_image, dim=0))
|
|
|
|
|
return tuple(vision_embeddings)
|
|
|
|
|
|
|
|
|
|
def get_multimodal_embeddings(
|
|
|
|
|
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
|
|
|
|
|
embeddings = []
|
|
|
|
|
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
|
|
|
|
|
modalities = {}
|
|
|
|
|
|
|
|
|
|
# NOTE: _parse_and_validate_visual_input has side-effects and pops
|
|
|
|
|
# keys from kwargs. We process images first, then videos.
|
|
|
|
|
image_input = self._parse_and_validate_visual_input(False, **kwargs)
|
|
|
|
|
if image_input:
|
|
|
|
|
embeddings.extend(self._process_image_input(image_input))
|
|
|
|
|
# Preserve the order of modalities if there are multiple of them
|
|
|
|
|
# from the order of kwargs.
|
|
|
|
|
for input_key in kwargs:
|
|
|
|
|
if input_key in ("pixel_values", "indicator_tokens",
|
|
|
|
|
"grids") and "images" not in modalities:
|
|
|
|
|
modalities["images"] = self._parse_and_validate_image_input(
|
|
|
|
|
**kwargs)
|
|
|
|
|
if input_key in ("video_pixel_values", "video_indicator_tokens",
|
|
|
|
|
"video_grids") and "videos" not in modalities:
|
|
|
|
|
modalities["videos"] = self._parse_and_validate_video_input(
|
|
|
|
|
**kwargs)
|
|
|
|
|
|
|
|
|
|
video_input = self._parse_and_validate_visual_input(True, **kwargs)
|
|
|
|
|
if video_input:
|
|
|
|
|
embeddings.extend(self._process_image_input(video_input))
|
|
|
|
|
return modalities
|
|
|
|
|
|
|
|
|
|
return tuple(embeddings) if embeddings else None
|
|
|
|
|
def get_multimodal_embeddings(self,
                              **kwargs: object) -> MultiModalEmbeddings:
    """Compute the embeddings for every modality present in `kwargs`.

    The inputs are parsed by `_parse_and_validate_multimodal_inputs`; both
    the "images" and the "videos" entries are then embedded with
    `_process_image_input`.

    Returns:
        An empty list when no modality was found, otherwise a tuple with the
        embeddings of all modalities, in the order in which the modalities
        appear in the parsed mapping.
    """
    modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
    if not modalities:
        return []

    multimodal_embeddings: tuple[torch.Tensor, ...] = ()
    # NOTE: It is important to iterate over this dictionary in its own
    # order to preserve the order of the modalities.
    for modality, parsed_input in modalities.items():
        # The per-modality parsers return None when neither pixel values nor
        # indicator tokens were supplied (e.g. only grids were passed);
        # there is nothing to embed in that case.
        if parsed_input is None:
            continue
        if modality in ("images", "videos"):
            multimodal_embeddings += self._process_image_input(parsed_input)

    return multimodal_embeddings
|
|
|
|
|
|
|
|
|
|
def get_input_embeddings(
|
|
|
|
|
self,
|
|
|
|
|
|