@@ -9,9 +9,10 @@
 from collections.abc import Iterable, Mapping, Sequence
 from typing import Literal, Optional, TypedDict, Union
 
 import regex as re
 import torch
 import torch.nn as nn
-from transformers import InternVLProcessor, PretrainedConfig
+from transformers import BatchFeature, InternVLProcessor, PretrainedConfig
 from transformers.activations import ACT2FN
 from transformers.models.got_ocr2.image_processing_got_ocr2_fast import (
     GotOcr2ImageProcessorFast)
@@ -139,13 +140,13 @@ def get_interns1_target_ratios(
 
 
 class InternS1ProcessingInfo(BaseProcessingInfo):
-    """Basic image-only ProcessingInfo for InternS1-style models."""
+    """ProcessingInfo for InternS1-style models."""
 
     def get_hf_processor(self, **kwargs: object) -> InternVLProcessor:
         return self.ctx.get_hf_processor(InternVLProcessor, **kwargs)
 
     def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
-        return {"image": None}
+        return {"image": None, "video": None}
 
     def get_num_image_tokens(
         self,
@@ -218,16 +219,35 @@ class InternS1ProcessingInfo(BaseProcessingInfo):
             processor=processor.image_processor,
         )
 
+    def get_num_frames_with_most_features(
+        self,
+        seq_len: int,
+        mm_counts: Mapping[str, int],
+    ) -> int:
+        max_images = mm_counts.get("image", 0)
+        max_videos = mm_counts.get("video", 0)
+
+        processor = self.get_hf_processor()
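+
+        # Budget: context tokens left after reserving for images, divided
+        # by tokens per frame, then split evenly across videos.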
+        max_image_tokens = self.get_max_image_tokens() * max_images
+        max_total_frames = (seq_len -
+                            max_image_tokens) // processor.image_seq_length
+        max_frames_per_video = max_total_frames // max(max_videos, 1)
+
+        return max(max_frames_per_video, 1)
+
 
 class InternS1DummyInputsBuilder(BaseDummyInputsBuilder[InternS1ProcessingInfo]
                                  ):
-    """Basic image-only DummyInputsBuilder for InternS1-style models."""
+    """DummyInputsBuilder for InternS1-style models."""
 
     def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
         num_images = mm_counts.get("image", 0)
+        num_videos = mm_counts.get("video", 0)
         image_token = self.info.get_hf_processor().image_token
+        video_token = self.info.get_hf_processor().video_token
 
-        return image_token * num_images
+        return image_token * num_images + video_token * num_videos
 
     def get_dummy_mm_data(
         self,
@@ -236,13 +256,24 @@ class InternS1DummyInputsBuilder(BaseDummyInputsBuilder[InternS1ProcessingInfo]
     ) -> MultiModalDataDict:
         target_width, target_height = \
             self.info.get_image_size_with_most_features()
+        target_num_frames = \
+            self.info.get_num_frames_with_most_features(seq_len, mm_counts)
         num_images = mm_counts.get("image", 0)
+        num_videos = mm_counts.get("video", 0)
+
+        config = self.info.get_hf_config()
+        image_size_h, image_size_w = config.vision_config.image_size
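+        # Dummy videos use the vision backbone's native input resolution;
+        # each frame maps to a single patch.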
 
         return {
             "image":
             self._get_dummy_images(width=target_width,
                                    height=target_height,
-                                   num_images=num_images)
+                                   num_images=num_images),
+            "video":
+            self._get_dummy_videos(width=image_size_w,
+                                   height=image_size_h,
+                                   num_frames=target_num_frames,
+                                   num_videos=num_videos),
         }
 
 
@@ -257,33 +288,89 @@ class InternS1MultiModalProcessor(
         mm_kwargs: Mapping[str, object],
         tok_kwargs: Mapping[str, object],
     ) -> Mapping[str, NestedTensors]:
-        processed_outputs = super()._call_hf_processor(
-            prompt=prompt,
-            mm_data=mm_data,
-            mm_kwargs=mm_kwargs,
-            tok_kwargs=tok_kwargs,
-        )
+        mm_data = dict(mm_data)
+        videos = mm_data.pop("videos", [])
+        images = mm_data.pop("images", [])
+        assert isinstance(videos, list)
+        assert isinstance(images, list)
 
         hf_processor = self.info.get_hf_processor(**mm_kwargs)
-        image_token_id = hf_processor.image_token_id
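+        # Resolve the video token ID via the tokenizer; the video token
+        # must encode to exactly one token.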
+        tokenizer = hf_processor.tokenizer
+        video_token_id = tokenizer.encode(hf_processor.video_token,
+                                          add_special_tokens=False)
+        assert len(video_token_id) == 1
+        video_token_id = video_token_id[0]
 
-        # Since there may be extra tokens in the feature placeholders,
-        # we need to pass the image token ID to the model to select the
-        # tokens to merge from the vision encoder outputs
-        processed_outputs["image_token_id"] = torch.tensor(image_token_id)
-
-        images = mm_data.get('images', None)
-        image_processor = self.info.get_hf_processor().image_processor
-        if images is not None:
-            image_inputs = image_processor(images=images)
-            image_num_patches = image_inputs.pop("num_patches")
-            if not isinstance(image_num_patches, list):
-                raise ValueError(
                    f'num_patches is supposed to be list, but got '
-                    f'{type(image_num_patches)}')
-            image_num_patches = torch.tensor(image_num_patches)
-            processed_outputs['image_num_patches'] = image_num_patches
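+        # Swap the multimodal tokens for unique placeholders so that each
+        # item can be expanded to its processed length below.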
+        prompt = re.sub(hf_processor.image_token, "<image_placeholder>",
+                        prompt)
+        prompt = re.sub(hf_processor.video_token, "<video_placeholder>",
+                        prompt)
 
-        return processed_outputs
+        image_outputs = {}
+        if images:
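+            # Run the HF processor on each image separately so per-item
+            # patch counts and placeholder expansions can be collected.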
+            image_pixel_values = []
+            for image in images:
+                processed_outputs = super()._call_hf_processor(
+                    prompt=hf_processor.image_token,
+                    mm_data={"images": image},
+                    mm_kwargs=mm_kwargs,
+                    tok_kwargs=tok_kwargs,
+                )
+                image_pixel_values.append(
+                    processed_outputs.pop("pixel_values"))
+
+                input_ids = processed_outputs.pop("input_ids")
+                image_placeholder = tokenizer.batch_decode(input_ids)[0]
+                prompt = prompt.replace("<image_placeholder>",
+                                        image_placeholder, 1)
+
+            num_patches = [len(item) for item in image_pixel_values]
+            image_outputs: dict[str, NestedTensors] = {
+                "pixel_values": torch.concat(image_pixel_values),
+                "image_num_patches": torch.tensor(num_patches),
+                "image_token_id": torch.tensor(hf_processor.image_token_id),
+            }
 
+        video_outputs = {}
+        if videos:
+            video_pixel_values = []
+            for video in videos:
+                processed_outputs = super()._call_hf_processor(
+                    prompt=hf_processor.video_token,
+                    mm_data={"videos": video},
+                    mm_kwargs=mm_kwargs,
+                    tok_kwargs=tok_kwargs,
+                )
+                video_pixel_values.append(
+                    processed_outputs.pop("pixel_values"))
+
+                input_ids = processed_outputs.pop("input_ids")
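+                # The HF processor emits image tokens for the video frames;
+                # remap them to the video token ID.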
+                input_ids[input_ids ==
+                          hf_processor.image_token_id] = video_token_id
+
+                video_placeholder = tokenizer.batch_decode(input_ids)[0]
+                prompt = prompt.replace("<video_placeholder>",
+                                        video_placeholder, 1)
+
+            num_frames = [len(item) for item in video_pixel_values]
+            video_outputs: dict[str, NestedTensors] = {
+                "pixel_values_videos": torch.concat(video_pixel_values),
+                "video_num_patches": torch.tensor(num_frames),
+                "video_token_id": torch.tensor(video_token_id),
+            }
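+
+        # Convert any placeholders that were not expanded back to the
+        # original tokens, then tokenize the final prompt.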
+        prompt = re.sub("<image_placeholder>", hf_processor.image_token,
+                        prompt)
+        prompt = re.sub("<video_placeholder>", hf_processor.video_token,
+                        prompt)
+        text_outputs = tokenizer(prompt, **tok_kwargs, return_tensors="pt")
+
+        combined_outputs = dict(
+            **text_outputs,
+            **image_outputs,
+            **video_outputs,
+        )
+        return BatchFeature(combined_outputs)
 
     def _get_mm_fields_config(
         self,
@@ -292,7 +379,9 @@ class InternS1MultiModalProcessor(
     ) -> Mapping[str, MultiModalFieldConfig]:
         image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0))
+        video_num_patches = hf_inputs.get("video_num_patches", torch.empty(0))
         num_images = len(image_num_patches)
+        num_videos = len(video_num_patches)
 
         return dict(
             pixel_values=MultiModalFieldConfig.flat_from_sizes(
@@ -300,6 +389,10 @@ class InternS1MultiModalProcessor(
             image_num_patches=MultiModalFieldConfig.batched("image"),
             image_embeds=MultiModalFieldConfig.batched("image"),
             image_token_id=MultiModalFieldConfig.shared("image", num_images),
+            pixel_values_videos=MultiModalFieldConfig.flat_from_sizes(
+                "video", video_num_patches),
+            video_num_patches=MultiModalFieldConfig.batched("video"),
+            video_token_id=MultiModalFieldConfig.shared("video", num_videos),
         )
 
     def _get_prompt_updates(
@@ -312,32 +405,61 @@ class InternS1MultiModalProcessor(
         img_context_token = hf_processor.image_token
         start_image_token = hf_processor.start_image_token
         end_image_token = hf_processor.end_image_token
+        video_token = hf_processor.video_token
 
-        def get_replacement(item_idx: int):
+        if "video_num_patches" in out_mm_kwargs:
+            video_num_patches = out_mm_kwargs["video_num_patches"]
+            assert isinstance(video_num_patches, torch.Tensor)
+            video_num_patches = video_num_patches.tolist()
+        else:
+            video_num_patches = []
+
+        if "image_num_patches" in out_mm_kwargs:
+            image_num_patches = out_mm_kwargs["image_num_patches"]
+            assert isinstance(image_num_patches, torch.Tensor)
+            image_num_patches = image_num_patches.tolist()
+        else:
+            image_num_patches = []
+
+        def get_replacement_interns1_image(item_idx: int):
             images = mm_items.get_items(
                 "image", (ImageEmbeddingItems, ImageProcessorItems))
 
             if isinstance(images, ImageEmbeddingItems):
                 feature_size = images.get_feature_size(item_idx)
             else:
-                image_size = images.get_image_size(item_idx)
-                feature_size = self.info.get_num_image_tokens(
-                    image_width=image_size.width,
-                    image_height=image_size.height,
-                    processor=hf_processor.image_processor,
-                )
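+                # Dynamic tiling: each image patch contributes
+                # image_seq_length context tokens.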
+                num_patches = image_num_patches[item_idx]
+                feature_size = num_patches * hf_processor.image_seq_length
 
             repl_features = img_context_token * feature_size
             repl_full = start_image_token + repl_features + end_image_token
             return PromptUpdateDetails.select_text(repl_full,
                                                    img_context_token)
 
+        def get_replacement_interns1_video(item_idx: int):
+            num_patches = video_num_patches[item_idx]
+            repl_features = video_token * hf_processor.image_seq_length
+            repl_features_with_sep = (start_image_token + repl_features +
+                                      end_image_token)
+            # num_patches is equal to num_frames
+            repl_full = '\n'.join([
+                f'Frame{i+1}: {repl_features_with_sep}'
+                for i in range(num_patches)
+            ])
+
+            return PromptUpdateDetails.select_text(repl_full, video_token)
+
         return [
             PromptReplacement(
                 modality="image",
                 target=img_context_token,
-                replacement=get_replacement,
-            )
+                replacement=get_replacement_interns1_image,
+            ),
+            PromptReplacement(
+                modality="video",
+                target=video_token,
+                replacement=get_replacement_interns1_video,
+            ),
         ]
@@ -514,7 +636,7 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal,
     def _parse_and_validate_video_input(
             self, **kwargs: object) -> Optional[InternS1VideoPixelInputs]:
-        pixel_values_flat_video = kwargs.pop("pixel_values_flat_video", None)
+        pixel_values_flat_video = kwargs.pop("pixel_values_videos", None)
         video_num_patches = kwargs.pop("video_num_patches", None)
         video_embeds = kwargs.pop("video_embeds", None)
@@ -595,8 +717,8 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal,
                     "image_embeds") and "images" not in modalities:
                 modalities["images"] = self._parse_and_validate_image_input(
                     **kwargs)
-            if input_key in ("pixel_values_flat_video",
-                             ) and "videos" not in modalities:
+            if input_key in (
+                    "pixel_values_videos", ) and "videos" not in modalities:
                 modalities["videos"] = self._parse_and_validate_video_input(
                     **kwargs)
@@ -614,7 +736,6 @@ class InternS1ForConditionalGeneration(nn.Module, SupportsMultiModal,
         modalities = self._parse_and_validate_multimodal_inputs(**kwargs)
         if not modalities:
-            return []
+            return None
 
         # The result multimodal_embeddings is a tuple of tensors, with each
         # tensor corresponding to a multimodal data item (image or video).