[Core] Dynamic image size support for VLMs (#5276)
Signed-off-by: Xiaowei Jiang <xwjiang2010@gmail.com>
Co-authored-by: Xiaowei Jiang <xwjiang2010@gmail.com>
Co-authored-by: ywang96 <ywang@roblox.com>
Co-authored-by: xwjiang2010 <87673679+xwjiang2010@users.noreply.github.com>
Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com>
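
Usage sketch (illustrative only, not part of this diff): with the image feature size now derived per image from the HD transform, a Phi-3-vision request no longer has to match a fixed `image_input_shape`. The snippet below assumes the dict-prompt multi-modal API of this vLLM revision (`LLM.generate` with `multi_modal_data={"image": ...}`, the same mapping that `input_processor_for_phi3v` consumes); the model path, prompt text, and file name are placeholders, and depending on the exact revision some `VisionLanguageConfig` engine options may still need to be supplied.

    # Illustrative example only; names and flags below are placeholders,
    # not taken from this commit.
    from PIL import Image

    from vllm import LLM, SamplingParams

    llm = LLM(model="microsoft/Phi-3-vision-128k-instruct",
              trust_remote_code=True)

    # Any resolution should now be accepted; the input processor HD-transforms
    # the image and inserts the matching number of image placeholder tokens.
    image = Image.open("example.jpg")  # e.g. 1920x1080, no manual resize

    prompt = ("<|user|>\n<|image_1|>\n"
              "What is shown in this image?<|end|>\n<|assistant|>\n")

    outputs = llm.generate(
        {
            "prompt": prompt,
            "multi_modal_data": {"image": image},
        },
        SamplingParams(max_tokens=64),
    )
    print(outputs[0].outputs[0].text)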
@@ -13,7 +13,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Dict, Iterable, List, Literal, Optional, Tuple, TypedDict
+import re
+from functools import lru_cache
+from typing import Iterable, List, Literal, Optional, Tuple, TypedDict
 
 import numpy as np
 import torch
@@ -22,8 +24,8 @@ from PIL import Image
 from transformers import CLIPVisionConfig, PretrainedConfig
 
 from vllm.attention import AttentionMetadata
-from vllm.config import CacheConfig, VisionLanguageConfig
-from vllm.inputs import INPUT_REGISTRY, InputContext
+from vllm.config import CacheConfig, ModelConfig, VisionLanguageConfig
+from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from vllm.logger import init_logger
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization.base_config import (
@@ -34,10 +36,12 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.clip import CLIPVisionModel
 from vllm.model_executor.models.llama import LlamaModel
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal import MULTIMODAL_REGISTRY, BatchedTensors
+from vllm.multimodal.image import cached_get_tokenizer
 from vllm.sequence import IntermediateTensors, SamplerOutput
 
-from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
+from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
+                   input_processor_for_clip)
 from .interfaces import SupportsVision
 
 logger = init_logger(__name__)
@@ -251,50 +255,22 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
 
 class Phi3VImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
-    data: torch.Tensor
-    """Shape: (batch_size, 1 + num_patches, num_channels, height, width)"""
+    data: BatchedTensors
+    """
+    Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`
+
+    Note that `num_patches` may be different for each batch.
+    """
 
     image_sizes: torch.Tensor
-    """Shape: (batch_size, 2)"""
+    """
+    Shape: `(batch_size, 2)`
+
+    This should be in `(height, width)` format.
+    """
 
 
-def _get_phi3v_image_feature_size(
-    *,
-    input_height: int,
-    input_width: int,
-) -> int:
-    h, w = input_height, input_width
-
-    # https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L178
-    return (h // 336 * w // 336 + 1) * 144 + 1 + (h // 336 + 1) * 12
-
-
-def dummy_data_for_phi3v(ctx: InputContext, seq_len: int):
-    multimodal_config = ctx.get_multimodal_config()
-
-    #TODO: change the logic for dummy data to support dynamic shape
-    _, _, dummy_height, dummy_width = multimodal_config.image_input_shape
-    image_feature_size = _get_phi3v_image_feature_size(
-        input_height=dummy_height,
-        input_width=dummy_width,
-    )
-
-    seq_data = dummy_seq_data_for_clip(
-        CLIP_VIT_LARGE_PATCH14_336_CONFIG,
-        seq_len,
-        image_token_id=32044,
-        image_feature_size_override=image_feature_size,
-    )
-    mm_data = dummy_image_for_clip(
-        CLIP_VIT_LARGE_PATCH14_336_CONFIG,
-        image_width_override=dummy_width,
-        image_height_override=dummy_height,
-    )
-
-    return seq_data, mm_data
-
-
-# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py
+# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L57
 def _calc_padded_size(*, width: int, height: int, padding_unit: int = 336):
     target_height = int(np.ceil(height / padding_unit) * padding_unit)
     top_padding = int((target_height - height) / 2)
@@ -304,7 +280,7 @@ def _calc_padded_size(*, width: int, height: int, padding_unit: int = 336):
     return padded_width, padded_height
 
 
-# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py
+# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L90
 def _calc_hd_transform_size(*, width: int, height: int, hd_num: int = 16):
     transposed = False
     if width < height:
@@ -329,27 +305,133 @@ def _calc_hd_transform_size(*, width: int, height: int, hd_num: int = 16):
     return padded_width, padded_height
 
 
-def _image_processor(ctx: InputContext,
-                     image: object) -> Dict[str, torch.Tensor]:
+# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L181
+def get_phi3v_image_feature_size(
+    hf_config: PretrainedConfig,
+    *,
+    input_height: int,
+    input_width: int,
+) -> int:
+    num_crops = getattr(hf_config, "num_crops", 16)
+    new_width, new_height = _calc_hd_transform_size(width=input_width,
+                                                    height=input_height,
+                                                    hd_num=num_crops)
 
-    if isinstance(image, Image.Image):
-        # Temporary patch before dynamic number of image tokens is supported
-        _, _, h, w = ctx.get_multimodal_config().image_input_shape
-        if (w, h) != _calc_hd_transform_size(width=image.width,
-                                             height=image.height):
-            logger.warning(
-                "Dynamic image shape is currently not supported. "
-                "Resizing input image to (%d, %d).", w, h)
-
-            image = image.resize((w, h))
-
-        return MULTIMODAL_REGISTRY._get_plugin("image") \
-            ._default_input_mapper(ctx, image)
-    raise TypeError(f"Invalid type for 'image': {type(image)}")
+    return (new_height // 336 * new_width // 336 + 1) * 144 + 1 \
+        + (new_height // 336 + 1) * 12
 
 
-@MULTIMODAL_REGISTRY.register_image_input_mapper(_image_processor)
+def dummy_data_for_phi3v(ctx: InputContext, seq_len: int):
+    # Result in the max possible feature size (h:w = 16:1)
+    dummy_height, dummy_width = 8000, 50
+    image_feature_size = get_phi3v_image_feature_size(
+        ctx.get_hf_config(PretrainedConfig),
+        input_height=dummy_height,
+        input_width=dummy_width,
+    )
+
+    seq_data = dummy_seq_data_for_clip(
+        CLIP_VIT_LARGE_PATCH14_336_CONFIG,
+        seq_len,
+        image_token_id=32044,
+        image_feature_size_override=image_feature_size,
+    )
+    mm_data = dummy_image_for_clip(
+        CLIP_VIT_LARGE_PATCH14_336_CONFIG,
+        image_width_override=dummy_width,
+        image_height_override=dummy_height,
+    )
+
+    return seq_data, mm_data
+
+
+# Reserve this function to also handle placeholders for additional images
+# [ref: PR #5820]
+@lru_cache
+def _get_image_placeholder_token_ids(model_config: ModelConfig,
+                                     idx: int) -> List[int]:
+    assert idx > 0
+
+    tokenizer = cached_get_tokenizer(model_config.tokenizer)
+
+    # We need to get the token for "<", not "▁<"
+    # https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/raw/main/tokenizer.json
+    a_token_id, = tokenizer.encode("a", add_special_tokens=False)
+    a_token_id_, *image_placeholder_token_ids = tokenizer.encode(
+        f"a<|image_{idx}|>", add_special_tokens=False)
+    assert a_token_id == a_token_id_
+
+    return image_placeholder_token_ids
+
+
+def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
+    multi_modal_data = llm_inputs.get("multi_modal_data")
+    if multi_modal_data is None or "image" not in multi_modal_data:
+        return llm_inputs
+
+    model_config = ctx.model_config
+    multimodal_config = ctx.get_multimodal_config()
+    hf_config = ctx.get_hf_config(PretrainedConfig)
+
+    image_data = multi_modal_data["image"]
+    if isinstance(image_data, Image.Image):
+        w, h = image_data.size
+        w, h = _calc_hd_transform_size(width=w, height=h)
+
+        image_feature_size = get_phi3v_image_feature_size(hf_config,
+                                                          input_width=w,
+                                                          input_height=h)
+    elif isinstance(image_data, torch.Tensor):
+        raise NotImplementedError("Embeddings input is not supported yet")
+    else:
+        raise TypeError(f"Invalid image type: {type(image_data)}")
+
+    prompt = llm_inputs.get("prompt")
+    if prompt is None:
+        new_prompt = None
+    else:
+        if prompt.count("<|image|>") > 0:
+            logger.warning("Please follow the prompt format that is "
+                           "documented on HuggingFace which does not involve "
+                           "repeating <|image|> tokens.")
+        elif len(re.findall(r"(<\|image_\d+\|>)+", prompt)) > 1:
+            logger.warning("Multiple image input is not supported yet, "
+                           "so any extra image tokens will be treated "
+                           "as plain text.")
+
+        new_prompt = prompt
+
+    prompt_token_ids = llm_inputs["prompt_token_ids"]
+    image_1_token_ids = _get_image_placeholder_token_ids(model_config, idx=1)
+
+    new_token_ids: List[int] = []
+    for i in range(len(prompt_token_ids) - len(image_1_token_ids) + 1):
+        if prompt_token_ids[i:i + len(image_1_token_ids)] == image_1_token_ids:
+            new_token_ids.append(multimodal_config.image_token_id)
+
+            # No need to further scan the list since we only replace once
+            new_token_ids.extend(prompt_token_ids[i + len(image_1_token_ids):])
+            break
+        else:
+            new_token_ids.append(prompt_token_ids[i])
+
+    # NOTE: Create a defensive copy of the original inputs
+    llm_inputs = LLMInputs(prompt_token_ids=new_token_ids,
+                           prompt=new_prompt,
+                           multi_modal_data=multi_modal_data)
+
+    return input_processor_for_clip(
+        model_config,
+        CLIP_VIT_LARGE_PATCH14_336_CONFIG,
+        llm_inputs,
+        image_token_id=multimodal_config.image_token_id,
+        image_feature_size_override=image_feature_size,
+    )
+
+
+@MULTIMODAL_REGISTRY.register_image_input_mapper()
 @INPUT_REGISTRY.register_dummy_data(dummy_data_for_phi3v)
+@INPUT_REGISTRY.register_input_processor(input_processor_for_phi3v)
 class Phi3VForCausalLM(nn.Module, SupportsVision):
 
     def __init__(self,
@@ -363,6 +445,8 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
         self.vlm_config = vlm_config
 
         self.model = LlamaModel(config, cache_config, quant_config)
+
+        # TODO: Optionally initializes this for supporting embeddings.
         self.vision_embed_tokens = Phi3HDImageEmbedding(
             vlm_config, config, self.model.embed_tokens)
         self.lm_head = ParallelLMHead(config.vocab_size,
@@ -376,12 +460,20 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
         pixel_values = kwargs.pop("pixel_values", None)
         image_sizes = kwargs.pop("image_sizes", None)
 
-        if pixel_values is not None and image_sizes is not None:
-            return Phi3VImagePixelInputs(type="pixel_values",
-                                         data=pixel_values,
-                                         image_sizes=image_sizes)
+        if pixel_values is None:
+            return None
 
-        return None
+        if not isinstance(pixel_values, (torch.Tensor, list)):
+            raise ValueError("Incorrect type of pixel values. "
+                             f"Got type: {type(pixel_values)}")
+
+        if not isinstance(image_sizes, torch.Tensor):
+            raise ValueError("Incorrect type of image sizes. "
+                             f"Got type: {type(image_sizes)}")
+
+        return Phi3VImagePixelInputs(type="pixel_values",
+                                     data=pixel_values,
+                                     image_sizes=image_sizes)
 
     def forward(self,
                 input_ids: torch.Tensor,