[Refactor] Remove Molmo2 processor wrapper (#36667)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2026-03-11 18:07:20 +08:00
committed by GitHub
parent 4286cc5ec2
commit 646b85544b

View File

@@ -3,7 +3,7 @@
import math
from collections.abc import Iterable, Mapping, Sequence
from dataclasses import dataclass, fields
from functools import cached_property, partial
from functools import partial
from itertools import islice
from typing import Annotated, Any
@@ -14,14 +14,14 @@ import torch.nn.functional as F
from PIL import ImageOps
from PIL.Image import Image
from transformers import (
BaseImageProcessor,
BaseVideoProcessor,
BatchFeature,
PretrainedConfig,
ProcessorMixin,
TensorType,
)
from transformers.image_utils import ImageInput
from transformers.tokenization_utils_base import TextInput
from transformers.video_utils import VideoInput, VideoMetadata
from transformers.video_utils import VideoMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
@@ -1337,12 +1337,14 @@ def exif_transpose(
def build_flat_image_bool_length(
image_grids: torch.LongTensor,
image_patch_id: int,
low_res_image_start_id: int,
image_start_id: int,
image_col_id: int,
image_end_id: int,
hf_config: PretrainedConfig,
) -> tuple[torch.LongTensor, torch.LongTensor]:
image_patch_id = hf_config.image_patch_id
low_res_image_start_id = hf_config.low_res_image_start_token_id
image_start_id = hf_config.image_start_token_id
image_col_id = hf_config.image_col_id
image_end_id = hf_config.image_end_token_id
device = image_grids.device
B = image_grids.shape[0]
@@ -1401,10 +1403,12 @@ def build_flat_image_bool_length(
def build_flat_video_bool_length(
video_grids: torch.LongTensor,
image_patch_id: int,
frame_start_id: int,
frame_end_id: int,
hf_config: PretrainedConfig,
) -> tuple[torch.LongTensor, torch.LongTensor]:
image_patch_id = hf_config.image_patch_id
frame_start_id = hf_config.frame_start_token_id
frame_end_id = hf_config.frame_end_token_id
device = video_grids.device
B = video_grids.shape[0]
@@ -1439,314 +1443,6 @@ def build_flat_video_bool_length(
return flat, lengths
class Molmo2ProcessorWrapper:
    """
    Wraps :class:`Molmo2Processor` so that it can be called directly.

    Besides forwarding calls to the wrapped processor, the wrapper exposes
    cached accessors for values read from the processor's sub-processors
    (crop/pooling/patch geometry), from its tokenizer (vocabulary, BOS token)
    and from ``hf_config`` (special token ids), and post-processes the
    processor output in :meth:`__call__`.
    """

    def __init__(self, processor: ProcessorMixin, hf_config: PretrainedConfig):
        super().__init__()

        self.processor = processor
        self.hf_config = hf_config

    # ------------------------------------------------------------------
    # Values read from the tokenizer / sub-processors
    # ------------------------------------------------------------------

    @cached_property
    def vocab(self) -> dict[str, int]:
        """Token-string to token-id mapping taken from the tokenizer."""
        return self.processor.tokenizer.vocab  # type: ignore

    @cached_property
    def max_crops(self) -> int:
        """``max_crops`` of the image processor."""
        image_processor = self.processor.image_processor  # type: ignore

        max_crops = image_processor.max_crops
        assert isinstance(max_crops, int)

        return max_crops

    @cached_property
    def image_pooling_h(self) -> int:
        """First entry of the image processor's ``pooling_size``."""
        image_processor = self.processor.image_processor  # type: ignore

        image_pooling_h = image_processor.pooling_size[0]
        assert isinstance(image_pooling_h, int)

        return image_pooling_h

    @cached_property
    def image_pooling_w(self) -> int:
        """Second entry of the image processor's ``pooling_size``."""
        image_processor = self.processor.image_processor  # type: ignore

        image_pooling_w = image_processor.pooling_size[1]
        assert isinstance(image_pooling_w, int)

        return image_pooling_w

    @cached_property
    def video_pooling_h(self) -> int:
        """First entry of the video processor's ``pooling_size``."""
        video_processor = self.processor.video_processor  # type: ignore

        video_pooling_h = video_processor.pooling_size[0]
        assert isinstance(video_pooling_h, int)

        return video_pooling_h

    @cached_property
    def video_pooling_w(self) -> int:
        """Second entry of the video processor's ``pooling_size``."""
        video_processor = self.processor.video_processor  # type: ignore

        video_pooling_w = video_processor.pooling_size[1]
        assert isinstance(video_pooling_w, int)

        return video_pooling_w

    @cached_property
    def _geometry_processor(self):
        """
        Sub-processor that supplies ``size`` and ``patch_size``.

        The image processor is used when the wrapped processor has one;
        otherwise the video processor is used.
        """
        image_processor = getattr(self.processor, "image_processor", None)
        if image_processor is not None:
            return image_processor

        return self.processor.video_processor  # type: ignore

    @cached_property
    def base_image_input_size(self) -> tuple[int, int]:
        """``(height, width)`` from the ``size`` of the geometry processor."""
        size = self._geometry_processor.size
        return (size["height"], size["width"])

    @cached_property
    def image_patch_size(self) -> int:
        """``patch_size`` of the geometry processor."""
        image_patch_size = self._geometry_processor.patch_size
        assert isinstance(image_patch_size, int)

        return image_patch_size

    @cached_property
    def overlap_margins(self) -> tuple[int, int]:
        """``(left_margin, right_margin)`` of the image processor."""
        image_processor = self.processor.image_processor  # type: ignore

        left_margin, right_margin = image_processor.overlap_margins
        assert isinstance(left_margin, int)
        assert isinstance(right_margin, int)

        return left_margin, right_margin

    @cached_property
    def bos_token(self) -> str:
        """The tokenizer's BOS token, or its EOS token if BOS is unset."""
        return self.processor.tokenizer.bos_token or self.processor.tokenizer.eos_token

    # ------------------------------------------------------------------
    # Special token ids read from the HF config
    # ------------------------------------------------------------------

    @cached_property
    def image_patch_id(self) -> int:
        return self.hf_config.image_patch_id

    @cached_property
    def im_col_id(self) -> int:
        return self.hf_config.image_col_id

    @cached_property
    def im_start_id(self) -> int:
        return self.hf_config.image_start_token_id

    @cached_property
    def im_end_id(self) -> int:
        return self.hf_config.image_end_token_id

    @cached_property
    def low_res_im_start_id(self) -> int:
        return self.hf_config.low_res_image_start_token_id

    @cached_property
    def frame_start_id(self) -> int:
        return self.hf_config.frame_start_token_id

    @cached_property
    def frame_end_id(self) -> int:
        return self.hf_config.frame_end_token_id

    @cached_property
    def im_low_res_id(self) -> int:
        return self.hf_config.image_low_res_id

    @cached_property
    def image_placeholder_id(self) -> int:
        """Vocabulary id of ``IMAGE_PROMPT``."""
        return self.vocab[IMAGE_PROMPT]

    @cached_property
    def video_placeholder_id(self) -> int:
        """Vocabulary id of ``VIDEO_PROMPT``."""
        return self.vocab[VIDEO_PROMPT]

    @cached_property
    def image_token_ids(self) -> list[int]:
        """All image/frame special token ids, in a fixed order."""
        return [
            self.image_patch_id,
            self.im_col_id,
            self.im_start_id,
            self.low_res_im_start_id,
            self.frame_start_id,
            self.im_end_id,
            self.frame_end_id,
            self.im_low_res_id,
        ]

    # ------------------------------------------------------------------
    # Geometry helpers
    # ------------------------------------------------------------------

    def _crop_window_geometry(self) -> tuple[int, int]:
        """
        Return ``(crop_window_size, total_margin_pixels)``.

        ``total_margin_pixels`` is the patch size times the sum of both
        overlap margins; ``crop_window_size`` is the pixel extent of a crop
        (measured along the base input height) once those margin patches
        are removed.
        """
        left_margin, right_margin = self.overlap_margins
        patch_size = self.image_patch_size

        total_margin_pixels = patch_size * (right_margin + left_margin)
        crop_patches = self.base_image_input_size[0] // patch_size
        crop_window_patches = crop_patches - (right_margin + left_margin)

        return crop_window_patches * patch_size, total_margin_pixels

    def select_tiling(
        self,
        *,
        image_height: int,
        image_width: int,
    ) -> tuple[int, int]:
        """
        Choose the crop tiling for an image of the given size.

        Delegates to the module-level ``select_tiling`` after subtracting
        the overlap margin from both image dimensions.

        Returns:
            ``(tiling_h, tiling_w)`` as returned by the module-level helper.
        """
        crop_window_size, total_margin_pixels = self._crop_window_geometry()

        tiling_h, tiling_w = select_tiling(
            height=image_height - total_margin_pixels,
            width=image_width - total_margin_pixels,
            patch_size=crop_window_size,
            max_num_patches=self.max_crops,
        )

        return tiling_h, tiling_w

    def get_base_grid_size(self, is_video: bool) -> tuple[int, int]:
        """
        Return the patch grid of a single base-size input.

        Uses the video pooling sizes when ``is_video`` is true and the image
        pooling sizes otherwise; the result is whatever the module-level
        ``get_patches_grid_size`` returns.
        """
        base_image_input_size = self.base_image_input_size

        return get_patches_grid_size(
            image_h=base_image_input_size[0],
            image_w=base_image_input_size[1],
            patch_size=self.image_patch_size,
            pool_h=self.video_pooling_h if is_video else self.image_pooling_h,
            pool_w=self.video_pooling_w if is_video else self.image_pooling_w,
        )

    def get_patches_grid_size(
        self,
        *,
        image_height: int,
        image_width: int,
    ) -> tuple[int, int]:
        """
        Return ``(nrows, ncols)`` of the patch grid for the tiled image.

        The image is treated as ``tiling * crop_window_size`` plus the
        overlap margin in each dimension, and that size is passed to the
        module-level ``get_patches_grid_size`` with the image pooling sizes.
        """
        crop_window_size, total_margin_pixels = self._crop_window_geometry()

        tiling_h, tiling_w = self.select_tiling(
            image_height=image_height,
            image_width=image_width,
        )

        nrows, ncols = get_patches_grid_size(
            image_h=tiling_h * crop_window_size + total_margin_pixels,
            image_w=tiling_w * crop_window_size + total_margin_pixels,
            patch_size=self.image_patch_size,
            pool_h=self.image_pooling_h,
            pool_w=self.image_pooling_w,
        )

        return nrows, ncols

    # ------------------------------------------------------------------
    # Call
    # ------------------------------------------------------------------

    def __call__(
        self,
        text: TextInput | list[TextInput] | None = None,
        images: ImageInput | None = None,
        videos: VideoInput | None = None,
        return_tensors: str | TensorType | None = None,
        **kwargs: object,
    ) -> BatchFeature:
        """
        Run the wrapped processor and add derived multimodal fields.

        ``images`` / ``videos`` are passed positionally to the wrapped
        processor only when it has the matching sub-processor. A leading BOS
        token is removed from ``input_ids``; ``attention_mask`` and
        ``token_type_ids`` are dropped from the output. When images or a
        video are given, the corresponding grids are popped and replaced by
        crop/patch counts and flat token masks.

        Raises:
            AssertionError: If more than one video is given, or if the
                processor output disagrees with the expected crop counts.
        """
        inputs = [text]
        images = exif_transpose(images)

        # Positional arguments are only supplied for sub-processors that exist
        if getattr(self.processor, "image_processor", None) is not None:
            inputs.append(images)
        if getattr(self.processor, "video_processor", None) is not None:
            inputs.append(videos)

        outputs = self.processor(  # type: ignore
            *inputs,
            return_tensors=return_tensors,
            **kwargs,
        )

        # revert insert bos token
        if outputs["input_ids"][0, 0] == self.vocab[self.bos_token]:
            outputs["input_ids"] = outputs["input_ids"][:, 1:]

        # Normalize to lists so that the items can be counted below
        if images is None:
            images = []
        if not isinstance(images, list):
            images = [images]

        if videos is None:
            videos = []
        if not isinstance(videos, list):
            videos = [videos]

        assert len(videos) in {0, 1}, "At most one video is supported for Molmo2"

        # These fields are removed from the output; their values are unused.
        # `attention_mask` must be present, `token_type_ids` is optional.
        outputs.pop("attention_mask")
        outputs.pop("token_type_ids", None)

        if len(images) > 0:
            # For each image: tiling_h * tiling_w + global view
            num_crops = []
            for image in images:
                image_size = get_image_size(image)
                tiling = self.select_tiling(
                    image_height=image_size.height,
                    image_width=image_size.width,
                )
                num_crops.append(np.prod(tiling) + 1)

            assert sum(num_crops) == len(outputs["pixel_values"])
            assert sum(num_crops) == outputs["image_num_crops"].sum().item()

            image_grids: torch.Tensor = outputs.pop("image_grids")
            # Product of the first two grid columns plus product of the rest
            image_num_pooled_patches: torch.Tensor = image_grids[:, :2].prod(
                dim=1
            ) + image_grids[:, 2:].prod(dim=1)
            outputs["image_num_pooled_patches"] = image_num_pooled_patches

            n_patches = outputs["pixel_values"].shape[1]
            outputs["image_num_patches"] = outputs["image_num_crops"] * n_patches

            image_tokens, num_image_tokens = build_flat_image_bool_length(
                image_grids,
                self.image_patch_id,
                self.low_res_im_start_id,
                self.im_start_id,
                self.im_col_id,
                self.im_end_id,
            )
            outputs["image_tokens"] = image_tokens
            outputs["num_image_tokens"] = num_image_tokens

        if len(videos) > 0:
            video_grids: torch.Tensor = outputs.pop("video_grids")
            assert video_grids[:, 0].sum() == len(outputs["pixel_values_videos"])

            outputs["video_num_crops"] = video_grids[:, 0]
            outputs["video_num_pooled_patches"] = video_grids.prod(dim=1)

            n_patches = outputs["pixel_values_videos"].shape[1]
            outputs["video_num_patches"] = outputs["video_num_crops"] * n_patches

            video_tokens, num_video_tokens = build_flat_video_bool_length(
                video_grids,
                self.image_patch_id,
                self.frame_start_id,
                self.frame_end_id,
            )
            outputs["video_tokens"] = video_tokens
            outputs["num_video_tokens"] = num_video_tokens

        return BatchFeature(outputs)
def get_candidate_target_fps(
video_fps: int | float,
sampling_fps: int | float,
@@ -1856,36 +1552,101 @@ class Molmo2ProcessingInfo(BaseProcessingInfo):
expected_hidden_size=self._get_expected_hidden_size(),
)
def get_hf_processor(self, **kwargs: object) -> Molmo2ProcessorWrapper:
    """
    Return the context's HF processor wrapped together with the HF config.

    ``kwargs`` are forwarded to ``self.ctx.get_hf_processor``.
    """
    return Molmo2ProcessorWrapper(
        self.ctx.get_hf_processor(**kwargs),
        self.ctx.get_hf_config(),
    )
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"image": None, "video": 1}
def select_tiling(
    self,
    *,
    image_width: int,
    image_height: int,
    image_processor: BaseImageProcessor,
) -> tuple[int, int]:
    """
    Choose the crop tiling for an image, using ``image_processor`` settings.

    The overlap margin (in pixels) is subtracted from both image dimensions
    before delegating to the module-level ``select_tiling``.

    Returns:
        ``(tiling_w, tiling_h)`` - note the width-first order, which is the
        reverse of what the module-level helper is unpacked as.
    """
    patch_px = image_processor.patch_size
    margin_left, margin_right = image_processor.overlap_margins
    margin_patches = margin_right + margin_left
    margin_px = patch_px * margin_patches

    # Pixel extent of one crop once the margin patches are removed
    patches_per_crop = image_processor.size["height"] // patch_px
    window_px = (patches_per_crop - margin_patches) * patch_px

    rows, cols = select_tiling(
        height=image_height - margin_px,
        width=image_width - margin_px,
        patch_size=window_px,
        max_num_patches=image_processor.max_crops,
    )
    return cols, rows
def get_base_grid_size(
    self,
    image_processor: BaseImageProcessor | BaseVideoProcessor,
) -> tuple[int, int]:
    """
    Return the patch grid of one base-size input for the given processor.

    The processor's ``size``, ``patch_size`` and ``pooling_size`` are passed
    to the module-level ``get_patches_grid_size``.

    Returns:
        ``(ncols, nrows)`` - width-first, i.e. the helper's result swapped.
    """
    base_size = image_processor.size
    pooling = image_processor.pooling_size

    grid_rows, grid_cols = get_patches_grid_size(
        image_h=base_size["height"],
        image_w=base_size["width"],
        patch_size=image_processor.patch_size,
        pool_h=pooling[0],
        pool_w=pooling[1],
    )
    return grid_cols, grid_rows
def get_patches_grid_size(
    self,
    *,
    image_width: int,
    image_height: int,
    image_processor: BaseImageProcessor,
) -> tuple[int, int]:
    """
    Return the patch grid of the tiled (multi-crop) view of an image.

    The tiling chosen by ``self.select_tiling`` is converted to a pixel size
    (``tiling * crop_window_size + margin``) per dimension and passed to the
    module-level ``get_patches_grid_size`` with the processor's pooling sizes.

    Returns:
        ``(ncols, nrows)`` - width-first, i.e. the helper's result swapped.
    """
    patch_px = image_processor.patch_size
    margin_left, margin_right = image_processor.overlap_margins
    margin_patches = margin_right + margin_left
    margin_px = patch_px * margin_patches

    # Pixel extent of one crop once the margin patches are removed
    patches_per_crop = image_processor.size["height"] // patch_px
    window_px = (patches_per_crop - margin_patches) * patch_px

    crops_wide, crops_high = self.select_tiling(
        image_height=image_height,
        image_width=image_width,
        image_processor=image_processor,
    )

    pooling = image_processor.pooling_size
    grid_rows, grid_cols = get_patches_grid_size(
        image_h=crops_high * window_px + margin_px,
        image_w=crops_wide * window_px + margin_px,
        patch_size=patch_px,
        pool_h=pooling[0],
        pool_w=pooling[1],
    )
    return grid_cols, grid_rows
def get_num_image_tokens(
self,
*,
image_height: int,
image_width: int,
processor: Molmo2ProcessorWrapper,
processor: ProcessorMixin,
) -> int:
hf_processor = processor.processor
image_processor = processor.image_processor
resize_nrows, resize_cols = processor.get_base_grid_size(is_video=False)
resize_ncols, resize_nrows = self.get_base_grid_size(image_processor)
# start/end tokens + image patch token + col tokens
if hf_processor.use_single_crop_col_tokens is not None:
use_col_tokens = hf_processor.use_single_crop_col_tokens
if processor.use_single_crop_col_tokens is not None:
use_col_tokens = processor.use_single_crop_col_tokens
else:
use_col_tokens = hf_processor.image_use_col_tokens
extra = 2 + resize_nrows * (resize_cols + int(use_col_tokens))
overlap_nrows, overlap_ncols = processor.get_patches_grid_size(
use_col_tokens = processor.image_use_col_tokens
extra = 2 + resize_nrows * (resize_ncols + int(use_col_tokens))
overlap_ncols, overlap_nrows = self.get_patches_grid_size(
image_height=image_height,
image_width=image_width,
image_processor=image_processor,
)
joint = 2 + overlap_nrows * (
overlap_ncols + int(hf_processor.image_use_col_tokens)
overlap_ncols + int(processor.image_use_col_tokens)
)
return extra + joint
@@ -1894,28 +1655,28 @@ class Molmo2ProcessingInfo(BaseProcessingInfo):
self,
*,
num_frames: int,
processor: Molmo2ProcessorWrapper,
processor: ProcessorMixin,
) -> int:
resize_nrows, resize_cols = processor.get_base_grid_size(is_video=True)
video_processor = processor.video_processor
resize_ncols, resize_nrows = self.get_base_grid_size(video_processor)
# start/end tokens
extra = 2 + resize_nrows * (
resize_cols + int(processor.processor.video_use_col_tokens)
)
extra = 2 + resize_nrows * (resize_ncols + int(processor.video_use_col_tokens))
return num_frames * extra
def get_image_size_with_most_features(self) -> ImageSize:
processor = self.get_hf_processor()
image_processor = processor.image_processor
left_margin, right_margin = processor.overlap_margins
base_image_input_size = processor.base_image_input_size
base_image_input_d = processor.image_patch_size
left_margin, right_margin = image_processor.overlap_margins
base_image_input_d = image_processor.patch_size
total_margin_pixels = base_image_input_d * (right_margin + left_margin)
crop_patches = base_image_input_size[0] // base_image_input_d
crop_patches = image_processor.size["height"] // base_image_input_d
crop_window_patches = crop_patches - (right_margin + left_margin)
crop_window_size = crop_window_patches * base_image_input_d
tilings = get_candidate_tilings(processor.max_crops)
tilings = get_candidate_tilings(image_processor.max_crops)
largest_feature_size, largest_feature_pinpoint = 0, None
for hr, wr in tilings:
@@ -1939,7 +1700,7 @@ class Molmo2ProcessingInfo(BaseProcessingInfo):
def _get_max_video_frames(
self,
max_tokens: int,
processor: Molmo2ProcessorWrapper,
processor: ProcessorMixin,
) -> int:
num_tokens_per_frame = self.get_num_video_tokens(
num_frames=1,
@@ -1954,7 +1715,8 @@ class Molmo2ProcessingInfo(BaseProcessingInfo):
mm_counts: Mapping[str, int],
) -> int:
processor = self.get_hf_processor()
video_processor = processor.processor.video_processor
video_processor = processor.video_processor
num_frames = video_processor.num_frames
max_videos = mm_counts.get("video", 0)
max_total_frames = self._get_max_video_frames(seq_len, processor)
@@ -2030,7 +1792,9 @@ class Molmo2ProcessingInfo(BaseProcessingInfo):
metadata: dict[str, Any],
do_sample_frames: bool | None = None,
) -> list[float]:
video_processor = self.get_hf_processor().processor.video_processor
processor = self.get_hf_processor()
video_processor = processor.video_processor
# metadata["fps"] refers to the true fps of the input video.
video_fps = metadata["fps"]
frames_indices = metadata.get("frames_indices")
@@ -2104,7 +1868,7 @@ class Molmo2DummyInputsBuilder(BaseDummyInputsBuilder[Molmo2ProcessingInfo]):
if num_videos > 0:
processor = self.info.get_hf_processor()
base_image_input_size = processor.base_image_input_size
video_size = processor.video_processor.size
target_num_frames = self.info.get_num_frames_with_most_features(
seq_len, mm_counts
)
@@ -2131,8 +1895,8 @@ class Molmo2DummyInputsBuilder(BaseDummyInputsBuilder[Molmo2ProcessingInfo]):
target_num_frames = min(target_num_frames, num_frames_override)
dummy_videos = self._get_dummy_videos(
width=base_image_input_size[1],
height=base_image_input_size[0],
width=video_size["width"],
height=video_size["height"],
num_frames=target_num_frames,
num_videos=num_videos,
)
@@ -2174,10 +1938,10 @@ class Molmo2MultiModalProcessor(BaseMultiModalProcessor[Molmo2ProcessingInfo]):
prompt_tokens: list[int],
) -> list[int]:
processor = self.info.get_hf_processor()
tokenizer = processor.processor.tokenizer
tokenizer = processor.tokenizer
bos_token_id = tokenizer.bos_token_id or tokenizer.eos_token_id
if len(prompt_tokens) > 0 and prompt_tokens[0] != bos_token_id:
if len(prompt_tokens) == 0 or prompt_tokens[0] != bos_token_id:
# Prepend the bos token to the prompt tokens
prompt_tokens = [bos_token_id] + prompt_tokens
@@ -2191,9 +1955,26 @@ class Molmo2MultiModalProcessor(BaseMultiModalProcessor[Molmo2ProcessingInfo]):
tok_kwargs: Mapping[str, object],
) -> BatchFeature:
mm_data = dict(mm_data)
processor = self.info.get_hf_processor(**mm_kwargs)
hf_config = self.info.get_hf_config()
hf_processor = self.info.get_hf_processor(**mm_kwargs)
def patched_call(text=None, images=None, videos=None, **kwargs) -> BatchFeature:
res = hf_processor(text=text, images=images, videos=videos, **kwargs)
# Molmo2Processor.insert_bos results in float outputs
# if the input text is empty
if not text:
res["input_ids"] = res["input_ids"].long()
return res
tokenizer = hf_processor.tokenizer
image_processor = hf_processor.image_processor
if videos := mm_data.pop("videos", []):
bos_token_id = tokenizer.bos_token_id or tokenizer.eos_token_id
pixel_values_videos_lst = []
video_token_pooling_lst = []
video_num_crops_lst = []
@@ -2228,18 +2009,32 @@ class Molmo2MultiModalProcessor(BaseMultiModalProcessor[Molmo2ProcessingInfo]):
video_mm_data["videos"] = [[video_array]]
video_mm_data["video_metadata"] = [[metadata]]
video_outputs = super()._call_hf_processor(
prompt=VIDEO_PROMPT,
mm_data=video_mm_data,
mm_kwargs=video_mm_kwargs,
tok_kwargs=tok_kwargs,
video_outputs = self.info.ctx.call_hf_processor(
patched_call,
dict(text=VIDEO_PROMPT, **video_mm_data),
dict(**video_mm_kwargs, **tok_kwargs),
)
input_ids = video_outputs.pop("input_ids")
video_string = processor.processor.tokenizer.batch_decode(input_ids)[0]
prompt = prompt.replace(
VIDEO_PROMPT,
video_string,
1,
if input_ids[0, 0] == bos_token_id:
input_ids = input_ids[:, 1:]
video_string = tokenizer.batch_decode(input_ids)[0]
prompt = prompt.replace(VIDEO_PROMPT, video_string, 1)
video_grids = video_outputs.pop("video_grids")
assert video_grids[:, 0].sum() == len(
video_outputs["pixel_values_videos"]
)
video_outputs["video_num_crops"] = video_grids[:, 0]
video_outputs["video_num_pooled_patches"] = video_grids.prod(dim=1)
n_patches = video_outputs["pixel_values_videos"].shape[1]
video_outputs["video_num_patches"] = (
video_outputs["video_num_crops"] * n_patches
)
(video_outputs["video_tokens"], video_outputs["num_video_tokens"]) = (
build_flat_video_bool_length(video_grids, hf_config)
)
pixel_values_videos_lst.append(video_outputs["pixel_values_videos"])
@@ -2252,7 +2047,7 @@ class Molmo2MultiModalProcessor(BaseMultiModalProcessor[Molmo2ProcessingInfo]):
video_tokens_lst.append(video_outputs["video_tokens"])
num_video_tokens_lst.append(video_outputs["num_video_tokens"])
video_outputs = dict(
all_video_outputs = dict(
pixel_values_videos=torch.cat(pixel_values_videos_lst),
video_token_pooling=torch.cat(video_token_pooling_lst),
video_num_crops=torch.cat(video_num_crops_lst),
@@ -2262,30 +2057,50 @@ class Molmo2MultiModalProcessor(BaseMultiModalProcessor[Molmo2ProcessingInfo]):
num_video_tokens=torch.cat(num_video_tokens_lst),
)
else:
video_outputs = dict()
all_video_outputs = dict()
processed_outputs = super()._call_hf_processor(
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
tok_kwargs=tok_kwargs,
processed_outputs = self.info.ctx.call_hf_processor(
patched_call,
dict(text=prompt, **mm_data),
dict(**mm_kwargs, **tok_kwargs),
)
bos_token_id = processor.vocab[processor.bos_token]
input_ids = processed_outputs["input_ids"]
# add bos token back to prompt start
if input_ids.numel() > 0 and input_ids[0, 0] != bos_token_id:
bos_token_id_tensor = torch.tensor(
[[bos_token_id]], device=input_ids.device, dtype=input_ids.dtype
if (images := mm_data.get("images")) is not None:
mm_items = self.info.parse_mm_data({"image": images}, validate=False)
parsed_images = mm_items.get_items("image", ImageProcessorItems)
image_sizes = [
parsed_images.get_image_size(i) for i in range(len(parsed_images))
]
# For each image: tiling_h * tiling_w + global view
tilings = [
self.info.select_tiling(
image_width=image_size.width,
image_height=image_size.height,
image_processor=image_processor,
)
for image_size in image_sizes
]
num_crops = torch.tensor(tilings).prod(-1) + 1
assert sum(num_crops) == len(processed_outputs["pixel_values"])
assert sum(num_crops) == processed_outputs["image_num_crops"].sum().item()
image_grids = processed_outputs.pop("image_grids")
image_num_pooled_patches = image_grids[:, :2].prod(dim=1) + image_grids[
:, 2:
].prod(dim=1)
processed_outputs["image_num_pooled_patches"] = image_num_pooled_patches
n_patches = processed_outputs["pixel_values"].shape[1]
processed_outputs["image_num_patches"] = (
processed_outputs["image_num_crops"] * n_patches
)
processed_outputs["input_ids"] = torch.concat(
[bos_token_id_tensor, input_ids], dim=1
)
combined_outputs = dict(
processed_outputs,
**video_outputs,
)
return BatchFeature(combined_outputs)
(
processed_outputs["image_tokens"],
processed_outputs["num_image_tokens"],
) = build_flat_image_bool_length(image_grids, hf_config)
return BatchFeature({**processed_outputs, **all_video_outputs})
def _get_mm_fields_config(
self,
@@ -2338,41 +2153,65 @@ class Molmo2MultiModalProcessor(BaseMultiModalProcessor[Molmo2ProcessingInfo]):
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]:
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
img_patch_id = processor.image_patch_id
img_col_id = processor.im_col_id
img_start_id = processor.im_start_id
img_end_id = processor.im_end_id
image_use_col_tokens = processor.processor.image_use_col_tokens
use_single_crop_col_tokens = processor.processor.use_single_crop_col_tokens
use_single_crop_start_token = processor.processor.use_single_crop_start_token
video_use_col_tokens = processor.processor.video_use_col_tokens
use_frame_special_tokens = processor.processor.use_frame_special_tokens
hf_config = self.info.get_hf_config()
img_patch_id = hf_config.image_patch_id
img_col_id = hf_config.image_col_id
img_start_id = hf_config.image_start_token_id
img_end_id = hf_config.image_end_token_id
low_res_im_start_id = hf_config.low_res_image_start_token_id
frame_start_id = hf_config.frame_start_token_id
frame_end_id = hf_config.frame_end_token_id
im_low_res_id = hf_config.image_low_res_id
def get_image_replacement_molmo2(item_idx: int) -> list[int]:
emb_tok_ids = [
img_patch_id,
img_col_id,
img_start_id,
low_res_im_start_id,
frame_start_id,
img_end_id,
frame_end_id,
im_low_res_id,
]
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
image_use_col_tokens = processor.image_use_col_tokens
use_single_crop_col_tokens = processor.use_single_crop_col_tokens
use_single_crop_start_token = processor.use_single_crop_start_token
video_use_col_tokens = processor.video_use_col_tokens
use_frame_special_tokens = processor.use_frame_special_tokens
tokenizer = processor.tokenizer
vocab = tokenizer.get_vocab()
image_processor = processor.image_processor
video_processor = processor.video_processor
def get_image_replacement_molmo2(item_idx: int):
images = mm_items.get_items("image", ImageProcessorItems)
image = images.get(item_idx)
image = exif_transpose(image)
resize_nrows, resize_cols = processor.get_base_grid_size(is_video=False)
resize_ncols, resize_nrows = self.info.get_base_grid_size(image_processor)
if use_single_crop_col_tokens is not None:
use_col_tokens = use_single_crop_col_tokens
else:
use_col_tokens = image_use_col_tokens
if use_single_crop_start_token:
start_id = processor.low_res_im_start_id
start_id = low_res_im_start_id
else:
start_id = img_start_id
extra_row = [img_patch_id] * resize_cols + [img_col_id] * int(
extra_row = [img_patch_id] * resize_ncols + [img_col_id] * int(
use_col_tokens
)
extra_joint = [start_id] + extra_row * resize_nrows + [img_end_id]
image_size = get_image_size(image)
nrows, ncols = processor.get_patches_grid_size(
ncols, nrows = self.info.get_patches_grid_size(
image_height=image_size.height,
image_width=image_size.width,
image_processor=image_processor,
)
joint_row = [img_patch_id] * ncols + [img_col_id] * int(
@@ -2381,21 +2220,18 @@ class Molmo2MultiModalProcessor(BaseMultiModalProcessor[Molmo2ProcessingInfo]):
joint = [img_start_id] + joint_row * nrows + [img_end_id]
img_token_ids = extra_joint + joint
return PromptUpdateDetails.select_token_ids(
img_token_ids,
processor.image_token_ids,
)
return PromptUpdateDetails.select_token_ids(img_token_ids, emb_tok_ids)
def get_video_replacement_molmo2(item_idx: int) -> list[int]:
def get_video_replacement_molmo2(item_idx: int):
video, metadata = mm_items["video"][item_idx]
do_sample_frames = hf_processor_mm_kwargs.get("do_sample_frames")
timestamps = self.info._get_video_second_idx(metadata, do_sample_frames)
nrows, ncols = processor.get_base_grid_size(is_video=True)
ncols, nrows = self.info.get_base_grid_size(video_processor)
if use_frame_special_tokens:
start_id = processor.frame_start_id
end_id = processor.frame_end_id
start_id = frame_start_id
end_id = frame_end_id
else:
start_id = img_start_id
end_id = img_end_id
@@ -2408,7 +2244,7 @@ class Molmo2MultiModalProcessor(BaseMultiModalProcessor[Molmo2ProcessingInfo]):
prev_space + f"{frame_time:.1f} "
) # explicit whitespace before/after image tokens
img_token_ids += processor.processor.tokenizer.encode(
img_token_ids += tokenizer.encode(
frame_prefix,
add_special_tokens=False,
)
@@ -2419,10 +2255,7 @@ class Molmo2MultiModalProcessor(BaseMultiModalProcessor[Molmo2ProcessingInfo]):
joint = [start_id] + nrows * joint_row + [end_id]
img_token_ids += joint
return PromptUpdateDetails.select_token_ids(
img_token_ids,
processor.image_token_ids,
)
return PromptUpdateDetails.select_token_ids(img_token_ids, emb_tok_ids)
return [
PromptReplacement(
@@ -2432,7 +2265,7 @@ class Molmo2MultiModalProcessor(BaseMultiModalProcessor[Molmo2ProcessingInfo]):
)
for modality, target, replacement_fn in zip(
["image", "video"],
[processor.image_placeholder_id, processor.video_placeholder_id],
[vocab[IMAGE_PROMPT], vocab[VIDEO_PROMPT]],
[get_image_replacement_molmo2, get_video_replacement_molmo2],
)
]