[2/3] Refactor InternVL-based processors (#37324)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -7,24 +7,17 @@
|
||||
# Copyright (c) 2023 OpenGVLab
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
# --------------------------------------------------------
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, TypeVar
|
||||
|
||||
import numpy.typing as npt
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
from PIL import Image
|
||||
from transformers import BatchFeature, PretrainedConfig, TensorType
|
||||
from transformers import BatchFeature, TensorType
|
||||
from transformers.processing_utils import ProcessorMixin
|
||||
|
||||
from vllm.multimodal.image import convert_image_mode
|
||||
from vllm.multimodal.processing import PromptUpdateDetails
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
IMG_START = "<img>"
|
||||
IMG_END = "</img>"
|
||||
IMG_CONTEXT = "<IMG_CONTEXT>"
|
||||
from vllm.tokenizers.hf import HfTokenizer
|
||||
|
||||
IMAGENET_MEAN = (0.485, 0.456, 0.406)
|
||||
IMAGENET_STD = (0.229, 0.224, 0.225)
|
||||
@@ -33,7 +26,7 @@ IMAGENET_STD = (0.229, 0.224, 0.225)
|
||||
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
|
||||
def build_transform(input_size: int):
|
||||
MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
|
||||
transform = T.Compose(
|
||||
return T.Compose(
|
||||
[
|
||||
T.Lambda(lambda img: convert_image_mode(img, "RGB")),
|
||||
T.Resize(
|
||||
@@ -43,7 +36,6 @@ def build_transform(input_size: int):
|
||||
T.Normalize(mean=MEAN, std=STD),
|
||||
]
|
||||
)
|
||||
return transform
|
||||
|
||||
|
||||
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
|
||||
@@ -223,65 +215,20 @@ def video_to_pixel_values_internvl(
|
||||
return pixel_values
|
||||
|
||||
|
||||
class BaseInternVLProcessor(ABC):
|
||||
"""
|
||||
This model doesn't define its own HF processor,
|
||||
so we implement our own one here.
|
||||
|
||||
The code to insert image tokens is based on:
|
||||
https://huggingface.co/OpenGVLab/InternVL2-1B/blob/main/modeling_internvl_chat.py#L252
|
||||
"""
|
||||
|
||||
class InternVLImageProcessor:
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
tokenizer: TokenizerLike,
|
||||
*,
|
||||
min_dynamic_patch: int | None = None,
|
||||
max_dynamic_patch: int | None = None,
|
||||
dynamic_image_size: bool | None = None,
|
||||
image_size: int,
|
||||
min_dynamic_patch: int,
|
||||
max_dynamic_patch: int,
|
||||
dynamic_image_size: bool,
|
||||
use_thumbnail: bool,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
image_size: int = config.vision_config.image_size
|
||||
patch_size: int = config.vision_config.patch_size
|
||||
|
||||
if min_dynamic_patch is None:
|
||||
min_dynamic_patch = config.min_dynamic_patch
|
||||
assert isinstance(min_dynamic_patch, int)
|
||||
|
||||
if max_dynamic_patch is None:
|
||||
max_dynamic_patch = config.max_dynamic_patch
|
||||
assert isinstance(max_dynamic_patch, int)
|
||||
|
||||
if dynamic_image_size is None:
|
||||
dynamic_image_size = config.dynamic_image_size
|
||||
assert isinstance(dynamic_image_size, bool)
|
||||
|
||||
self.num_image_token = int(
|
||||
(image_size // patch_size) ** 2 * (config.downsample_ratio**2)
|
||||
)
|
||||
self.image_size = image_size
|
||||
self.min_dynamic_patch = min_dynamic_patch
|
||||
self.max_dynamic_patch = max_dynamic_patch
|
||||
self.dynamic_image_size = dynamic_image_size
|
||||
self.use_thumbnail: bool = config.use_thumbnail
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def image_token_id(self) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_image_repl(
|
||||
self,
|
||||
feature_size: int,
|
||||
num_patches: int | None,
|
||||
) -> PromptUpdateDetails[str]:
|
||||
raise NotImplementedError
|
||||
self.use_thumbnail = use_thumbnail
|
||||
|
||||
def resolve_min_max_num(
|
||||
self,
|
||||
@@ -291,18 +238,14 @@ class BaseInternVLProcessor(ABC):
|
||||
dynamic_image_size: bool | None = None,
|
||||
use_thumbnail: bool | None = None,
|
||||
) -> tuple[int, int]:
|
||||
min_dynamic_patch = (
|
||||
self.min_dynamic_patch if min_dynamic_patch is None else min_dynamic_patch
|
||||
)
|
||||
max_dynamic_patch = (
|
||||
self.max_dynamic_patch if max_dynamic_patch is None else max_dynamic_patch
|
||||
)
|
||||
dynamic_image_size = (
|
||||
self.dynamic_image_size
|
||||
if dynamic_image_size is None
|
||||
else dynamic_image_size
|
||||
)
|
||||
use_thumbnail = self.use_thumbnail if use_thumbnail is None else use_thumbnail
|
||||
if min_dynamic_patch is None:
|
||||
min_dynamic_patch = self.min_dynamic_patch
|
||||
if max_dynamic_patch is None:
|
||||
max_dynamic_patch = self.max_dynamic_patch
|
||||
if dynamic_image_size is None:
|
||||
dynamic_image_size = self.dynamic_image_size
|
||||
if use_thumbnail is None:
|
||||
use_thumbnail = self.use_thumbnail
|
||||
|
||||
return resolve_internvl_min_max_num(
|
||||
min_dynamic_patch=min_dynamic_patch,
|
||||
@@ -311,43 +254,6 @@ class BaseInternVLProcessor(ABC):
|
||||
use_thumbnail=use_thumbnail,
|
||||
)
|
||||
|
||||
def resolve_target_ratios(
    self,
    *,
    min_dynamic_patch: int | None = None,
    max_dynamic_patch: int | None = None,
    dynamic_image_size: bool | None = None,
    use_thumbnail: bool | None = None,
) -> list[tuple[int, int]]:
    """
    Return the target ratios for the resolved patch-count bounds.

    All keyword arguments are forwarded unchanged to
    `self.resolve_min_max_num`; the `(min, max)` pair it returns is then
    handed to `get_internvl_target_ratios`.
    """
    overrides = {
        "min_dynamic_patch": min_dynamic_patch,
        "max_dynamic_patch": max_dynamic_patch,
        "dynamic_image_size": dynamic_image_size,
        "use_thumbnail": use_thumbnail,
    }
    lower_bound, upper_bound = self.resolve_min_max_num(**overrides)

    return get_internvl_target_ratios(lower_bound, upper_bound)
|
||||
|
||||
def get_num_image_tokens(
    self,
    *,
    image_width: int,
    image_height: int,
) -> int:
    """
    Compute how many image tokens an image of the given size expands to.

    The patch count comes from `calculate_internvl_targets` and is
    multiplied by `self.num_image_token`.
    """
    # The thumbnail is applied inside calculate_internvl_targets, so it is
    # switched off while resolving the candidate ratios.
    ratios = self.resolve_target_ratios(use_thumbnail=False)

    targets = calculate_internvl_targets(
        orig_width=image_width,
        orig_height=image_height,
        image_size=self.image_size,
        target_ratios=ratios,
        use_thumbnail=self.use_thumbnail,
    )
    patch_count, _, _ = targets

    return patch_count * self.num_image_token
|
||||
|
||||
def _images_to_pixel_values_lst(
|
||||
self,
|
||||
images: list[Image.Image],
|
||||
@@ -355,7 +261,14 @@ class BaseInternVLProcessor(ABC):
|
||||
max_dynamic_patch: int | None = None,
|
||||
dynamic_image_size: bool | None = None,
|
||||
) -> list[torch.Tensor]:
|
||||
min_num, max_num = self.resolve_min_max_num(
|
||||
if min_dynamic_patch is None:
|
||||
min_dynamic_patch = self.min_dynamic_patch
|
||||
if max_dynamic_patch is None:
|
||||
max_dynamic_patch = self.max_dynamic_patch
|
||||
if dynamic_image_size is None:
|
||||
dynamic_image_size = self.dynamic_image_size
|
||||
|
||||
min_num, max_num = resolve_internvl_min_max_num(
|
||||
min_dynamic_patch=min_dynamic_patch,
|
||||
max_dynamic_patch=max_dynamic_patch,
|
||||
dynamic_image_size=dynamic_image_size,
|
||||
@@ -373,49 +286,9 @@ class BaseInternVLProcessor(ABC):
|
||||
for image in images
|
||||
]
|
||||
|
||||
def _preprocess_image(
|
||||
self,
|
||||
text: list[str],
|
||||
images: list[Image.Image],
|
||||
min_dynamic_patch: int | None = None,
|
||||
max_dynamic_patch: int | None = None,
|
||||
dynamic_image_size: bool | None = None,
|
||||
) -> tuple[list[str], dict[str, torch.Tensor]]:
|
||||
if len(images) == 0:
|
||||
image_inputs = {}
|
||||
else:
|
||||
pixel_values_lst = self._images_to_pixel_values_lst(
|
||||
images,
|
||||
min_dynamic_patch=min_dynamic_patch,
|
||||
max_dynamic_patch=max_dynamic_patch,
|
||||
dynamic_image_size=dynamic_image_size,
|
||||
)
|
||||
image_inputs = {
|
||||
"pixel_values_flat": torch.cat(pixel_values_lst),
|
||||
"image_num_patches": torch.tensor(
|
||||
[len(item) for item in pixel_values_lst]
|
||||
),
|
||||
}
|
||||
|
||||
for pixel_values in pixel_values_lst:
|
||||
num_patches = pixel_values.shape[0]
|
||||
feature_size = num_patches * self.num_image_token
|
||||
|
||||
image_repl = self.get_image_repl(feature_size, num_patches)
|
||||
text = [t.replace("<image>", image_repl.full, 1) for t in text]
|
||||
return text, image_inputs
|
||||
|
||||
def _make_batch_input(self, input_item: _T | list[_T] | None = None) -> list[_T]:
|
||||
if input_item is None:
|
||||
input_item = []
|
||||
if not isinstance(input_item, list):
|
||||
input_item = [input_item]
|
||||
return input_item
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
text: str | list[str] | None = None,
|
||||
images: Image.Image | list[Image.Image] | None = None,
|
||||
images: Image.Image | list[Image.Image],
|
||||
*,
|
||||
min_dynamic_patch: int | None = None,
|
||||
max_dynamic_patch: int | None = None,
|
||||
@@ -423,120 +296,173 @@ class BaseInternVLProcessor(ABC):
|
||||
return_tensors: str | TensorType | None = None,
|
||||
**kwargs,
|
||||
) -> BatchFeature:
|
||||
text = self._make_batch_input(text)
|
||||
images = self._make_batch_input(images)
|
||||
images_lst = [images] if not isinstance(images, list) else images
|
||||
|
||||
text, image_inputs = self._preprocess_image(
|
||||
text=text,
|
||||
images=images,
|
||||
pixel_values_lst = self._images_to_pixel_values_lst(
|
||||
images_lst,
|
||||
min_dynamic_patch=min_dynamic_patch,
|
||||
max_dynamic_patch=max_dynamic_patch,
|
||||
dynamic_image_size=dynamic_image_size,
|
||||
)
|
||||
|
||||
text_inputs = self.tokenizer(text)
|
||||
|
||||
combined_outputs = {**text_inputs, **image_inputs}
|
||||
|
||||
return BatchFeature(combined_outputs, tensor_type=return_tensors)
|
||||
image_inputs = {
|
||||
"pixel_values_flat": torch.cat(pixel_values_lst),
|
||||
"image_num_patches": torch.tensor([len(item) for item in pixel_values_lst]),
|
||||
}
|
||||
return BatchFeature(image_inputs, tensor_type=return_tensors)
|
||||
|
||||
|
||||
class InternVLProcessor(BaseInternVLProcessor):
|
||||
"""
|
||||
HF Processor for InternVLChatModel with extended video processing logic.
|
||||
|
||||
Code for video processing is adapted from video example:
|
||||
https://huggingface.co/OpenGVLab/InternVL3-1B#inference-with-transformers
|
||||
"""
|
||||
|
||||
class InternVLVideoProcessor:
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
tokenizer: TokenizerLike,
|
||||
*,
|
||||
min_dynamic_patch: int | None = None,
|
||||
max_dynamic_patch: int | None = None,
|
||||
dynamic_image_size: bool | None = None,
|
||||
video_token: str | None = None,
|
||||
image_size: int,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
config=config,
|
||||
tokenizer=tokenizer,
|
||||
min_dynamic_patch=min_dynamic_patch,
|
||||
max_dynamic_patch=max_dynamic_patch,
|
||||
dynamic_image_size=dynamic_image_size,
|
||||
)
|
||||
# add extra video token for video processing
|
||||
self.video_token = video_token
|
||||
|
||||
@property
|
||||
def image_token_id(self) -> int:
|
||||
return self.tokenizer.get_vocab()[IMG_CONTEXT]
|
||||
|
||||
@property
|
||||
def video_token_id(self) -> int | None:
|
||||
if self.video_token is None:
|
||||
return None
|
||||
return self.tokenizer.get_vocab().get(self.video_token, None)
|
||||
|
||||
@property
|
||||
def supports_video(self) -> bool:
|
||||
return self.video_token_id is not None
|
||||
self.image_size = image_size
|
||||
|
||||
def _videos_to_pixel_values_lst(
|
||||
self,
|
||||
videos: list[npt.NDArray],
|
||||
dynamic_image_size: bool | None = None,
|
||||
) -> list[torch.Tensor]:
|
||||
min_num, max_num = self.resolve_min_max_num(
|
||||
min_dynamic_patch=1,
|
||||
max_dynamic_patch=1,
|
||||
dynamic_image_size=dynamic_image_size,
|
||||
use_thumbnail=False, # Applied in image_to_pixel_values
|
||||
)
|
||||
|
||||
return [
|
||||
video_to_pixel_values_internvl(
|
||||
video,
|
||||
input_size=self.image_size,
|
||||
min_num=min_num,
|
||||
max_num=max_num,
|
||||
min_num=1,
|
||||
max_num=1,
|
||||
use_thumbnail=False,
|
||||
)
|
||||
for video in videos
|
||||
]
|
||||
|
||||
def _preprocess_video(
|
||||
def __call__(
|
||||
self,
|
||||
text: list[str],
|
||||
videos: list[npt.NDArray],
|
||||
dynamic_image_size: bool | None = None,
|
||||
) -> tuple[list[str], dict[str, Any]]:
|
||||
if len(videos) == 0 or not self.supports_video:
|
||||
return text, {}
|
||||
videos: npt.NDArray | list[npt.NDArray],
|
||||
*,
|
||||
return_tensors: str | TensorType | None = None,
|
||||
**kwargs,
|
||||
) -> BatchFeature:
|
||||
videos_lst = [videos] if not isinstance(videos, list) else videos
|
||||
|
||||
video_token = self.video_token
|
||||
assert video_token is not None
|
||||
pixel_values_lst = self._videos_to_pixel_values_lst(videos_lst)
|
||||
|
||||
pixel_values_lst_video = self._videos_to_pixel_values_lst(
|
||||
videos,
|
||||
dynamic_image_size=dynamic_image_size,
|
||||
)
|
||||
video_inputs = {
|
||||
"pixel_values_flat_video": torch.cat(pixel_values_lst_video),
|
||||
"video_num_patches": torch.tensor(
|
||||
[len(item) for item in pixel_values_lst_video]
|
||||
),
|
||||
image_inputs = {
|
||||
"pixel_values_flat_video": torch.cat(pixel_values_lst),
|
||||
"video_num_patches": torch.tensor([len(item) for item in pixel_values_lst]),
|
||||
}
|
||||
return BatchFeature(image_inputs, tensor_type=return_tensors)
|
||||
|
||||
for pixel_values in pixel_values_lst_video:
|
||||
num_patches = pixel_values.shape[0]
|
||||
|
||||
video_repl = self.get_video_repl(
|
||||
self.num_image_token, num_patches, video_token
|
||||
)
|
||||
text = [t.replace("<video>", video_repl.full, 1) for t in text]
|
||||
return text, video_inputs
|
||||
class InternVLProcessor(ProcessorMixin):
|
||||
"""
|
||||
This model doesn't define its own HF processor,
|
||||
so we implement our own one here.
|
||||
|
||||
The code to insert image tokens is based on:
|
||||
https://huggingface.co/OpenGVLab/InternVL2-1B/blob/main/modeling_internvl_chat.py#L252
|
||||
|
||||
Code for video processing is adapted from video example:
|
||||
https://huggingface.co/OpenGVLab/InternVL3-1B#inference-with-transformers
|
||||
"""
|
||||
|
||||
attributes = ["image_processor", "tokenizer", "video_processor"]
|
||||
|
||||
def __init__(
    self,
    image_processor: InternVLImageProcessor,
    tokenizer: HfTokenizer,
    video_processor: InternVLVideoProcessor | None = None,
    *,
    image_seq_length: int,
    start_image_token: str = "<img>",
    end_image_token: str = "</img>",
    ctx_image_token: str = "<IMG_CONTEXT>",
    ctx_video_token: str | None = None,
) -> None:
    """
    Store the sub-processors and special tokens and look up their ids.

    Args:
        image_processor: Processor used for image inputs.
        tokenizer: Tokenizer; its `convert_tokens_to_ids` supplies the
            `*_token_id` attributes.
        video_processor: Processor used for video inputs, or `None`.
        image_seq_length: Stored as `self.image_seq_length`.
        start_image_token: Token placed before the context tokens.
        end_image_token: Token placed after the context tokens.
        ctx_image_token: Context token for images.
        ctx_video_token: Context token for videos; when `None`,
            `self.ctx_video_token_id` is `None` as well.
    """
    # NOTE(review): ProcessorMixin.__init__ is not invoked here — confirm
    # this is intentional.
    self.image_processor = image_processor
    self.tokenizer = tokenizer
    self.video_processor = video_processor

    self.image_seq_length = image_seq_length
    self.start_image_token = start_image_token
    self.end_image_token = end_image_token
    self.ctx_image_token = ctx_image_token
    self.ctx_video_token = ctx_video_token

    to_id = tokenizer.convert_tokens_to_ids
    self.start_image_token_id = to_id(start_image_token)
    self.end_image_token_id = to_id(end_image_token)
    self.ctx_image_token_id = to_id(ctx_image_token)
    if ctx_video_token is None:
        self.ctx_video_token_id = None
    else:
        self.ctx_video_token_id = to_id(ctx_video_token)
|
||||
|
||||
def resolve_target_ratios(
    self,
    *,
    min_dynamic_patch: int | None = None,
    max_dynamic_patch: int | None = None,
    dynamic_image_size: bool | None = None,
    use_thumbnail: bool | None = None,
) -> list[tuple[int, int]]:
    """
    Return the target ratios for the resolved patch-count bounds.

    All keyword arguments are forwarded unchanged to
    `self.image_processor.resolve_min_max_num`; the `(min, max)` pair it
    returns is then handed to `get_internvl_target_ratios`.
    """
    overrides = {
        "min_dynamic_patch": min_dynamic_patch,
        "max_dynamic_patch": max_dynamic_patch,
        "dynamic_image_size": dynamic_image_size,
        "use_thumbnail": use_thumbnail,
    }
    lower_bound, upper_bound = self.image_processor.resolve_min_max_num(**overrides)

    return get_internvl_target_ratios(lower_bound, upper_bound)
|
||||
|
||||
def get_num_image_tokens(
    self,
    *,
    image_width: int,
    image_height: int,
) -> int:
    """
    Compute how many image tokens an image of the given size expands to.

    The patch count comes from `calculate_internvl_targets`, using the
    image processor's `image_size` and `use_thumbnail`, and is multiplied
    by `self.image_seq_length`.
    """
    # The thumbnail is applied inside calculate_internvl_targets, so it is
    # switched off while resolving the candidate ratios.
    ratios = self.resolve_target_ratios(use_thumbnail=False)

    img_proc = self.image_processor
    targets = calculate_internvl_targets(
        orig_width=image_width,
        orig_height=image_height,
        image_size=img_proc.image_size,
        target_ratios=ratios,
        use_thumbnail=img_proc.use_thumbnail,
    )
    patch_count, _, _ = targets

    return patch_count * self.image_seq_length
|
||||
|
||||
def get_image_repl(
    self,
    num_patches: int | None,
    num_features: int | None = None,
) -> PromptUpdateDetails[str]:
    """
    Build the prompt replacement text for one image.

    Args:
        num_patches: Patch count of the image. When given, the number of
            context tokens is `num_patches * self.image_seq_length` and
            any `num_features` argument is ignored.
        num_features: Number of context tokens to emit. Only consulted,
            and then required, when `num_patches` is `None`.

    Returns:
        `PromptUpdateDetails.select_text` applied to the context tokens
        wrapped in the start/end image tokens, selecting
        `self.ctx_image_token`.

    Raises:
        ValueError: If both `num_patches` and `num_features` are `None`.
    """
    if num_patches is not None:
        num_features = num_patches * self.image_seq_length
    elif num_features is None:
        # Explicit check instead of `assert`, which is stripped under -O
        # and would then surface as an opaque TypeError on `str * None`.
        raise ValueError(
            "`num_features` must be provided when `num_patches` is None"
        )

    repl_features = self.ctx_image_token * num_features
    repl_full = self.start_image_token + repl_features + self.end_image_token

    return PromptUpdateDetails.select_text(repl_full, self.ctx_image_token)
|
||||
|
||||
def get_video_repl(self, num_patches: int) -> PromptUpdateDetails[str]:
    """
    Build the prompt replacement text for one video.

    Every frame contributes `"Frame<k>: "` followed by
    `self.image_seq_length` video context tokens wrapped in the start/end
    image tokens; frames are numbered from 1. The result is passed through
    `PromptUpdateDetails.select_text`, selecting the video context token.
    """
    video_token = self.ctx_video_token
    assert video_token is not None

    frame_features = "".join(
        (
            self.start_image_token,
            video_token * self.image_seq_length,
            self.end_image_token,
        )
    )

    # num_patches is equal to num_frames
    frame_texts = []
    for frame_no in range(1, num_patches + 1):
        frame_texts.append(f"Frame{frame_no}: {frame_features}")
    repl_full = "".join(frame_texts)

    return PromptUpdateDetails.select_text(repl_full, video_token)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@@ -550,54 +476,88 @@ class InternVLProcessor(BaseInternVLProcessor):
|
||||
return_tensors: str | TensorType | None = None,
|
||||
**kwargs,
|
||||
) -> BatchFeature:
|
||||
text = self._make_batch_input(text)
|
||||
images = self._make_batch_input(images)
|
||||
videos = self._make_batch_input(videos)
|
||||
if images is not None:
|
||||
image_inputs = self.image_processor(
|
||||
images=images,
|
||||
min_dynamic_patch=min_dynamic_patch,
|
||||
max_dynamic_patch=max_dynamic_patch,
|
||||
dynamic_image_size=dynamic_image_size,
|
||||
return_tensors=return_tensors,
|
||||
)
|
||||
image_num_patches = image_inputs["image_num_patches"]
|
||||
else:
|
||||
image_inputs = {}
|
||||
image_num_patches = []
|
||||
|
||||
text, image_inputs = self._preprocess_image(
|
||||
text=text,
|
||||
images=images,
|
||||
min_dynamic_patch=min_dynamic_patch,
|
||||
max_dynamic_patch=max_dynamic_patch,
|
||||
dynamic_image_size=dynamic_image_size,
|
||||
)
|
||||
if videos is not None:
|
||||
if self.video_processor is None:
|
||||
raise ValueError("This model does not support video inputs")
|
||||
|
||||
text, video_inputs = self._preprocess_video(
|
||||
text=text,
|
||||
videos=videos,
|
||||
dynamic_image_size=dynamic_image_size,
|
||||
)
|
||||
video_inputs = self.video_processor(
|
||||
videos=videos,
|
||||
return_tensors=return_tensors,
|
||||
)
|
||||
video_num_patches = video_inputs["video_num_patches"]
|
||||
else:
|
||||
video_inputs = {}
|
||||
video_num_patches = []
|
||||
|
||||
text_inputs = self.tokenizer(text)
|
||||
if text is not None:
|
||||
if not isinstance(text, list):
|
||||
text = [text]
|
||||
|
||||
if image_inputs:
|
||||
image_token = "<image>"
|
||||
image_index = 0
|
||||
processed_text = list[str]()
|
||||
replace_strings = list[str]()
|
||||
|
||||
for prompt in text:
|
||||
new_prompt = prompt
|
||||
|
||||
while image_token in new_prompt:
|
||||
new_prompt = new_prompt.replace(image_token, "<placeholder>", 1)
|
||||
image_repl = self.get_image_repl(image_num_patches[image_index])
|
||||
replace_strings.append(image_repl.full)
|
||||
image_index += 1
|
||||
|
||||
while "<placeholder>" in new_prompt:
|
||||
replace_str = replace_strings.pop(0)
|
||||
new_prompt = new_prompt.replace("<placeholder>", replace_str, 1)
|
||||
|
||||
processed_text.append(new_prompt)
|
||||
|
||||
text = processed_text
|
||||
|
||||
if video_inputs:
|
||||
video_token = "<video>"
|
||||
video_index = 0
|
||||
processed_text = list[str]()
|
||||
replace_strings = list[str]()
|
||||
|
||||
assert video_token is not None
|
||||
|
||||
for prompt in text:
|
||||
new_prompt = prompt
|
||||
|
||||
while video_token in new_prompt:
|
||||
new_prompt = new_prompt.replace(video_token, "<placeholder>", 1)
|
||||
video_repl = self.get_video_repl(video_num_patches[video_index])
|
||||
replace_strings.append(video_repl.full)
|
||||
video_index += 1
|
||||
|
||||
while "<placeholder>" in new_prompt:
|
||||
replace_str = replace_strings.pop(0)
|
||||
new_prompt = new_prompt.replace("<placeholder>", replace_str, 1)
|
||||
|
||||
processed_text.append(new_prompt)
|
||||
|
||||
text = processed_text
|
||||
|
||||
text_inputs = self.tokenizer(text, return_tensors=return_tensors)
|
||||
else:
|
||||
text_inputs = {}
|
||||
|
||||
combined_outputs = {**text_inputs, **image_inputs, **video_inputs}
|
||||
|
||||
return BatchFeature(combined_outputs, tensor_type=return_tensors)
|
||||
|
||||
def get_image_repl(
    self,
    feature_size: int,
    num_patches: int | None,
) -> PromptUpdateDetails[str]:
    """
    Build the prompt replacement text for one image.

    Emits `feature_size` copies of `IMG_CONTEXT` between `IMG_START` and
    `IMG_END`, then passes the text through
    `PromptUpdateDetails.select_text`, selecting `IMG_CONTEXT`.
    `num_patches` is not read by this implementation.
    """
    context_run = IMG_CONTEXT * feature_size
    wrapped = "".join((IMG_START, context_run, IMG_END))

    return PromptUpdateDetails.select_text(wrapped, IMG_CONTEXT)
|
||||
|
||||
def get_video_repl(
    self,
    feature_size: int,
    num_patches: int | None,
    video_context_token: str = IMG_CONTEXT,
) -> PromptUpdateDetails[str]:
    """
    Build the prompt replacement text for one video.

    Every frame contributes `"Frame<k>: "` followed by
    `self.num_image_token` copies of `video_context_token` wrapped in
    `IMG_START` / `IMG_END`; frames are numbered from 1. `feature_size` is
    not read by this implementation.

    Raises:
        NotImplementedError: If `num_patches` is `None`.
    """
    if num_patches is None:
        raise NotImplementedError("Embedding inputs are not supported")

    frame_features = "".join(
        (IMG_START, video_context_token * self.num_image_token, IMG_END)
    )

    # num_patches is equal to num_frames
    frame_texts = []
    for frame_no in range(1, num_patches + 1):
        frame_texts.append(f"Frame{frame_no}: {frame_features}")
    repl_full = "".join(frame_texts)

    return PromptUpdateDetails.select_text(repl_full, video_context_token)
|
||||
|
||||
Reference in New Issue
Block a user