[VLM] Merged multi-modal processor for InternVL-based models (#12553)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> Signed-off-by: Isotr0py <2037008807@qq.com> Co-authored-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
@@ -6,35 +6,37 @@
|
||||
# Copyright (c) 2023 OpenGVLab
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
# --------------------------------------------------------
|
||||
import re
|
||||
from functools import cached_property, partial
|
||||
from abc import ABC, abstractmethod
|
||||
from functools import cached_property
|
||||
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
|
||||
TypedDict, Union)
|
||||
TypedDict, TypeVar, Union)
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision.transforms as T
|
||||
from PIL import Image
|
||||
from transformers import PretrainedConfig
|
||||
from transformers import BatchFeature, PretrainedConfig, TensorType
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
|
||||
InputContext, token_inputs)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.quantization.awq import AWQConfig
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||
from vllm.model_executor.models.intern_vit import (InternVisionModel,
|
||||
InternVisionPatchModel)
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
|
||||
from vllm.multimodal.inputs import NestedTensors, PlaceholderRange
|
||||
from vllm.multimodal.utils import cached_get_tokenizer
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
|
||||
NestedTensors)
|
||||
from vllm.multimodal.parse import (ImageEmbeddingItems, ImageProcessorItems,
|
||||
ImageSize, MultiModalDataItems)
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
BaseProcessingInfo, PromptReplacement,
|
||||
PromptReplacementDetails)
|
||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import is_list_of
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
|
||||
from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
|
||||
get_clip_num_patches)
|
||||
from .interfaces import SupportsMultiModal, SupportsPP
|
||||
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
|
||||
maybe_prefix, merge_multimodal_embeddings)
|
||||
@@ -75,22 +77,27 @@ InternVLImageInputs = Union[InternVLImagePixelInputs,
|
||||
InternVLImageEmbeddingInputs]
|
||||
|
||||
|
||||
# copied from https://huggingface.co/OpenGVLab/InternVL2-1B
|
||||
def build_transform(input_size):
|
||||
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
|
||||
def build_transform(input_size: int):
|
||||
MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
|
||||
transform = T.Compose([
|
||||
return T.Compose([
|
||||
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
|
||||
T.Resize((input_size, input_size),
|
||||
interpolation=T.InterpolationMode.BICUBIC),
|
||||
T.ToTensor(),
|
||||
T.Normalize(mean=MEAN, std=STD)
|
||||
])
|
||||
return transform
|
||||
|
||||
|
||||
# copied from https://huggingface.co/OpenGVLab/InternVL2-1B
|
||||
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height,
|
||||
image_size):
|
||||
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
|
||||
def find_closest_aspect_ratio(
|
||||
aspect_ratio: float,
|
||||
target_ratios: list[tuple[int, int]],
|
||||
*,
|
||||
width: int,
|
||||
height: int,
|
||||
image_size: int,
|
||||
) -> tuple[int, int]:
|
||||
best_ratio_diff = float('inf')
|
||||
best_ratio = (1, 1)
|
||||
area = width * height
|
||||
@@ -106,67 +113,82 @@ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height,
|
||||
return best_ratio
|
||||
|
||||
|
||||
def calculate_num_blocks(orig_width: int, orig_height: int, min_num: int,
|
||||
max_num: int, image_size: int,
|
||||
use_thumbnail: bool) -> Tuple[int, int, int]:
|
||||
def resolve_internvl_min_max_num(
|
||||
*,
|
||||
min_dynamic_patch: int,
|
||||
max_dynamic_patch: int,
|
||||
dynamic_image_size: bool,
|
||||
use_thumbnail: bool,
|
||||
) -> tuple[int, int]:
|
||||
max_dynamic_patch = max_dynamic_patch if dynamic_image_size else 1
|
||||
|
||||
if use_thumbnail and max_dynamic_patch != 1:
|
||||
max_dynamic_patch += 1
|
||||
|
||||
return min_dynamic_patch, max_dynamic_patch
|
||||
|
||||
|
||||
def get_internvl_target_ratios(
|
||||
min_num: int,
|
||||
max_num: int,
|
||||
) -> list[tuple[int, int]]:
|
||||
target_ratios = {(i, j)
|
||||
for n in range(min_num, max_num + 1)
|
||||
for i in range(1, n + 1)
|
||||
for j in range(1, n + 1) if min_num <= i * j <= max_num}
|
||||
return sorted(target_ratios, key=lambda x: x[0] * x[1])
|
||||
|
||||
|
||||
def calculate_internvl_targets(
|
||||
*,
|
||||
orig_width: int,
|
||||
orig_height: int,
|
||||
target_ratios: list[tuple[int, int]],
|
||||
image_size: int,
|
||||
use_thumbnail: bool,
|
||||
) -> tuple[int, int, int]:
|
||||
aspect_ratio = orig_width / orig_height
|
||||
|
||||
# calculate the existing image aspect ratio
|
||||
target_ratios = set((i, j) for n in range(min_num, max_num + 1)
|
||||
for i in range(1, n + 1) for j in range(1, n + 1)
|
||||
if i * j <= max_num and i * j >= min_num)
|
||||
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
|
||||
|
||||
# find the closest aspect ratio to the target
|
||||
target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio,
|
||||
target_ratios, orig_width,
|
||||
orig_height, image_size)
|
||||
target_aspect_ratio = find_closest_aspect_ratio(
|
||||
aspect_ratio,
|
||||
target_ratios,
|
||||
width=orig_width,
|
||||
height=orig_height,
|
||||
image_size=image_size,
|
||||
)
|
||||
|
||||
# calculate the target width and height
|
||||
target_width = image_size * target_aspect_ratio[0]
|
||||
target_height = image_size * target_aspect_ratio[1]
|
||||
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
|
||||
# add thumbnail image if num_blocks > 1
|
||||
if use_thumbnail and blocks > 1:
|
||||
|
||||
# add thumbnail image if num_blocks != 1
|
||||
if use_thumbnail and blocks != 1:
|
||||
blocks += 1
|
||||
|
||||
return blocks, target_width, target_height
|
||||
|
||||
|
||||
def calculate_num_blocks_wrapper(
|
||||
hf_config: PretrainedConfig,
|
||||
max_dynamic_patch: Optional[int] = None,
|
||||
dynamic_image_size: Optional[bool] = None,
|
||||
):
|
||||
if dynamic_image_size is None:
|
||||
dynamic_image_size = hf_config.dynamic_image_size
|
||||
|
||||
max_dynamic_patch = max_dynamic_patch if dynamic_image_size else 1
|
||||
if max_dynamic_patch is None:
|
||||
max_dynamic_patch = hf_config.max_dynamic_patch
|
||||
min_num = hf_config.min_dynamic_patch
|
||||
image_size = hf_config.vision_config.image_size
|
||||
use_thumbnail = hf_config.use_thumbnail
|
||||
return partial(calculate_num_blocks,
|
||||
min_num=min_num,
|
||||
max_num=max_dynamic_patch,
|
||||
image_size=image_size,
|
||||
use_thumbnail=use_thumbnail)
|
||||
|
||||
|
||||
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
|
||||
def dynamic_preprocess(image: Image.Image, min_num: int, max_num: int,
|
||||
image_size: int,
|
||||
use_thumbnail: bool) -> List[Image.Image]:
|
||||
def dynamic_preprocess_internvl(
|
||||
image: Image.Image,
|
||||
*,
|
||||
target_ratios: list[tuple[int, int]],
|
||||
image_size: int,
|
||||
use_thumbnail: bool,
|
||||
) -> list[Image.Image]:
|
||||
orig_width, orig_height = image.size
|
||||
|
||||
# calculate the number of blocks without thumbnail
|
||||
blocks, target_width, target_height = calculate_num_blocks(
|
||||
orig_width,
|
||||
orig_height,
|
||||
min_num,
|
||||
max_num,
|
||||
image_size,
|
||||
use_thumbnail=False)
|
||||
blocks, target_width, target_height = calculate_internvl_targets(
|
||||
orig_width=orig_width,
|
||||
orig_height=orig_height,
|
||||
target_ratios=target_ratios,
|
||||
image_size=image_size,
|
||||
use_thumbnail=False,
|
||||
)
|
||||
|
||||
# resize the image
|
||||
resized_img = image.resize((target_width, target_height))
|
||||
processed_images = []
|
||||
@@ -178,301 +200,463 @@ def dynamic_preprocess(image: Image.Image, min_num: int, max_num: int,
|
||||
# split the image
|
||||
split_img = resized_img.crop(box)
|
||||
processed_images.append(split_img)
|
||||
|
||||
assert len(processed_images) == blocks
|
||||
|
||||
if use_thumbnail and len(processed_images) != 1:
|
||||
thumbnail_img = image.resize((image_size, image_size))
|
||||
processed_images.append(thumbnail_img)
|
||||
|
||||
return processed_images
|
||||
|
||||
|
||||
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
|
||||
def image_to_pixel_values(image: Image.Image, input_size: int, min_num: int,
|
||||
max_num: int, use_thumbnail: bool) -> torch.Tensor:
|
||||
def image_to_pixel_values_internvl(
|
||||
image: Image.Image,
|
||||
*,
|
||||
input_size: int,
|
||||
min_num: int,
|
||||
max_num: int,
|
||||
use_thumbnail: bool,
|
||||
) -> torch.Tensor:
|
||||
target_ratios = get_internvl_target_ratios(min_num, max_num)
|
||||
|
||||
transform = build_transform(input_size=input_size)
|
||||
images = dynamic_preprocess(image,
|
||||
min_num=min_num,
|
||||
max_num=max_num,
|
||||
image_size=input_size,
|
||||
use_thumbnail=use_thumbnail)
|
||||
pixel_values = [transform(image) for image in images]
|
||||
pixel_values = torch.stack(pixel_values)
|
||||
images = dynamic_preprocess_internvl(
|
||||
image,
|
||||
target_ratios=target_ratios,
|
||||
image_size=input_size,
|
||||
use_thumbnail=use_thumbnail,
|
||||
)
|
||||
|
||||
pixel_values = torch.stack([transform(image) for image in images])
|
||||
return pixel_values
|
||||
|
||||
|
||||
def image_to_pixel_values_wrapper(
|
||||
hf_config: PretrainedConfig,
|
||||
max_dynamic_patch: Optional[int] = None,
|
||||
dynamic_image_size: Optional[bool] = None,
|
||||
):
|
||||
image_size = hf_config.vision_config.image_size
|
||||
min_num = hf_config.min_dynamic_patch
|
||||
if dynamic_image_size is None:
|
||||
dynamic_image_size = hf_config.dynamic_image_size
|
||||
class BaseInternVLProcessor(ABC):
|
||||
"""
|
||||
This model doesn't define its own HF processor,
|
||||
so we implement our own one here.
|
||||
|
||||
max_dynamic_patch = max_dynamic_patch if dynamic_image_size else 1
|
||||
if max_dynamic_patch is None:
|
||||
max_dynamic_patch = hf_config.max_dynamic_patch
|
||||
use_thumbnail = hf_config.use_thumbnail
|
||||
return partial(image_to_pixel_values,
|
||||
input_size=image_size,
|
||||
min_num=min_num,
|
||||
max_num=max_dynamic_patch,
|
||||
use_thumbnail=use_thumbnail)
|
||||
|
||||
|
||||
def get_internvl_num_patches(hf_config: PretrainedConfig):
|
||||
vision_config = hf_config.vision_config
|
||||
downsample_ratio = hf_config.downsample_ratio
|
||||
image_size = vision_config.image_size
|
||||
patch_size = vision_config.patch_size
|
||||
return int(
|
||||
get_clip_num_patches(image_size=image_size, patch_size=patch_size) *
|
||||
(downsample_ratio**2))
|
||||
|
||||
|
||||
def get_max_internvl_image_tokens(
|
||||
ctx: InputContext,
|
||||
*,
|
||||
max_dynamic_patch: Optional[int] = None,
|
||||
dynamic_image_size: Optional[bool] = None,
|
||||
):
|
||||
hf_config = ctx.get_hf_config()
|
||||
if dynamic_image_size is None:
|
||||
dynamic_image_size = hf_config.dynamic_image_size
|
||||
|
||||
max_dynamic_patch = max_dynamic_patch if dynamic_image_size else 1
|
||||
if max_dynamic_patch is None:
|
||||
max_dynamic_patch = hf_config.max_dynamic_patch
|
||||
use_thumbnail = hf_config.use_thumbnail
|
||||
if use_thumbnail and max_dynamic_patch > 1:
|
||||
max_dynamic_patch += 1
|
||||
|
||||
num_patches = get_internvl_num_patches(hf_config)
|
||||
return num_patches * max_dynamic_patch
|
||||
|
||||
|
||||
def get_max_internvl_image_size(
|
||||
ctx: InputContext,
|
||||
*,
|
||||
max_dynamic_patch: Optional[int] = None,
|
||||
dynamic_image_size: Optional[bool] = None,
|
||||
):
|
||||
hf_config = ctx.get_hf_config()
|
||||
image_size = hf_config.vision_config.image_size
|
||||
if dynamic_image_size is None:
|
||||
dynamic_image_size = hf_config.dynamic_image_size
|
||||
|
||||
max_dynamic_patch = max_dynamic_patch if dynamic_image_size else 1
|
||||
if max_dynamic_patch is None:
|
||||
max_dynamic_patch = hf_config.max_dynamic_patch
|
||||
use_thumbnail = hf_config.use_thumbnail
|
||||
if use_thumbnail and max_dynamic_patch > 1:
|
||||
max_dynamic_patch += 1
|
||||
width = image_size * max_dynamic_patch
|
||||
height = image_size
|
||||
return width, height
|
||||
|
||||
|
||||
class InternVLInputPipeline:
|
||||
The code to insert image tokens is based on:
|
||||
https://huggingface.co/OpenGVLab/InternVL2-1B/blob/main/modeling_internvl_chat.py#L252
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
img_start_token: str,
|
||||
img_end_token: str,
|
||||
img_context_token: str,
|
||||
config: PretrainedConfig,
|
||||
tokenizer: AnyTokenizer,
|
||||
*,
|
||||
max_dynamic_patch: Optional[int] = None,
|
||||
dynamic_image_size: Optional[bool] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.img_start_token = img_start_token
|
||||
self.img_end_token = img_end_token
|
||||
self.img_context_token = img_context_token
|
||||
self.config = config
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
def _create_image_prompt(self, feature_size: int, num_patches: int) -> str:
|
||||
return (self.img_start_token + self.img_context_token * feature_size +
|
||||
self.img_end_token)
|
||||
image_size: int = config.vision_config.image_size
|
||||
patch_size: int = config.vision_config.patch_size
|
||||
|
||||
def _expand_image_prompt(
|
||||
if dynamic_image_size is None:
|
||||
dynamic_image_size = config.dynamic_image_size
|
||||
assert isinstance(dynamic_image_size, bool)
|
||||
|
||||
if max_dynamic_patch is None:
|
||||
max_dynamic_patch = config.max_dynamic_patch
|
||||
assert isinstance(max_dynamic_patch, int)
|
||||
|
||||
self.num_image_token = int(
|
||||
(image_size // patch_size)**2 * (config.downsample_ratio**2))
|
||||
self.image_size = image_size
|
||||
self.min_dynamic_patch: int = config.min_dynamic_patch
|
||||
self.max_dynamic_patch = max_dynamic_patch
|
||||
self.dynamic_image_size = dynamic_image_size
|
||||
self.use_thumbnail: bool = config.use_thumbnail
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def image_token_id(self) -> int:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def get_image_repl_features(
|
||||
self,
|
||||
prompt: str,
|
||||
feature_sizes: List[int],
|
||||
num_patches: int,
|
||||
feature_size: int,
|
||||
num_patches: Optional[int],
|
||||
) -> str:
|
||||
image_idx = sorted(
|
||||
map(int, re.findall(r"Image-(\d+): <image>\n", prompt)))
|
||||
raise NotImplementedError
|
||||
|
||||
new_prompt = prompt
|
||||
for idx, feature_size in enumerate(feature_sizes, start=1):
|
||||
image_prompt = self._create_image_prompt(feature_size, num_patches)
|
||||
if not image_idx:
|
||||
image_prompt = f"Image-{idx}: {image_prompt}"
|
||||
|
||||
new_prompt = new_prompt.replace('<image>', image_prompt, 1)
|
||||
|
||||
return new_prompt
|
||||
|
||||
def input_processor(
|
||||
@abstractmethod
|
||||
def get_image_repl_full(
|
||||
self,
|
||||
feature_size: int,
|
||||
num_patches: Optional[int],
|
||||
) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def resolve_min_max_num(
|
||||
self,
|
||||
ctx: InputContext,
|
||||
inputs: DecoderOnlyInputs,
|
||||
*,
|
||||
max_dynamic_patch: Optional[int] = None,
|
||||
dynamic_image_size: Optional[bool] = None,
|
||||
) -> DecoderOnlyInputs:
|
||||
multi_modal_data = inputs.get("multi_modal_data")
|
||||
if multi_modal_data is None or "image" not in multi_modal_data:
|
||||
return inputs
|
||||
use_thumbnail: Optional[bool] = None,
|
||||
) -> tuple[int, int]:
|
||||
min_dynamic_patch = self.min_dynamic_patch
|
||||
max_dynamic_patch = (self.max_dynamic_patch if max_dynamic_patch
|
||||
is None else max_dynamic_patch)
|
||||
dynamic_image_size = (self.dynamic_image_size if dynamic_image_size
|
||||
is None else dynamic_image_size)
|
||||
use_thumbnail = (self.use_thumbnail
|
||||
if use_thumbnail is None else use_thumbnail)
|
||||
|
||||
model_config = ctx.model_config
|
||||
hf_config = ctx.get_hf_config()
|
||||
return resolve_internvl_min_max_num(
|
||||
min_dynamic_patch=min_dynamic_patch,
|
||||
max_dynamic_patch=max_dynamic_patch,
|
||||
dynamic_image_size=dynamic_image_size,
|
||||
use_thumbnail=use_thumbnail,
|
||||
)
|
||||
|
||||
image_data = multi_modal_data["image"]
|
||||
num_patches = get_internvl_num_patches(hf_config)
|
||||
num_blocks_calculator = calculate_num_blocks_wrapper(
|
||||
hf_config, max_dynamic_patch, dynamic_image_size)
|
||||
if isinstance(image_data, Image.Image):
|
||||
width, height = image_data.size
|
||||
num_blocks, _, _ = num_blocks_calculator(width, height)
|
||||
image_feature_sizes = [num_blocks * num_patches]
|
||||
elif is_list_of(image_data, Image.Image):
|
||||
image_feature_sizes = []
|
||||
for image in image_data:
|
||||
width, height = image.size
|
||||
num_blocks, _, _ = num_blocks_calculator(width, height)
|
||||
image_feature_sizes.append(num_blocks * num_patches)
|
||||
elif isinstance(image_data, torch.Tensor):
|
||||
num_images, image_feature_size, hidden_size = image_data.shape
|
||||
image_feature_sizes = [image_feature_size]
|
||||
else:
|
||||
raise TypeError(f"Invalid image type: {type(image_data)}")
|
||||
|
||||
tokenizer = cached_get_tokenizer(
|
||||
model_config.tokenizer,
|
||||
trust_remote_code=model_config.trust_remote_code)
|
||||
|
||||
prompt = inputs.get("prompt")
|
||||
prompt_token_ids = inputs["prompt_token_ids"]
|
||||
if prompt is None:
|
||||
prompt = tokenizer.decode(prompt_token_ids)
|
||||
|
||||
new_prompt = self._expand_image_prompt(prompt, image_feature_sizes,
|
||||
num_patches)
|
||||
new_prompt_token_ids = tokenizer.encode(new_prompt)
|
||||
img_context_token_id = tokenizer.encode(self.img_context_token,
|
||||
add_special_tokens=False)
|
||||
assert len(img_context_token_id) == 1, \
|
||||
(f"Invalid image token '{self.img_context_token}': A valid image "
|
||||
f"token encodes to a single token ID, got {img_context_token_id}.")
|
||||
img_context_token_id = img_context_token_id[0]
|
||||
|
||||
# Get precise tracking of placeholder positions
|
||||
token_idx = image_idx = 0
|
||||
placeholder_ranges = []
|
||||
while token_idx < len(new_prompt_token_ids):
|
||||
if new_prompt_token_ids[token_idx] == img_context_token_id:
|
||||
curr_image_featue_size = image_feature_sizes[image_idx]
|
||||
placeholder_ranges.append(
|
||||
PlaceholderRange(offset=token_idx,
|
||||
length=curr_image_featue_size))
|
||||
image_idx += 1
|
||||
token_idx += curr_image_featue_size
|
||||
else:
|
||||
token_idx += 1
|
||||
|
||||
return token_inputs(
|
||||
prompt=prompt,
|
||||
prompt_token_ids=new_prompt_token_ids,
|
||||
multi_modal_data=multi_modal_data,
|
||||
multi_modal_placeholders={"image": placeholder_ranges})
|
||||
|
||||
def input_mapper(
|
||||
def resolve_target_ratios(
|
||||
self,
|
||||
ctx: InputContext,
|
||||
data: object,
|
||||
*,
|
||||
max_dynamic_patch: Optional[int] = None,
|
||||
dynamic_image_size: Optional[bool] = None,
|
||||
):
|
||||
hf_config = ctx.get_hf_config()
|
||||
use_thumbnail: Optional[bool] = None,
|
||||
) -> list[tuple[int, int]]:
|
||||
min_num, max_num = self.resolve_min_max_num(
|
||||
max_dynamic_patch=max_dynamic_patch,
|
||||
dynamic_image_size=dynamic_image_size,
|
||||
use_thumbnail=use_thumbnail,
|
||||
)
|
||||
|
||||
image_pixel_values_mapper = image_to_pixel_values_wrapper(
|
||||
hf_config, max_dynamic_patch, dynamic_image_size)
|
||||
if isinstance(data, Image.Image):
|
||||
data = image_pixel_values_mapper(data)
|
||||
# Add an N dimension for number of images per prompt (currently 1).
|
||||
data = data.unsqueeze(0)
|
||||
elif is_list_of(data, Image.Image):
|
||||
# we can't stack here because images may have different num_patches
|
||||
data = [image_pixel_values_mapper(img) for img in data]
|
||||
else:
|
||||
return MultiModalKwargs({"image_embeds": data})
|
||||
model_config = ctx.model_config
|
||||
tokenizer = cached_get_tokenizer(
|
||||
model_config.tokenizer,
|
||||
trust_remote_code=model_config.trust_remote_code)
|
||||
image_token_id = tokenizer.encode(self.img_context_token,
|
||||
add_special_tokens=False,
|
||||
return_tensors="pt")[0]
|
||||
return get_internvl_target_ratios(min_num, max_num)
|
||||
|
||||
return MultiModalKwargs({
|
||||
"pixel_values": data,
|
||||
"image_token_id": image_token_id
|
||||
})
|
||||
|
||||
def dummy_data(
|
||||
def get_num_image_tokens(
|
||||
self,
|
||||
*,
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
) -> int:
|
||||
target_ratios = self.resolve_target_ratios(
|
||||
use_thumbnail=False, # Applied in calculate_targets
|
||||
)
|
||||
|
||||
num_patches, _, _ = calculate_internvl_targets(
|
||||
orig_width=image_width,
|
||||
orig_height=image_height,
|
||||
image_size=self.image_size,
|
||||
target_ratios=target_ratios,
|
||||
use_thumbnail=self.use_thumbnail,
|
||||
)
|
||||
|
||||
return num_patches * self.num_image_token
|
||||
|
||||
def _images_to_pixel_values_lst(
|
||||
self,
|
||||
images: list[Image.Image],
|
||||
max_dynamic_patch: Optional[int] = None,
|
||||
dynamic_image_size: Optional[bool] = None,
|
||||
) -> list[torch.Tensor]:
|
||||
min_num, max_num = self.resolve_min_max_num(
|
||||
max_dynamic_patch=max_dynamic_patch,
|
||||
dynamic_image_size=dynamic_image_size,
|
||||
use_thumbnail=False, # Applied in image_to_pixel_values
|
||||
)
|
||||
|
||||
return [
|
||||
image_to_pixel_values_internvl(
|
||||
image,
|
||||
input_size=self.image_size,
|
||||
min_num=min_num,
|
||||
max_num=max_num,
|
||||
use_thumbnail=self.use_thumbnail,
|
||||
) for image in images
|
||||
]
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
text: Optional[Union[str, list[str]]] = None,
|
||||
images: Optional[Union[Image.Image, list[Image.Image]]] = None,
|
||||
max_dynamic_patch: Optional[int] = None,
|
||||
dynamic_image_size: Optional[bool] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
) -> BatchFeature:
|
||||
if text is None:
|
||||
text = []
|
||||
if not isinstance(text, list):
|
||||
text = [text]
|
||||
if images is None:
|
||||
images = []
|
||||
if not isinstance(images, list):
|
||||
images = [images]
|
||||
|
||||
if len(images) == 0:
|
||||
image_inputs = {}
|
||||
else:
|
||||
pixel_values_lst = self._images_to_pixel_values_lst(
|
||||
images,
|
||||
max_dynamic_patch=max_dynamic_patch,
|
||||
dynamic_image_size=dynamic_image_size,
|
||||
)
|
||||
image_inputs = {
|
||||
"pixel_values_flat": torch.cat(pixel_values_lst),
|
||||
"image_num_patches": list(map(len, pixel_values_lst)),
|
||||
}
|
||||
|
||||
for pixel_values in pixel_values_lst:
|
||||
num_patches = pixel_values.shape[0]
|
||||
feature_size = num_patches * self.num_image_token
|
||||
|
||||
image_repl = self.get_image_repl_full(feature_size,
|
||||
num_patches)
|
||||
text = [t.replace('<image>', image_repl, 1) for t in text]
|
||||
|
||||
text_inputs = self.tokenizer(text)
|
||||
|
||||
return BatchFeature(
|
||||
{
|
||||
**text_inputs,
|
||||
**image_inputs,
|
||||
},
|
||||
tensor_type=return_tensors,
|
||||
)
|
||||
|
||||
|
||||
class InternVLProcessor(BaseInternVLProcessor):
|
||||
|
||||
@property
|
||||
def image_token_id(self) -> int:
|
||||
return self.tokenizer.get_vocab()[IMG_CONTEXT]
|
||||
|
||||
def get_image_repl_features(
|
||||
self,
|
||||
feature_size: int,
|
||||
num_patches: Optional[int],
|
||||
) -> str:
|
||||
return IMG_CONTEXT * feature_size
|
||||
|
||||
def get_image_repl_full(
|
||||
self,
|
||||
feature_size: int,
|
||||
num_patches: Optional[int],
|
||||
) -> str:
|
||||
features = self.get_image_repl_features(feature_size, num_patches)
|
||||
return IMG_START + features + IMG_END
|
||||
|
||||
|
||||
class BaseInternVLProcessingInfo(BaseProcessingInfo):
|
||||
|
||||
@abstractmethod
|
||||
def get_hf_processor(
|
||||
self,
|
||||
*,
|
||||
max_dynamic_patch: Optional[int] = None,
|
||||
dynamic_image_size: Optional[bool] = None,
|
||||
) -> BaseInternVLProcessor:
|
||||
raise NotImplementedError
|
||||
|
||||
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
||||
return {"image": None}
|
||||
|
||||
def get_mm_max_tokens_per_item(
|
||||
self,
|
||||
ctx: InputContext,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> Mapping[str, int]:
|
||||
return {"image": self.get_max_image_tokens()}
|
||||
|
||||
def get_num_image_tokens(
|
||||
self,
|
||||
*,
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
processor: Optional[BaseInternVLProcessor],
|
||||
) -> int:
|
||||
if processor is None:
|
||||
processor = self.get_hf_processor()
|
||||
|
||||
return processor.get_num_image_tokens(
|
||||
image_width=image_width,
|
||||
image_height=image_height,
|
||||
)
|
||||
|
||||
def get_max_image_tokens(self) -> int:
|
||||
target_width, target_height = self.get_image_size_with_most_features()
|
||||
|
||||
return self.get_num_image_tokens(
|
||||
image_width=target_width,
|
||||
image_height=target_height,
|
||||
processor=None,
|
||||
)
|
||||
|
||||
def get_image_size_with_most_features(self) -> ImageSize:
|
||||
processor = self.get_hf_processor()
|
||||
|
||||
base_size = processor.image_size
|
||||
target_ratios = processor.resolve_target_ratios()
|
||||
|
||||
largest_feature_size, largest_feature_pinpoint = 0, None
|
||||
for wr, hr in target_ratios:
|
||||
width, height = base_size * wr, base_size * hr
|
||||
|
||||
feat_size = self.get_num_image_tokens(
|
||||
image_width=width,
|
||||
image_height=height,
|
||||
processor=processor,
|
||||
)
|
||||
if feat_size > largest_feature_size:
|
||||
largest_feature_size = feat_size
|
||||
largest_feature_pinpoint = ImageSize(width=width,
|
||||
height=height)
|
||||
|
||||
if largest_feature_size == 0 or largest_feature_pinpoint is None:
|
||||
raise ValueError("Cannot have a largest feature size of 0!")
|
||||
|
||||
return largest_feature_pinpoint
|
||||
|
||||
|
||||
_I = TypeVar("_I", bound=BaseInternVLProcessingInfo)
|
||||
|
||||
|
||||
class InternVLDummyInputsBuilder(BaseDummyInputsBuilder[_I]):
|
||||
|
||||
def get_dummy_processor_inputs(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> ProcessorInputs:
|
||||
target_width, target_height = \
|
||||
self.info.get_image_size_with_most_features()
|
||||
num_images = mm_counts.get("image", 0)
|
||||
|
||||
mm_data = {
|
||||
"image":
|
||||
self._get_dummy_images(width=target_width,
|
||||
height=target_height,
|
||||
num_images=num_images)
|
||||
}
|
||||
|
||||
return ProcessorInputs(
|
||||
prompt_text="<image>" * num_images,
|
||||
mm_data=mm_data,
|
||||
)
|
||||
|
||||
|
||||
class InternVLMultiModalProcessor(BaseMultiModalProcessor[_I]):
|
||||
|
||||
def _call_hf_processor(
|
||||
self,
|
||||
prompt: str,
|
||||
mm_data: Mapping[str, object],
|
||||
mm_kwargs: Mapping[str, object],
|
||||
) -> BatchFeature:
|
||||
processed_outputs = super()._call_hf_processor(
|
||||
prompt=prompt,
|
||||
mm_data=mm_data,
|
||||
mm_kwargs=mm_kwargs,
|
||||
)
|
||||
|
||||
image_token_id = self.info.get_hf_processor(**mm_kwargs).image_token_id
|
||||
image_data = mm_data.get("images", [])
|
||||
assert isinstance(image_data, list)
|
||||
|
||||
# Since there may be extra tokens in the feature placeholders,
|
||||
# we need to pass the image token ID to the model to select the
|
||||
# tokens to merge from the vision encoder outputs
|
||||
processed_outputs["image_token_id"] = [image_token_id
|
||||
] * len(image_data)
|
||||
|
||||
return processed_outputs
|
||||
|
||||
def _get_mm_fields_config(
|
||||
self,
|
||||
hf_inputs: BatchFeature,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
image_num_patches = hf_inputs.get("image_num_patches", torch.empty(0))
|
||||
|
||||
return dict(
|
||||
pixel_values_flat=MultiModalFieldConfig.flat_from_sizes(
|
||||
"image", image_num_patches),
|
||||
image_num_patches=MultiModalFieldConfig.batched("image"),
|
||||
image_embeds=MultiModalFieldConfig.batched("image"),
|
||||
image_token_id=MultiModalFieldConfig.batched("image"),
|
||||
)
|
||||
|
||||
def _get_prompt_replacements(
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
out_mm_kwargs: MultiModalKwargs,
|
||||
) -> list[PromptReplacement]:
|
||||
hf_processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
|
||||
|
||||
if "image_num_patches" in out_mm_kwargs:
|
||||
image_num_patches = out_mm_kwargs["image_num_patches"]
|
||||
assert isinstance(image_num_patches, torch.Tensor)
|
||||
image_num_patches = image_num_patches.tolist()
|
||||
elif "image_embeds" in out_mm_kwargs:
|
||||
# TODO: Use image size information in dictionary embedding inputs
|
||||
# to compute num_patches (similar to Qwen2-VL)
|
||||
image_num_patches = [None] * len(out_mm_kwargs["image_embeds"])
|
||||
else:
|
||||
image_num_patches = []
|
||||
|
||||
def get_replacement_internvl(item_idx: int):
|
||||
images = mm_items.get_items(
|
||||
"image", (ImageEmbeddingItems, ImageProcessorItems))
|
||||
|
||||
if isinstance(images, ImageEmbeddingItems):
|
||||
feature_size = images.get_feature_size(item_idx)
|
||||
else:
|
||||
image_size = images.get_image_size(item_idx)
|
||||
feature_size = self.info.get_num_image_tokens(
|
||||
image_width=image_size.width,
|
||||
image_height=image_size.height,
|
||||
processor=hf_processor,
|
||||
)
|
||||
|
||||
num_patches = image_num_patches[item_idx]
|
||||
if num_patches is not None:
|
||||
assert isinstance(num_patches, int)
|
||||
|
||||
return PromptReplacementDetails(
|
||||
full=hf_processor.get_image_repl_full(feature_size,
|
||||
num_patches),
|
||||
features=hf_processor.get_image_repl_features(
|
||||
feature_size, num_patches),
|
||||
)
|
||||
|
||||
return [
|
||||
PromptReplacement(
|
||||
modality="image",
|
||||
target="<image>",
|
||||
replacement=get_replacement_internvl,
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
class InternVLProcessingInfo(BaseInternVLProcessingInfo):
|
||||
|
||||
def get_hf_processor(
|
||||
self,
|
||||
*,
|
||||
max_dynamic_patch: Optional[int] = None,
|
||||
dynamic_image_size: Optional[bool] = None,
|
||||
):
|
||||
num_images = mm_counts["image"]
|
||||
|
||||
hf_config = ctx.get_hf_config()
|
||||
|
||||
image_feature_size = get_max_internvl_image_tokens(
|
||||
ctx,
|
||||
max_dynamic_patch=max_dynamic_patch,
|
||||
dynamic_image_size=dynamic_image_size,
|
||||
)
|
||||
model_config = ctx.model_config
|
||||
tokenizer = cached_get_tokenizer(
|
||||
model_config.tokenizer,
|
||||
trust_remote_code=model_config.trust_remote_code)
|
||||
|
||||
seq_data, ranges = dummy_seq_data_for_clip(
|
||||
hf_config.vision_config,
|
||||
seq_len,
|
||||
num_images,
|
||||
image_token_id=tokenizer.encode(self.img_context_token,
|
||||
add_special_tokens=False)[0],
|
||||
image_feature_size_override=image_feature_size,
|
||||
)
|
||||
|
||||
max_image_width, max_image_height = get_max_internvl_image_size(
|
||||
ctx,
|
||||
) -> InternVLProcessor:
|
||||
return InternVLProcessor(
|
||||
self.get_hf_config(),
|
||||
self.get_tokenizer(),
|
||||
max_dynamic_patch=max_dynamic_patch,
|
||||
dynamic_image_size=dynamic_image_size,
|
||||
)
|
||||
|
||||
mm_data = dummy_image_for_clip(
|
||||
hf_config.vision_config,
|
||||
num_images,
|
||||
image_width_override=max_image_width,
|
||||
image_height_override=max_image_height,
|
||||
)
|
||||
|
||||
return DummyData(seq_data, mm_data, ranges)
|
||||
|
||||
|
||||
input_pipeline = InternVLInputPipeline(IMG_START, IMG_END, IMG_CONTEXT)
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_image_input_mapper(input_pipeline.input_mapper)
|
||||
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_internvl_image_tokens)
|
||||
@INPUT_REGISTRY.register_dummy_data(input_pipeline.dummy_data)
|
||||
@INPUT_REGISTRY.register_input_processor(input_pipeline.input_processor)
|
||||
@MULTIMODAL_REGISTRY.register_processor(
|
||||
InternVLMultiModalProcessor,
|
||||
info=InternVLProcessingInfo,
|
||||
dummy_inputs=InternVLDummyInputsBuilder)
|
||||
class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
|
||||
@@ -621,11 +805,11 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
|
||||
def _parse_and_validate_image_input(
|
||||
self, **kwargs: object) -> Optional[InternVLImageInputs]:
|
||||
pixel_values = kwargs.pop("pixel_values", None)
|
||||
image_token_id = kwargs.pop("image_token_id", None)
|
||||
pixel_values_flat = kwargs.pop("pixel_values_flat", None)
|
||||
image_num_patches = kwargs.pop("image_num_patches", None)
|
||||
image_embeds = kwargs.pop("image_embeds", None)
|
||||
|
||||
if pixel_values is None and image_embeds is None:
|
||||
if pixel_values_flat is None and image_embeds is None:
|
||||
return None
|
||||
|
||||
if image_embeds is not None:
|
||||
@@ -638,31 +822,30 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
data=flatten_bn(image_embeds),
|
||||
)
|
||||
|
||||
self.img_context_token_id = image_token_id[0]
|
||||
image_token_id = kwargs["image_token_id"]
|
||||
assert isinstance(image_token_id, torch.Tensor)
|
||||
self.img_context_token_id = image_token_id.flatten().unique().item()
|
||||
|
||||
if pixel_values is not None:
|
||||
if not isinstance(pixel_values, (torch.Tensor, list)):
|
||||
if pixel_values_flat is not None:
|
||||
if not isinstance(pixel_values_flat, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of pixel values. "
|
||||
f"Got type: {type(pixel_values)}")
|
||||
f"Got type: {type(pixel_values_flat)}")
|
||||
|
||||
assert isinstance(image_num_patches, (torch.Tensor, list))
|
||||
|
||||
patches_per_image = []
|
||||
for request_pixel_values in pixel_values:
|
||||
for image_pixel_values in request_pixel_values:
|
||||
patches_per_image.append(image_pixel_values.shape[0])
|
||||
# We need to flatten (B, N, P) to (B*N*P),
|
||||
# so we call flatten_bn twice.
|
||||
return InternVLImagePixelInputs(
|
||||
type="pixel_values",
|
||||
data=self._validate_pixel_values(
|
||||
flatten_bn(flatten_bn(pixel_values), concat=True)),
|
||||
patches_per_image=patches_per_image)
|
||||
flatten_bn(pixel_values_flat, concat=True)),
|
||||
patches_per_image=flatten_bn(image_num_patches,
|
||||
concat=True).tolist())
|
||||
|
||||
raise AssertionError("This line should be unreachable.")
|
||||
|
||||
def _process_image_input(
|
||||
self,
|
||||
image_input: InternVLImageInputs,
|
||||
) -> Tuple[torch.Tensor]:
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
if image_input["type"] == "image_embeds":
|
||||
return image_input["data"]
|
||||
|
||||
@@ -689,7 +872,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
image_embeds = image_embeds.split(image_feature_sizes)
|
||||
return image_embeds
|
||||
|
||||
def _set_visual_token_mask(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None:
|
||||
if self.is_mono:
|
||||
self.visual_token_mask = (
|
||||
input_ids == self.img_context_token_id).reshape(-1, 1)
|
||||
|
||||
Reference in New Issue
Block a user