[Misc] Move processors to transformers_utils (#35953)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -4,7 +4,7 @@
|
||||
import math
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from dataclasses import dataclass
|
||||
from functools import cached_property, partial
|
||||
from functools import partial
|
||||
from itertools import islice
|
||||
from typing import Annotated
|
||||
|
||||
@@ -13,9 +13,11 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from transformers import BatchFeature, PretrainedConfig, ProcessorMixin, TensorType
|
||||
from transformers.image_utils import ImageInput
|
||||
from transformers.tokenization_utils_base import TextInput
|
||||
from transformers import (
|
||||
BaseImageProcessor,
|
||||
BatchFeature,
|
||||
PretrainedConfig,
|
||||
)
|
||||
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
@@ -1017,117 +1019,28 @@ def select_tiling(
|
||||
return candidate_tilings[ix]
|
||||
|
||||
|
||||
class MolmoProcessorWrapper:
|
||||
"""
|
||||
Wraps `MolmoProcessor` so that it can be called directly.
|
||||
def _as_2tuple(x: int | tuple[int, int]) -> tuple[int, int]:
|
||||
if isinstance(x, int):
|
||||
return x, x
|
||||
|
||||
The original definition can be found here:
|
||||
https://huggingface.co/allenai/Molmo-7B-D-0924/blob/main/preprocessing_molmo.py
|
||||
"""
|
||||
return x
|
||||
|
||||
def __init__(self, processor: ProcessorMixin):
|
||||
super().__init__()
|
||||
|
||||
self.processor = processor
|
||||
|
||||
@cached_property
|
||||
def vocab(self) -> dict[str, int]:
|
||||
return self.processor.tokenizer.vocab # type: ignore
|
||||
|
||||
@cached_property
|
||||
def max_crops(self) -> int:
|
||||
image_processor = self.processor.image_processor # type: ignore
|
||||
|
||||
max_crops = image_processor.max_crops
|
||||
assert isinstance(max_crops, int)
|
||||
|
||||
return max_crops
|
||||
|
||||
@cached_property
|
||||
def base_image_input_size(self) -> tuple[int, int]:
|
||||
image_processor = self.processor.image_processor # type: ignore
|
||||
|
||||
base_image_input_size = image_processor.base_image_input_size
|
||||
if isinstance(base_image_input_size, int):
|
||||
return base_image_input_size, base_image_input_size
|
||||
|
||||
return tuple(base_image_input_size)
|
||||
|
||||
@cached_property
|
||||
def image_patch_size(self) -> int:
|
||||
image_processor = self.processor.image_processor # type: ignore
|
||||
|
||||
image_patch_size = image_processor.image_patch_size
|
||||
assert isinstance(image_patch_size, int)
|
||||
|
||||
return image_patch_size
|
||||
|
||||
@cached_property
|
||||
def overlap_margins(self) -> tuple[int, int]:
|
||||
image_processor = self.processor.image_processor # type: ignore
|
||||
|
||||
left_margin, right_margin = image_processor.overlap_margins
|
||||
assert isinstance(left_margin, int)
|
||||
assert isinstance(right_margin, int)
|
||||
|
||||
return left_margin, right_margin
|
||||
|
||||
@cached_property
|
||||
def image_token_length_w(self) -> int:
|
||||
image_processor = self.processor.image_processor # type: ignore
|
||||
|
||||
image_token_length_w = image_processor.image_token_length_w
|
||||
assert isinstance(image_token_length_w, int)
|
||||
|
||||
return image_token_length_w
|
||||
|
||||
@cached_property
|
||||
def image_token_length_h(self) -> int:
|
||||
image_processor = self.processor.image_processor # type: ignore
|
||||
|
||||
image_token_length_h = image_processor.image_token_length_h
|
||||
assert isinstance(image_token_length_h, int)
|
||||
|
||||
return image_token_length_h
|
||||
|
||||
@property
|
||||
def message_format(self) -> str | None:
|
||||
return "role"
|
||||
|
||||
@property
|
||||
def always_start_with_space(self) -> bool:
|
||||
return True
|
||||
|
||||
@cached_property
|
||||
def image_patch_id(self) -> int:
|
||||
return self.vocab[IMAGE_PATCH_TOKEN]
|
||||
|
||||
@cached_property
|
||||
def im_col_id(self) -> int:
|
||||
return self.vocab[IM_COL_TOKEN]
|
||||
|
||||
@cached_property
|
||||
def im_start_id(self) -> int:
|
||||
return self.vocab[IM_START_TOKEN]
|
||||
|
||||
@cached_property
|
||||
def im_end_id(self) -> int:
|
||||
return self.vocab[IM_END_TOKEN]
|
||||
|
||||
@property
|
||||
def pooling_size(self) -> int:
|
||||
return POOLING_SIZE
|
||||
class MolmoProcessingInfo(BaseProcessingInfo):
|
||||
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
|
||||
return {"image": None}
|
||||
|
||||
def select_tiling(
|
||||
self,
|
||||
*,
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
image_processor: BaseImageProcessor,
|
||||
) -> tuple[int, int]:
|
||||
max_crops = self.max_crops
|
||||
left_margin, right_margin = self.overlap_margins
|
||||
base_image_input_size = self.base_image_input_size
|
||||
base_image_input_d = self.image_patch_size
|
||||
max_crops = image_processor.max_crops
|
||||
left_margin, right_margin = image_processor.overlap_margins
|
||||
base_image_input_size = _as_2tuple(image_processor.base_image_input_size)
|
||||
base_image_input_d = image_processor.image_patch_size
|
||||
|
||||
total_margin_pixels = base_image_input_d * (right_margin + left_margin)
|
||||
crop_patches = base_image_input_size[0] // base_image_input_d
|
||||
@@ -1147,16 +1060,18 @@ class MolmoProcessorWrapper:
|
||||
*,
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
image_processor: BaseImageProcessor,
|
||||
) -> tuple[int, int]:
|
||||
left_margin, right_margin = self.overlap_margins
|
||||
base_image_input_size = self.base_image_input_size
|
||||
base_image_input_d = self.image_patch_size
|
||||
pooling_size = self.pooling_size
|
||||
left_margin, right_margin = image_processor.overlap_margins
|
||||
base_image_input_size = _as_2tuple(image_processor.base_image_input_size)
|
||||
base_image_input_d = image_processor.image_patch_size
|
||||
pooling_size = POOLING_SIZE
|
||||
|
||||
crop_patches = base_image_input_size[0] // base_image_input_d
|
||||
tiling_w, tiling_h = self.select_tiling(
|
||||
image_height=image_height,
|
||||
image_width=image_width,
|
||||
image_processor=image_processor,
|
||||
)
|
||||
|
||||
nrows, ncols = get_patches_grid_size(
|
||||
@@ -1170,70 +1085,22 @@ class MolmoProcessorWrapper:
|
||||
|
||||
return ncols, nrows
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
text: TextInput | list[TextInput] | None = None,
|
||||
images: ImageInput | list[ImageInput] | None = None,
|
||||
return_tensors: str | TensorType | None = None,
|
||||
**kwargs,
|
||||
) -> BatchFeature:
|
||||
outputs = self.processor.process( # type: ignore
|
||||
text, images, **kwargs
|
||||
)
|
||||
|
||||
if images is None:
|
||||
images = []
|
||||
if not isinstance(images, list):
|
||||
images = [images]
|
||||
|
||||
input_ids: torch.Tensor = outputs.pop("input_ids")
|
||||
outputs["input_ids"] = input_ids.unsqueeze(0)
|
||||
|
||||
image_input_idx = outputs.pop("image_input_idx", None)
|
||||
if image_input_idx is not None:
|
||||
feat_is_patch = image_input_idx >= 0
|
||||
|
||||
tilings = [
|
||||
self.select_tiling(
|
||||
image_width=image.size[0],
|
||||
image_height=image.size[1],
|
||||
)
|
||||
for image in images
|
||||
]
|
||||
# For each image: tiling_h * tiling_w + extra
|
||||
num_crops = torch.tensor(tilings).prod(-1) + 1
|
||||
assert num_crops.sum() == len(feat_is_patch)
|
||||
|
||||
outputs["image_input_idx"] = image_input_idx
|
||||
outputs["num_crops"] = num_crops
|
||||
outputs["img_patch_id"] = self.image_patch_id
|
||||
|
||||
return BatchFeature(outputs)
|
||||
|
||||
|
||||
class MolmoProcessingInfo(BaseProcessingInfo):
|
||||
def get_hf_processor(self, **kwargs: object) -> MolmoProcessorWrapper:
|
||||
processor = self.ctx.get_hf_processor(**kwargs)
|
||||
return MolmoProcessorWrapper(processor)
|
||||
|
||||
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
|
||||
return {"image": None}
|
||||
|
||||
def get_num_image_tokens(
|
||||
self,
|
||||
*,
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
processor: MolmoProcessorWrapper,
|
||||
image_processor: BaseImageProcessor,
|
||||
) -> int:
|
||||
ncols, nrows = processor.get_patches_grid_size(
|
||||
ncols, nrows = self.get_patches_grid_size(
|
||||
image_width=image_width,
|
||||
image_height=image_height,
|
||||
image_processor=image_processor,
|
||||
)
|
||||
pooling_size = processor.pooling_size
|
||||
pooling_size = POOLING_SIZE
|
||||
|
||||
image_token_length_w = processor.image_token_length_w
|
||||
image_token_length_h = processor.image_token_length_h
|
||||
image_token_length_w = image_processor.image_token_length_w
|
||||
image_token_length_h = image_processor.image_token_length_h
|
||||
|
||||
# Calculate total tokens: 2 for start/end + (w+1)*h for column separators
|
||||
extra = 2 + (image_token_length_w + 1) * image_token_length_h
|
||||
@@ -1243,9 +1110,10 @@ class MolmoProcessingInfo(BaseProcessingInfo):
|
||||
|
||||
def get_image_size_with_most_features(self) -> ImageSize:
|
||||
processor = self.get_hf_processor()
|
||||
image_processor = processor.image_processor
|
||||
|
||||
tilings = get_candidate_tilings(processor.max_crops)
|
||||
base_h, base_w = processor.base_image_input_size
|
||||
tilings = get_candidate_tilings(image_processor.max_crops)
|
||||
base_h, base_w = _as_2tuple(image_processor.base_image_input_size)
|
||||
|
||||
largest_feature_size, largest_feature_pinpoint = 0, None
|
||||
for wr, hr in tilings:
|
||||
@@ -1254,7 +1122,7 @@ class MolmoProcessingInfo(BaseProcessingInfo):
|
||||
feat_size = self.get_num_image_tokens(
|
||||
image_width=width,
|
||||
image_height=height,
|
||||
processor=processor,
|
||||
image_processor=image_processor,
|
||||
)
|
||||
if feat_size > largest_feature_size:
|
||||
largest_feature_size = feat_size
|
||||
@@ -1292,6 +1160,54 @@ class MolmoDummyInputsBuilder(BaseDummyInputsBuilder[MolmoProcessingInfo]):
|
||||
|
||||
|
||||
class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
|
||||
def _call_hf_processor(
|
||||
self,
|
||||
prompt: str,
|
||||
mm_data: Mapping[str, object],
|
||||
mm_kwargs: Mapping[str, object],
|
||||
tok_kwargs: Mapping[str, object],
|
||||
) -> BatchFeature:
|
||||
hf_processor = self.info.get_hf_processor(**mm_kwargs)
|
||||
processed_outputs = self.info.ctx.call_hf_processor(
|
||||
hf_processor.process,
|
||||
dict(text=prompt, **mm_data),
|
||||
dict(**mm_kwargs, **tok_kwargs),
|
||||
)
|
||||
|
||||
tokenizer = hf_processor.tokenizer
|
||||
image_patch_id = tokenizer.vocab[IMAGE_PATCH_TOKEN]
|
||||
|
||||
image_processor = hf_processor.image_processor
|
||||
|
||||
input_ids: torch.Tensor = processed_outputs.pop("input_ids")
|
||||
processed_outputs["input_ids"] = input_ids.unsqueeze(0)
|
||||
|
||||
if (images := mm_data.get("images")) is not None:
|
||||
mm_items = self.info.parse_mm_data({"image": images}, validate=False)
|
||||
parsed_images = mm_items.get_items("image", ImageProcessorItems)
|
||||
image_sizes = [
|
||||
parsed_images.get_image_size(i) for i in range(len(parsed_images))
|
||||
]
|
||||
|
||||
feat_is_patch = processed_outputs["image_input_idx"] >= 0
|
||||
|
||||
tilings = [
|
||||
self.info.select_tiling(
|
||||
image_width=image_size.width,
|
||||
image_height=image_size.height,
|
||||
image_processor=image_processor,
|
||||
)
|
||||
for image_size in image_sizes
|
||||
]
|
||||
# For each image: tiling_h * tiling_w + extra
|
||||
num_crops = torch.tensor(tilings).prod(-1) + 1
|
||||
assert num_crops.sum() == len(feat_is_patch)
|
||||
|
||||
processed_outputs["num_crops"] = num_crops
|
||||
processed_outputs["img_patch_id"] = image_patch_id
|
||||
|
||||
return processed_outputs
|
||||
|
||||
def _apply_hf_processor_tokens_only(
|
||||
self,
|
||||
prompt_tokens: list[int],
|
||||
@@ -1301,18 +1217,19 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
|
||||
# The chat template is already applied to the prompt tokens
|
||||
# Use message_format="none" to avoid applying it again
|
||||
# Prepend an empty space if `always_start_with_space` is True
|
||||
tokens = processor.processor.get_tokens_input( # type: ignore
|
||||
tokens = processor.get_tokens_input(
|
||||
self.info.get_tokenizer().decode(prompt_tokens),
|
||||
message_format="none",
|
||||
always_start_with_space=processor.always_start_with_space,
|
||||
always_start_with_space=True,
|
||||
)
|
||||
|
||||
# Prepend a BOS token id to the tokens
|
||||
processed_data = self.info.ctx.call_hf_processor(
|
||||
processor, # type: ignore
|
||||
processor.process,
|
||||
dict(tokens=tokens),
|
||||
)
|
||||
(prompt_ids,) = processed_data.pop("input_ids").tolist()
|
||||
prompt_ids = processed_data.pop("input_ids").tolist()
|
||||
print(prompt_ids, len(prompt_ids))
|
||||
|
||||
return prompt_ids
|
||||
|
||||
@@ -1338,16 +1255,18 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
out_mm_kwargs: MultiModalKwargsItems,
|
||||
) -> Sequence[PromptUpdate]:
|
||||
tokenizer = self.info.get_tokenizer()
|
||||
vocab = tokenizer.get_vocab()
|
||||
img_patch_id = vocab[IMAGE_PATCH_TOKEN]
|
||||
img_col_id = vocab[IM_COL_TOKEN]
|
||||
img_start_id = vocab[IM_START_TOKEN]
|
||||
img_end_id = vocab[IM_END_TOKEN]
|
||||
|
||||
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
|
||||
|
||||
image_token_length_w = processor.image_token_length_w
|
||||
image_token_length_h = processor.image_token_length_h
|
||||
pooling_size = processor.pooling_size
|
||||
|
||||
img_patch_id = processor.image_patch_id
|
||||
img_col_id = processor.im_col_id
|
||||
img_start_id = processor.im_start_id
|
||||
img_end_id = processor.im_end_id
|
||||
image_processor = processor.image_processor
|
||||
image_token_length_w = image_processor.image_token_length_w
|
||||
image_token_length_h = image_processor.image_token_length_h
|
||||
pooling_size = POOLING_SIZE
|
||||
|
||||
extra_row = [img_patch_id] * image_token_length_w + [img_col_id]
|
||||
extra_joint = [img_start_id] + extra_row * image_token_length_h + [img_end_id]
|
||||
@@ -1356,9 +1275,10 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
|
||||
images = mm_items.get_items("image", ImageProcessorItems)
|
||||
image_size = images.get_image_size(item_idx)
|
||||
|
||||
ncols, nrows = processor.get_patches_grid_size(
|
||||
ncols, nrows = self.info.get_patches_grid_size(
|
||||
image_width=image_size.width,
|
||||
image_height=image_size.height,
|
||||
image_processor=image_processor,
|
||||
)
|
||||
|
||||
joint_row = [img_patch_id] * ((ncols + 1) // pooling_size) + [img_col_id]
|
||||
|
||||
Reference in New Issue
Block a user