[Core] Registry for processing model inputs (#5214)
Co-authored-by: ywang96 <ywang@roblox.com>
@@ -3,14 +3,14 @@ from typing import (Dict, Iterable, List, Literal, Optional, Tuple, TypedDict,
                     Union)

 import torch
 import torch.nn as nn
-from PIL import Image
-from transformers import LlavaNextConfig
+from transformers import CLIPVisionConfig, LlavaNextConfig
 from transformers.models.llava_next.modeling_llava_next import (
     get_anyres_image_grid_shape, unpad_image)
 from typing_extensions import NotRequired

 from vllm.attention import AttentionMetadata
-from vllm.config import CacheConfig, ModelConfig, VisionLanguageConfig
+from vllm.config import CacheConfig, VisionLanguageConfig
+from vllm.inputs import INPUT_REGISTRY, InputContext
 from vllm.logger import init_logger
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization.base_config import (
@@ -22,9 +22,11 @@ from vllm.model_executor.models.clip import CLIPVisionModel
 from vllm.model_executor.models.llama import LlamaModel
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalData
-from vllm.multimodal.image import ImagePixelData, get_dummy_image_data
-from vllm.sequence import SamplerOutput, SequenceData
+from vllm.multimodal.image import ImagePixelData
+from vllm.sequence import SamplerOutput

+from .clip import (dummy_feature_data_for_clip, dummy_pixel_data_for_clip,
+                   dummy_seq_data_for_clip, get_clip_patch_grid_length)
 from .interfaces import SupportsVision
 from .llava import LlavaMultiModalProjector, merge_vision_embeddings
@@ -58,41 +60,118 @@ LlavaNextImageInputs = Union[LlavaNextImagePixelInputs,
                              LlavaNextImageFeatureInputs]


-def _get_dummy_image_data(
-    seq_len: int,
-    model_config: ModelConfig,
-    vlm_config: VisionLanguageConfig,
-) -> Tuple[SequenceData, MultiModalData]:
-    seq_data, fake_mm_data = get_dummy_image_data(seq_len, model_config,
-                                                  vlm_config)
-
-    config_input_type = vlm_config.image_input_type
-    ImageInputType = VisionLanguageConfig.ImageInputType
-
-    if config_input_type == ImageInputType.PIXEL_VALUES:
-        _, c, h, w = vlm_config.image_input_shape
-        mode = {1: "L", 3: "RGB"}[c]
-        fake_mm_data = ImagePixelData(Image.new(mode, (w, h), color=0))
-
-    return seq_data, fake_mm_data
+def _get_llava_next_num_unpadded_features(
+    height: int,
+    width: int,
+    npatches: int,
+    num_patch_height: int,
+    num_patch_width: int,
+) -> Tuple[int, int]:
+    # Taken from: https://github.com/huggingface/text-generation-inference/blob/799a193b109662743bed1b18a09af1fdcd508c8b/server/text_generation_server/models/vlm_causal_lm.py#L111
+    current_height = npatches * num_patch_height
+    current_width = npatches * num_patch_width
+
+    aspect_ratio: float = width / height
+    current_aspect_ratio: float = current_width / current_height
+    if aspect_ratio > current_aspect_ratio:
+        new_height = (height * current_width) // width
+        current_height = new_height
+    else:
+        new_width = (width * current_height) // height
+        current_width = new_width
+
+    unpadded_features = current_height * current_width
+    newline_features = current_height
+    return (unpadded_features, newline_features)


+def _get_llava_next_image_feature_size(
+    hf_config: LlavaNextConfig,
+    *,
+    input_height: int,
+    input_width: int,
+) -> int:
+    vision_config = hf_config.vision_config
+
+    if isinstance(vision_config, CLIPVisionConfig):
+        num_patches = get_clip_patch_grid_length(
+            image_size=vision_config.image_size,
+            patch_size=vision_config.patch_size,
+        )
+        base_feature_size = num_patches * num_patches
+
+        num_patch_height, num_patch_width = get_anyres_image_grid_shape(
+            image_size=(input_height, input_width),
+            grid_pinpoints=hf_config.image_grid_pinpoints,
+            patch_size=vision_config.image_size,
+        )
+
+        (
+            unpadded_feature_size,
+            newline_feature_size,
+        ) = _get_llava_next_num_unpadded_features(input_height, input_width,
+                                                  num_patches,
+                                                  num_patch_height,
+                                                  num_patch_width)
+
+        return unpadded_feature_size + newline_feature_size + base_feature_size
+
+    msg = f"Unsupported vision config: {type(vision_config)}"
+    raise NotImplementedError(msg)
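For illustration (not part of this commit), here is a minimal, self-contained sketch of the feature-size arithmetic above, using assumed values: the usual CLIP ViT-L/14-336 backbone (a 24x24 base patch grid) and an assumed 2x2 anyres tile grid for a 640x480 input. The actual tile grid comes from get_anyres_image_grid_shape and the model's grid_pinpoints.

```python
# Editorial sketch only: mirrors _get_llava_next_num_unpadded_features and the
# final sum in _get_llava_next_image_feature_size with hard-coded assumptions.
height, width = 480, 640             # input image (height, width)
npatches = 336 // 14                 # 24 patches per side, assumed CLIP ViT-L/14-336
num_patch_height, num_patch_width = 2, 2   # assumed anyres tile grid

current_height = npatches * num_patch_height   # 48
current_width = npatches * num_patch_width     # 48

# The image (4:3) is wider than the padded tile grid (1:1), so rows are cropped.
if width / height > current_width / current_height:
    current_height = (height * current_width) // width   # (480 * 48) // 640 = 36
else:
    current_width = (width * current_height) // height

unpadded_features = current_height * current_width   # 36 * 48 = 1728
newline_features = current_height                    # one newline feature per row
base_features = npatches * npatches                  # 576 from the base image pass

print(unpadded_features + newline_features + base_features)  # 2340 image tokens
```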


+def dummy_data_for_llava_next(ctx: InputContext, seq_len: int):
+    multimodal_config = ctx.get_multimodal_config()
+    hf_config = ctx.get_hf_config(LlavaNextConfig)
+    vision_config = hf_config.vision_config
+
+    #TODO: change the logic for dummy data to support dynamic shape
+    _, _, dummy_height, dummy_width = multimodal_config.image_input_shape
+    image_feature_size = _get_llava_next_image_feature_size(
+        hf_config, input_height=dummy_height, input_width=dummy_width)
+
+    if isinstance(vision_config, CLIPVisionConfig):
+        seq_data = dummy_seq_data_for_clip(
+            vision_config,
+            seq_len,
+            image_token_id=hf_config.image_token_index,
+            image_feature_size_override=image_feature_size,
+        )
+
+        image_input_type = multimodal_config.image_input_type
+        ImageInputType = VisionLanguageConfig.ImageInputType
+        mm_data: MultiModalData
+        if image_input_type == ImageInputType.PIXEL_VALUES:
+            mm_data = dummy_pixel_data_for_clip(
+                vision_config,
+                image_width_override=dummy_width,
+                image_height_override=dummy_height,
+            )
+        elif image_input_type == ImageInputType.IMAGE_FEATURES:
+            mm_data = dummy_feature_data_for_clip(
+                vision_config,
+                image_feature_size_override=image_feature_size,
+            )
+
+        return seq_data, mm_data
+
+    msg = f"Unsupported vision config: {type(vision_config)}"
+    raise NotImplementedError(msg)
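Roughly speaking, the dummy data used for memory profiling amounts to the following hedged sketch. It uses plain Python and Pillow instead of the dummy_*_for_clip helpers (which live in clip.py and are only assumed here), and the concrete numbers (seq_len, feature size, image_token_index = 32000) are illustrative assumptions.

```python
from PIL import Image

# Assumed values for illustration only.
seq_len, image_feature_size, image_token_id = 4096, 2340, 32000
_, c, h, w = (1, 3, 336, 336)   # assumed static image_input_shape

# The profiling prompt is front-loaded with image placeholder tokens,
# then padded out to the full sequence length...
token_ids = [image_token_id] * image_feature_size
token_ids += [0] * (seq_len - image_feature_size)

# ...and paired with a blank image of the configured input size,
# much like the old _get_dummy_image_data did inline.
mode = {1: "L", 3: "RGB"}[c]
dummy_image = Image.new(mode, (w, h), color=0)

print(len(token_ids), dummy_image.size)  # 4096 (336, 336)
```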


-def _image_pixel_processor(
-    data: ImagePixelData,
-    model_config: ModelConfig,
-    vlm_config: VisionLanguageConfig,
-) -> Dict[str, torch.Tensor]:
+def _pixel_mapper(ctx: InputContext,
+                  data: ImagePixelData) -> Dict[str, torch.Tensor]:
     image = data.image

     if isinstance(image, torch.Tensor):
-        pixel_values = image.to(model_config.dtype)
+        pixel_values = image.to(ctx.model_config.dtype)
         batch_size, _, _, h, w = pixel_values.shape
         image_sizes = torch.tensor([(w, h) for _ in range(batch_size)])

         return {"pixel_values": pixel_values, "image_sizes": image_sizes}

     # Temporary patch before dynamic number of image tokens is supported
-    _, _, h, w = vlm_config.image_input_shape
+    _, _, h, w = ctx.get_multimodal_config().image_input_shape
     if (w, h) != (image.width, image.height):
         logger.warning(
             "Dynamic image shape is currently not supported. "
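For the pre-batched tensor path in _pixel_mapper above, a small editorial sketch of the returned keyword tensors, assuming a (batch, num_tiles, channels, height, width) layout (only the trailing two dimensions are inspected to build image_sizes):

```python
import torch

# Assumed layout; the mapper only unpacks the last two dims as (h, w).
pixel_values = torch.rand(2, 5, 3, 336, 336)

batch_size, _, _, h, w = pixel_values.shape
image_sizes = torch.tensor([(w, h) for _ in range(batch_size)])

out = {"pixel_values": pixel_values, "image_sizes": image_sizes}
print(out["pixel_values"].shape)  # torch.Size([2, 5, 3, 336, 336])
print(out["image_sizes"])         # tensor([[336, 336], [336, 336]])
```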
@@ -101,11 +180,12 @@ def _image_pixel_processor(
     data.image = image.resize((w, h))

     return MULTIMODAL_REGISTRY._get_plugin_for_data_type(ImagePixelData) \
-        ._default_input_processor(data, model_config, vlm_config)
+        ._default_input_mapper(ctx, data)


-@MULTIMODAL_REGISTRY.register_image_pixel_input(_image_pixel_processor)
-@MULTIMODAL_REGISTRY.register_dummy_data(_get_dummy_image_data)
+@MULTIMODAL_REGISTRY.register_image_feature_input_mapper()
+@MULTIMODAL_REGISTRY.register_image_pixel_input_mapper(_pixel_mapper)
+@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next)
 class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):

     def __init__(self,
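The decorator stack above is the registration pattern this PR introduces: per-data-type input mappers are registered with MULTIMODAL_REGISTRY, while dummy data for memory profiling now goes through the model-agnostic INPUT_REGISTRY. A hedged sketch of how another model could opt in, using only the hooks shown in this diff; MyVLMForCausalLM, _my_pixel_mapper and dummy_data_for_my_vlm are hypothetical placeholders, not part of this commit.

```python
from typing import Dict

import torch
import torch.nn as nn

from vllm.inputs import INPUT_REGISTRY, InputContext
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import ImagePixelData


def dummy_data_for_my_vlm(ctx: InputContext, seq_len: int):
    """Return (SequenceData, MultiModalData) sized for worst-case profiling,
    e.g. built with the dummy_*_for_clip helpers used by LLaVA-NeXT above."""
    raise NotImplementedError


def _my_pixel_mapper(ctx: InputContext,
                     data: ImagePixelData) -> Dict[str, torch.Tensor]:
    """Turn raw image pixels into the keyword tensors forward() expects,
    mirroring _pixel_mapper above."""
    raise NotImplementedError


@MULTIMODAL_REGISTRY.register_image_pixel_input_mapper(_my_pixel_mapper)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_my_vlm)
class MyVLMForCausalLM(nn.Module):  # real models also mix in SupportsVision
    ...
```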