[Core] Dynamic image size support for VLMs (#5276)

Signed-off-by: Xiaowei Jiang <xwjiang2010@gmail.com>
Co-authored-by: Xiaowei Jiang <xwjiang2010@gmail.com>
Co-authored-by: ywang96 <ywang@roblox.com>
Co-authored-by: xwjiang2010 <87673679+xwjiang2010@users.noreply.github.com>
Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com>
This commit is contained in:
Cyrus Leung
2024-07-03 11:34:00 +08:00
committed by GitHub
parent 482045ee77
commit 9831aec49f
38 changed files with 1453 additions and 664 deletions

View File

@@ -1,4 +1,4 @@
from typing import Dict, Iterable, List, Literal, Optional, Tuple, TypedDict
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict
import torch
import torch.nn as nn
@@ -10,7 +10,7 @@ from typing_extensions import NotRequired
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, VisionLanguageConfig
from vllm.inputs import INPUT_REGISTRY, InputContext
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
@@ -21,13 +21,14 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal import MULTIMODAL_REGISTRY, BatchedTensors
from vllm.sequence import IntermediateTensors, SamplerOutput
from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
get_clip_patch_grid_length)
get_clip_patch_grid_length, input_processor_for_clip)
from .interfaces import SupportsVision
from .llava import LlavaMultiModalProjector, merge_vision_embeddings
from .llava import LlavaMultiModalProjector
from .utils import merge_vision_embeddings
logger = init_logger(__name__)
@@ -39,16 +40,27 @@ _KEYS_TO_MODIFY_MAPPING = {
class LlavaNextImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: torch.Tensor
"""Shape: (batch_size, 1 + num_patches, num_channels, height, width)"""
data: BatchedTensors
"""
Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`
Note that `num_patches` may be different for each batch.
"""
image_sizes: NotRequired[torch.Tensor]
"""Shape: (batch_size, 2)"""
"""
Shape: `(batch_size, 2)`
This should be in `(height, width)` format.
"""
LlavaNextImageInputs = LlavaNextImagePixelInputs
# Taken from: https://github.com/huggingface/text-generation-inference/blob/v2.0.4/server/text_generation_server/models/vlm_causal_lm.py#L91
# NOTE: new_height and new_width are further incremented to properly invert the
# floordiv operation: https://github.com/huggingface/transformers/blob/v4.42.2/src/transformers/models/llava_next/modeling_llava_next.py#L133
def _get_llava_next_num_unpadded_features(
height: int,
width: int,
@@ -56,7 +68,6 @@ def _get_llava_next_num_unpadded_features(
num_patch_height: int,
num_patch_width: int,
) -> Tuple[int, int]:
# Taken from: https://github.com/huggingface/text-generation-inference/blob/799a193b109662743bed1b18a09af1fdcd508c8b/server/text_generation_server/models/vlm_causal_lm.py#L111
current_height = npatches * num_patch_height
current_width = npatches * num_patch_width
@@ -64,9 +75,13 @@ def _get_llava_next_num_unpadded_features(
current_aspect_ratio: float = current_width / current_height
if aspect_ratio > current_aspect_ratio:
new_height = (height * current_width) // width
if new_height % 2 == 1:
new_height += 1
current_height = new_height
else:
new_width = (width * current_height) // height
if new_width % 2 == 1:
new_width += 1
current_width = new_width
unpadded_features = current_height * current_width
@@ -74,7 +89,8 @@ def _get_llava_next_num_unpadded_features(
return (unpadded_features, newline_features)
def _get_llava_next_image_feature_size(
# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.0.4/server/text_generation_server/models/vlm_causal_lm.py#L111
def get_llava_next_image_feature_size(
hf_config: LlavaNextConfig,
*,
input_height: int,
@@ -89,7 +105,9 @@ def _get_llava_next_image_feature_size(
)
base_feature_size = num_patches * num_patches
num_patch_height, num_patch_width = get_anyres_image_grid_shape(
# Note: We follow the "wrong" width/height order
# [ref: PR huggingface/transformers#31588]
num_patch_width, num_patch_height = get_anyres_image_grid_shape(
image_size=(input_height, input_width),
grid_pinpoints=hf_config.image_grid_pinpoints,
patch_size=vision_config.image_size,
@@ -110,14 +128,16 @@ def _get_llava_next_image_feature_size(
def dummy_data_for_llava_next(ctx: InputContext, seq_len: int):
multimodal_config = ctx.get_multimodal_config()
hf_config = ctx.get_hf_config(LlavaNextConfig)
vision_config = hf_config.vision_config
#TODO: change the logic for dummy data to support dynamic shape
_, _, dummy_height, dummy_width = multimodal_config.image_input_shape
image_feature_size = _get_llava_next_image_feature_size(
hf_config, input_height=dummy_height, input_width=dummy_width)
# Result in the max possible feature size (2x2 grid of 336x336px tiles)
dummy_height = dummy_width = 448
image_feature_size = get_llava_next_image_feature_size(
hf_config,
input_height=dummy_height,
input_width=dummy_width,
)
if isinstance(vision_config, CLIPVisionConfig):
seq_data = dummy_seq_data_for_clip(
@@ -139,27 +159,47 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int):
raise NotImplementedError(msg)
def _pixel_mapper(ctx: InputContext, image: object) -> Dict[str, torch.Tensor]:
def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs):
multi_modal_data = llm_inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return llm_inputs
if isinstance(image, Image.Image):
model_config = ctx.model_config
hf_config = ctx.get_hf_config(LlavaNextConfig)
vision_config = hf_config.vision_config
# Temporary patch before dynamic number of image tokens is supported
_, _, h, w = ctx.get_multimodal_config().image_input_shape
if (w, h) != (image.width, image.height):
logger.warning(
"Dynamic image shape is currently not supported. "
"Resizing input image to (%d, %d).", w, h)
image_data = multi_modal_data["image"]
if isinstance(image_data, Image.Image):
width, height = image_data.size
image = image.resize((w, h))
image_feature_size = get_llava_next_image_feature_size(
hf_config,
input_height=height,
input_width=width,
)
elif isinstance(image_data, torch.Tensor):
raise NotImplementedError("Embeddings input is not supported yet")
else:
raise TypeError(f"Invalid image type: {type(image_data)}")
return MULTIMODAL_REGISTRY._get_plugin("image") \
._default_input_mapper(ctx, image)
vision_config = hf_config.vision_config
raise TypeError(f"Invalid type for 'image': {type(image)}")
if isinstance(vision_config, CLIPVisionConfig):
return input_processor_for_clip(
model_config,
vision_config,
llm_inputs,
image_token_id=hf_config.image_token_index,
image_feature_size_override=image_feature_size,
)
msg = f"Unsupported vision config: {type(vision_config)}"
raise NotImplementedError(msg)
@MULTIMODAL_REGISTRY.register_image_input_mapper(_pixel_mapper)
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next)
@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next)
class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
def __init__(self,
@@ -172,8 +212,8 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
self.config = config
self.vlm_config = vlm_config
# TODO: Optionally initializes this for supporting embeddings.
self.vision_tower = CLIPVisionModel(config=config.vision_config)
self.multi_modal_projector = LlavaMultiModalProjector(
vision_hidden_size=config.vision_config.hidden_size,
text_hidden_size=config.text_config.hidden_size,
@@ -196,24 +236,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
self.image_newline = nn.Parameter(
torch.empty(config.text_config.hidden_size))
def _validate_image_pixels(self, data: torch.Tensor) -> torch.Tensor:
_, num_channels, _, _ = self.vlm_config.image_input_shape
# Note that this is different from that of vLLM vision_language_config
# since the image is resized by the HuggingFace preprocessor
height = width = self.config.vision_config.image_size
if list(data.shape[2:]) != [num_channels, height, width]:
raise ValueError(
f"The expected image tensor shape is batch dimension plus "
f"num_patches plus {[num_channels, height, width]}. "
f"You supplied {data.shape}. "
f"If you are using vLLM's entrypoint, make sure your "
f"supplied image input is consistent with "
f"image_input_shape in engine args.")
return data
def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
if list(data.shape[1:]) != [2]:
raise ValueError(
@@ -223,14 +245,14 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
return data
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[LlavaNextImageInputs]:
self, **kwargs: object) -> Optional[LlavaNextImagePixelInputs]:
pixel_values = kwargs.pop("pixel_values", None)
image_sizes = kwargs.pop("image_sizes", None)
if pixel_values is None or image_sizes is None:
if pixel_values is None:
return None
if not isinstance(pixel_values, torch.Tensor):
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
@@ -240,7 +262,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
return LlavaNextImagePixelInputs(
type="pixel_values",
data=self._validate_image_pixels(pixel_values),
data=pixel_values,
image_sizes=self._validate_image_sizes(image_sizes),
)
@@ -267,15 +289,14 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
strategy=self.config.vision_feature_select_strategy,
)
# Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py
def _merge_image_patch_embeddings(self, image_size: torch.Tensor,
patch_embeddings: torch.Tensor, *,
strategy: str) -> torch.Tensor:
# Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py
if strategy == "flat":
return patch_embeddings.flatten(0, 1)
if strategy.startswith("spatial"):
orig_width, orig_height = image_size
height = width = self.config.vision_config.image_size \
// self.config.vision_config.patch_size
@@ -289,13 +310,15 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
other_patch_embeds = patch_embeddings[1:]
# image_aspect_ratio == "anyres"
# Note: We follow the "wrong" width/height order
# [ref: PR huggingface/transformers#31588]
num_patch_width, num_patch_height = get_anyres_image_grid_shape(
(orig_width, orig_height),
image_size,
self.config.image_grid_pinpoints,
self.config.vision_config.image_size,
)
other_patch_embeds = other_patch_embeds \
.view(num_patch_width, num_patch_height, height, width, -1)
.view(num_patch_height, num_patch_width, height, width, -1)
if "unpad" in strategy:
other_patch_embeds = other_patch_embeds \
@@ -333,44 +356,53 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
raise ValueError(f"Unexpected patch merge strategy: {strategy}")
def _process_image_pixels(
self, inputs: LlavaNextImagePixelInputs) -> torch.Tensor:
self,
inputs: LlavaNextImagePixelInputs,
) -> BatchedTensors:
assert self.vision_tower is not None
pixel_values = inputs["data"]
b, num_patches, c, h, w = pixel_values.shape
stacked_pixel_values = pixel_values.view(b * num_patches, c, h, w)
if isinstance(pixel_values, torch.Tensor):
b, num_patches, c, h, w = pixel_values.shape
stacked_pixel_values = pixel_values.view(b * num_patches, c, h, w)
stacked_image_features = self._image_pixels_to_features(
self.vision_tower, stacked_pixel_values)
stacked_patch_embeddings = self.multi_modal_projector(
stacked_image_features)
return stacked_patch_embeddings.view(
b, num_patches, *stacked_patch_embeddings.shape[1:])
num_patches_per_batch = [v.shape[0] for v in pixel_values]
stacked_pixel_values = torch.cat(pixel_values)
stacked_image_features = self._image_pixels_to_features(
self.vision_tower, stacked_pixel_values)
return stacked_image_features.view(b, num_patches,
*stacked_image_features.shape[-2:])
return [
self.multi_modal_projector(image_features) for image_features in
torch.split(stacked_image_features, num_patches_per_batch)
]
def _process_image_input(
self, image_input: LlavaNextImageInputs) -> torch.Tensor:
assert self.vision_tower is not None
image_features = self._process_image_pixels(image_input)
patch_embeddings = self.multi_modal_projector(image_features)
self, image_input: LlavaNextImageInputs) -> BatchedTensors:
patch_embeddings = self._process_image_pixels(image_input)
image_sizes = image_input.get("image_sizes")
if image_sizes is None:
batch_size = image_input["data"].shape[0]
batch_size = len(image_input["data"])
vision_config = self.config.vision_config
default_width = default_height = vision_config.image_size
image_sizes = torch.as_tensor([[default_width, default_height]
default_height = default_width = vision_config.image_size
image_sizes = torch.as_tensor([[default_height, default_width]
for _ in range(batch_size)])
merged_patch_embeddings = [
return [
self._merge_image_patch_embeddings(image_sizes[i],
patch_features,
patch_features_batch,
strategy="spatial_unpad")
for i, patch_features in enumerate(patch_embeddings)
for i, patch_features_batch in enumerate(patch_embeddings)
]
return torch.stack(merged_patch_embeddings, dim=0)
def forward(
self,
input_ids: torch.Tensor,
@@ -404,8 +436,8 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
input_ids: Flattened (concatenated) input_ids corresponding to a
batch.
pixel_values: The pixels in each grid patch for each input image.
Expects a batch with shape `[1, num_patches, 3, 336, 336]`.
image_sizes: The original `(width, height)` for each input image.
Expects a batch with shape `[1, num_patches, 3, h, w]`.
image_sizes: The original `(height, width)` for each input image.
Expects a batch with shape `[1, 2]`.
See also: