[Core] Dynamic image size support for VLMs (#5276)
Signed-off-by: Xiaowei Jiang <xwjiang2010@gmail.com>
Co-authored-by: Xiaowei Jiang <xwjiang2010@gmail.com>
Co-authored-by: ywang96 <ywang@roblox.com>
Co-authored-by: xwjiang2010 <87673679+xwjiang2010@users.noreply.github.com>
Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com>
vllm/model_executor/models/clip.py
@@ -8,10 +8,14 @@ from PIL import Image
 from transformers import CLIPVisionConfig
 from transformers.models.clip.modeling_clip import CLIPAttention
 
+from vllm.config import ModelConfig
+from vllm.inputs import LLMInputs
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.multimodal.image import (cached_get_tokenizer,
+                                   repeat_and_pad_image_tokens)
 from vllm.sequence import SequenceData
 
 
@@ -64,6 +68,39 @@ def dummy_image_for_clip(
     return {"image": image}
 
 
+def input_processor_for_clip(
+    model_config: ModelConfig,
+    hf_config: CLIPVisionConfig,
+    llm_inputs: LLMInputs,
+    *,
+    image_token_id: int,
+    image_feature_size_override: Optional[int] = None,
+):
+    multi_modal_data = llm_inputs.get("multi_modal_data")
+    if multi_modal_data is None or "image" not in multi_modal_data:
+        return llm_inputs
+
+    tokenizer = cached_get_tokenizer(model_config.tokenizer)
+
+    if image_feature_size_override is None:
+        image_feature_size = get_clip_image_feature_size(hf_config)
+    else:
+        image_feature_size = image_feature_size_override
+
+    new_prompt, new_token_ids = repeat_and_pad_image_tokens(
+        tokenizer,
+        llm_inputs.get("prompt"),
+        llm_inputs["prompt_token_ids"],
+        image_token_id=image_token_id,
+        repeat_count=image_feature_size,
+    )
+
+    # NOTE: Create a defensive copy of the original inputs
+    return LLMInputs(prompt_token_ids=new_token_ids,
+                     prompt=new_prompt,
+                     multi_modal_data=multi_modal_data)
+
+
 # Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa
 class CLIPVisionEmbeddings(nn.Module):
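Note for readers of this diff: the new `input_processor_for_clip` rewrites the prompt before scheduling, replacing the single image placeholder with one placeholder per visual embedding. A minimal sketch of that token-level effect, with toy ids (this is not vLLM's `repeat_and_pad_image_tokens` implementation):

```python
# Rough sketch of the token-level effect (toy ids, illustrative only):
def repeat_image_token(token_ids: list[int], image_token_id: int,
                       repeat_count: int) -> list[int]:
    out: list[int] = []
    for t in token_ids:
        # expand the placeholder; copy everything else verbatim
        out.extend([image_token_id] * repeat_count if t == image_token_id else [t])
    return out

# "<s> IMG hello" -> "<s> IMG*576 hello" for a 336px CLIP tower (24*24 patches)
assert repeat_image_token([1, 32000, 15043], 32000, 576)[:3] == [1, 32000, 32000]
```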
vllm/model_executor/models/llava.py
@@ -6,7 +6,7 @@ from transformers import CLIPVisionConfig, LlavaConfig
 
 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, VisionLanguageConfig
-from vllm.inputs import INPUT_REGISTRY, InputContext
+from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization.base_config import (
@@ -20,8 +20,10 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.sequence import IntermediateTensors, SamplerOutput
 
-from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
+from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
+                   input_processor_for_clip)
 from .interfaces import SupportsVision
+from .utils import merge_vision_embeddings
 
 _KEYS_TO_MODIFY_MAPPING = {
     "language_model.lm_head": "lm_head",
@@ -51,28 +53,10 @@ class LlavaMultiModalProjector(nn.Module):
         return hidden_states
 
 
-def merge_vision_embeddings(input_ids: torch.Tensor,
-                            inputs_embeds: torch.Tensor,
-                            vision_embeddings: torch.Tensor,
-                            image_token_id: int) -> torch.Tensor:
-    """In place merges in vision_embeddings with inputs_embeds."""
-    mask = (input_ids == image_token_id)
-
-    image_feature_size = vision_embeddings.shape[0] * vision_embeddings.shape[1]
-    if mask.sum() != image_feature_size:
-        raise ValueError(f"image_feature_size should be {image_feature_size}, "
-                         f"but found: {mask.sum()}")
-
-    inputs_embeds[mask] = vision_embeddings.view(image_feature_size,
-                                                 vision_embeddings.shape[-1])
-
-    return inputs_embeds
-
-
 class LlavaImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
     data: torch.Tensor
-    """Shape: (batch_size, num_channels, height, width)"""
+    """Shape: `(batch_size, num_channels, height, width)`"""
 
 
 LlavaImageInputs = LlavaImagePixelInputs
@@ -96,8 +80,30 @@ def dummy_data_for_llava(ctx: InputContext, seq_len: int):
     raise NotImplementedError(msg)
 
 
+def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs):
+    multi_modal_data = llm_inputs.get("multi_modal_data")
+    if multi_modal_data is None or "image" not in multi_modal_data:
+        return llm_inputs
+
+    model_config = ctx.model_config
+    hf_config = ctx.get_hf_config(LlavaConfig)
+    vision_config = hf_config.vision_config
+
+    if isinstance(vision_config, CLIPVisionConfig):
+        return input_processor_for_clip(
+            model_config,
+            vision_config,
+            llm_inputs,
+            image_token_id=hf_config.image_token_index,
+        )
+
+    msg = f"Unsupported vision config: {type(vision_config)}"
+    raise NotImplementedError(msg)
+
+
 @MULTIMODAL_REGISTRY.register_image_input_mapper()
 @INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava)
+@INPUT_REGISTRY.register_input_processor(input_processor_for_llava)
 class LlavaForConditionalGeneration(nn.Module, SupportsVision):
 
     def __init__(self,
@@ -112,7 +118,6 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
 
+        # TODO: Optionally initializes this for supporting embeddings.
         self.vision_tower = CLIPVisionModel(config.vision_config)
-
         self.multi_modal_projector = LlavaMultiModalProjector(
             vision_hidden_size=config.vision_config.hidden_size,
             text_hidden_size=config.text_config.hidden_size,
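For plain LLaVA the repeat count is static: it comes from `get_clip_image_feature_size(hf_config)`. A back-of-envelope check for the usual tower, assuming the feature size is simply the patch count:

```python
# Values for CLIP ViT-L/14 at 336px; the exact accounting of the CLS token
# depends on get_clip_image_feature_size, so treat this as an approximation.
image_size, patch_size = 336, 14
patch_grid = image_size // patch_size   # 24
image_feature_size = patch_grid ** 2    # 576 placeholder tokens per image
```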
vllm/model_executor/models/llava_next.py
@@ -1,4 +1,4 @@
-from typing import Dict, Iterable, List, Literal, Optional, Tuple, TypedDict
+from typing import Iterable, List, Literal, Optional, Tuple, TypedDict
 
 import torch
 import torch.nn as nn
@@ -10,7 +10,7 @@ from typing_extensions import NotRequired
 
 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, VisionLanguageConfig
-from vllm.inputs import INPUT_REGISTRY, InputContext
+from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from vllm.logger import init_logger
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization.base_config import (
@@ -21,13 +21,14 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.clip import CLIPVisionModel
 from vllm.model_executor.models.llama import LlamaModel
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal import MULTIMODAL_REGISTRY, BatchedTensors
 from vllm.sequence import IntermediateTensors, SamplerOutput
 
 from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
-                   get_clip_patch_grid_length)
+                   get_clip_patch_grid_length, input_processor_for_clip)
 from .interfaces import SupportsVision
-from .llava import LlavaMultiModalProjector, merge_vision_embeddings
+from .llava import LlavaMultiModalProjector
+from .utils import merge_vision_embeddings
 
 logger = init_logger(__name__)
 
@@ -39,16 +40,27 @@ _KEYS_TO_MODIFY_MAPPING = {
 
 class LlavaNextImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
-    data: torch.Tensor
-    """Shape: (batch_size, 1 + num_patches, num_channels, height, width)"""
+    data: BatchedTensors
+    """
+    Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`
+
+    Note that `num_patches` may be different for each batch.
+    """
 
     image_sizes: NotRequired[torch.Tensor]
-    """Shape: (batch_size, 2)"""
+    """
+    Shape: `(batch_size, 2)`
+
+    This should be in `(height, width)` format.
+    """
 
 
 LlavaNextImageInputs = LlavaNextImagePixelInputs
 
 
 # Taken from: https://github.com/huggingface/text-generation-inference/blob/v2.0.4/server/text_generation_server/models/vlm_causal_lm.py#L91
+# NOTE: new_height and new_width are further incremented to properly invert the
+# floordiv operation: https://github.com/huggingface/transformers/blob/v4.42.2/src/transformers/models/llava_next/modeling_llava_next.py#L133
 def _get_llava_next_num_unpadded_features(
     height: int,
     width: int,
@@ -56,7 +68,6 @@ def _get_llava_next_num_unpadded_features(
     num_patch_height: int,
     num_patch_width: int,
 ) -> Tuple[int, int]:
-    # Taken from: https://github.com/huggingface/text-generation-inference/blob/799a193b109662743bed1b18a09af1fdcd508c8b/server/text_generation_server/models/vlm_causal_lm.py#L111
     current_height = npatches * num_patch_height
     current_width = npatches * num_patch_width
 
@@ -64,9 +75,13 @@ def _get_llava_next_num_unpadded_features(
     current_aspect_ratio: float = current_width / current_height
     if aspect_ratio > current_aspect_ratio:
         new_height = (height * current_width) // width
+        if new_height % 2 == 1:
+            new_height += 1
         current_height = new_height
     else:
         new_width = (width * current_height) // height
+        if new_width % 2 == 1:
+            new_width += 1
         current_width = new_width
 
     unpadded_features = current_height * current_width
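A worked trace of the branch above (values chosen to hit the `else` path; `newline_features` comes from lines outside this hunk):

```python
# 500x300 image on a 2x1 grid of 24-patch tiles
height, width = 300, 500
npatches, num_patch_height, num_patch_width = 24, 1, 2
current_height = npatches * num_patch_height          # 24
current_width = npatches * num_patch_width            # 48
# aspect_ratio = 500/300 = 1.67 < current 48/24 = 2.0, so the width shrinks:
new_width = (width * current_height) // height        # 40 (already even)
unpadded_features = current_height * new_width        # 960
```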
@@ -74,7 +89,8 @@ def _get_llava_next_num_unpadded_features(
     return (unpadded_features, newline_features)
 
 
-def _get_llava_next_image_feature_size(
+# Based on: https://github.com/huggingface/text-generation-inference/blob/v2.0.4/server/text_generation_server/models/vlm_causal_lm.py#L111
+def get_llava_next_image_feature_size(
     hf_config: LlavaNextConfig,
     *,
     input_height: int,
@@ -89,7 +105,9 @@ def _get_llava_next_image_feature_size(
     )
     base_feature_size = num_patches * num_patches
 
-    num_patch_height, num_patch_width = get_anyres_image_grid_shape(
+    # Note: We follow the "wrong" width/height order
+    # [ref: PR huggingface/transformers#31588]
+    num_patch_width, num_patch_height = get_anyres_image_grid_shape(
         image_size=(input_height, input_width),
         grid_pinpoints=hf_config.image_grid_pinpoints,
         patch_size=vision_config.image_size,
@@ -110,14 +128,16 @@ def _get_llava_next_image_feature_size(
 
 
 def dummy_data_for_llava_next(ctx: InputContext, seq_len: int):
-    multimodal_config = ctx.get_multimodal_config()
     hf_config = ctx.get_hf_config(LlavaNextConfig)
     vision_config = hf_config.vision_config
 
-    #TODO: change the logic for dummy data to support dynamic shape
-    _, _, dummy_height, dummy_width = multimodal_config.image_input_shape
-    image_feature_size = _get_llava_next_image_feature_size(
-        hf_config, input_height=dummy_height, input_width=dummy_width)
+    # Result in the max possible feature size (2x2 grid of 336x336px tiles)
+    dummy_height = dummy_width = 448
+    image_feature_size = get_llava_next_image_feature_size(
+        hf_config,
+        input_height=dummy_height,
+        input_width=dummy_width,
+    )
 
     if isinstance(vision_config, CLIPVisionConfig):
         seq_data = dummy_seq_data_for_clip(
@@ -139,27 +159,47 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int):
     raise NotImplementedError(msg)
 
 
-def _pixel_mapper(ctx: InputContext, image: object) -> Dict[str, torch.Tensor]:
-
-    if isinstance(image, Image.Image):
-
-        # Temporary patch before dynamic number of image tokens is supported
-        _, _, h, w = ctx.get_multimodal_config().image_input_shape
-        if (w, h) != (image.width, image.height):
-            logger.warning(
-                "Dynamic image shape is currently not supported. "
-                "Resizing input image to (%d, %d).", w, h)
-
-            image = image.resize((w, h))
-
-        return MULTIMODAL_REGISTRY._get_plugin("image") \
-            ._default_input_mapper(ctx, image)
-
-    raise TypeError(f"Invalid type for 'image': {type(image)}")
+def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs):
+    multi_modal_data = llm_inputs.get("multi_modal_data")
+    if multi_modal_data is None or "image" not in multi_modal_data:
+        return llm_inputs
+
+    model_config = ctx.model_config
+    hf_config = ctx.get_hf_config(LlavaNextConfig)
+    vision_config = hf_config.vision_config
+
+    image_data = multi_modal_data["image"]
+    if isinstance(image_data, Image.Image):
+        width, height = image_data.size
+
+        image_feature_size = get_llava_next_image_feature_size(
+            hf_config,
+            input_height=height,
+            input_width=width,
+        )
+    elif isinstance(image_data, torch.Tensor):
+        raise NotImplementedError("Embeddings input is not supported yet")
+    else:
+        raise TypeError(f"Invalid image type: {type(image_data)}")
+
+    vision_config = hf_config.vision_config
+
+    if isinstance(vision_config, CLIPVisionConfig):
+        return input_processor_for_clip(
+            model_config,
+            vision_config,
+            llm_inputs,
+            image_token_id=hf_config.image_token_index,
+            image_feature_size_override=image_feature_size,
+        )
+
+    msg = f"Unsupported vision config: {type(vision_config)}"
+    raise NotImplementedError(msg)
 
 
-@MULTIMODAL_REGISTRY.register_image_input_mapper(_pixel_mapper)
+@MULTIMODAL_REGISTRY.register_image_input_mapper()
 @INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next)
+@INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next)
 class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
 
     def __init__(self,
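A hedged usage sketch of the helper registered above; the model id and `AutoConfig` loading are illustrative, not part of this commit (the config must resolve to a `LlavaNextConfig`):

```python
from PIL import Image
from transformers import AutoConfig

# illustrative model id; fetching the config requires network access
hf_config = AutoConfig.from_pretrained("llava-hf/llava-v1.6-vicuna-7b-hf")
image = Image.new("RGB", (800, 600))
width, height = image.size
num_placeholders = get_llava_next_image_feature_size(hf_config,
                                                     input_height=height,
                                                     input_width=width)
```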
@@ -172,8 +212,8 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
         self.config = config
         self.vlm_config = vlm_config
 
+        # TODO: Optionally initializes this for supporting embeddings.
         self.vision_tower = CLIPVisionModel(config=config.vision_config)
-
         self.multi_modal_projector = LlavaMultiModalProjector(
             vision_hidden_size=config.vision_config.hidden_size,
             text_hidden_size=config.text_config.hidden_size,
@@ -196,24 +236,6 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
         self.image_newline = nn.Parameter(
             torch.empty(config.text_config.hidden_size))
 
-    def _validate_image_pixels(self, data: torch.Tensor) -> torch.Tensor:
-        _, num_channels, _, _ = self.vlm_config.image_input_shape
-
-        # Note that this is different from that of vLLM vision_language_config
-        # since the image is resized by the HuggingFace preprocessor
-        height = width = self.config.vision_config.image_size
-
-        if list(data.shape[2:]) != [num_channels, height, width]:
-            raise ValueError(
-                f"The expected image tensor shape is batch dimension plus "
-                f"num_patches plus {[num_channels, height, width]}. "
-                f"You supplied {data.shape}. "
-                f"If you are using vLLM's entrypoint, make sure your "
-                f"supplied image input is consistent with "
-                f"image_input_shape in engine args.")
-
-        return data
-
     def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
         if list(data.shape[1:]) != [2]:
             raise ValueError(
@@ -223,14 +245,14 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
         return data
 
     def _parse_and_validate_image_input(
-            self, **kwargs: object) -> Optional[LlavaNextImageInputs]:
+            self, **kwargs: object) -> Optional[LlavaNextImagePixelInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
         image_sizes = kwargs.pop("image_sizes", None)
 
-        if pixel_values is None or image_sizes is None:
+        if pixel_values is None:
             return None
 
-        if not isinstance(pixel_values, torch.Tensor):
+        if not isinstance(pixel_values, (torch.Tensor, list)):
             raise ValueError("Incorrect type of pixel values. "
                              f"Got type: {type(pixel_values)}")
 
@@ -240,7 +262,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
 
         return LlavaNextImagePixelInputs(
             type="pixel_values",
-            data=self._validate_image_pixels(pixel_values),
+            data=pixel_values,
             image_sizes=self._validate_image_sizes(image_sizes),
         )
 
@@ -267,15 +289,14 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
             strategy=self.config.vision_feature_select_strategy,
         )
 
+    # Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py
     def _merge_image_patch_embeddings(self, image_size: torch.Tensor,
                                       patch_embeddings: torch.Tensor, *,
                                       strategy: str) -> torch.Tensor:
-        # Based on: https://github.com/haotian-liu/LLaVA/blob/main/llava/model/llava_arch.py
         if strategy == "flat":
            return patch_embeddings.flatten(0, 1)
 
         if strategy.startswith("spatial"):
-            orig_width, orig_height = image_size
+            orig_height, orig_width = image_size
             height = width = self.config.vision_config.image_size \
                 // self.config.vision_config.patch_size
 
@@ -289,13 +310,15 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
             other_patch_embeds = patch_embeddings[1:]
 
             # image_aspect_ratio == "anyres"
+            # Note: We follow the "wrong" width/height order
+            # [ref: PR huggingface/transformers#31588]
             num_patch_width, num_patch_height = get_anyres_image_grid_shape(
-                (orig_width, orig_height),
+                image_size,
                 self.config.image_grid_pinpoints,
                 self.config.vision_config.image_size,
             )
             other_patch_embeds = other_patch_embeds \
-                .view(num_patch_width, num_patch_height, height, width, -1)
+                .view(num_patch_height, num_patch_width, height, width, -1)
 
             if "unpad" in strategy:
                 other_patch_embeds = other_patch_embeds \
@@ -333,44 +356,53 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
         raise ValueError(f"Unexpected patch merge strategy: {strategy}")
 
     def _process_image_pixels(
-        self, inputs: LlavaNextImagePixelInputs) -> torch.Tensor:
+        self,
+        inputs: LlavaNextImagePixelInputs,
+    ) -> BatchedTensors:
         assert self.vision_tower is not None
 
         pixel_values = inputs["data"]
 
-        b, num_patches, c, h, w = pixel_values.shape
-        stacked_pixel_values = pixel_values.view(b * num_patches, c, h, w)
+        if isinstance(pixel_values, torch.Tensor):
+            b, num_patches, c, h, w = pixel_values.shape
+            stacked_pixel_values = pixel_values.view(b * num_patches, c, h, w)
+            stacked_image_features = self._image_pixels_to_features(
+                self.vision_tower, stacked_pixel_values)
+            stacked_patch_embeddings = self.multi_modal_projector(
+                stacked_image_features)
+
+            return stacked_patch_embeddings.view(
+                b, num_patches, *stacked_patch_embeddings.shape[1:])
+
+        num_patches_per_batch = [v.shape[0] for v in pixel_values]
+        stacked_pixel_values = torch.cat(pixel_values)
         stacked_image_features = self._image_pixels_to_features(
             self.vision_tower, stacked_pixel_values)
 
-        return stacked_image_features.view(b, num_patches,
-                                           *stacked_image_features.shape[-2:])
+        return [
+            self.multi_modal_projector(image_features) for image_features in
+            torch.split(stacked_image_features, num_patches_per_batch)
+        ]
 
     def _process_image_input(
-        self, image_input: LlavaNextImageInputs) -> torch.Tensor:
-        assert self.vision_tower is not None
-        image_features = self._process_image_pixels(image_input)
-
-        patch_embeddings = self.multi_modal_projector(image_features)
+        self, image_input: LlavaNextImageInputs) -> BatchedTensors:
+        patch_embeddings = self._process_image_pixels(image_input)
 
         image_sizes = image_input.get("image_sizes")
         if image_sizes is None:
-            batch_size = image_input["data"].shape[0]
+            batch_size = len(image_input["data"])
             vision_config = self.config.vision_config
-            default_width = default_height = vision_config.image_size
-            image_sizes = torch.as_tensor([[default_width, default_height]
+            default_height = default_width = vision_config.image_size
+            image_sizes = torch.as_tensor([[default_height, default_width]
                                            for _ in range(batch_size)])
 
-        merged_patch_embeddings = [
+        return [
             self._merge_image_patch_embeddings(image_sizes[i],
-                                               patch_features,
+                                               patch_features_batch,
                                                strategy="spatial_unpad")
-            for i, patch_features in enumerate(patch_embeddings)
+            for i, patch_features_batch in enumerate(patch_embeddings)
         ]
 
-        return torch.stack(merged_patch_embeddings, dim=0)
-
     def forward(
         self,
         input_ids: torch.Tensor,
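The list path above relies on `torch.cat`/`torch.split` round-tripping ragged batches; an illustration with assumed shapes:

```python
import torch

pixel_values = [
    torch.randn(5, 3, 336, 336),  # prompt 0: 1 base + 4 grid patches
    torch.randn(3, 3, 336, 336),  # prompt 1: 1 base + 2 grid patches
]
num_patches_per_batch = [v.shape[0] for v in pixel_values]   # [5, 3]
stacked = torch.cat(pixel_values)                            # (8, 3, 336, 336)
per_prompt = torch.split(stacked, num_patches_per_batch)
assert [t.shape[0] for t in per_prompt] == [5, 3]
```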
@@ -404,8 +436,8 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
             input_ids: Flattened (concatenated) input_ids corresponding to a
                 batch.
             pixel_values: The pixels in each grid patch for each input image.
-                Expects a batch with shape `[1, num_patches, 3, 336, 336]`.
-            image_sizes: The original `(width, height)` for each input image.
+                Expects a batch with shape `[1, num_patches, 3, h, w]`.
+            image_sizes: The original `(height, width)` for each input image.
                 Expects a batch with shape `[1, 2]`.
 
         See also:
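The shapes the updated docstring describes, with illustrative values:

```python
import torch

pixel_values = torch.randn(1, 5, 3, 336, 336)  # 1 prompt, 1 base + 4 patches
image_sizes = torch.tensor([[600, 800]])       # original (height, width)
```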
vllm/model_executor/models/phi3v.py
@@ -13,7 +13,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Dict, Iterable, List, Literal, Optional, Tuple, TypedDict
+import re
+from functools import lru_cache
+from typing import Iterable, List, Literal, Optional, Tuple, TypedDict
 
 import numpy as np
 import torch
@@ -22,8 +24,8 @@ from PIL import Image
 from transformers import CLIPVisionConfig, PretrainedConfig
 
 from vllm.attention import AttentionMetadata
-from vllm.config import CacheConfig, VisionLanguageConfig
-from vllm.inputs import INPUT_REGISTRY, InputContext
+from vllm.config import CacheConfig, ModelConfig, VisionLanguageConfig
+from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from vllm.logger import init_logger
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization.base_config import (
@@ -34,10 +36,12 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.clip import CLIPVisionModel
 from vllm.model_executor.models.llama import LlamaModel
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.multimodal import MULTIMODAL_REGISTRY
+from vllm.multimodal import MULTIMODAL_REGISTRY, BatchedTensors
+from vllm.multimodal.image import cached_get_tokenizer
 from vllm.sequence import IntermediateTensors, SamplerOutput
 
-from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
+from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
+                   input_processor_for_clip)
 from .interfaces import SupportsVision
 
 logger = init_logger(__name__)
@@ -251,50 +255,22 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
 
 class Phi3VImagePixelInputs(TypedDict):
     type: Literal["pixel_values"]
-    data: torch.Tensor
-    """Shape: (batch_size, 1 + num_patches, num_channels, height, width)"""
+    data: BatchedTensors
+    """
+    Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`
+
+    Note that `num_patches` may be different for each batch.
+    """
 
     image_sizes: torch.Tensor
-    """Shape: (batch_size, 2)"""
+    """
+    Shape: `(batch_size, 2)`
+
+    This should be in `(height, width)` format.
+    """
 
 
-def _get_phi3v_image_feature_size(
-    *,
-    input_height: int,
-    input_width: int,
-) -> int:
-    h, w = input_height, input_width
-
-    # https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L178
-    return (h // 336 * w // 336 + 1) * 144 + 1 + (h // 336 + 1) * 12
-
-
-def dummy_data_for_phi3v(ctx: InputContext, seq_len: int):
-    multimodal_config = ctx.get_multimodal_config()
-
-    #TODO: change the logic for dummy data to support dynamic shape
-    _, _, dummy_height, dummy_width = multimodal_config.image_input_shape
-    image_feature_size = _get_phi3v_image_feature_size(
-        input_height=dummy_height,
-        input_width=dummy_width,
-    )
-
-    seq_data = dummy_seq_data_for_clip(
-        CLIP_VIT_LARGE_PATCH14_336_CONFIG,
-        seq_len,
-        image_token_id=32044,
-        image_feature_size_override=image_feature_size,
-    )
-    mm_data = dummy_image_for_clip(
-        CLIP_VIT_LARGE_PATCH14_336_CONFIG,
-        image_width_override=dummy_width,
-        image_height_override=dummy_height,
-    )
-
-    return seq_data, mm_data
-
-
-# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py
+# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L57
 def _calc_padded_size(*, width: int, height: int, padding_unit: int = 336):
     target_height = int(np.ceil(height / padding_unit) * padding_unit)
     top_padding = int((target_height - height) / 2)
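A worked trace of the padding arithmetic shown above (the symmetric bottom padding and the final return are outside this hunk):

```python
import numpy as np

height, padding_unit = 840, 336
target_height = int(np.ceil(height / padding_unit) * padding_unit)  # 1008
top_padding = int((target_height - height) / 2)                     # 84
# the remaining lines presumably pad symmetrically, so padded_height
# becomes 1008 while padded_width stays unchanged
```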
@@ -304,7 +280,7 @@ def _calc_padded_size(*, width: int, height: int, padding_unit: int = 336):
     return padded_width, padded_height
 
 
-# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py
+# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L90
 def _calc_hd_transform_size(*, width: int, height: int, hd_num: int = 16):
     transposed = False
     if width < height:
@@ -329,27 +305,133 @@ def _calc_hd_transform_size(*, width: int, height: int, hd_num: int = 16):
     return padded_width, padded_height
 
 
-def _image_processor(ctx: InputContext,
-                     image: object) -> Dict[str, torch.Tensor]:
-
-    if isinstance(image, Image.Image):
-        # Temporary patch before dynamic number of image tokens is supported
-        _, _, h, w = ctx.get_multimodal_config().image_input_shape
-        if (w, h) != _calc_hd_transform_size(width=image.width,
-                                             height=image.height):
-            logger.warning(
-                "Dynamic image shape is currently not supported. "
-                "Resizing input image to (%d, %d).", w, h)
-
-            image = image.resize((w, h))
-
-        return MULTIMODAL_REGISTRY._get_plugin("image") \
-            ._default_input_mapper(ctx, image)
-
-    raise TypeError(f"Invalid type for 'image': {type(image)}")
+# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L181
+def get_phi3v_image_feature_size(
+    hf_config: PretrainedConfig,
+    *,
+    input_height: int,
+    input_width: int,
+) -> int:
+    num_crops = getattr(hf_config, "num_crops", 16)
+    new_width, new_height = _calc_hd_transform_size(width=input_width,
+                                                    height=input_height,
+                                                    hd_num=num_crops)
+
+    return (new_height // 336 * new_width // 336 + 1) * 144 + 1 \
+        + (new_height // 336 + 1) * 12
 
 
-@MULTIMODAL_REGISTRY.register_image_input_mapper(_image_processor)
+def dummy_data_for_phi3v(ctx: InputContext, seq_len: int):
+    # Result in the max possible feature size (h:w = 16:1)
+    dummy_height, dummy_width = 8000, 50
+    image_feature_size = get_phi3v_image_feature_size(
+        ctx.get_hf_config(PretrainedConfig),
+        input_height=dummy_height,
+        input_width=dummy_width,
+    )
+
+    seq_data = dummy_seq_data_for_clip(
+        CLIP_VIT_LARGE_PATCH14_336_CONFIG,
+        seq_len,
+        image_token_id=32044,
+        image_feature_size_override=image_feature_size,
+    )
+    mm_data = dummy_image_for_clip(
+        CLIP_VIT_LARGE_PATCH14_336_CONFIG,
+        image_width_override=dummy_width,
+        image_height_override=dummy_height,
+    )
+
+    return seq_data, mm_data
+
+
+# Reserve this function to also handle placeholders for additional images
+# [ref: PR #5820]
+@lru_cache
+def _get_image_placeholder_token_ids(model_config: ModelConfig,
+                                     idx: int) -> List[int]:
+    assert idx > 0
+
+    tokenizer = cached_get_tokenizer(model_config.tokenizer)
+
+    # We need to get the token for "<", not "▁<"
+    # https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/raw/main/tokenizer.json
+    a_token_id, = tokenizer.encode("a", add_special_tokens=False)
+    a_token_id_, *image_placeholder_token_ids = tokenizer.encode(
+        f"a<|image_{idx}|>", add_special_tokens=False)
+    assert a_token_id == a_token_id_
+
+    return image_placeholder_token_ids
+
+
+def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
+    multi_modal_data = llm_inputs.get("multi_modal_data")
+    if multi_modal_data is None or "image" not in multi_modal_data:
+        return llm_inputs
+
+    model_config = ctx.model_config
+    multimodal_config = ctx.get_multimodal_config()
+    hf_config = ctx.get_hf_config(PretrainedConfig)
+
+    image_data = multi_modal_data["image"]
+    if isinstance(image_data, Image.Image):
+        w, h = image_data.size
+        w, h = _calc_hd_transform_size(width=w, height=h)
+
+        image_feature_size = get_phi3v_image_feature_size(hf_config,
+                                                          input_width=w,
+                                                          input_height=h)
+    elif isinstance(image_data, torch.Tensor):
+        raise NotImplementedError("Embeddings input is not supported yet")
+    else:
+        raise TypeError(f"Invalid image type: {type(image_data)}")
+
+    prompt = llm_inputs.get("prompt")
+    if prompt is None:
+        new_prompt = None
+    else:
+        if prompt.count("<|image|>") > 0:
+            logger.warning("Please follow the prompt format that is "
+                           "documented on HuggingFace which does not involve "
+                           "repeating <|image|> tokens.")
+        elif len(re.findall(r"(<\|image_\d+\|>)+", prompt)) > 1:
+            logger.warning("Multiple image input is not supported yet, "
+                           "so any extra image tokens will be treated "
+                           "as plain text.")
+
+        new_prompt = prompt
+
+    prompt_token_ids = llm_inputs["prompt_token_ids"]
+    image_1_token_ids = _get_image_placeholder_token_ids(model_config, idx=1)
+
+    new_token_ids: List[int] = []
+    for i in range(len(prompt_token_ids) - len(image_1_token_ids) + 1):
+        if prompt_token_ids[i:i + len(image_1_token_ids)] == image_1_token_ids:
+            new_token_ids.append(multimodal_config.image_token_id)
+
+            # No need to further scan the list since we only replace once
+            new_token_ids.extend(prompt_token_ids[i + len(image_1_token_ids):])
+            break
+        else:
+            new_token_ids.append(prompt_token_ids[i])
+
+    # NOTE: Create a defensive copy of the original inputs
+    llm_inputs = LLMInputs(prompt_token_ids=new_token_ids,
+                           prompt=new_prompt,
+                           multi_modal_data=multi_modal_data)
+
+    return input_processor_for_clip(
+        model_config,
+        CLIP_VIT_LARGE_PATCH14_336_CONFIG,
+        llm_inputs,
+        image_token_id=multimodal_config.image_token_id,
+        image_feature_size_override=image_feature_size,
+    )
+
+
+@MULTIMODAL_REGISTRY.register_image_input_mapper()
 @INPUT_REGISTRY.register_dummy_data(dummy_data_for_phi3v)
+@INPUT_REGISTRY.register_input_processor(input_processor_for_phi3v)
 class Phi3VForCausalLM(nn.Module, SupportsVision):
 
     def __init__(self,
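Sanity check of the new dummy size: assuming the referenced HD transform maps the 8000x50 dummy to `new_width=336, new_height=5376` (16 stacked 336px crops, given `hd_num=16`), the formula above yields the maximum token count:

```python
# Back-of-envelope check; the intermediate transform result is an assumption
# based on the upstream HF preprocessor this diff references.
new_width, new_height = 336, 5376
tokens = (new_height // 336 * new_width // 336 + 1) * 144 + 1 \
    + (new_height // 336 + 1) * 12
assert tokens == (16 * 1 + 1) * 144 + 1 + 17 * 12 == 2653
```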
@@ -363,6 +445,8 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
         self.vlm_config = vlm_config
 
         self.model = LlamaModel(config, cache_config, quant_config)
+
+        # TODO: Optionally initializes this for supporting embeddings.
         self.vision_embed_tokens = Phi3HDImageEmbedding(
             vlm_config, config, self.model.embed_tokens)
         self.lm_head = ParallelLMHead(config.vocab_size,
@@ -376,12 +460,20 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
         pixel_values = kwargs.pop("pixel_values", None)
         image_sizes = kwargs.pop("image_sizes", None)
 
-        if pixel_values is not None and image_sizes is not None:
-            return Phi3VImagePixelInputs(type="pixel_values",
-                                         data=pixel_values,
-                                         image_sizes=image_sizes)
+        if pixel_values is None:
+            return None
 
-        return None
+        if not isinstance(pixel_values, (torch.Tensor, list)):
+            raise ValueError("Incorrect type of pixel values. "
+                             f"Got type: {type(pixel_values)}")
+
+        if not isinstance(image_sizes, torch.Tensor):
+            raise ValueError("Incorrect type of image sizes. "
+                             f"Got type: {type(image_sizes)}")
+
+        return Phi3VImagePixelInputs(type="pixel_values",
+                                     data=pixel_values,
+                                     image_sizes=image_sizes)
 
     def forward(self,
                 input_ids: torch.Tensor,
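A toy re-trace of the placeholder scan in `input_processor_for_phi3v`; every token id below is made up, and the real ids come from the tokenizer. After this substitution, `input_processor_for_clip` repeats the single placeholder `image_feature_size` times:

```python
prompt_token_ids = [1, 529, 30000, 29989, 2]
image_1_token_ids = [529, 30000, 29989]   # hypothetical ids for "<|image_1|>"
image_token_id = 32044

new_token_ids: list[int] = []
for i in range(len(prompt_token_ids) - len(image_1_token_ids) + 1):
    if prompt_token_ids[i:i + len(image_1_token_ids)] == image_1_token_ids:
        new_token_ids.append(image_token_id)
        new_token_ids.extend(prompt_token_ids[i + len(image_1_token_ids):])
        break
    else:
        new_token_ids.append(prompt_token_ids[i])

assert new_token_ids == [1, 32044, 2]
```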
vllm/model_executor/models/utils.py (new file, 41 lines)
@@ -0,0 +1,41 @@
+import torch
+
+from vllm.multimodal import BatchedTensors
+
+
+def merge_vision_embeddings(input_ids: torch.Tensor,
+                            inputs_embeds: torch.Tensor,
+                            vision_embeddings: BatchedTensors,
+                            image_token_id: int) -> torch.Tensor:
+    """
+    Merge `vision_embeddings` into `inputs_embeds` by overwriting the positions
+    in `inputs_embeds` corresponding to placeholder image tokens in `input_ids`.
+
+    Note:
+        This updates `inputs_embeds` in place.
+    """
+    mask = (input_ids == image_token_id)
+    num_expected_tokens = mask.sum()
+
+    if isinstance(vision_embeddings, torch.Tensor):
+        batch_size, batch_tokens, *_, embed_dim = vision_embeddings.shape
+        total_tokens = batch_size * batch_tokens
+        if num_expected_tokens != total_tokens:
+            expr = f"{batch_size} x {batch_tokens}"
+            raise ValueError(
+                f"Attempted to assign {expr} = {total_tokens} "
+                f"image tokens to {num_expected_tokens} placeholders")
+
+        inputs_embeds[mask] = vision_embeddings.view(total_tokens, embed_dim)
+    else:
+        size_per_batch = [t.shape[0] for t in vision_embeddings]
+        total_tokens = sum(size_per_batch)
+        if num_expected_tokens != total_tokens:
+            expr = ' + '.join(map(str, size_per_batch))
+            raise ValueError(
+                f"Attempted to assign {expr} = {total_tokens} "
+                f"image tokens to {num_expected_tokens} placeholders")
+
+        inputs_embeds[mask] = torch.cat(vision_embeddings)
+
+    return inputs_embeds
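A minimal usage sketch of the new helper with toy tensors (all shapes illustrative):

```python
import torch

input_ids = torch.tensor([1, 32000, 32000, 32000, 2])
inputs_embeds = torch.zeros(5, 4)          # 5 tokens, hidden size 4
vision_embeddings = torch.ones(1, 3, 4)    # 1 image, 3 tokens, hidden size 4

out = merge_vision_embeddings(input_ids, inputs_embeds, vision_embeddings,
                              image_token_id=32000)
assert torch.equal(out[1:4], torch.ones(3, 4))  # placeholders overwritten
```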