[CI/Build] Improve Tensor Schema test speed by avoiding engine core initialization (#23357)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
@@ -3,7 +3,7 @@
|
||||
""" PyTorch Ovis model."""
|
||||
from collections.abc import Iterable, Mapping
|
||||
from functools import partial
|
||||
from typing import Optional, Union
|
||||
from typing import Literal, Optional, TypedDict, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -50,6 +50,27 @@ IMAGE_PAD_TOKEN_ID_MAP = {
|
||||
}
|
||||
|
||||
|
||||
class OvisVideoPatchInputs(TypedDict):
    """Flattened video inputs handed to the Ovis2.5 visual pipeline.

    Instances are built by `_parse_and_validate_video_input`, which
    supplies every key declared here.
    """

    type: Literal["video_patches"]
    flat_data: torch.Tensor
    """
    Shape:
    `(batch_size * num_patches, patch_size_x * patch_size_y * num_channels)`
    """

    indicator_tokens: torch.Tensor
    """
    Shape:
    `(batch_size * (num_patches + 1))`
    """

    patches_per_image: list[int]
    """
    List of number of total patches for each frame in the video.
    This is used to restore the first two dimensions of `flat_data`.
    """

    grids: torch.Tensor
    """
    Grid information for the video frames, flattened over the leading
    batch dimensions in the same way as `flat_data`.
    NOTE(review): the exact shape is not visible from this file section;
    confirm against the processor output.
    """
|
||||
|
||||
|
||||
def _ovis2_5_field_config():
|
||||
return dict(pixel_values=MultiModalFieldConfig.batched("image"),
|
||||
grids=MultiModalFieldConfig.batched("image"),
|
||||
@@ -429,17 +450,11 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.get_language_model().make_empty_intermediate_tensors)
|
||||
|
||||
def _parse_and_validate_visual_input(
|
||||
self, is_video,
|
||||
**kwargs: object) -> Optional[OvisImagePatchInputs]:
|
||||
if is_video:
|
||||
pixel_values = kwargs.pop("video_pixel_values", None)
|
||||
indicator_tokens = kwargs.pop("video_indicator_tokens", None)
|
||||
grids = kwargs.pop("video_grids", None)
|
||||
else:
|
||||
pixel_values = kwargs.pop("pixel_values", None)
|
||||
indicator_tokens = kwargs.pop("indicator_tokens", None)
|
||||
grids = kwargs.pop("grids", None)
|
||||
def _parse_and_validate_image_input(
|
||||
self, **kwargs: object) -> Optional[OvisImagePatchInputs]:
|
||||
pixel_values = kwargs.pop("pixel_values", None)
|
||||
indicator_tokens = kwargs.pop("indicator_tokens", None)
|
||||
grids = kwargs.pop("grids", None)
|
||||
if pixel_values is None and indicator_tokens is None:
|
||||
return None
|
||||
|
||||
@@ -466,8 +481,40 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
|
||||
raise AssertionError("This line should be unreachable.")
|
||||
|
||||
def _parse_and_validate_video_input(
        self, **kwargs: object) -> Optional[OvisVideoPatchInputs]:
    """Extract and validate the video tensors from the multimodal kwargs.

    Reads `video_pixel_values`, `video_indicator_tokens` and
    `video_grids` from `kwargs` and flattens them into a single
    `OvisVideoPatchInputs` record.

    Returns:
        `None` when neither `video_pixel_values` nor
        `video_indicator_tokens` is supplied, otherwise the flattened
        video inputs.

    Raises:
        ValueError: If `video_pixel_values` or `video_indicator_tokens`
            is neither a `torch.Tensor` nor a `list`.
        AssertionError: If exactly one of `video_pixel_values` and
            `video_indicator_tokens` is supplied.
    """
    # `kwargs` is this call's own dict, so popping does not affect the
    # mapping held by the caller.
    pixel_values = kwargs.pop("video_pixel_values", None)
    indicator_tokens = kwargs.pop("video_indicator_tokens", None)
    grids = kwargs.pop("video_grids", None)
    if pixel_values is None and indicator_tokens is None:
        return None

    if pixel_values is not None and indicator_tokens is not None:
        if not isinstance(pixel_values, (torch.Tensor, list)):
            raise ValueError("Incorrect type of pixel values. "
                             f"Got type: {type(pixel_values)}")

        if not isinstance(indicator_tokens, (torch.Tensor, list)):
            raise ValueError("Incorrect type of indicator_tokens. "
                             f"Got type: {type(indicator_tokens)}")

        return OvisVideoPatchInputs(
            type="video_patches",
            flat_data=flatten_bn(flatten_bn(pixel_values), concat=True),
            # One entry per item left after a single flatten: its leading
            # dimension divided by the squared ViT hidden stride.
            patches_per_image=[
                x.shape[0] // (self.config.vit_config.hidden_stride**2)
                for x in flatten_bn(pixel_values)
            ],
            indicator_tokens=flatten_bn(flatten_bn(indicator_tokens),
                                        concat=True),
            grids=flatten_bn(flatten_bn(grids), concat=True),
        )

    # Despite the message, this is reached when only one of the two
    # inputs above was provided.
    raise AssertionError("This line should be unreachable.")
|
||||
|
||||
def _process_image_input(
|
||||
self, image_input: OvisImagePatchInputs) -> MultiModalEmbeddings:
|
||||
self, image_input: Union[OvisImagePatchInputs, OvisVideoPatchInputs]
|
||||
) -> MultiModalEmbeddings:
|
||||
image_patches_flat = image_input["flat_data"]
|
||||
patches_per_image = image_input["patches_per_image"]
|
||||
indicator_tokens = image_input["indicator_tokens"]
|
||||
@@ -500,21 +547,44 @@ class Ovis2_5(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
torch.cat(vision_embeddings_per_image, dim=0))
|
||||
return tuple(vision_embeddings)
|
||||
|
||||
def get_multimodal_embeddings(
|
||||
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
|
||||
embeddings = []
|
||||
def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
|
||||
modalities = {}
|
||||
|
||||
# NOTE: _parse_and_validate_visual_input has side-effects and pops
|
||||
# keys from kwargs. We process images first, then videos.
|
||||
image_input = self._parse_and_validate_visual_input(False, **kwargs)
|
||||
if image_input:
|
||||
embeddings.extend(self._process_image_input(image_input))
|
||||
# Preserve the order of modalities if there are multiple of them
|
||||
# from the order of kwargs.
|
||||
for input_key in kwargs:
|
||||
if input_key in ("pixel_values", "indicator_tokens",
|
||||
"grids") and "images" not in modalities:
|
||||
modalities["images"] = self._parse_and_validate_image_input(
|
||||
**kwargs)
|
||||
if input_key in ("video_pixel_values", "video_indicator_tokens",
|
||||
"video_grids") and "videos" not in modalities:
|
||||
modalities["videos"] = self._parse_and_validate_video_input(
|
||||
**kwargs)
|
||||
|
||||
video_input = self._parse_and_validate_visual_input(True, **kwargs)
|
||||
if video_input:
|
||||
embeddings.extend(self._process_image_input(video_input))
|
||||
return modalities
|
||||
|
||||
return tuple(embeddings) if embeddings else None
|
||||
def get_multimodal_embeddings(self,
                              **kwargs: object) -> MultiModalEmbeddings:
    """Compute the embeddings for every modality present in `kwargs`.

    The kwargs are parsed by `_parse_and_validate_multimodal_inputs`;
    each parsed image or video input is then embedded with
    `_process_image_input` and the results are concatenated in the
    order the modalities appear in the parsed dictionary.

    Returns:
        An empty list when no modality is present, otherwise a tuple of
        the embeddings produced for all parsed inputs.
    """
    modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
    if not modalities:
        return []

    multimodal_embeddings: tuple[torch.Tensor, ...] = ()
    # NOTE: It is important to iterate over the keys in this dictionary
    # to preserve the order of the modalities.
    for modality, parsed_input in modalities.items():
        # A parser returns None when its tensors are absent; skip such an
        # entry instead of handing None to `_process_image_input`, which
        # would fail when indexing it.
        if parsed_input is None:
            continue
        # Images and videos share the same embedding routine.
        if modality in ("images", "videos"):
            multimodal_embeddings += self._process_image_input(parsed_input)

    return multimodal_embeddings
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user