[VLM] Remove image_input_type from VLM config (#5852)

Signed-off-by: Xiaowei Jiang <xwjiang2010@gmail.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: Roger Wang <ywang@roblox.com>
Author: xwjiang2010
Date:   2024-07-02 00:57:09 -07:00
Commit: 98d6682cd1 (parent 2c37540aa6)

35 changed files with 329 additions and 751 deletions
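At its core, the commit replaces a config-time switch with runtime type dispatch: instead of declaring up front whether a model receives pixel values or precomputed image features, the input mapper inspects the data it is given. A standalone sketch of that idea (the helper below is illustrative, not vLLM's API):

import torch
from PIL import Image


def infer_input_kind(image: object) -> str:
    # The data's own runtime type selects the code path; no
    # image_input_type field in the config is consulted.
    if isinstance(image, torch.Tensor):
        return "pixel_values"
    if isinstance(image, Image.Image):
        return "pil_image"
    raise TypeError(f"Invalid type for 'image': {type(image)}")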

vllm/model_executor/models/llava_next.py

@@ -1,8 +1,8 @@
-from typing import (Dict, Iterable, List, Literal, Optional, Tuple, TypedDict,
-                    Union)
+from typing import Dict, Iterable, List, Literal, Optional, Tuple, TypedDict

 import torch
 import torch.nn as nn
+from PIL import Image
 from transformers import CLIPVisionConfig, LlavaNextConfig
 from transformers.models.llava_next.modeling_llava_next import (
     get_anyres_image_grid_shape, unpad_image)
@@ -21,12 +21,11 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.clip import CLIPVisionModel
 from vllm.model_executor.models.llama import LlamaModel
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalData
-from vllm.multimodal.image import ImagePixelData
+from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.sequence import SamplerOutput

-from .clip import (dummy_feature_data_for_clip, dummy_pixel_data_for_clip,
-                   dummy_seq_data_for_clip, get_clip_patch_grid_length)
+from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
+                   get_clip_patch_grid_length)
 from .interfaces import SupportsVision
 from .llava import LlavaMultiModalProjector, merge_vision_embeddings
@@ -47,17 +46,7 @@ class LlavaNextImagePixelInputs(TypedDict):
     """Shape: (batch_size, 2)"""


-class LlavaNextImageFeatureInputs(TypedDict):
-    type: Literal["image_features"]
-    data: torch.Tensor
-    """Shape: (batch_size, 1 + num_patches, image_feature_size, hidden_size)"""
-
-    image_sizes: NotRequired[torch.Tensor]
-    """Shape: (batch_size, 2)"""
-
-
-LlavaNextImageInputs = Union[LlavaNextImagePixelInputs,
-                             LlavaNextImageFeatureInputs]
+LlavaNextImageInputs = LlavaNextImagePixelInputs


 def _get_llava_next_num_unpadded_features(
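The hunk above collapses the input-type union to a single member. For reference, a self-contained sketch of the tagged-TypedDict pattern this file uses, with illustrative names and an assumed 336x336 CLIP resolution:

from typing import Literal, TypedDict

import torch


class ImagePixelInputs(TypedDict):
    # The Literal tag discriminates input kinds, both for type checkers
    # and for runtime checks like image_input["type"] == "pixel_values".
    type: Literal["pixel_values"]
    data: torch.Tensor
    image_sizes: torch.Tensor


ImageInputs = ImagePixelInputs  # one-member "union" after the removal

example: ImageInputs = {
    "type": "pixel_values",
    "data": torch.zeros(1, 2, 3, 336, 336),  # (batch, 1 + patches, C, H, W)
    "image_sizes": torch.tensor([[336, 336]]),  # (batch, 2)
}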
@@ -138,20 +127,11 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int):
             image_feature_size_override=image_feature_size,
         )

-        image_input_type = multimodal_config.image_input_type
-        ImageInputType = VisionLanguageConfig.ImageInputType
-        mm_data: MultiModalData
-
-        if image_input_type == ImageInputType.PIXEL_VALUES:
-            mm_data = dummy_pixel_data_for_clip(
-                vision_config,
-                image_width_override=dummy_width,
-                image_height_override=dummy_height,
-            )
-        elif image_input_type == ImageInputType.IMAGE_FEATURES:
-            mm_data = dummy_feature_data_for_clip(
-                vision_config,
-                image_feature_size_override=image_feature_size,
-            )
+        mm_data = dummy_image_for_clip(
+            vision_config,
+            image_width_override=dummy_width,
+            image_height_override=dummy_height,
+        )

         return seq_data, mm_data
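The dummy-data path now always builds an image rather than branching on the configured input type. The diff does not show dummy_image_for_clip itself; a hedged sketch of the general technique it stands for: a profiling-time dummy image only needs the right resolution, not meaningful pixels.

from PIL import Image


def make_dummy_image(width: int, height: int) -> Image.Image:
    # Pixel contents are irrelevant when profiling memory usage; only
    # the resolution (and hence the image token count) matters.
    return Image.new("RGB", (width, height), color=0)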
@@ -159,32 +139,26 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int):
     raise NotImplementedError(msg)


-def _pixel_mapper(ctx: InputContext,
-                  data: ImagePixelData) -> Dict[str, torch.Tensor]:
-    image = data.image
+def _pixel_mapper(ctx: InputContext, image: object) -> Dict[str, torch.Tensor]:

     if isinstance(image, torch.Tensor):
         pixel_values = image.to(ctx.model_config.dtype)
         batch_size, _, _, h, w = pixel_values.shape
         image_sizes = torch.tensor([(w, h) for _ in range(batch_size)])

         return {"pixel_values": pixel_values, "image_sizes": image_sizes}

-    # Temporary patch before dynamic number of image tokens is supported
-    _, _, h, w = ctx.get_multimodal_config().image_input_shape
-    if (w, h) != (image.width, image.height):
-        logger.warning(
-            "Dynamic image shape is currently not supported. "
-            "Resizing input image to (%d, %d).", w, h)
-
-        data.image = image.resize((w, h))
-
-    return MULTIMODAL_REGISTRY._get_plugin_for_data_type(ImagePixelData) \
-        ._default_input_mapper(ctx, data)
+    if isinstance(image, Image.Image):
+        # Temporary patch before dynamic number of image tokens is supported
+        _, _, h, w = ctx.get_multimodal_config().image_input_shape
+        if (w, h) != (image.width, image.height):
+            logger.warning(
+                "Dynamic image shape is currently not supported. "
+                "Resizing input image to (%d, %d).", w, h)
+
+            image = image.resize((w, h))
+
+        return MULTIMODAL_REGISTRY._get_plugin("image") \
+            ._default_input_mapper(ctx, image)
+
+    raise TypeError(f"Invalid type for 'image': {type(image)}")


-@MULTIMODAL_REGISTRY.register_image_feature_input_mapper()
-@MULTIMODAL_REGISTRY.register_image_pixel_input_mapper(_pixel_mapper)
+@MULTIMODAL_REGISTRY.register_image_input_mapper(_pixel_mapper)
 @INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next)
 class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
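The rewritten _pixel_mapper above is the heart of the change: one registered mapper that branches on the value's type. A runnable end-to-end illustration of that shape, where the preprocessing stand-in and all names are mine rather than vLLM's:

from typing import Dict, Tuple

import numpy as np
import torch
from PIL import Image


def map_image(image: object,
              target_hw: Tuple[int, int]) -> Dict[str, torch.Tensor]:
    if isinstance(image, torch.Tensor):
        # Pre-computed pixel values pass through; sizes come from shape.
        batch_size, _, _, h, w = image.shape
        sizes = torch.tensor([(w, h)] * batch_size)
        return {"pixel_values": image, "image_sizes": sizes}

    if isinstance(image, Image.Image):
        # Resize to the static shape, then stand in for the HF image
        # processor with a plain uint8 -> float conversion.
        h, w = target_hw
        if (image.width, image.height) != (w, h):
            image = image.resize((w, h))
        array = np.asarray(image, dtype=np.float32) / 255.0
        pixel_values = torch.from_numpy(array).permute(2, 0, 1)[None, None]
        return {"pixel_values": pixel_values,
                "image_sizes": torch.tensor([(w, h)])}

    raise TypeError(f"Invalid type for 'image': {type(image)}")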
@@ -198,11 +172,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
         self.config = config
         self.vlm_config = vlm_config

-        if self.vlm_config.image_input_type == (
-                VisionLanguageConfig.ImageInputType.PIXEL_VALUES):
-            self.vision_tower = CLIPVisionModel(config=config.vision_config)
-        else:
-            raise TypeError("Image features are not supported by LLaVA-NeXT")
+        self.vision_tower = CLIPVisionModel(config=config.vision_config)

         self.multi_modal_projector = LlavaMultiModalProjector(
             vision_hidden_size=config.vision_config.hidden_size,
@@ -255,36 +225,23 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
             self, **kwargs: object) -> Optional[LlavaNextImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
         image_sizes = kwargs.pop("image_sizes", None)
-        image_features = kwargs.pop("image_features", None)

-        expected_input_type = self.vlm_config.image_input_type
-        ImageInputType = VisionLanguageConfig.ImageInputType
-
-        if expected_input_type == ImageInputType.PIXEL_VALUES:
-            if image_features is not None:
-                raise ValueError(
-                    "Expected pixel values but got image features")
-
-            if pixel_values is None:
-                return None
-
-            if not isinstance(pixel_values, torch.Tensor):
-                raise ValueError("Incorrect type of pixel values. "
-                                 f"Got type: {type(pixel_values)}")
-
-            if not isinstance(image_sizes, torch.Tensor):
-                raise ValueError("Incorrect type of image sizes. "
-                                 f"Got type: {type(image_sizes)}")
-
-            return LlavaNextImagePixelInputs(
-                type="pixel_values",
-                data=self._validate_image_pixels(pixel_values),
-                image_sizes=self._validate_image_sizes(image_sizes),
-            )
-
-        assert expected_input_type != ImageInputType.IMAGE_FEATURES, (
-            "Failed to validate this at initialization time")
-
-        return None
+        if pixel_values is None or image_sizes is None:
+            return None
+
+        if not isinstance(pixel_values, torch.Tensor):
+            raise ValueError("Incorrect type of pixel values. "
+                             f"Got type: {type(pixel_values)}")
+
+        if not isinstance(image_sizes, torch.Tensor):
+            raise ValueError("Incorrect type of image sizes. "
+                             f"Got type: {type(image_sizes)}")
+
+        return LlavaNextImagePixelInputs(
+            type="pixel_values",
+            data=self._validate_image_pixels(pixel_values),
+            image_sizes=self._validate_image_sizes(image_sizes),
+        )

     def _select_image_features(self, image_features: torch.Tensor, *,
                                strategy: str) -> torch.Tensor:
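With the config flag gone, _parse_and_validate_image_input reduces to a guard plus two isinstance checks. A minimal standalone sketch of that flow, as a hypothetical free function with the same logic:

from typing import Any, Dict, Optional

import torch


def parse_image_kwargs(**kwargs: Any) -> Optional[Dict[str, Any]]:
    pixel_values = kwargs.pop("pixel_values", None)
    image_sizes = kwargs.pop("image_sizes", None)

    # Text-only request: nothing image-related to validate.
    if pixel_values is None or image_sizes is None:
        return None

    for name, value in (("pixel_values", pixel_values),
                        ("image_sizes", image_sizes)):
        if not isinstance(value, torch.Tensor):
            raise ValueError(f"Incorrect type of {name}. "
                             f"Got type: {type(value)}")

    return {"type": "pixel_values",
            "data": pixel_values,
            "image_sizes": image_sizes}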
@@ -391,11 +348,8 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
     def _process_image_input(
             self, image_input: LlavaNextImageInputs) -> torch.Tensor:
-        if image_input["type"] == "pixel_values":
-            assert self.vision_tower is not None
-            image_features = self._process_image_pixels(image_input)
-        else:
-            image_features = image_input["data"]
+        assert self.vision_tower is not None
+        image_features = self._process_image_pixels(image_input)

         patch_embeddings = self.multi_modal_projector(image_features)
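After this hunk, _process_image_input has exactly one path: pixels through the vision tower, then the projector. Schematically, with stand-in modules rather than the real CLIP tower and LlavaMultiModalProjector:

import torch
import torch.nn as nn


class ImagePipeline(nn.Module):

    def __init__(self, vision_tower: nn.Module, projector: nn.Module):
        super().__init__()
        self.vision_tower = vision_tower
        self.projector = projector

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        # Exactly one path remains: pixels -> vision features -> patch
        # embeddings; the image-features shortcut no longer exists.
        image_features = self.vision_tower(pixel_values)
        return self.projector(image_features)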