[VLM] Remove image_input_type from VLM config (#5852)
Signed-off-by: Xiaowei Jiang <xwjiang2010@gmail.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: Roger Wang <ywang@roblox.com>
```diff
@@ -1,8 +1,8 @@
-from typing import (Dict, Iterable, List, Literal, Optional, Tuple, TypedDict,
-                    Union)
+from typing import Dict, Iterable, List, Literal, Optional, Tuple, TypedDict

 import torch
 import torch.nn as nn
+from PIL import Image
 from transformers import CLIPVisionConfig, LlavaNextConfig
 from transformers.models.llava_next.modeling_llava_next import (
     get_anyres_image_grid_shape, unpad_image)
```
```diff
@@ -21,12 +21,11 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.clip import CLIPVisionModel
 from vllm.model_executor.models.llama import LlamaModel
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalData
-from vllm.multimodal.image import ImagePixelData
+from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.sequence import SamplerOutput

-from .clip import (dummy_feature_data_for_clip, dummy_pixel_data_for_clip,
-                   dummy_seq_data_for_clip, get_clip_patch_grid_length)
+from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
+                   get_clip_patch_grid_length)
 from .interfaces import SupportsVision
 from .llava import LlavaMultiModalProjector, merge_vision_embeddings

```
```diff
@@ -47,17 +46,7 @@ class LlavaNextImagePixelInputs(TypedDict):
     """Shape: (batch_size, 2)"""


-class LlavaNextImageFeatureInputs(TypedDict):
-    type: Literal["image_features"]
-    data: torch.Tensor
-    """Shape: (batch_size, 1 + num_patches, image_feature_size, hidden_size)"""
-
-    image_sizes: NotRequired[torch.Tensor]
-    """Shape: (batch_size, 2)"""
-
-
-LlavaNextImageInputs = Union[LlavaNextImagePixelInputs,
-                             LlavaNextImageFeatureInputs]
+LlavaNextImageInputs = LlavaNextImagePixelInputs


 def _get_llava_next_num_unpadded_features(
```
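With the feature-input variant gone, `LlavaNextImageInputs` is just an alias for the pixel TypedDict. A minimal, self-contained sketch of how such an input would be populated (the shapes here are illustrative assumptions, not values taken from the diff):

```python
from typing import Literal, TypedDict

import torch


class LlavaNextImagePixelInputs(TypedDict):
    """Trimmed copy of the TypedDict this commit keeps."""
    type: Literal["pixel_values"]
    data: torch.Tensor
    image_sizes: torch.Tensor


LlavaNextImageInputs = LlavaNextImagePixelInputs  # the new alias

# Assumed shapes: data is (batch, 1 + num_patches, channels, height, width),
# matching the 5-way unpacking in _pixel_mapper; image_sizes is (batch, 2).
inputs: LlavaNextImageInputs = {
    "type": "pixel_values",
    "data": torch.zeros(1, 5, 3, 336, 336),
    "image_sizes": torch.tensor([[672, 672]]),
}
```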
```diff
@@ -138,20 +127,11 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int):
             image_feature_size_override=image_feature_size,
         )

-        image_input_type = multimodal_config.image_input_type
-        ImageInputType = VisionLanguageConfig.ImageInputType
-        mm_data: MultiModalData
-        if image_input_type == ImageInputType.PIXEL_VALUES:
-            mm_data = dummy_pixel_data_for_clip(
-                vision_config,
-                image_width_override=dummy_width,
-                image_height_override=dummy_height,
-            )
-        elif image_input_type == ImageInputType.IMAGE_FEATURES:
-            mm_data = dummy_feature_data_for_clip(
-                vision_config,
-                image_feature_size_override=image_feature_size,
-            )
+        mm_data = dummy_image_for_clip(
+            vision_config,
+            image_width_override=dummy_width,
+            image_height_override=dummy_height,
+        )

         return seq_data, mm_data
```
```diff
@@ -159,32 +139,26 @@ def dummy_data_for_llava_next(ctx: InputContext, seq_len: int):
     raise NotImplementedError(msg)


-def _pixel_mapper(ctx: InputContext,
-                  data: ImagePixelData) -> Dict[str, torch.Tensor]:
-    image = data.image
+def _pixel_mapper(ctx: InputContext, image: object) -> Dict[str, torch.Tensor]:

     if isinstance(image, torch.Tensor):
         pixel_values = image.to(ctx.model_config.dtype)
         batch_size, _, _, h, w = pixel_values.shape
         image_sizes = torch.tensor([(w, h) for _ in range(batch_size)])

         return {"pixel_values": pixel_values, "image_sizes": image_sizes}

-    # Temporary patch before dynamic number of image tokens is supported
-    _, _, h, w = ctx.get_multimodal_config().image_input_shape
-    if (w, h) != (image.width, image.height):
-        logger.warning(
-            "Dynamic image shape is currently not supported. "
-            "Resizing input image to (%d, %d).", w, h)
+    if isinstance(image, Image.Image):

-        data.image = image.resize((w, h))
+        # Temporary patch before dynamic number of image tokens is supported
+        _, _, h, w = ctx.get_multimodal_config().image_input_shape
+        if (w, h) != (image.width, image.height):
+            logger.warning(
+                "Dynamic image shape is currently not supported. "
+                "Resizing input image to (%d, %d).", w, h)

-    return MULTIMODAL_REGISTRY._get_plugin_for_data_type(ImagePixelData) \
-        ._default_input_mapper(ctx, data)
+            image = image.resize((w, h))

+        return MULTIMODAL_REGISTRY._get_plugin("image") \
+            ._default_input_mapper(ctx, image)

+    raise TypeError(f"Invalid type for 'image': {type(image)}")


-@MULTIMODAL_REGISTRY.register_image_feature_input_mapper()
-@MULTIMODAL_REGISTRY.register_image_pixel_input_mapper(_pixel_mapper)
+@MULTIMODAL_REGISTRY.register_image_input_mapper(_pixel_mapper)
 @INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next)
 class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
```
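The rewritten `_pixel_mapper` dispatches on the runtime type of `image` rather than on a config flag. Below is a simplified, self-contained sketch of that dispatch pattern; the registry delegation is replaced by a naive PIL-to-tensor conversion, and `EXPECTED_W`/`EXPECTED_H` are hypothetical stand-ins for the static `image_input_shape`:

```python
from typing import Dict

import numpy as np
import torch
from PIL import Image

EXPECTED_W, EXPECTED_H = 336, 336  # assumed static input shape


def map_image(image: object) -> Dict[str, torch.Tensor]:
    # Pre-batched tensor: keep it as-is and synthesize per-image sizes,
    # mirroring the torch.Tensor branch in the diff.
    if isinstance(image, torch.Tensor):
        batch_size, _, _, h, w = image.shape
        image_sizes = torch.tensor([(w, h) for _ in range(batch_size)])
        return {"pixel_values": image, "image_sizes": image_sizes}

    # PIL image: resize to the static shape (the "temporary patch"), then
    # convert. The real mapper delegates to the image plugin at this point.
    if isinstance(image, Image.Image):
        if (EXPECTED_W, EXPECTED_H) != (image.width, image.height):
            image = image.resize((EXPECTED_W, EXPECTED_H))
        pixel_values = torch.from_numpy(np.array(image)).permute(2, 0, 1)
        return {"pixel_values": pixel_values.unsqueeze(0).float()}

    raise TypeError(f"Invalid type for 'image': {type(image)}")


# Both input kinds now route through the single registered mapper:
print(map_image(Image.new("RGB", (64, 64)))["pixel_values"].shape)
print(map_image(torch.zeros(2, 5, 3, 336, 336))["image_sizes"])
```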
```diff
@@ -198,11 +172,7 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
         self.config = config
         self.vlm_config = vlm_config

-        if self.vlm_config.image_input_type == (
-                VisionLanguageConfig.ImageInputType.PIXEL_VALUES):
-            self.vision_tower = CLIPVisionModel(config=config.vision_config)
-        else:
-            raise TypeError("Image features are not supported by LLaVA-NeXT")
+        self.vision_tower = CLIPVisionModel(config=config.vision_config)

         self.multi_modal_projector = LlavaMultiModalProjector(
             vision_hidden_size=config.vision_config.hidden_size,
```
```diff
@@ -255,36 +225,23 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):
             self, **kwargs: object) -> Optional[LlavaNextImageInputs]:
         pixel_values = kwargs.pop("pixel_values", None)
         image_sizes = kwargs.pop("image_sizes", None)
-        image_features = kwargs.pop("image_features", None)

-        expected_input_type = self.vlm_config.image_input_type
-        ImageInputType = VisionLanguageConfig.ImageInputType
+        if pixel_values is None or image_sizes is None:
+            return None

-        if expected_input_type == ImageInputType.PIXEL_VALUES:
-            if image_features is not None:
-                raise ValueError(
-                    "Expected pixel values but got image features")
-            if pixel_values is None:
-                return None
+        if not isinstance(pixel_values, torch.Tensor):
+            raise ValueError("Incorrect type of pixel values. "
+                             f"Got type: {type(pixel_values)}")

-            if not isinstance(pixel_values, torch.Tensor):
-                raise ValueError("Incorrect type of pixel values. "
-                                 f"Got type: {type(pixel_values)}")
+        if not isinstance(image_sizes, torch.Tensor):
+            raise ValueError("Incorrect type of image sizes. "
+                             f"Got type: {type(image_sizes)}")

-            if not isinstance(image_sizes, torch.Tensor):
-                raise ValueError("Incorrect type of image sizes. "
-                                 f"Got type: {type(image_sizes)}")
-
-            return LlavaNextImagePixelInputs(
-                type="pixel_values",
-                data=self._validate_image_pixels(pixel_values),
-                image_sizes=self._validate_image_sizes(image_sizes),
-            )
-
-        assert expected_input_type != ImageInputType.IMAGE_FEATURES, (
-            "Failed to validate this at initialization time")
-
-        return None
+        return LlavaNextImagePixelInputs(
+            type="pixel_values",
+            data=self._validate_image_pixels(pixel_values),
+            image_sizes=self._validate_image_sizes(image_sizes),
+        )

     def _select_image_features(self, image_features: torch.Tensor, *,
                                strategy: str) -> torch.Tensor:
```
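With only pixel inputs possible, parsing and validation collapse to one early-return guard plus tensor type checks. A standalone sketch of the same flow (the function name and plain-dict return are hypothetical stand-ins for the method above):

```python
from typing import Dict, Optional

import torch


def parse_image_kwargs(**kwargs: object) -> Optional[Dict[str, object]]:
    pixel_values = kwargs.pop("pixel_values", None)
    image_sizes = kwargs.pop("image_sizes", None)

    # A text-only request simply carries no image tensors.
    if pixel_values is None or image_sizes is None:
        return None

    for name, value in (("pixel values", pixel_values),
                        ("image sizes", image_sizes)):
        if not isinstance(value, torch.Tensor):
            raise ValueError(f"Incorrect type of {name}. "
                             f"Got type: {type(value)}")

    return {"type": "pixel_values",
            "data": pixel_values,
            "image_sizes": image_sizes}


assert parse_image_kwargs() is None  # no image -> None, not an error
out = parse_image_kwargs(pixel_values=torch.zeros(1, 5, 3, 336, 336),
                         image_sizes=torch.tensor([[672, 672]]))
assert out is not None and out["type"] == "pixel_values"
```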
```diff
@@ -391,11 +348,8 @@ class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):

     def _process_image_input(
             self, image_input: LlavaNextImageInputs) -> torch.Tensor:
-        if image_input["type"] == "pixel_values":
-            assert self.vision_tower is not None
-            image_features = self._process_image_pixels(image_input)
-        else:
-            image_features = image_input["data"]
+        assert self.vision_tower is not None
+        image_features = self._process_image_pixels(image_input)

         patch_embeddings = self.multi_modal_projector(image_features)
```
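After this change, `_process_image_input` is a straight pixels → vision tower → projector pipeline with no branching. A toy stand-in illustrating that now-linear flow (layer sizes are arbitrary placeholders, not the real CLIP/LLaMA dimensions):

```python
import torch
import torch.nn as nn


class ImagePipelineSketch(nn.Module):
    """Toy model showing the unconditional image path after the commit."""

    def __init__(self, vision_hidden: int = 64, text_hidden: int = 128):
        super().__init__()
        # Stand-ins for the CLIP vision tower and multimodal projector.
        self.vision_tower = nn.Linear(3 * 336 * 336, vision_hidden)
        self.multi_modal_projector = nn.Linear(vision_hidden, text_hidden)

    def _process_image_pixels(self, pixel_values: torch.Tensor) -> torch.Tensor:
        return self.vision_tower(pixel_values.flatten(1))

    def _process_image_input(self, pixel_values: torch.Tensor) -> torch.Tensor:
        # No "image_features" branch left: pixels are the only input type.
        assert self.vision_tower is not None
        image_features = self._process_image_pixels(pixel_values)
        return self.multi_modal_projector(image_features)


sketch = ImagePipelineSketch()
print(sketch._process_image_input(torch.zeros(2, 3, 336, 336)).shape)
```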