[Refactor] Remove Molmo2 processor wrapper (#36667)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -3,7 +3,7 @@
|
||||
import math
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from dataclasses import dataclass, fields
|
||||
from functools import cached_property, partial
|
||||
from functools import partial
|
||||
from itertools import islice
|
||||
from typing import Annotated, Any
|
||||
|
||||
@@ -14,14 +14,14 @@ import torch.nn.functional as F
|
||||
from PIL import ImageOps
|
||||
from PIL.Image import Image
|
||||
from transformers import (
|
||||
BaseImageProcessor,
|
||||
BaseVideoProcessor,
|
||||
BatchFeature,
|
||||
PretrainedConfig,
|
||||
ProcessorMixin,
|
||||
TensorType,
|
||||
)
|
||||
from transformers.image_utils import ImageInput
|
||||
from transformers.tokenization_utils_base import TextInput
|
||||
from transformers.video_utils import VideoInput, VideoMetadata
|
||||
from transformers.video_utils import VideoMetadata
|
||||
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
@@ -1337,12 +1337,14 @@ def exif_transpose(
|
||||
|
||||
def build_flat_image_bool_length(
|
||||
image_grids: torch.LongTensor,
|
||||
image_patch_id: int,
|
||||
low_res_image_start_id: int,
|
||||
image_start_id: int,
|
||||
image_col_id: int,
|
||||
image_end_id: int,
|
||||
hf_config: PretrainedConfig,
|
||||
) -> tuple[torch.LongTensor, torch.LongTensor]:
|
||||
image_patch_id = hf_config.image_patch_id
|
||||
low_res_image_start_id = hf_config.low_res_image_start_token_id
|
||||
image_start_id = hf_config.image_start_token_id
|
||||
image_col_id = hf_config.image_col_id
|
||||
image_end_id = hf_config.image_end_token_id
|
||||
|
||||
device = image_grids.device
|
||||
B = image_grids.shape[0]
|
||||
|
||||
@@ -1401,10 +1403,12 @@ def build_flat_image_bool_length(
|
||||
|
||||
def build_flat_video_bool_length(
|
||||
video_grids: torch.LongTensor,
|
||||
image_patch_id: int,
|
||||
frame_start_id: int,
|
||||
frame_end_id: int,
|
||||
hf_config: PretrainedConfig,
|
||||
) -> tuple[torch.LongTensor, torch.LongTensor]:
|
||||
image_patch_id = hf_config.image_patch_id
|
||||
frame_start_id = hf_config.frame_start_token_id
|
||||
frame_end_id = hf_config.frame_end_token_id
|
||||
|
||||
device = video_grids.device
|
||||
B = video_grids.shape[0]
|
||||
|
||||
@@ -1439,314 +1443,6 @@ def build_flat_video_bool_length(
|
||||
return flat, lengths
|
||||
|
||||
|
||||
class Molmo2ProcessorWrapper:
|
||||
"""
|
||||
Wraps :class:`Molmo2Processor` so that it can be called directly.
|
||||
"""
|
||||
|
||||
def __init__(self, processor: ProcessorMixin, hf_config: PretrainedConfig):
|
||||
super().__init__()
|
||||
|
||||
self.processor = processor
|
||||
self.hf_config = hf_config
|
||||
|
||||
@cached_property
|
||||
def vocab(self) -> dict[str, int]:
|
||||
return self.processor.tokenizer.vocab # type: ignore
|
||||
|
||||
@cached_property
|
||||
def max_crops(self) -> int:
|
||||
image_processor = self.processor.image_processor # type: ignore
|
||||
|
||||
max_crops = image_processor.max_crops
|
||||
assert isinstance(max_crops, int)
|
||||
|
||||
return max_crops
|
||||
|
||||
@cached_property
|
||||
def image_pooling_h(self) -> int:
|
||||
image_processor = self.processor.image_processor # type: ignore
|
||||
|
||||
image_pooling_h = image_processor.pooling_size[0]
|
||||
assert isinstance(image_pooling_h, int)
|
||||
|
||||
return image_pooling_h
|
||||
|
||||
@cached_property
|
||||
def image_pooling_w(self) -> int:
|
||||
image_processor = self.processor.image_processor # type: ignore
|
||||
|
||||
image_pooling_w = image_processor.pooling_size[1]
|
||||
assert isinstance(image_pooling_w, int)
|
||||
|
||||
return image_pooling_w
|
||||
|
||||
@cached_property
|
||||
def video_pooling_h(self) -> int:
|
||||
video_processor = self.processor.video_processor # type: ignore
|
||||
|
||||
video_pooling_h = video_processor.pooling_size[0]
|
||||
assert isinstance(video_pooling_h, int)
|
||||
|
||||
return video_pooling_h
|
||||
|
||||
@cached_property
|
||||
def video_pooling_w(self) -> int:
|
||||
video_processor = self.processor.video_processor # type: ignore
|
||||
|
||||
video_pooling_w = video_processor.pooling_size[1]
|
||||
assert isinstance(video_pooling_w, int)
|
||||
|
||||
return video_pooling_w
|
||||
|
||||
@cached_property
|
||||
def base_image_input_size(self) -> tuple[int, int]:
|
||||
if getattr(self.processor, "image_processor", None) is not None:
|
||||
processor = self.processor.image_processor # type: ignore
|
||||
else:
|
||||
processor = self.processor.video_processor # type: ignore
|
||||
|
||||
base_image_input_size = (processor.size["height"], processor.size["width"])
|
||||
|
||||
return base_image_input_size
|
||||
|
||||
@cached_property
|
||||
def image_patch_size(self) -> int:
|
||||
if getattr(self.processor, "image_processor", None) is not None:
|
||||
processor = self.processor.image_processor # type: ignore
|
||||
else:
|
||||
processor = self.processor.video_processor # type: ignore
|
||||
|
||||
image_patch_size = processor.patch_size
|
||||
assert isinstance(image_patch_size, int)
|
||||
|
||||
return image_patch_size
|
||||
|
||||
@cached_property
|
||||
def overlap_margins(self) -> tuple[int, int]:
|
||||
image_processor = self.processor.image_processor # type: ignore
|
||||
|
||||
left_margin, right_margin = image_processor.overlap_margins
|
||||
assert isinstance(left_margin, int)
|
||||
assert isinstance(right_margin, int)
|
||||
|
||||
return left_margin, right_margin
|
||||
|
||||
@cached_property
|
||||
def bos_token(self) -> str:
|
||||
return self.processor.tokenizer.bos_token or self.processor.tokenizer.eos_token
|
||||
|
||||
@cached_property
|
||||
def image_patch_id(self) -> int:
|
||||
return self.hf_config.image_patch_id
|
||||
|
||||
@cached_property
|
||||
def im_col_id(self) -> int:
|
||||
return self.hf_config.image_col_id
|
||||
|
||||
@cached_property
|
||||
def im_start_id(self) -> int:
|
||||
return self.hf_config.image_start_token_id
|
||||
|
||||
@cached_property
|
||||
def im_end_id(self) -> int:
|
||||
return self.hf_config.image_end_token_id
|
||||
|
||||
@cached_property
|
||||
def low_res_im_start_id(self) -> int:
|
||||
return self.hf_config.low_res_image_start_token_id
|
||||
|
||||
@cached_property
|
||||
def frame_start_id(self) -> int:
|
||||
return self.hf_config.frame_start_token_id
|
||||
|
||||
@cached_property
|
||||
def frame_end_id(self) -> int:
|
||||
return self.hf_config.frame_end_token_id
|
||||
|
||||
@cached_property
|
||||
def im_low_res_id(self) -> int:
|
||||
return self.hf_config.image_low_res_id
|
||||
|
||||
@cached_property
|
||||
def image_placeholder_id(self) -> int:
|
||||
return self.vocab[IMAGE_PROMPT]
|
||||
|
||||
@cached_property
|
||||
def video_placeholder_id(self) -> int:
|
||||
return self.vocab[VIDEO_PROMPT]
|
||||
|
||||
@cached_property
|
||||
def image_token_ids(self) -> list[int]:
|
||||
return [
|
||||
self.image_patch_id,
|
||||
self.im_col_id,
|
||||
self.im_start_id,
|
||||
self.low_res_im_start_id,
|
||||
self.frame_start_id,
|
||||
self.im_end_id,
|
||||
self.frame_end_id,
|
||||
self.im_low_res_id,
|
||||
]
|
||||
|
||||
def select_tiling(
|
||||
self,
|
||||
*,
|
||||
image_height: int,
|
||||
image_width: int,
|
||||
) -> tuple[int, int]:
|
||||
max_crops = self.max_crops
|
||||
left_margin, right_margin = self.overlap_margins
|
||||
base_image_input_size = self.base_image_input_size
|
||||
base_image_input_d = self.image_patch_size
|
||||
|
||||
total_margin_pixels = base_image_input_d * (right_margin + left_margin)
|
||||
crop_patches = base_image_input_size[0] // base_image_input_d
|
||||
crop_window_patches = crop_patches - (right_margin + left_margin)
|
||||
crop_window_size = crop_window_patches * base_image_input_d
|
||||
tiling_h, tiling_w = select_tiling(
|
||||
height=image_height - total_margin_pixels,
|
||||
width=image_width - total_margin_pixels,
|
||||
patch_size=crop_window_size,
|
||||
max_num_patches=max_crops,
|
||||
)
|
||||
|
||||
return tiling_h, tiling_w
|
||||
|
||||
def get_base_grid_size(self, is_video: bool) -> tuple[int, int]:
|
||||
base_image_input_size = self.base_image_input_size
|
||||
|
||||
return get_patches_grid_size(
|
||||
image_h=base_image_input_size[0],
|
||||
image_w=base_image_input_size[1],
|
||||
patch_size=self.image_patch_size,
|
||||
pool_h=self.video_pooling_h if is_video else self.image_pooling_h,
|
||||
pool_w=self.video_pooling_w if is_video else self.image_pooling_w,
|
||||
)
|
||||
|
||||
def get_patches_grid_size(
|
||||
self,
|
||||
*,
|
||||
image_height: int,
|
||||
image_width: int,
|
||||
) -> tuple[int, int]:
|
||||
left_margin, right_margin = self.overlap_margins
|
||||
base_image_input_size = self.base_image_input_size
|
||||
base_image_input_d = self.image_patch_size
|
||||
|
||||
total_margin_pixels = base_image_input_d * (right_margin + left_margin)
|
||||
crop_patches = base_image_input_size[0] // base_image_input_d
|
||||
crop_window_patches = crop_patches - (right_margin + left_margin)
|
||||
crop_window_size = crop_window_patches * base_image_input_d
|
||||
|
||||
tiling_h, tiling_w = self.select_tiling(
|
||||
image_height=image_height,
|
||||
image_width=image_width,
|
||||
)
|
||||
|
||||
h, w = [
|
||||
tiling_h * crop_window_size + total_margin_pixels,
|
||||
tiling_w * crop_window_size + total_margin_pixels,
|
||||
]
|
||||
nrows, ncols = get_patches_grid_size(
|
||||
image_h=h,
|
||||
image_w=w,
|
||||
patch_size=base_image_input_d,
|
||||
pool_h=self.image_pooling_h,
|
||||
pool_w=self.image_pooling_w,
|
||||
)
|
||||
|
||||
return nrows, ncols
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
text: TextInput | list[TextInput] | None = None,
|
||||
images: ImageInput | None = None,
|
||||
videos: VideoInput | None = None,
|
||||
return_tensors: str | TensorType = None,
|
||||
**kwargs: object,
|
||||
) -> BatchFeature:
|
||||
inputs = [text]
|
||||
images = exif_transpose(images)
|
||||
if getattr(self.processor, "image_processor", None) is not None:
|
||||
inputs.append(images)
|
||||
if getattr(self.processor, "video_processor", None) is not None:
|
||||
inputs.append(videos)
|
||||
outputs = self.processor( # type: ignore
|
||||
*inputs,
|
||||
return_tensors=return_tensors,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# revert insert bos token
|
||||
if outputs["input_ids"][0, 0] == self.vocab[self.bos_token]:
|
||||
outputs["input_ids"] = outputs["input_ids"][:, 1:]
|
||||
|
||||
if images is None:
|
||||
images = []
|
||||
if not isinstance(images, list):
|
||||
images = [images]
|
||||
|
||||
if videos is None:
|
||||
videos = []
|
||||
if not isinstance(videos, list):
|
||||
videos = [videos]
|
||||
|
||||
assert len(videos) in {0, 1}, "At most one video is supported for Molmo2"
|
||||
|
||||
_attention_mask: torch.Tensor = outputs.pop("attention_mask")
|
||||
_token_type_ids: torch.Tensor = outputs.pop("token_type_ids", None)
|
||||
|
||||
if len(images) > 0:
|
||||
# For each image: tiling_h * tiling_w + global view
|
||||
num_crops = []
|
||||
for image in images:
|
||||
image_size = get_image_size(image)
|
||||
tiling = self.select_tiling(
|
||||
image_height=image_size.height,
|
||||
image_width=image_size.width,
|
||||
)
|
||||
num_crops.append(np.prod(tiling) + 1)
|
||||
|
||||
assert sum(num_crops) == len(outputs["pixel_values"])
|
||||
assert sum(num_crops) == outputs["image_num_crops"].sum().item()
|
||||
image_grids: torch.Tensor = outputs.pop("image_grids")
|
||||
image_num_pooled_patches: torch.Tensor = image_grids[:, :2].prod(
|
||||
dim=1
|
||||
) + image_grids[:, 2:].prod(dim=1)
|
||||
outputs["image_num_pooled_patches"] = image_num_pooled_patches
|
||||
n_patches = outputs["pixel_values"].shape[1]
|
||||
outputs["image_num_patches"] = outputs["image_num_crops"] * n_patches
|
||||
image_tokens, num_image_tokens = build_flat_image_bool_length(
|
||||
image_grids,
|
||||
self.image_patch_id,
|
||||
self.low_res_im_start_id,
|
||||
self.im_start_id,
|
||||
self.im_col_id,
|
||||
self.im_end_id,
|
||||
)
|
||||
outputs["image_tokens"] = image_tokens
|
||||
outputs["num_image_tokens"] = num_image_tokens
|
||||
|
||||
if len(videos) > 0:
|
||||
video_grids: torch.Tensor = outputs.pop("video_grids")
|
||||
assert video_grids[:, 0].sum() == len(outputs["pixel_values_videos"])
|
||||
outputs["video_num_crops"] = video_grids[:, 0]
|
||||
outputs["video_num_pooled_patches"] = video_grids.prod(dim=1)
|
||||
n_patches = outputs["pixel_values_videos"].shape[1]
|
||||
outputs["video_num_patches"] = outputs["video_num_crops"] * n_patches
|
||||
video_tokens, num_video_tokens = build_flat_video_bool_length(
|
||||
video_grids,
|
||||
self.image_patch_id,
|
||||
self.frame_start_id,
|
||||
self.frame_end_id,
|
||||
)
|
||||
outputs["video_tokens"] = video_tokens
|
||||
outputs["num_video_tokens"] = num_video_tokens
|
||||
|
||||
return BatchFeature(outputs)
|
||||
|
||||
|
||||
def get_candidate_target_fps(
|
||||
video_fps: int | float,
|
||||
sampling_fps: int | float,
|
||||
@@ -1856,36 +1552,101 @@ class Molmo2ProcessingInfo(BaseProcessingInfo):
|
||||
expected_hidden_size=self._get_expected_hidden_size(),
|
||||
)
|
||||
|
||||
def get_hf_processor(self, **kwargs: object) -> Molmo2ProcessorWrapper:
|
||||
processor = self.ctx.get_hf_processor(**kwargs)
|
||||
hf_config = self.ctx.get_hf_config()
|
||||
return Molmo2ProcessorWrapper(processor, hf_config)
|
||||
|
||||
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
|
||||
return {"image": None, "video": 1}
|
||||
|
||||
def select_tiling(
|
||||
self,
|
||||
*,
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
image_processor: BaseImageProcessor,
|
||||
) -> tuple[int, int]:
|
||||
max_crops = image_processor.max_crops
|
||||
left_margin, right_margin = image_processor.overlap_margins
|
||||
base_image_input_d = image_processor.patch_size
|
||||
|
||||
total_margin_pixels = base_image_input_d * (right_margin + left_margin)
|
||||
crop_patches = image_processor.size["height"] // base_image_input_d
|
||||
crop_window_patches = crop_patches - (right_margin + left_margin)
|
||||
crop_window_size = crop_window_patches * base_image_input_d
|
||||
tiling_h, tiling_w = select_tiling(
|
||||
height=image_height - total_margin_pixels,
|
||||
width=image_width - total_margin_pixels,
|
||||
patch_size=crop_window_size,
|
||||
max_num_patches=max_crops,
|
||||
)
|
||||
|
||||
return tiling_w, tiling_h
|
||||
|
||||
def get_base_grid_size(
|
||||
self,
|
||||
image_processor: BaseImageProcessor | BaseVideoProcessor,
|
||||
) -> tuple[int, int]:
|
||||
nrows, ncols = get_patches_grid_size(
|
||||
image_h=image_processor.size["height"],
|
||||
image_w=image_processor.size["width"],
|
||||
patch_size=image_processor.patch_size,
|
||||
pool_h=image_processor.pooling_size[0],
|
||||
pool_w=image_processor.pooling_size[1],
|
||||
)
|
||||
|
||||
return ncols, nrows
|
||||
|
||||
def get_patches_grid_size(
|
||||
self,
|
||||
*,
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
image_processor: BaseImageProcessor,
|
||||
) -> tuple[int, int]:
|
||||
left_margin, right_margin = image_processor.overlap_margins
|
||||
base_image_input_d = image_processor.patch_size
|
||||
|
||||
total_margin_pixels = base_image_input_d * (right_margin + left_margin)
|
||||
crop_patches = image_processor.size["height"] // base_image_input_d
|
||||
crop_window_patches = crop_patches - (right_margin + left_margin)
|
||||
crop_window_size = crop_window_patches * base_image_input_d
|
||||
|
||||
tiling_w, tiling_h = self.select_tiling(
|
||||
image_height=image_height,
|
||||
image_width=image_width,
|
||||
image_processor=image_processor,
|
||||
)
|
||||
|
||||
nrows, ncols = get_patches_grid_size(
|
||||
image_h=tiling_h * crop_window_size + total_margin_pixels,
|
||||
image_w=tiling_w * crop_window_size + total_margin_pixels,
|
||||
patch_size=base_image_input_d,
|
||||
pool_h=image_processor.pooling_size[0],
|
||||
pool_w=image_processor.pooling_size[1],
|
||||
)
|
||||
|
||||
return ncols, nrows
|
||||
|
||||
def get_num_image_tokens(
|
||||
self,
|
||||
*,
|
||||
image_height: int,
|
||||
image_width: int,
|
||||
processor: Molmo2ProcessorWrapper,
|
||||
processor: ProcessorMixin,
|
||||
) -> int:
|
||||
hf_processor = processor.processor
|
||||
image_processor = processor.image_processor
|
||||
|
||||
resize_nrows, resize_cols = processor.get_base_grid_size(is_video=False)
|
||||
resize_ncols, resize_nrows = self.get_base_grid_size(image_processor)
|
||||
# start/end tokens + image patch token + col tokens
|
||||
if hf_processor.use_single_crop_col_tokens is not None:
|
||||
use_col_tokens = hf_processor.use_single_crop_col_tokens
|
||||
if processor.use_single_crop_col_tokens is not None:
|
||||
use_col_tokens = processor.use_single_crop_col_tokens
|
||||
else:
|
||||
use_col_tokens = hf_processor.image_use_col_tokens
|
||||
extra = 2 + resize_nrows * (resize_cols + int(use_col_tokens))
|
||||
overlap_nrows, overlap_ncols = processor.get_patches_grid_size(
|
||||
use_col_tokens = processor.image_use_col_tokens
|
||||
extra = 2 + resize_nrows * (resize_ncols + int(use_col_tokens))
|
||||
overlap_ncols, overlap_nrows = self.get_patches_grid_size(
|
||||
image_height=image_height,
|
||||
image_width=image_width,
|
||||
image_processor=image_processor,
|
||||
)
|
||||
joint = 2 + overlap_nrows * (
|
||||
overlap_ncols + int(hf_processor.image_use_col_tokens)
|
||||
overlap_ncols + int(processor.image_use_col_tokens)
|
||||
)
|
||||
|
||||
return extra + joint
|
||||
@@ -1894,28 +1655,28 @@ class Molmo2ProcessingInfo(BaseProcessingInfo):
|
||||
self,
|
||||
*,
|
||||
num_frames: int,
|
||||
processor: Molmo2ProcessorWrapper,
|
||||
processor: ProcessorMixin,
|
||||
) -> int:
|
||||
resize_nrows, resize_cols = processor.get_base_grid_size(is_video=True)
|
||||
video_processor = processor.video_processor
|
||||
|
||||
resize_ncols, resize_nrows = self.get_base_grid_size(video_processor)
|
||||
# start/end tokens
|
||||
extra = 2 + resize_nrows * (
|
||||
resize_cols + int(processor.processor.video_use_col_tokens)
|
||||
)
|
||||
extra = 2 + resize_nrows * (resize_ncols + int(processor.video_use_col_tokens))
|
||||
return num_frames * extra
|
||||
|
||||
def get_image_size_with_most_features(self) -> ImageSize:
|
||||
processor = self.get_hf_processor()
|
||||
image_processor = processor.image_processor
|
||||
|
||||
left_margin, right_margin = processor.overlap_margins
|
||||
base_image_input_size = processor.base_image_input_size
|
||||
base_image_input_d = processor.image_patch_size
|
||||
left_margin, right_margin = image_processor.overlap_margins
|
||||
base_image_input_d = image_processor.patch_size
|
||||
|
||||
total_margin_pixels = base_image_input_d * (right_margin + left_margin)
|
||||
crop_patches = base_image_input_size[0] // base_image_input_d
|
||||
crop_patches = image_processor.size["height"] // base_image_input_d
|
||||
crop_window_patches = crop_patches - (right_margin + left_margin)
|
||||
crop_window_size = crop_window_patches * base_image_input_d
|
||||
|
||||
tilings = get_candidate_tilings(processor.max_crops)
|
||||
tilings = get_candidate_tilings(image_processor.max_crops)
|
||||
largest_feature_size, largest_feature_pinpoint = 0, None
|
||||
|
||||
for hr, wr in tilings:
|
||||
@@ -1939,7 +1700,7 @@ class Molmo2ProcessingInfo(BaseProcessingInfo):
|
||||
def _get_max_video_frames(
|
||||
self,
|
||||
max_tokens: int,
|
||||
processor: Molmo2ProcessorWrapper,
|
||||
processor: ProcessorMixin,
|
||||
) -> int:
|
||||
num_tokens_per_frame = self.get_num_video_tokens(
|
||||
num_frames=1,
|
||||
@@ -1954,7 +1715,8 @@ class Molmo2ProcessingInfo(BaseProcessingInfo):
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> int:
|
||||
processor = self.get_hf_processor()
|
||||
video_processor = processor.processor.video_processor
|
||||
video_processor = processor.video_processor
|
||||
|
||||
num_frames = video_processor.num_frames
|
||||
max_videos = mm_counts.get("video", 0)
|
||||
max_total_frames = self._get_max_video_frames(seq_len, processor)
|
||||
@@ -2030,7 +1792,9 @@ class Molmo2ProcessingInfo(BaseProcessingInfo):
|
||||
metadata: dict[str, Any],
|
||||
do_sample_frames: bool | None = None,
|
||||
) -> list[float]:
|
||||
video_processor = self.get_hf_processor().processor.video_processor
|
||||
processor = self.get_hf_processor()
|
||||
video_processor = processor.video_processor
|
||||
|
||||
# metadata["fps"] refers to the true fps of the input video.
|
||||
video_fps = metadata["fps"]
|
||||
frames_indices = metadata.get("frames_indices")
|
||||
@@ -2104,7 +1868,7 @@ class Molmo2DummyInputsBuilder(BaseDummyInputsBuilder[Molmo2ProcessingInfo]):
|
||||
|
||||
if num_videos > 0:
|
||||
processor = self.info.get_hf_processor()
|
||||
base_image_input_size = processor.base_image_input_size
|
||||
video_size = processor.video_processor.size
|
||||
target_num_frames = self.info.get_num_frames_with_most_features(
|
||||
seq_len, mm_counts
|
||||
)
|
||||
@@ -2131,8 +1895,8 @@ class Molmo2DummyInputsBuilder(BaseDummyInputsBuilder[Molmo2ProcessingInfo]):
|
||||
target_num_frames = min(target_num_frames, num_frames_override)
|
||||
|
||||
dummy_videos = self._get_dummy_videos(
|
||||
width=base_image_input_size[1],
|
||||
height=base_image_input_size[0],
|
||||
width=video_size["width"],
|
||||
height=video_size["height"],
|
||||
num_frames=target_num_frames,
|
||||
num_videos=num_videos,
|
||||
)
|
||||
@@ -2174,10 +1938,10 @@ class Molmo2MultiModalProcessor(BaseMultiModalProcessor[Molmo2ProcessingInfo]):
|
||||
prompt_tokens: list[int],
|
||||
) -> list[int]:
|
||||
processor = self.info.get_hf_processor()
|
||||
tokenizer = processor.processor.tokenizer
|
||||
tokenizer = processor.tokenizer
|
||||
bos_token_id = tokenizer.bos_token_id or tokenizer.eos_token_id
|
||||
|
||||
if len(prompt_tokens) > 0 and prompt_tokens[0] != bos_token_id:
|
||||
if len(prompt_tokens) == 0 or prompt_tokens[0] != bos_token_id:
|
||||
# Prepend the bos token to the prompt tokens
|
||||
prompt_tokens = [bos_token_id] + prompt_tokens
|
||||
|
||||
@@ -2191,9 +1955,26 @@ class Molmo2MultiModalProcessor(BaseMultiModalProcessor[Molmo2ProcessingInfo]):
|
||||
tok_kwargs: Mapping[str, object],
|
||||
) -> BatchFeature:
|
||||
mm_data = dict(mm_data)
|
||||
processor = self.info.get_hf_processor(**mm_kwargs)
|
||||
|
||||
hf_config = self.info.get_hf_config()
|
||||
hf_processor = self.info.get_hf_processor(**mm_kwargs)
|
||||
|
||||
def patched_call(text=None, images=None, videos=None, **kwargs) -> BatchFeature:
|
||||
res = hf_processor(text=text, images=images, videos=videos, **kwargs)
|
||||
|
||||
# Molmo2Processor.insert_bos results in float outputs
|
||||
# if the input text is empty
|
||||
if not text:
|
||||
res["input_ids"] = res["input_ids"].long()
|
||||
|
||||
return res
|
||||
|
||||
tokenizer = hf_processor.tokenizer
|
||||
image_processor = hf_processor.image_processor
|
||||
|
||||
if videos := mm_data.pop("videos", []):
|
||||
bos_token_id = tokenizer.bos_token_id or tokenizer.eos_token_id
|
||||
|
||||
pixel_values_videos_lst = []
|
||||
video_token_pooling_lst = []
|
||||
video_num_crops_lst = []
|
||||
@@ -2228,18 +2009,32 @@ class Molmo2MultiModalProcessor(BaseMultiModalProcessor[Molmo2ProcessingInfo]):
|
||||
video_mm_data["videos"] = [[video_array]]
|
||||
video_mm_data["video_metadata"] = [[metadata]]
|
||||
|
||||
video_outputs = super()._call_hf_processor(
|
||||
prompt=VIDEO_PROMPT,
|
||||
mm_data=video_mm_data,
|
||||
mm_kwargs=video_mm_kwargs,
|
||||
tok_kwargs=tok_kwargs,
|
||||
video_outputs = self.info.ctx.call_hf_processor(
|
||||
patched_call,
|
||||
dict(text=VIDEO_PROMPT, **video_mm_data),
|
||||
dict(**video_mm_kwargs, **tok_kwargs),
|
||||
)
|
||||
|
||||
input_ids = video_outputs.pop("input_ids")
|
||||
video_string = processor.processor.tokenizer.batch_decode(input_ids)[0]
|
||||
prompt = prompt.replace(
|
||||
VIDEO_PROMPT,
|
||||
video_string,
|
||||
1,
|
||||
if input_ids[0, 0] == bos_token_id:
|
||||
input_ids = input_ids[:, 1:]
|
||||
|
||||
video_string = tokenizer.batch_decode(input_ids)[0]
|
||||
prompt = prompt.replace(VIDEO_PROMPT, video_string, 1)
|
||||
|
||||
video_grids = video_outputs.pop("video_grids")
|
||||
assert video_grids[:, 0].sum() == len(
|
||||
video_outputs["pixel_values_videos"]
|
||||
)
|
||||
|
||||
video_outputs["video_num_crops"] = video_grids[:, 0]
|
||||
video_outputs["video_num_pooled_patches"] = video_grids.prod(dim=1)
|
||||
n_patches = video_outputs["pixel_values_videos"].shape[1]
|
||||
video_outputs["video_num_patches"] = (
|
||||
video_outputs["video_num_crops"] * n_patches
|
||||
)
|
||||
(video_outputs["video_tokens"], video_outputs["num_video_tokens"]) = (
|
||||
build_flat_video_bool_length(video_grids, hf_config)
|
||||
)
|
||||
|
||||
pixel_values_videos_lst.append(video_outputs["pixel_values_videos"])
|
||||
@@ -2252,7 +2047,7 @@ class Molmo2MultiModalProcessor(BaseMultiModalProcessor[Molmo2ProcessingInfo]):
|
||||
video_tokens_lst.append(video_outputs["video_tokens"])
|
||||
num_video_tokens_lst.append(video_outputs["num_video_tokens"])
|
||||
|
||||
video_outputs = dict(
|
||||
all_video_outputs = dict(
|
||||
pixel_values_videos=torch.cat(pixel_values_videos_lst),
|
||||
video_token_pooling=torch.cat(video_token_pooling_lst),
|
||||
video_num_crops=torch.cat(video_num_crops_lst),
|
||||
@@ -2262,30 +2057,50 @@ class Molmo2MultiModalProcessor(BaseMultiModalProcessor[Molmo2ProcessingInfo]):
|
||||
num_video_tokens=torch.cat(num_video_tokens_lst),
|
||||
)
|
||||
else:
|
||||
video_outputs = dict()
|
||||
all_video_outputs = dict()
|
||||
|
||||
processed_outputs = super()._call_hf_processor(
|
||||
prompt=prompt,
|
||||
mm_data=mm_data,
|
||||
mm_kwargs=mm_kwargs,
|
||||
tok_kwargs=tok_kwargs,
|
||||
processed_outputs = self.info.ctx.call_hf_processor(
|
||||
patched_call,
|
||||
dict(text=prompt, **mm_data),
|
||||
dict(**mm_kwargs, **tok_kwargs),
|
||||
)
|
||||
|
||||
bos_token_id = processor.vocab[processor.bos_token]
|
||||
input_ids = processed_outputs["input_ids"]
|
||||
# add bos token back to prompt start
|
||||
if input_ids.numel() > 0 and input_ids[0, 0] != bos_token_id:
|
||||
bos_token_id_tensor = torch.tensor(
|
||||
[[bos_token_id]], device=input_ids.device, dtype=input_ids.dtype
|
||||
if (images := mm_data.get("images")) is not None:
|
||||
mm_items = self.info.parse_mm_data({"image": images}, validate=False)
|
||||
parsed_images = mm_items.get_items("image", ImageProcessorItems)
|
||||
image_sizes = [
|
||||
parsed_images.get_image_size(i) for i in range(len(parsed_images))
|
||||
]
|
||||
|
||||
# For each image: tiling_h * tiling_w + global view
|
||||
tilings = [
|
||||
self.info.select_tiling(
|
||||
image_width=image_size.width,
|
||||
image_height=image_size.height,
|
||||
image_processor=image_processor,
|
||||
)
|
||||
for image_size in image_sizes
|
||||
]
|
||||
num_crops = torch.tensor(tilings).prod(-1) + 1
|
||||
assert sum(num_crops) == len(processed_outputs["pixel_values"])
|
||||
assert sum(num_crops) == processed_outputs["image_num_crops"].sum().item()
|
||||
|
||||
image_grids = processed_outputs.pop("image_grids")
|
||||
image_num_pooled_patches = image_grids[:, :2].prod(dim=1) + image_grids[
|
||||
:, 2:
|
||||
].prod(dim=1)
|
||||
|
||||
processed_outputs["image_num_pooled_patches"] = image_num_pooled_patches
|
||||
n_patches = processed_outputs["pixel_values"].shape[1]
|
||||
processed_outputs["image_num_patches"] = (
|
||||
processed_outputs["image_num_crops"] * n_patches
|
||||
)
|
||||
processed_outputs["input_ids"] = torch.concat(
|
||||
[bos_token_id_tensor, input_ids], dim=1
|
||||
)
|
||||
combined_outputs = dict(
|
||||
processed_outputs,
|
||||
**video_outputs,
|
||||
)
|
||||
return BatchFeature(combined_outputs)
|
||||
(
|
||||
processed_outputs["image_tokens"],
|
||||
processed_outputs["num_image_tokens"],
|
||||
) = build_flat_image_bool_length(image_grids, hf_config)
|
||||
|
||||
return BatchFeature({**processed_outputs, **all_video_outputs})
|
||||
|
||||
def _get_mm_fields_config(
|
||||
self,
|
||||
@@ -2338,41 +2153,65 @@ class Molmo2MultiModalProcessor(BaseMultiModalProcessor[Molmo2ProcessingInfo]):
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
out_mm_kwargs: MultiModalKwargsItems,
|
||||
) -> Sequence[PromptUpdate]:
|
||||
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
|
||||
img_patch_id = processor.image_patch_id
|
||||
img_col_id = processor.im_col_id
|
||||
img_start_id = processor.im_start_id
|
||||
img_end_id = processor.im_end_id
|
||||
image_use_col_tokens = processor.processor.image_use_col_tokens
|
||||
use_single_crop_col_tokens = processor.processor.use_single_crop_col_tokens
|
||||
use_single_crop_start_token = processor.processor.use_single_crop_start_token
|
||||
video_use_col_tokens = processor.processor.video_use_col_tokens
|
||||
use_frame_special_tokens = processor.processor.use_frame_special_tokens
|
||||
hf_config = self.info.get_hf_config()
|
||||
img_patch_id = hf_config.image_patch_id
|
||||
img_col_id = hf_config.image_col_id
|
||||
img_start_id = hf_config.image_start_token_id
|
||||
img_end_id = hf_config.image_end_token_id
|
||||
low_res_im_start_id = hf_config.low_res_image_start_token_id
|
||||
frame_start_id = hf_config.frame_start_token_id
|
||||
frame_end_id = hf_config.frame_end_token_id
|
||||
im_low_res_id = hf_config.image_low_res_id
|
||||
|
||||
def get_image_replacement_molmo2(item_idx: int) -> list[int]:
|
||||
emb_tok_ids = [
|
||||
img_patch_id,
|
||||
img_col_id,
|
||||
img_start_id,
|
||||
low_res_im_start_id,
|
||||
frame_start_id,
|
||||
img_end_id,
|
||||
frame_end_id,
|
||||
im_low_res_id,
|
||||
]
|
||||
|
||||
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
|
||||
image_use_col_tokens = processor.image_use_col_tokens
|
||||
use_single_crop_col_tokens = processor.use_single_crop_col_tokens
|
||||
use_single_crop_start_token = processor.use_single_crop_start_token
|
||||
video_use_col_tokens = processor.video_use_col_tokens
|
||||
use_frame_special_tokens = processor.use_frame_special_tokens
|
||||
|
||||
tokenizer = processor.tokenizer
|
||||
vocab = tokenizer.get_vocab()
|
||||
|
||||
image_processor = processor.image_processor
|
||||
video_processor = processor.video_processor
|
||||
|
||||
def get_image_replacement_molmo2(item_idx: int):
|
||||
images = mm_items.get_items("image", ImageProcessorItems)
|
||||
image = images.get(item_idx)
|
||||
image = exif_transpose(image)
|
||||
|
||||
resize_nrows, resize_cols = processor.get_base_grid_size(is_video=False)
|
||||
resize_ncols, resize_nrows = self.info.get_base_grid_size(image_processor)
|
||||
if use_single_crop_col_tokens is not None:
|
||||
use_col_tokens = use_single_crop_col_tokens
|
||||
else:
|
||||
use_col_tokens = image_use_col_tokens
|
||||
if use_single_crop_start_token:
|
||||
start_id = processor.low_res_im_start_id
|
||||
start_id = low_res_im_start_id
|
||||
else:
|
||||
start_id = img_start_id
|
||||
extra_row = [img_patch_id] * resize_cols + [img_col_id] * int(
|
||||
extra_row = [img_patch_id] * resize_ncols + [img_col_id] * int(
|
||||
use_col_tokens
|
||||
)
|
||||
extra_joint = [start_id] + extra_row * resize_nrows + [img_end_id]
|
||||
|
||||
image_size = get_image_size(image)
|
||||
|
||||
nrows, ncols = processor.get_patches_grid_size(
|
||||
ncols, nrows = self.info.get_patches_grid_size(
|
||||
image_height=image_size.height,
|
||||
image_width=image_size.width,
|
||||
image_processor=image_processor,
|
||||
)
|
||||
|
||||
joint_row = [img_patch_id] * ncols + [img_col_id] * int(
|
||||
@@ -2381,21 +2220,18 @@ class Molmo2MultiModalProcessor(BaseMultiModalProcessor[Molmo2ProcessingInfo]):
|
||||
joint = [img_start_id] + joint_row * nrows + [img_end_id]
|
||||
img_token_ids = extra_joint + joint
|
||||
|
||||
return PromptUpdateDetails.select_token_ids(
|
||||
img_token_ids,
|
||||
processor.image_token_ids,
|
||||
)
|
||||
return PromptUpdateDetails.select_token_ids(img_token_ids, emb_tok_ids)
|
||||
|
||||
def get_video_replacement_molmo2(item_idx: int) -> list[int]:
|
||||
def get_video_replacement_molmo2(item_idx: int):
|
||||
video, metadata = mm_items["video"][item_idx]
|
||||
do_sample_frames = hf_processor_mm_kwargs.get("do_sample_frames")
|
||||
|
||||
timestamps = self.info._get_video_second_idx(metadata, do_sample_frames)
|
||||
nrows, ncols = processor.get_base_grid_size(is_video=True)
|
||||
ncols, nrows = self.info.get_base_grid_size(video_processor)
|
||||
|
||||
if use_frame_special_tokens:
|
||||
start_id = processor.frame_start_id
|
||||
end_id = processor.frame_end_id
|
||||
start_id = frame_start_id
|
||||
end_id = frame_end_id
|
||||
else:
|
||||
start_id = img_start_id
|
||||
end_id = img_end_id
|
||||
@@ -2408,7 +2244,7 @@ class Molmo2MultiModalProcessor(BaseMultiModalProcessor[Molmo2ProcessingInfo]):
|
||||
prev_space + f"{frame_time:.1f} "
|
||||
) # explicit whitespace before/after image tokens
|
||||
|
||||
img_token_ids += processor.processor.tokenizer.encode(
|
||||
img_token_ids += tokenizer.encode(
|
||||
frame_prefix,
|
||||
add_special_tokens=False,
|
||||
)
|
||||
@@ -2419,10 +2255,7 @@ class Molmo2MultiModalProcessor(BaseMultiModalProcessor[Molmo2ProcessingInfo]):
|
||||
joint = [start_id] + nrows * joint_row + [end_id]
|
||||
img_token_ids += joint
|
||||
|
||||
return PromptUpdateDetails.select_token_ids(
|
||||
img_token_ids,
|
||||
processor.image_token_ids,
|
||||
)
|
||||
return PromptUpdateDetails.select_token_ids(img_token_ids, emb_tok_ids)
|
||||
|
||||
return [
|
||||
PromptReplacement(
|
||||
@@ -2432,7 +2265,7 @@ class Molmo2MultiModalProcessor(BaseMultiModalProcessor[Molmo2ProcessingInfo]):
|
||||
)
|
||||
for modality, target, replacement_fn in zip(
|
||||
["image", "video"],
|
||||
[processor.image_placeholder_id, processor.video_placeholder_id],
|
||||
[vocab[IMAGE_PROMPT], vocab[VIDEO_PROMPT]],
|
||||
[get_image_replacement_molmo2, get_video_replacement_molmo2],
|
||||
)
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user