[Misc] Move processors to transformers_utils (#35953)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2026-03-09 11:31:39 +08:00
committed by GitHub
parent bd2659a566
commit d62856b928
13 changed files with 507 additions and 595 deletions

View File

@@ -4,7 +4,7 @@
import math
from collections.abc import Iterable, Mapping, Sequence
from dataclasses import dataclass
from functools import cached_property, partial
from functools import partial
from itertools import islice
from typing import Annotated
@@ -13,9 +13,11 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers import BatchFeature, PretrainedConfig, ProcessorMixin, TensorType
from transformers.image_utils import ImageInput
from transformers.tokenization_utils_base import TextInput
from transformers import (
BaseImageProcessor,
BatchFeature,
PretrainedConfig,
)
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
@@ -1017,117 +1019,28 @@ def select_tiling(
return candidate_tilings[ix]
class MolmoProcessorWrapper:
"""
Wraps `MolmoProcessor` so that it can be called directly.
def _as_2tuple(x: int | tuple[int, int]) -> tuple[int, int]:
if isinstance(x, int):
return x, x
The original definition can be found here:
https://huggingface.co/allenai/Molmo-7B-D-0924/blob/main/preprocessing_molmo.py
"""
return x
def __init__(self, processor: ProcessorMixin):
super().__init__()
self.processor = processor
@cached_property
def vocab(self) -> dict[str, int]:
return self.processor.tokenizer.vocab # type: ignore
@cached_property
def max_crops(self) -> int:
image_processor = self.processor.image_processor # type: ignore
max_crops = image_processor.max_crops
assert isinstance(max_crops, int)
return max_crops
@cached_property
def base_image_input_size(self) -> tuple[int, int]:
image_processor = self.processor.image_processor # type: ignore
base_image_input_size = image_processor.base_image_input_size
if isinstance(base_image_input_size, int):
return base_image_input_size, base_image_input_size
return tuple(base_image_input_size)
@cached_property
def image_patch_size(self) -> int:
image_processor = self.processor.image_processor # type: ignore
image_patch_size = image_processor.image_patch_size
assert isinstance(image_patch_size, int)
return image_patch_size
@cached_property
def overlap_margins(self) -> tuple[int, int]:
image_processor = self.processor.image_processor # type: ignore
left_margin, right_margin = image_processor.overlap_margins
assert isinstance(left_margin, int)
assert isinstance(right_margin, int)
return left_margin, right_margin
@cached_property
def image_token_length_w(self) -> int:
image_processor = self.processor.image_processor # type: ignore
image_token_length_w = image_processor.image_token_length_w
assert isinstance(image_token_length_w, int)
return image_token_length_w
@cached_property
def image_token_length_h(self) -> int:
image_processor = self.processor.image_processor # type: ignore
image_token_length_h = image_processor.image_token_length_h
assert isinstance(image_token_length_h, int)
return image_token_length_h
@property
def message_format(self) -> str | None:
return "role"
@property
def always_start_with_space(self) -> bool:
return True
@cached_property
def image_patch_id(self) -> int:
return self.vocab[IMAGE_PATCH_TOKEN]
@cached_property
def im_col_id(self) -> int:
return self.vocab[IM_COL_TOKEN]
@cached_property
def im_start_id(self) -> int:
return self.vocab[IM_START_TOKEN]
@cached_property
def im_end_id(self) -> int:
return self.vocab[IM_END_TOKEN]
@property
def pooling_size(self) -> int:
return POOLING_SIZE
class MolmoProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"image": None}
def select_tiling(
self,
*,
image_width: int,
image_height: int,
image_processor: BaseImageProcessor,
) -> tuple[int, int]:
max_crops = self.max_crops
left_margin, right_margin = self.overlap_margins
base_image_input_size = self.base_image_input_size
base_image_input_d = self.image_patch_size
max_crops = image_processor.max_crops
left_margin, right_margin = image_processor.overlap_margins
base_image_input_size = _as_2tuple(image_processor.base_image_input_size)
base_image_input_d = image_processor.image_patch_size
total_margin_pixels = base_image_input_d * (right_margin + left_margin)
crop_patches = base_image_input_size[0] // base_image_input_d
@@ -1147,16 +1060,18 @@ class MolmoProcessorWrapper:
*,
image_width: int,
image_height: int,
image_processor: BaseImageProcessor,
) -> tuple[int, int]:
left_margin, right_margin = self.overlap_margins
base_image_input_size = self.base_image_input_size
base_image_input_d = self.image_patch_size
pooling_size = self.pooling_size
left_margin, right_margin = image_processor.overlap_margins
base_image_input_size = _as_2tuple(image_processor.base_image_input_size)
base_image_input_d = image_processor.image_patch_size
pooling_size = POOLING_SIZE
crop_patches = base_image_input_size[0] // base_image_input_d
tiling_w, tiling_h = self.select_tiling(
image_height=image_height,
image_width=image_width,
image_processor=image_processor,
)
nrows, ncols = get_patches_grid_size(
@@ -1170,70 +1085,22 @@ class MolmoProcessorWrapper:
return ncols, nrows
def __call__(
self,
text: TextInput | list[TextInput] | None = None,
images: ImageInput | list[ImageInput] | None = None,
return_tensors: str | TensorType | None = None,
**kwargs,
) -> BatchFeature:
outputs = self.processor.process( # type: ignore
text, images, **kwargs
)
if images is None:
images = []
if not isinstance(images, list):
images = [images]
input_ids: torch.Tensor = outputs.pop("input_ids")
outputs["input_ids"] = input_ids.unsqueeze(0)
image_input_idx = outputs.pop("image_input_idx", None)
if image_input_idx is not None:
feat_is_patch = image_input_idx >= 0
tilings = [
self.select_tiling(
image_width=image.size[0],
image_height=image.size[1],
)
for image in images
]
# For each image: tiling_h * tiling_w + extra
num_crops = torch.tensor(tilings).prod(-1) + 1
assert num_crops.sum() == len(feat_is_patch)
outputs["image_input_idx"] = image_input_idx
outputs["num_crops"] = num_crops
outputs["img_patch_id"] = self.image_patch_id
return BatchFeature(outputs)
class MolmoProcessingInfo(BaseProcessingInfo):
def get_hf_processor(self, **kwargs: object) -> MolmoProcessorWrapper:
processor = self.ctx.get_hf_processor(**kwargs)
return MolmoProcessorWrapper(processor)
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"image": None}
def get_num_image_tokens(
self,
*,
image_width: int,
image_height: int,
processor: MolmoProcessorWrapper,
image_processor: BaseImageProcessor,
) -> int:
ncols, nrows = processor.get_patches_grid_size(
ncols, nrows = self.get_patches_grid_size(
image_width=image_width,
image_height=image_height,
image_processor=image_processor,
)
pooling_size = processor.pooling_size
pooling_size = POOLING_SIZE
image_token_length_w = processor.image_token_length_w
image_token_length_h = processor.image_token_length_h
image_token_length_w = image_processor.image_token_length_w
image_token_length_h = image_processor.image_token_length_h
# Calculate total tokens: 2 for start/end + (w+1)*h for column separators
extra = 2 + (image_token_length_w + 1) * image_token_length_h
@@ -1243,9 +1110,10 @@ class MolmoProcessingInfo(BaseProcessingInfo):
def get_image_size_with_most_features(self) -> ImageSize:
processor = self.get_hf_processor()
image_processor = processor.image_processor
tilings = get_candidate_tilings(processor.max_crops)
base_h, base_w = processor.base_image_input_size
tilings = get_candidate_tilings(image_processor.max_crops)
base_h, base_w = _as_2tuple(image_processor.base_image_input_size)
largest_feature_size, largest_feature_pinpoint = 0, None
for wr, hr in tilings:
@@ -1254,7 +1122,7 @@ class MolmoProcessingInfo(BaseProcessingInfo):
feat_size = self.get_num_image_tokens(
image_width=width,
image_height=height,
processor=processor,
image_processor=image_processor,
)
if feat_size > largest_feature_size:
largest_feature_size = feat_size
@@ -1292,6 +1160,54 @@ class MolmoDummyInputsBuilder(BaseDummyInputsBuilder[MolmoProcessingInfo]):
class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
def _call_hf_processor(
    self,
    prompt: str,
    mm_data: Mapping[str, object],
    mm_kwargs: Mapping[str, object],
    tok_kwargs: Mapping[str, object],
) -> BatchFeature:
    """Invoke the HF Molmo processor and post-process its outputs.

    Adds a batch dimension to ``input_ids`` and, when images are present,
    attaches ``num_crops`` (crops per image) and ``img_patch_id`` so that
    downstream code can associate image features with placeholder tokens.
    """
    hf_processor = self.info.get_hf_processor(**mm_kwargs)
    processed_outputs = self.info.ctx.call_hf_processor(
        hf_processor.process,
        dict(text=prompt, **mm_data),
        dict(**mm_kwargs, **tok_kwargs),
    )

    tokenizer = hf_processor.tokenizer
    # Token id used for image-patch placeholders in the prompt.
    image_patch_id = tokenizer.vocab[IMAGE_PATCH_TOKEN]

    image_processor = hf_processor.image_processor

    # `process` returns unbatched ids; add the leading batch dimension
    # expected downstream.
    input_ids: torch.Tensor = processed_outputs.pop("input_ids")
    processed_outputs["input_ids"] = input_ids.unsqueeze(0)

    if (images := mm_data.get("images")) is not None:
        mm_items = self.info.parse_mm_data({"image": images}, validate=False)
        parsed_images = mm_items.get_items("image", ImageProcessorItems)
        image_sizes = [
            parsed_images.get_image_size(i) for i in range(len(parsed_images))
        ]

        # Entries >= 0 in image_input_idx index real patch features.
        feat_is_patch = processed_outputs["image_input_idx"] >= 0

        tilings = [
            self.info.select_tiling(
                image_width=image_size.width,
                image_height=image_size.height,
                image_processor=image_processor,
            )
            for image_size in image_sizes
        ]
        # For each image: tiling_h * tiling_w crops plus 1 extra
        # (presumably the base/global crop — confirm against the HF
        # Molmo preprocessing code).
        num_crops = torch.tensor(tilings).prod(-1) + 1
        # Sanity check: total crops must match the per-feature mask length.
        assert num_crops.sum() == len(feat_is_patch)

        processed_outputs["num_crops"] = num_crops
        processed_outputs["img_patch_id"] = image_patch_id

    return processed_outputs
def _apply_hf_processor_tokens_only(
self,
prompt_tokens: list[int],
@@ -1301,18 +1217,19 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
# The chat template is already applied to the prompt tokens
# Use message_format="none" to avoid applying it again
# Prepend an empty space if `always_start_with_space` is True
tokens = processor.processor.get_tokens_input( # type: ignore
tokens = processor.get_tokens_input(
self.info.get_tokenizer().decode(prompt_tokens),
message_format="none",
always_start_with_space=processor.always_start_with_space,
always_start_with_space=True,
)
# Prepend a BOS token id to the tokens
processed_data = self.info.ctx.call_hf_processor(
processor, # type: ignore
processor.process,
dict(tokens=tokens),
)
(prompt_ids,) = processed_data.pop("input_ids").tolist()
prompt_ids = processed_data.pop("input_ids").tolist()
print(prompt_ids, len(prompt_ids))
return prompt_ids
@@ -1338,16 +1255,18 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]:
tokenizer = self.info.get_tokenizer()
vocab = tokenizer.get_vocab()
img_patch_id = vocab[IMAGE_PATCH_TOKEN]
img_col_id = vocab[IM_COL_TOKEN]
img_start_id = vocab[IM_START_TOKEN]
img_end_id = vocab[IM_END_TOKEN]
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
image_token_length_w = processor.image_token_length_w
image_token_length_h = processor.image_token_length_h
pooling_size = processor.pooling_size
img_patch_id = processor.image_patch_id
img_col_id = processor.im_col_id
img_start_id = processor.im_start_id
img_end_id = processor.im_end_id
image_processor = processor.image_processor
image_token_length_w = image_processor.image_token_length_w
image_token_length_h = image_processor.image_token_length_h
pooling_size = POOLING_SIZE
extra_row = [img_patch_id] * image_token_length_w + [img_col_id]
extra_joint = [img_start_id] + extra_row * image_token_length_h + [img_end_id]
@@ -1356,9 +1275,10 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
images = mm_items.get_items("image", ImageProcessorItems)
image_size = images.get_image_size(item_idx)
ncols, nrows = processor.get_patches_grid_size(
ncols, nrows = self.info.get_patches_grid_size(
image_width=image_size.width,
image_height=image_size.height,
image_processor=image_processor,
)
joint_row = [img_patch_id] * ((ncols + 1) // pooling_size) + [img_col_id]