[Core] Dynamic image size support for VLMs (#5276)

Signed-off-by: Xiaowei Jiang <xwjiang2010@gmail.com>
Co-authored-by: Xiaowei Jiang <xwjiang2010@gmail.com>
Co-authored-by: ywang96 <ywang@roblox.com>
Co-authored-by: xwjiang2010 <87673679+xwjiang2010@users.noreply.github.com>
Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com>
This commit is contained in:
Cyrus Leung
2024-07-03 11:34:00 +08:00
committed by GitHub
parent 482045ee77
commit 9831aec49f
38 changed files with 1453 additions and 664 deletions

View File

@@ -13,7 +13,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Iterable, List, Literal, Optional, Tuple, TypedDict
import re
from functools import lru_cache
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict
import numpy as np
import torch
@@ -22,8 +24,8 @@ from PIL import Image
from transformers import CLIPVisionConfig, PretrainedConfig
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, VisionLanguageConfig
from vllm.inputs import INPUT_REGISTRY, InputContext
from vllm.config import CacheConfig, ModelConfig, VisionLanguageConfig
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
@@ -34,10 +36,12 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal import MULTIMODAL_REGISTRY, BatchedTensors
from vllm.multimodal.image import cached_get_tokenizer
from vllm.sequence import IntermediateTensors, SamplerOutput
from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
input_processor_for_clip)
from .interfaces import SupportsVision
logger = init_logger(__name__)
@@ -251,50 +255,22 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
class Phi3VImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: torch.Tensor
"""Shape: (batch_size, 1 + num_patches, num_channels, height, width)"""
data: BatchedTensors
"""
Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`
Note that `num_patches` may be different for each batch.
"""
image_sizes: torch.Tensor
"""Shape: (batch_size, 2)"""
"""
Shape: `(batch_size, 2)`
This should be in `(height, width)` format.
"""
def _get_phi3v_image_feature_size(
*,
input_height: int,
input_width: int,
) -> int:
h, w = input_height, input_width
# https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L178
return (h // 336 * w // 336 + 1) * 144 + 1 + (h // 336 + 1) * 12
def dummy_data_for_phi3v(ctx: InputContext, seq_len: int):
multimodal_config = ctx.get_multimodal_config()
#TODO: change the logic for dummy data to support dynamic shape
_, _, dummy_height, dummy_width = multimodal_config.image_input_shape
image_feature_size = _get_phi3v_image_feature_size(
input_height=dummy_height,
input_width=dummy_width,
)
seq_data = dummy_seq_data_for_clip(
CLIP_VIT_LARGE_PATCH14_336_CONFIG,
seq_len,
image_token_id=32044,
image_feature_size_override=image_feature_size,
)
mm_data = dummy_image_for_clip(
CLIP_VIT_LARGE_PATCH14_336_CONFIG,
image_width_override=dummy_width,
image_height_override=dummy_height,
)
return seq_data, mm_data
# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py
# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L57
def _calc_padded_size(*, width: int, height: int, padding_unit: int = 336):
target_height = int(np.ceil(height / padding_unit) * padding_unit)
top_padding = int((target_height - height) / 2)
@@ -304,7 +280,7 @@ def _calc_padded_size(*, width: int, height: int, padding_unit: int = 336):
return padded_width, padded_height
# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py
# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L90
def _calc_hd_transform_size(*, width: int, height: int, hd_num: int = 16):
transposed = False
if width < height:
@@ -329,27 +305,133 @@ def _calc_hd_transform_size(*, width: int, height: int, hd_num: int = 16):
return padded_width, padded_height
def _image_processor(ctx: InputContext,
image: object) -> Dict[str, torch.Tensor]:
# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L181
def get_phi3v_image_feature_size(
hf_config: PretrainedConfig,
*,
input_height: int,
input_width: int,
) -> int:
num_crops = getattr(hf_config, "num_crops", 16)
new_width, new_height = _calc_hd_transform_size(width=input_width,
height=input_height,
hd_num=num_crops)
if isinstance(image, Image.Image):
# Temporary patch before dynamic number of image tokens is supported
_, _, h, w = ctx.get_multimodal_config().image_input_shape
if (w, h) != _calc_hd_transform_size(width=image.width,
height=image.height):
logger.warning(
"Dynamic image shape is currently not supported. "
"Resizing input image to (%d, %d).", w, h)
image = image.resize((w, h))
return MULTIMODAL_REGISTRY._get_plugin("image") \
._default_input_mapper(ctx, image)
raise TypeError(f"Invalid type for 'image': {type(image)}")
return (new_height // 336 * new_width // 336 + 1) * 144 + 1 \
+ (new_height // 336 + 1) * 12
@MULTIMODAL_REGISTRY.register_image_input_mapper(_image_processor)
def dummy_data_for_phi3v(ctx: InputContext, seq_len: int):
# Result in the max possible feature size (h:w = 16:1)
dummy_height, dummy_width = 8000, 50
image_feature_size = get_phi3v_image_feature_size(
ctx.get_hf_config(PretrainedConfig),
input_height=dummy_height,
input_width=dummy_width,
)
seq_data = dummy_seq_data_for_clip(
CLIP_VIT_LARGE_PATCH14_336_CONFIG,
seq_len,
image_token_id=32044,
image_feature_size_override=image_feature_size,
)
mm_data = dummy_image_for_clip(
CLIP_VIT_LARGE_PATCH14_336_CONFIG,
image_width_override=dummy_width,
image_height_override=dummy_height,
)
return seq_data, mm_data
# Reserve this function to also handle placeholders for additional images
# [ref: PR #5820]
@lru_cache
def _get_image_placeholder_token_ids(model_config: ModelConfig,
idx: int) -> List[int]:
assert idx > 0
tokenizer = cached_get_tokenizer(model_config.tokenizer)
# We need to get the token for "<", not "▁<"
# https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/raw/main/tokenizer.json
a_token_id, = tokenizer.encode("a", add_special_tokens=False)
a_token_id_, *image_placeholder_token_ids = tokenizer.encode(
f"a<|image_{idx}|>", add_special_tokens=False)
assert a_token_id == a_token_id_
return image_placeholder_token_ids
def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
multi_modal_data = llm_inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return llm_inputs
model_config = ctx.model_config
multimodal_config = ctx.get_multimodal_config()
hf_config = ctx.get_hf_config(PretrainedConfig)
image_data = multi_modal_data["image"]
if isinstance(image_data, Image.Image):
w, h = image_data.size
w, h = _calc_hd_transform_size(width=w, height=h)
image_feature_size = get_phi3v_image_feature_size(hf_config,
input_width=w,
input_height=h)
elif isinstance(image_data, torch.Tensor):
raise NotImplementedError("Embeddings input is not supported yet")
else:
raise TypeError(f"Invalid image type: {type(image_data)}")
prompt = llm_inputs.get("prompt")
if prompt is None:
new_prompt = None
else:
if prompt.count("<|image|>") > 0:
logger.warning("Please follow the prompt format that is "
"documented on HuggingFace which does not involve "
"repeating <|image|> tokens.")
elif len(re.findall(r"(<\|image_\d+\|>)+", prompt)) > 1:
logger.warning("Multiple image input is not supported yet, "
"so any extra image tokens will be treated "
"as plain text.")
new_prompt = prompt
prompt_token_ids = llm_inputs["prompt_token_ids"]
image_1_token_ids = _get_image_placeholder_token_ids(model_config, idx=1)
new_token_ids: List[int] = []
for i in range(len(prompt_token_ids) - len(image_1_token_ids) + 1):
if prompt_token_ids[i:i + len(image_1_token_ids)] == image_1_token_ids:
new_token_ids.append(multimodal_config.image_token_id)
# No need to further scan the list since we only replace once
new_token_ids.extend(prompt_token_ids[i + len(image_1_token_ids):])
break
else:
new_token_ids.append(prompt_token_ids[i])
# NOTE: Create a defensive copy of the original inputs
llm_inputs = LLMInputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data)
return input_processor_for_clip(
model_config,
CLIP_VIT_LARGE_PATCH14_336_CONFIG,
llm_inputs,
image_token_id=multimodal_config.image_token_id,
image_feature_size_override=image_feature_size,
)
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_phi3v)
@INPUT_REGISTRY.register_input_processor(input_processor_for_phi3v)
class Phi3VForCausalLM(nn.Module, SupportsVision):
def __init__(self,
@@ -363,6 +445,8 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
self.vlm_config = vlm_config
self.model = LlamaModel(config, cache_config, quant_config)
# TODO: Optionally initializes this for supporting embeddings.
self.vision_embed_tokens = Phi3HDImageEmbedding(
vlm_config, config, self.model.embed_tokens)
self.lm_head = ParallelLMHead(config.vocab_size,
@@ -376,12 +460,20 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
pixel_values = kwargs.pop("pixel_values", None)
image_sizes = kwargs.pop("image_sizes", None)
if pixel_values is not None and image_sizes is not None:
return Phi3VImagePixelInputs(type="pixel_values",
data=pixel_values,
image_sizes=image_sizes)
if pixel_values is None:
return None
return None
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
if not isinstance(image_sizes, torch.Tensor):
raise ValueError("Incorrect type of image sizes. "
f"Got type: {type(image_sizes)}")
return Phi3VImagePixelInputs(type="pixel_values",
data=pixel_values,
image_sizes=image_sizes)
def forward(self,
input_ids: torch.Tensor,