[Model][VLM] Support multi-images inputs for InternVL2 models (#8201)
@@ -5,6 +5,7 @@
 # Licensed under The MIT License [see LICENSE for details]
 # --------------------------------------------------------
 import itertools
+import re
 from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                     TypedDict, Union)
 
@@ -26,6 +27,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.base import MultiModalInputs
 from vllm.multimodal.utils import cached_get_tokenizer
 from vllm.sequence import IntermediateTensors
+from vllm.utils import is_list_of
 
 from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
                    get_clip_num_patches)
@@ -95,8 +97,8 @@ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height,
 
 
 def calculate_num_blocks(orig_width: int, orig_height: int, min_num: int,
-                         max_num: int,
-                         image_size: int) -> Tuple[int, int, int]:
+                         max_num: int, image_size: int,
+                         use_thumbnail: bool) -> Tuple[int, int, int]:
     aspect_ratio = orig_width / orig_height
 
     # calculate the existing image aspect ratio
@@ -114,17 +116,26 @@ def calculate_num_blocks(orig_width: int, orig_height: int, min_num: int,
     target_width = image_size * target_aspect_ratio[0]
     target_height = image_size * target_aspect_ratio[1]
     blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
+    # add thumbnail image if num_blocks > 1
+    if use_thumbnail and blocks > 1:
+        blocks += 1
     return blocks, target_width, target_height
 
 
 # adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
 def dynamic_preprocess(image: Image.Image, min_num: int, max_num: int,
                        image_size: int,
-                       use_thumbnail: int) -> List[Image.Image]:
+                       use_thumbnail: bool) -> List[Image.Image]:
     orig_width, orig_height = image.size
 
+    # calculate the number of blocks without thumbnail
     blocks, target_width, target_height = calculate_num_blocks(
-        orig_width, orig_height, min_num, max_num, image_size)
+        orig_width,
+        orig_height,
+        min_num,
+        max_num,
+        image_size,
+        use_thumbnail=False)
     # resize the image
     resized_img = image.resize((target_width, target_height))
     processed_images = []
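
For context (not part of the commit): the block count above follows InternVL2's dynamic tiling, where an image is matched to the candidate grid whose aspect ratio is closest to the original, and an extra thumbnail tile is appended whenever the image is split at all. The standalone sketch below mirrors that arithmetic under assumed defaults (min_num=1, max_num=12, image_size=448, the usual InternVL2 config); the helper names estimate_num_blocks and _candidate_grids are illustrative, and the area-based tie-break imitates find_closest_aspect_ratio, which this hunk does not show.

import math
from typing import List, Tuple


def _candidate_grids(min_num: int, max_num: int) -> List[Tuple[int, int]]:
    # all (cols, rows) grids whose tile count lies in [min_num, max_num]
    grids = {(i, j)
             for n in range(min_num, max_num + 1)
             for i in range(1, n + 1)
             for j in range(1, n + 1)
             if min_num <= i * j <= max_num}
    return sorted(grids, key=lambda g: g[0] * g[1])


def estimate_num_blocks(width: int,
                        height: int,
                        min_num: int = 1,
                        max_num: int = 12,
                        image_size: int = 448,
                        use_thumbnail: bool = True) -> int:
    aspect_ratio = width / height
    best, best_diff = (1, 1), math.inf
    for grid in _candidate_grids(min_num, max_num):
        diff = abs(aspect_ratio - grid[0] / grid[1])
        if diff < best_diff:
            best, best_diff = grid, diff
        elif diff == best_diff and width * height > 0.5 * image_size**2 * grid[0] * grid[1]:
            # on ties, sufficiently large images prefer the grid with more tiles
            best = grid
    blocks = best[0] * best[1]
    if use_thumbnail and blocks > 1:
        blocks += 1  # the extra global thumbnail tile counted above
    return blocks


print(estimate_num_blocks(1920, 1080))  # 4x2 grid + thumbnail -> 9
print(estimate_num_blocks(448, 448))    # small square image -> 1 (no thumbnail)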
@@ -197,17 +208,23 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
                                            downsample_ratio)
 
     image_data = multi_modal_data["image"]
+    min_num = hf_config.min_dynamic_patch
+    max_num = hf_config.max_dynamic_patch
+    use_thumbnail = hf_config.use_thumbnail
     if isinstance(image_data, Image.Image):
         width, height = image_data.size
-        min_num = hf_config.min_dynamic_patch
-        max_num = hf_config.max_dynamic_patch
         num_blocks, _, _ = calculate_num_blocks(width, height, min_num,
-                                                max_num, image_size)
-        # add thumbnail image if num_blocks > 1
-        if hf_config.use_thumbnail and num_blocks > 1:
-            num_blocks += 1
-        image_feature_size = num_blocks * num_patches
-
+                                                max_num, image_size,
+                                                use_thumbnail)
+        image_feature_size = [num_blocks * num_patches]
+    elif is_list_of(image_data, Image.Image):
+        image_feature_size = []
+        for image in image_data:
+            width, height = image.size
+            num_blocks, _, _ = calculate_num_blocks(width, height, min_num,
+                                                    max_num, image_size,
+                                                    use_thumbnail)
+            image_feature_size.append(num_blocks * num_patches)
     elif isinstance(image_data, torch.Tensor):
         num_images, image_feature_size, hidden_size = image_data.shape
     else:
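
For illustration (not part of the commit), the per-image feature sizes collected above scale linearly with the block count. With typical InternVL2 vision settings (image_size 448, patch size 14, downsample_ratio 0.5; the real code reads these from hf_config), each 448x448 tile contributes 256 IMG_CONTEXT tokens, so a prompt with a 3-block image and a 9-block image would end up with image_feature_size = [768, 2304]:

# Assumed typical InternVL2 values; the processor reads them from hf_config.
image_size, patch_size, downsample_ratio = 448, 14, 0.5
num_patches = int((image_size // patch_size)**2 * downsample_ratio**2)  # 256

# e.g. two images that were split into 3 and 9 blocks respectively
num_blocks_per_image = [3, 9]
image_feature_size = [n * num_patches for n in num_blocks_per_image]
print(image_feature_size)  # [768, 2304]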
@@ -220,8 +237,14 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
     prompt_token_ids = llm_inputs["prompt_token_ids"]
     if prompt is None:
         prompt = tokenizer.decode(prompt_token_ids)
-    image_prompt = IMG_START + IMG_CONTEXT * image_feature_size + IMG_END
-    new_prompt = prompt.replace('<image>', image_prompt, 1)
+
+    new_prompt = prompt
+    image_idx = sorted(map(int, re.findall(r"Image-(\d+): <image>\n", prompt)))
+    for idx, feature_size in enumerate(image_feature_size, start=1):
+        image_prompt = IMG_START + IMG_CONTEXT * feature_size + IMG_END
+        if not image_idx:
+            image_prompt = f"Image-{idx}: {image_prompt}"
+        new_prompt = new_prompt.replace('<image>', image_prompt, 1)
     new_prompt_token_ids = tokenizer.encode(new_prompt)
 
     return LLMInputs(prompt=prompt,
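
To make the new replacement loop concrete (not part of the commit), here is a small sketch of how a two-image prompt is expanded. The feature sizes are placeholder numbers, and the special-token strings are assumed to match the IMG_START/IMG_CONTEXT/IMG_END constants defined elsewhere in the model file; prompts that already carry "Image-N: <image>" prefixes are left as-is, while bare "<image>" placeholders get an explicit index added so multi-image prompts stay unambiguous.

import re

IMG_START, IMG_CONTEXT, IMG_END = "<img>", "<IMG_CONTEXT>", "</img>"
image_feature_size = [768, 2304]  # per-image token counts from the step above

prompt = "Image-1: <image>\nImage-2: <image>\nWhat differs between the images?"

new_prompt = prompt
image_idx = sorted(map(int, re.findall(r"Image-(\d+): <image>\n", prompt)))
for idx, feature_size in enumerate(image_feature_size, start=1):
    image_prompt = IMG_START + IMG_CONTEXT * feature_size + IMG_END
    if not image_idx:
        # the prompt used bare "<image>" placeholders; add an explicit index
        image_prompt = f"Image-{idx}: {image_prompt}"
    new_prompt = new_prompt.replace("<image>", image_prompt, 1)

# each "<image>" is now "<img>" + 768 (resp. 2304) "<IMG_CONTEXT>" tokens + "</img>"
print(new_prompt[:80], "...")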
@@ -245,6 +268,15 @@ def input_mapper_for_internvl(ctx: InputContext, data: object):
                                      use_thumbnail=use_thumbnail)
         # Add an N dimension for number of images per prompt (currently 1).
         data = data.unsqueeze(0)
+    elif is_list_of(data, Image.Image):
+        data = [
+            image_to_pixel_values(img,
+                                  image_size,
+                                  min_num,
+                                  max_num,
+                                  use_thumbnail=use_thumbnail) for img in data
+        ]
+        data = torch.stack(data)
     model_config = ctx.model_config
     tokenizer = cached_get_tokenizer(model_config.tokenizer,
                                      trust_remote_code=True)
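
A rough usage sketch (not part of the commit) of what the multi-image path enables. The model name, file paths, sampling parameters, and the hand-written chat formatting are assumptions; in practice the prompt should be built with the model's own chat template, and limit_mm_per_prompt raises vLLM's default of one image per prompt.

from PIL import Image
from vllm import LLM, SamplingParams

# Assumed model and per-prompt image limit; adjust to the actual deployment.
llm = LLM(model="OpenGVLab/InternVL2-2B",
          trust_remote_code=True,
          limit_mm_per_prompt={"image": 2})

image_1 = Image.open("cat.jpg").convert("RGB")
image_2 = Image.open("dog.jpg").convert("RGB")

# One "<image>" placeholder per image; the "Image-N: " prefixes keep the
# ordering explicit, matching the processor logic in this diff.
prompt = ("<|im_start|>user\n"
          "Image-1: <image>\nImage-2: <image>\n"
          "What is different between the two images?<|im_end|>\n"
          "<|im_start|>assistant\n")

outputs = llm.generate(
    {
        "prompt": prompt,
        "multi_modal_data": {"image": [image_1, image_2]},
    },
    SamplingParams(max_tokens=128),
)
print(outputs[0].outputs[0].text)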