diff --git a/vllm/model_executor/models/molmo2.py b/vllm/model_executor/models/molmo2.py index 18476d8ab..85f0f1932 100644 --- a/vllm/model_executor/models/molmo2.py +++ b/vllm/model_executor/models/molmo2.py @@ -3,7 +3,7 @@ import math from collections.abc import Iterable, Mapping, Sequence from dataclasses import dataclass, fields -from functools import cached_property, partial +from functools import partial from itertools import islice from typing import Annotated, Any @@ -14,14 +14,14 @@ import torch.nn.functional as F from PIL import ImageOps from PIL.Image import Image from transformers import ( + BaseImageProcessor, + BaseVideoProcessor, BatchFeature, PretrainedConfig, ProcessorMixin, - TensorType, ) from transformers.image_utils import ImageInput -from transformers.tokenization_utils_base import TextInput -from transformers.video_utils import VideoInput, VideoMetadata +from transformers.video_utils import VideoMetadata from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig @@ -1337,12 +1337,14 @@ def exif_transpose( def build_flat_image_bool_length( image_grids: torch.LongTensor, - image_patch_id: int, - low_res_image_start_id: int, - image_start_id: int, - image_col_id: int, - image_end_id: int, + hf_config: PretrainedConfig, ) -> tuple[torch.LongTensor, torch.LongTensor]: + image_patch_id = hf_config.image_patch_id + low_res_image_start_id = hf_config.low_res_image_start_token_id + image_start_id = hf_config.image_start_token_id + image_col_id = hf_config.image_col_id + image_end_id = hf_config.image_end_token_id + device = image_grids.device B = image_grids.shape[0] @@ -1401,10 +1403,12 @@ def build_flat_image_bool_length( def build_flat_video_bool_length( video_grids: torch.LongTensor, - image_patch_id: int, - frame_start_id: int, - frame_end_id: int, + hf_config: PretrainedConfig, ) -> tuple[torch.LongTensor, torch.LongTensor]: + image_patch_id = hf_config.image_patch_id + frame_start_id = hf_config.frame_start_token_id + frame_end_id = hf_config.frame_end_token_id + device = video_grids.device B = video_grids.shape[0] @@ -1439,314 +1443,6 @@ def build_flat_video_bool_length( return flat, lengths -class Molmo2ProcessorWrapper: - """ - Wraps :class:`Molmo2Processor` so that it can be called directly. - """ - - def __init__(self, processor: ProcessorMixin, hf_config: PretrainedConfig): - super().__init__() - - self.processor = processor - self.hf_config = hf_config - - @cached_property - def vocab(self) -> dict[str, int]: - return self.processor.tokenizer.vocab # type: ignore - - @cached_property - def max_crops(self) -> int: - image_processor = self.processor.image_processor # type: ignore - - max_crops = image_processor.max_crops - assert isinstance(max_crops, int) - - return max_crops - - @cached_property - def image_pooling_h(self) -> int: - image_processor = self.processor.image_processor # type: ignore - - image_pooling_h = image_processor.pooling_size[0] - assert isinstance(image_pooling_h, int) - - return image_pooling_h - - @cached_property - def image_pooling_w(self) -> int: - image_processor = self.processor.image_processor # type: ignore - - image_pooling_w = image_processor.pooling_size[1] - assert isinstance(image_pooling_w, int) - - return image_pooling_w - - @cached_property - def video_pooling_h(self) -> int: - video_processor = self.processor.video_processor # type: ignore - - video_pooling_h = video_processor.pooling_size[0] - assert isinstance(video_pooling_h, int) - - return video_pooling_h - - @cached_property - def video_pooling_w(self) -> int: - video_processor = self.processor.video_processor # type: ignore - - video_pooling_w = video_processor.pooling_size[1] - assert isinstance(video_pooling_w, int) - - return video_pooling_w - - @cached_property - def base_image_input_size(self) -> tuple[int, int]: - if getattr(self.processor, "image_processor", None) is not None: - processor = self.processor.image_processor # type: ignore - else: - processor = self.processor.video_processor # type: ignore - - base_image_input_size = (processor.size["height"], processor.size["width"]) - - return base_image_input_size - - @cached_property - def image_patch_size(self) -> int: - if getattr(self.processor, "image_processor", None) is not None: - processor = self.processor.image_processor # type: ignore - else: - processor = self.processor.video_processor # type: ignore - - image_patch_size = processor.patch_size - assert isinstance(image_patch_size, int) - - return image_patch_size - - @cached_property - def overlap_margins(self) -> tuple[int, int]: - image_processor = self.processor.image_processor # type: ignore - - left_margin, right_margin = image_processor.overlap_margins - assert isinstance(left_margin, int) - assert isinstance(right_margin, int) - - return left_margin, right_margin - - @cached_property - def bos_token(self) -> str: - return self.processor.tokenizer.bos_token or self.processor.tokenizer.eos_token - - @cached_property - def image_patch_id(self) -> int: - return self.hf_config.image_patch_id - - @cached_property - def im_col_id(self) -> int: - return self.hf_config.image_col_id - - @cached_property - def im_start_id(self) -> int: - return self.hf_config.image_start_token_id - - @cached_property - def im_end_id(self) -> int: - return self.hf_config.image_end_token_id - - @cached_property - def low_res_im_start_id(self) -> int: - return self.hf_config.low_res_image_start_token_id - - @cached_property - def frame_start_id(self) -> int: - return self.hf_config.frame_start_token_id - - @cached_property - def frame_end_id(self) -> int: - return self.hf_config.frame_end_token_id - - @cached_property - def im_low_res_id(self) -> int: - return self.hf_config.image_low_res_id - - @cached_property - def image_placeholder_id(self) -> int: - return self.vocab[IMAGE_PROMPT] - - @cached_property - def video_placeholder_id(self) -> int: - return self.vocab[VIDEO_PROMPT] - - @cached_property - def image_token_ids(self) -> list[int]: - return [ - self.image_patch_id, - self.im_col_id, - self.im_start_id, - self.low_res_im_start_id, - self.frame_start_id, - self.im_end_id, - self.frame_end_id, - self.im_low_res_id, - ] - - def select_tiling( - self, - *, - image_height: int, - image_width: int, - ) -> tuple[int, int]: - max_crops = self.max_crops - left_margin, right_margin = self.overlap_margins - base_image_input_size = self.base_image_input_size - base_image_input_d = self.image_patch_size - - total_margin_pixels = base_image_input_d * (right_margin + left_margin) - crop_patches = base_image_input_size[0] // base_image_input_d - crop_window_patches = crop_patches - (right_margin + left_margin) - crop_window_size = crop_window_patches * base_image_input_d - tiling_h, tiling_w = select_tiling( - height=image_height - total_margin_pixels, - width=image_width - total_margin_pixels, - patch_size=crop_window_size, - max_num_patches=max_crops, - ) - - return tiling_h, tiling_w - - def get_base_grid_size(self, is_video: bool) -> tuple[int, int]: - base_image_input_size = self.base_image_input_size - - return get_patches_grid_size( - image_h=base_image_input_size[0], - image_w=base_image_input_size[1], - patch_size=self.image_patch_size, - pool_h=self.video_pooling_h if is_video else self.image_pooling_h, - pool_w=self.video_pooling_w if is_video else self.image_pooling_w, - ) - - def get_patches_grid_size( - self, - *, - image_height: int, - image_width: int, - ) -> tuple[int, int]: - left_margin, right_margin = self.overlap_margins - base_image_input_size = self.base_image_input_size - base_image_input_d = self.image_patch_size - - total_margin_pixels = base_image_input_d * (right_margin + left_margin) - crop_patches = base_image_input_size[0] // base_image_input_d - crop_window_patches = crop_patches - (right_margin + left_margin) - crop_window_size = crop_window_patches * base_image_input_d - - tiling_h, tiling_w = self.select_tiling( - image_height=image_height, - image_width=image_width, - ) - - h, w = [ - tiling_h * crop_window_size + total_margin_pixels, - tiling_w * crop_window_size + total_margin_pixels, - ] - nrows, ncols = get_patches_grid_size( - image_h=h, - image_w=w, - patch_size=base_image_input_d, - pool_h=self.image_pooling_h, - pool_w=self.image_pooling_w, - ) - - return nrows, ncols - - def __call__( - self, - text: TextInput | list[TextInput] | None = None, - images: ImageInput | None = None, - videos: VideoInput | None = None, - return_tensors: str | TensorType = None, - **kwargs: object, - ) -> BatchFeature: - inputs = [text] - images = exif_transpose(images) - if getattr(self.processor, "image_processor", None) is not None: - inputs.append(images) - if getattr(self.processor, "video_processor", None) is not None: - inputs.append(videos) - outputs = self.processor( # type: ignore - *inputs, - return_tensors=return_tensors, - **kwargs, - ) - - # revert insert bos token - if outputs["input_ids"][0, 0] == self.vocab[self.bos_token]: - outputs["input_ids"] = outputs["input_ids"][:, 1:] - - if images is None: - images = [] - if not isinstance(images, list): - images = [images] - - if videos is None: - videos = [] - if not isinstance(videos, list): - videos = [videos] - - assert len(videos) in {0, 1}, "At most one video is supported for Molmo2" - - _attention_mask: torch.Tensor = outputs.pop("attention_mask") - _token_type_ids: torch.Tensor = outputs.pop("token_type_ids", None) - - if len(images) > 0: - # For each image: tiling_h * tiling_w + global view - num_crops = [] - for image in images: - image_size = get_image_size(image) - tiling = self.select_tiling( - image_height=image_size.height, - image_width=image_size.width, - ) - num_crops.append(np.prod(tiling) + 1) - - assert sum(num_crops) == len(outputs["pixel_values"]) - assert sum(num_crops) == outputs["image_num_crops"].sum().item() - image_grids: torch.Tensor = outputs.pop("image_grids") - image_num_pooled_patches: torch.Tensor = image_grids[:, :2].prod( - dim=1 - ) + image_grids[:, 2:].prod(dim=1) - outputs["image_num_pooled_patches"] = image_num_pooled_patches - n_patches = outputs["pixel_values"].shape[1] - outputs["image_num_patches"] = outputs["image_num_crops"] * n_patches - image_tokens, num_image_tokens = build_flat_image_bool_length( - image_grids, - self.image_patch_id, - self.low_res_im_start_id, - self.im_start_id, - self.im_col_id, - self.im_end_id, - ) - outputs["image_tokens"] = image_tokens - outputs["num_image_tokens"] = num_image_tokens - - if len(videos) > 0: - video_grids: torch.Tensor = outputs.pop("video_grids") - assert video_grids[:, 0].sum() == len(outputs["pixel_values_videos"]) - outputs["video_num_crops"] = video_grids[:, 0] - outputs["video_num_pooled_patches"] = video_grids.prod(dim=1) - n_patches = outputs["pixel_values_videos"].shape[1] - outputs["video_num_patches"] = outputs["video_num_crops"] * n_patches - video_tokens, num_video_tokens = build_flat_video_bool_length( - video_grids, - self.image_patch_id, - self.frame_start_id, - self.frame_end_id, - ) - outputs["video_tokens"] = video_tokens - outputs["num_video_tokens"] = num_video_tokens - - return BatchFeature(outputs) - - def get_candidate_target_fps( video_fps: int | float, sampling_fps: int | float, @@ -1856,36 +1552,101 @@ class Molmo2ProcessingInfo(BaseProcessingInfo): expected_hidden_size=self._get_expected_hidden_size(), ) - def get_hf_processor(self, **kwargs: object) -> Molmo2ProcessorWrapper: - processor = self.ctx.get_hf_processor(**kwargs) - hf_config = self.ctx.get_hf_config() - return Molmo2ProcessorWrapper(processor, hf_config) - def get_supported_mm_limits(self) -> Mapping[str, int | None]: return {"image": None, "video": 1} + def select_tiling( + self, + *, + image_width: int, + image_height: int, + image_processor: BaseImageProcessor, + ) -> tuple[int, int]: + max_crops = image_processor.max_crops + left_margin, right_margin = image_processor.overlap_margins + base_image_input_d = image_processor.patch_size + + total_margin_pixels = base_image_input_d * (right_margin + left_margin) + crop_patches = image_processor.size["height"] // base_image_input_d + crop_window_patches = crop_patches - (right_margin + left_margin) + crop_window_size = crop_window_patches * base_image_input_d + tiling_h, tiling_w = select_tiling( + height=image_height - total_margin_pixels, + width=image_width - total_margin_pixels, + patch_size=crop_window_size, + max_num_patches=max_crops, + ) + + return tiling_w, tiling_h + + def get_base_grid_size( + self, + image_processor: BaseImageProcessor | BaseVideoProcessor, + ) -> tuple[int, int]: + nrows, ncols = get_patches_grid_size( + image_h=image_processor.size["height"], + image_w=image_processor.size["width"], + patch_size=image_processor.patch_size, + pool_h=image_processor.pooling_size[0], + pool_w=image_processor.pooling_size[1], + ) + + return ncols, nrows + + def get_patches_grid_size( + self, + *, + image_width: int, + image_height: int, + image_processor: BaseImageProcessor, + ) -> tuple[int, int]: + left_margin, right_margin = image_processor.overlap_margins + base_image_input_d = image_processor.patch_size + + total_margin_pixels = base_image_input_d * (right_margin + left_margin) + crop_patches = image_processor.size["height"] // base_image_input_d + crop_window_patches = crop_patches - (right_margin + left_margin) + crop_window_size = crop_window_patches * base_image_input_d + + tiling_w, tiling_h = self.select_tiling( + image_height=image_height, + image_width=image_width, + image_processor=image_processor, + ) + + nrows, ncols = get_patches_grid_size( + image_h=tiling_h * crop_window_size + total_margin_pixels, + image_w=tiling_w * crop_window_size + total_margin_pixels, + patch_size=base_image_input_d, + pool_h=image_processor.pooling_size[0], + pool_w=image_processor.pooling_size[1], + ) + + return ncols, nrows + def get_num_image_tokens( self, *, image_height: int, image_width: int, - processor: Molmo2ProcessorWrapper, + processor: ProcessorMixin, ) -> int: - hf_processor = processor.processor + image_processor = processor.image_processor - resize_nrows, resize_cols = processor.get_base_grid_size(is_video=False) + resize_ncols, resize_nrows = self.get_base_grid_size(image_processor) # start/end tokens + image patch token + col tokens - if hf_processor.use_single_crop_col_tokens is not None: - use_col_tokens = hf_processor.use_single_crop_col_tokens + if processor.use_single_crop_col_tokens is not None: + use_col_tokens = processor.use_single_crop_col_tokens else: - use_col_tokens = hf_processor.image_use_col_tokens - extra = 2 + resize_nrows * (resize_cols + int(use_col_tokens)) - overlap_nrows, overlap_ncols = processor.get_patches_grid_size( + use_col_tokens = processor.image_use_col_tokens + extra = 2 + resize_nrows * (resize_ncols + int(use_col_tokens)) + overlap_ncols, overlap_nrows = self.get_patches_grid_size( image_height=image_height, image_width=image_width, + image_processor=image_processor, ) joint = 2 + overlap_nrows * ( - overlap_ncols + int(hf_processor.image_use_col_tokens) + overlap_ncols + int(processor.image_use_col_tokens) ) return extra + joint @@ -1894,28 +1655,28 @@ class Molmo2ProcessingInfo(BaseProcessingInfo): self, *, num_frames: int, - processor: Molmo2ProcessorWrapper, + processor: ProcessorMixin, ) -> int: - resize_nrows, resize_cols = processor.get_base_grid_size(is_video=True) + video_processor = processor.video_processor + + resize_ncols, resize_nrows = self.get_base_grid_size(video_processor) # start/end tokens - extra = 2 + resize_nrows * ( - resize_cols + int(processor.processor.video_use_col_tokens) - ) + extra = 2 + resize_nrows * (resize_ncols + int(processor.video_use_col_tokens)) return num_frames * extra def get_image_size_with_most_features(self) -> ImageSize: processor = self.get_hf_processor() + image_processor = processor.image_processor - left_margin, right_margin = processor.overlap_margins - base_image_input_size = processor.base_image_input_size - base_image_input_d = processor.image_patch_size + left_margin, right_margin = image_processor.overlap_margins + base_image_input_d = image_processor.patch_size total_margin_pixels = base_image_input_d * (right_margin + left_margin) - crop_patches = base_image_input_size[0] // base_image_input_d + crop_patches = image_processor.size["height"] // base_image_input_d crop_window_patches = crop_patches - (right_margin + left_margin) crop_window_size = crop_window_patches * base_image_input_d - tilings = get_candidate_tilings(processor.max_crops) + tilings = get_candidate_tilings(image_processor.max_crops) largest_feature_size, largest_feature_pinpoint = 0, None for hr, wr in tilings: @@ -1939,7 +1700,7 @@ class Molmo2ProcessingInfo(BaseProcessingInfo): def _get_max_video_frames( self, max_tokens: int, - processor: Molmo2ProcessorWrapper, + processor: ProcessorMixin, ) -> int: num_tokens_per_frame = self.get_num_video_tokens( num_frames=1, @@ -1954,7 +1715,8 @@ class Molmo2ProcessingInfo(BaseProcessingInfo): mm_counts: Mapping[str, int], ) -> int: processor = self.get_hf_processor() - video_processor = processor.processor.video_processor + video_processor = processor.video_processor + num_frames = video_processor.num_frames max_videos = mm_counts.get("video", 0) max_total_frames = self._get_max_video_frames(seq_len, processor) @@ -2030,7 +1792,9 @@ class Molmo2ProcessingInfo(BaseProcessingInfo): metadata: dict[str, Any], do_sample_frames: bool | None = None, ) -> list[float]: - video_processor = self.get_hf_processor().processor.video_processor + processor = self.get_hf_processor() + video_processor = processor.video_processor + # metadata["fps"] refers to the true fps of the input video. video_fps = metadata["fps"] frames_indices = metadata.get("frames_indices") @@ -2104,7 +1868,7 @@ class Molmo2DummyInputsBuilder(BaseDummyInputsBuilder[Molmo2ProcessingInfo]): if num_videos > 0: processor = self.info.get_hf_processor() - base_image_input_size = processor.base_image_input_size + video_size = processor.video_processor.size target_num_frames = self.info.get_num_frames_with_most_features( seq_len, mm_counts ) @@ -2131,8 +1895,8 @@ class Molmo2DummyInputsBuilder(BaseDummyInputsBuilder[Molmo2ProcessingInfo]): target_num_frames = min(target_num_frames, num_frames_override) dummy_videos = self._get_dummy_videos( - width=base_image_input_size[1], - height=base_image_input_size[0], + width=video_size["width"], + height=video_size["height"], num_frames=target_num_frames, num_videos=num_videos, ) @@ -2174,10 +1938,10 @@ class Molmo2MultiModalProcessor(BaseMultiModalProcessor[Molmo2ProcessingInfo]): prompt_tokens: list[int], ) -> list[int]: processor = self.info.get_hf_processor() - tokenizer = processor.processor.tokenizer + tokenizer = processor.tokenizer bos_token_id = tokenizer.bos_token_id or tokenizer.eos_token_id - if len(prompt_tokens) > 0 and prompt_tokens[0] != bos_token_id: + if len(prompt_tokens) == 0 or prompt_tokens[0] != bos_token_id: # Prepend the bos token to the prompt tokens prompt_tokens = [bos_token_id] + prompt_tokens @@ -2191,9 +1955,26 @@ class Molmo2MultiModalProcessor(BaseMultiModalProcessor[Molmo2ProcessingInfo]): tok_kwargs: Mapping[str, object], ) -> BatchFeature: mm_data = dict(mm_data) - processor = self.info.get_hf_processor(**mm_kwargs) + + hf_config = self.info.get_hf_config() + hf_processor = self.info.get_hf_processor(**mm_kwargs) + + def patched_call(text=None, images=None, videos=None, **kwargs) -> BatchFeature: + res = hf_processor(text=text, images=images, videos=videos, **kwargs) + + # Molmo2Processor.insert_bos results in float outputs + # if the input text is empty + if not text: + res["input_ids"] = res["input_ids"].long() + + return res + + tokenizer = hf_processor.tokenizer + image_processor = hf_processor.image_processor if videos := mm_data.pop("videos", []): + bos_token_id = tokenizer.bos_token_id or tokenizer.eos_token_id + pixel_values_videos_lst = [] video_token_pooling_lst = [] video_num_crops_lst = [] @@ -2228,18 +2009,32 @@ class Molmo2MultiModalProcessor(BaseMultiModalProcessor[Molmo2ProcessingInfo]): video_mm_data["videos"] = [[video_array]] video_mm_data["video_metadata"] = [[metadata]] - video_outputs = super()._call_hf_processor( - prompt=VIDEO_PROMPT, - mm_data=video_mm_data, - mm_kwargs=video_mm_kwargs, - tok_kwargs=tok_kwargs, + video_outputs = self.info.ctx.call_hf_processor( + patched_call, + dict(text=VIDEO_PROMPT, **video_mm_data), + dict(**video_mm_kwargs, **tok_kwargs), ) + input_ids = video_outputs.pop("input_ids") - video_string = processor.processor.tokenizer.batch_decode(input_ids)[0] - prompt = prompt.replace( - VIDEO_PROMPT, - video_string, - 1, + if input_ids[0, 0] == bos_token_id: + input_ids = input_ids[:, 1:] + + video_string = tokenizer.batch_decode(input_ids)[0] + prompt = prompt.replace(VIDEO_PROMPT, video_string, 1) + + video_grids = video_outputs.pop("video_grids") + assert video_grids[:, 0].sum() == len( + video_outputs["pixel_values_videos"] + ) + + video_outputs["video_num_crops"] = video_grids[:, 0] + video_outputs["video_num_pooled_patches"] = video_grids.prod(dim=1) + n_patches = video_outputs["pixel_values_videos"].shape[1] + video_outputs["video_num_patches"] = ( + video_outputs["video_num_crops"] * n_patches + ) + (video_outputs["video_tokens"], video_outputs["num_video_tokens"]) = ( + build_flat_video_bool_length(video_grids, hf_config) ) pixel_values_videos_lst.append(video_outputs["pixel_values_videos"]) @@ -2252,7 +2047,7 @@ class Molmo2MultiModalProcessor(BaseMultiModalProcessor[Molmo2ProcessingInfo]): video_tokens_lst.append(video_outputs["video_tokens"]) num_video_tokens_lst.append(video_outputs["num_video_tokens"]) - video_outputs = dict( + all_video_outputs = dict( pixel_values_videos=torch.cat(pixel_values_videos_lst), video_token_pooling=torch.cat(video_token_pooling_lst), video_num_crops=torch.cat(video_num_crops_lst), @@ -2262,30 +2057,50 @@ class Molmo2MultiModalProcessor(BaseMultiModalProcessor[Molmo2ProcessingInfo]): num_video_tokens=torch.cat(num_video_tokens_lst), ) else: - video_outputs = dict() + all_video_outputs = dict() - processed_outputs = super()._call_hf_processor( - prompt=prompt, - mm_data=mm_data, - mm_kwargs=mm_kwargs, - tok_kwargs=tok_kwargs, + processed_outputs = self.info.ctx.call_hf_processor( + patched_call, + dict(text=prompt, **mm_data), + dict(**mm_kwargs, **tok_kwargs), ) - bos_token_id = processor.vocab[processor.bos_token] - input_ids = processed_outputs["input_ids"] - # add bos token back to prompt start - if input_ids.numel() > 0 and input_ids[0, 0] != bos_token_id: - bos_token_id_tensor = torch.tensor( - [[bos_token_id]], device=input_ids.device, dtype=input_ids.dtype + if (images := mm_data.get("images")) is not None: + mm_items = self.info.parse_mm_data({"image": images}, validate=False) + parsed_images = mm_items.get_items("image", ImageProcessorItems) + image_sizes = [ + parsed_images.get_image_size(i) for i in range(len(parsed_images)) + ] + + # For each image: tiling_h * tiling_w + global view + tilings = [ + self.info.select_tiling( + image_width=image_size.width, + image_height=image_size.height, + image_processor=image_processor, + ) + for image_size in image_sizes + ] + num_crops = torch.tensor(tilings).prod(-1) + 1 + assert sum(num_crops) == len(processed_outputs["pixel_values"]) + assert sum(num_crops) == processed_outputs["image_num_crops"].sum().item() + + image_grids = processed_outputs.pop("image_grids") + image_num_pooled_patches = image_grids[:, :2].prod(dim=1) + image_grids[ + :, 2: + ].prod(dim=1) + + processed_outputs["image_num_pooled_patches"] = image_num_pooled_patches + n_patches = processed_outputs["pixel_values"].shape[1] + processed_outputs["image_num_patches"] = ( + processed_outputs["image_num_crops"] * n_patches ) - processed_outputs["input_ids"] = torch.concat( - [bos_token_id_tensor, input_ids], dim=1 - ) - combined_outputs = dict( - processed_outputs, - **video_outputs, - ) - return BatchFeature(combined_outputs) + ( + processed_outputs["image_tokens"], + processed_outputs["num_image_tokens"], + ) = build_flat_image_bool_length(image_grids, hf_config) + + return BatchFeature({**processed_outputs, **all_video_outputs}) def _get_mm_fields_config( self, @@ -2338,41 +2153,65 @@ class Molmo2MultiModalProcessor(BaseMultiModalProcessor[Molmo2ProcessingInfo]): hf_processor_mm_kwargs: Mapping[str, object], out_mm_kwargs: MultiModalKwargsItems, ) -> Sequence[PromptUpdate]: - processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) - img_patch_id = processor.image_patch_id - img_col_id = processor.im_col_id - img_start_id = processor.im_start_id - img_end_id = processor.im_end_id - image_use_col_tokens = processor.processor.image_use_col_tokens - use_single_crop_col_tokens = processor.processor.use_single_crop_col_tokens - use_single_crop_start_token = processor.processor.use_single_crop_start_token - video_use_col_tokens = processor.processor.video_use_col_tokens - use_frame_special_tokens = processor.processor.use_frame_special_tokens + hf_config = self.info.get_hf_config() + img_patch_id = hf_config.image_patch_id + img_col_id = hf_config.image_col_id + img_start_id = hf_config.image_start_token_id + img_end_id = hf_config.image_end_token_id + low_res_im_start_id = hf_config.low_res_image_start_token_id + frame_start_id = hf_config.frame_start_token_id + frame_end_id = hf_config.frame_end_token_id + im_low_res_id = hf_config.image_low_res_id - def get_image_replacement_molmo2(item_idx: int) -> list[int]: + emb_tok_ids = [ + img_patch_id, + img_col_id, + img_start_id, + low_res_im_start_id, + frame_start_id, + img_end_id, + frame_end_id, + im_low_res_id, + ] + + processor = self.info.get_hf_processor(**hf_processor_mm_kwargs) + image_use_col_tokens = processor.image_use_col_tokens + use_single_crop_col_tokens = processor.use_single_crop_col_tokens + use_single_crop_start_token = processor.use_single_crop_start_token + video_use_col_tokens = processor.video_use_col_tokens + use_frame_special_tokens = processor.use_frame_special_tokens + + tokenizer = processor.tokenizer + vocab = tokenizer.get_vocab() + + image_processor = processor.image_processor + video_processor = processor.video_processor + + def get_image_replacement_molmo2(item_idx: int): images = mm_items.get_items("image", ImageProcessorItems) image = images.get(item_idx) image = exif_transpose(image) - resize_nrows, resize_cols = processor.get_base_grid_size(is_video=False) + resize_ncols, resize_nrows = self.info.get_base_grid_size(image_processor) if use_single_crop_col_tokens is not None: use_col_tokens = use_single_crop_col_tokens else: use_col_tokens = image_use_col_tokens if use_single_crop_start_token: - start_id = processor.low_res_im_start_id + start_id = low_res_im_start_id else: start_id = img_start_id - extra_row = [img_patch_id] * resize_cols + [img_col_id] * int( + extra_row = [img_patch_id] * resize_ncols + [img_col_id] * int( use_col_tokens ) extra_joint = [start_id] + extra_row * resize_nrows + [img_end_id] image_size = get_image_size(image) - nrows, ncols = processor.get_patches_grid_size( + ncols, nrows = self.info.get_patches_grid_size( image_height=image_size.height, image_width=image_size.width, + image_processor=image_processor, ) joint_row = [img_patch_id] * ncols + [img_col_id] * int( @@ -2381,21 +2220,18 @@ class Molmo2MultiModalProcessor(BaseMultiModalProcessor[Molmo2ProcessingInfo]): joint = [img_start_id] + joint_row * nrows + [img_end_id] img_token_ids = extra_joint + joint - return PromptUpdateDetails.select_token_ids( - img_token_ids, - processor.image_token_ids, - ) + return PromptUpdateDetails.select_token_ids(img_token_ids, emb_tok_ids) - def get_video_replacement_molmo2(item_idx: int) -> list[int]: + def get_video_replacement_molmo2(item_idx: int): video, metadata = mm_items["video"][item_idx] do_sample_frames = hf_processor_mm_kwargs.get("do_sample_frames") timestamps = self.info._get_video_second_idx(metadata, do_sample_frames) - nrows, ncols = processor.get_base_grid_size(is_video=True) + ncols, nrows = self.info.get_base_grid_size(video_processor) if use_frame_special_tokens: - start_id = processor.frame_start_id - end_id = processor.frame_end_id + start_id = frame_start_id + end_id = frame_end_id else: start_id = img_start_id end_id = img_end_id @@ -2408,7 +2244,7 @@ class Molmo2MultiModalProcessor(BaseMultiModalProcessor[Molmo2ProcessingInfo]): prev_space + f"{frame_time:.1f} " ) # explicit whitespace before/after image tokens - img_token_ids += processor.processor.tokenizer.encode( + img_token_ids += tokenizer.encode( frame_prefix, add_special_tokens=False, ) @@ -2419,10 +2255,7 @@ class Molmo2MultiModalProcessor(BaseMultiModalProcessor[Molmo2ProcessingInfo]): joint = [start_id] + nrows * joint_row + [end_id] img_token_ids += joint - return PromptUpdateDetails.select_token_ids( - img_token_ids, - processor.image_token_ids, - ) + return PromptUpdateDetails.select_token_ids(img_token_ids, emb_tok_ids) return [ PromptReplacement( @@ -2432,7 +2265,7 @@ class Molmo2MultiModalProcessor(BaseMultiModalProcessor[Molmo2ProcessingInfo]): ) for modality, target, replacement_fn in zip( ["image", "video"], - [processor.image_placeholder_id, processor.video_placeholder_id], + [vocab[IMAGE_PROMPT], vocab[VIDEO_PROMPT]], [get_image_replacement_molmo2, get_video_replacement_molmo2], ) ]