[Core] Registry for processing model inputs (#5214)

Co-authored-by: ywang96 <ywang@roblox.com>
Authored by Cyrus Leung on 2024-06-28 20:09:56 +08:00 (committed by GitHub)
parent 0d0e3a42ac
commit 5cbe8d155c
26 changed files with 784 additions and 398 deletions

vllm/model_executor/models/llava_next.py

@@ -3,14 +3,14 @@ from typing import (Dict, Iterable, List, Literal, Optional, Tuple, TypedDict,
 import torch
 import torch.nn as nn
 from PIL import Image
-from transformers import LlavaNextConfig
+from transformers import CLIPVisionConfig, LlavaNextConfig
 from transformers.models.llava_next.modeling_llava_next import (
     get_anyres_image_grid_shape, unpad_image)
 from typing_extensions import NotRequired

 from vllm.attention import AttentionMetadata
-from vllm.config import CacheConfig, ModelConfig, VisionLanguageConfig
+from vllm.config import CacheConfig, VisionLanguageConfig
+from vllm.inputs import INPUT_REGISTRY, InputContext
 from vllm.logger import init_logger
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization.base_config import (
@@ -22,9 +22,11 @@ from vllm.model_executor.models.clip import CLIPVisionModel
 from vllm.model_executor.models.llama import LlamaModel
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalData
-from vllm.multimodal.image import ImagePixelData, get_dummy_image_data
-from vllm.sequence import SamplerOutput, SequenceData
+from vllm.multimodal.image import ImagePixelData
+from vllm.sequence import SamplerOutput

+from .clip import (dummy_feature_data_for_clip, dummy_pixel_data_for_clip,
+                   dummy_seq_data_for_clip, get_clip_patch_grid_length)
 from .interfaces import SupportsVision
 from .llava import LlavaMultiModalProjector, merge_vision_embeddings
@@ -58,41 +60,118 @@ LlavaNextImageInputs = Union[LlavaNextImagePixelInputs,
                              LlavaNextImageFeatureInputs]


-def _get_dummy_image_data(
-    seq_len: int,
-    model_config: ModelConfig,
-    vlm_config: VisionLanguageConfig,
-) -> Tuple[SequenceData, MultiModalData]:
-    seq_data, fake_mm_data = get_dummy_image_data(seq_len, model_config,
-                                                  vlm_config)
+def _get_llava_next_num_unpadded_features(
+    height: int,
+    width: int,
+    npatches: int,
+    num_patch_height: int,
+    num_patch_width: int,
+) -> Tuple[int, int]:
+    # Taken from: https://github.com/huggingface/text-generation-inference/blob/799a193b109662743bed1b18a09af1fdcd508c8b/server/text_generation_server/models/vlm_causal_lm.py#L111
+    current_height = npatches * num_patch_height
+    current_width = npatches * num_patch_width

-    config_input_type = vlm_config.image_input_type
-    ImageInputType = VisionLanguageConfig.ImageInputType
+    aspect_ratio: float = width / height
+    current_aspect_ratio: float = current_width / current_height
+    if aspect_ratio > current_aspect_ratio:
+        new_height = (height * current_width) // width
+        current_height = new_height
+    else:
+        new_width = (width * current_height) // height
+        current_width = new_width

-    if config_input_type == ImageInputType.PIXEL_VALUES:
-        _, c, h, w = vlm_config.image_input_shape
-        mode = {1: "L", 3: "RGB"}[c]
-        fake_mm_data = ImagePixelData(Image.new(mode, (w, h), color=0))
-
-    return seq_data, fake_mm_data
+    unpadded_features = current_height * current_width
+    newline_features = current_height
+    return (unpadded_features, newline_features)
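
As a quick sanity check of the arithmetic above (an editorial illustration, not part of the diff): for LLaVA-NeXT's CLIP ViT-L/14 tower at 336px, npatches = 336 // 14 = 24, and assuming the anyres grid selected for a 1920x1080 input is 2x2 patches:

    # Hand computation of _get_llava_next_num_unpadded_features(
    #     height=1080, width=1920, npatches=24,
    #     num_patch_height=2, num_patch_width=2)
    current_height = 24 * 2                 # 48
    current_width = 24 * 2                  # 48
    # aspect_ratio 1920/1080 ~= 1.78 > current_aspect_ratio 1.0,
    # so the height axis is the one that carried padding:
    current_height = (1080 * 48) // 1920    # 27
    unpadded_features = 27 * 48             # 1296
    newline_features = 27                   # one newline feature per row
    # -> returns (1296, 27)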
-def _image_pixel_processor(
-    data: ImagePixelData,
-    model_config: ModelConfig,
-    vlm_config: VisionLanguageConfig,
-) -> Dict[str, torch.Tensor]:
+def _get_llava_next_image_feature_size(
+    hf_config: LlavaNextConfig,
+    *,
+    input_height: int,
+    input_width: int,
+) -> int:
+    vision_config = hf_config.vision_config
+
+    if isinstance(vision_config, CLIPVisionConfig):
+        num_patches = get_clip_patch_grid_length(
+            image_size=vision_config.image_size,
+            patch_size=vision_config.patch_size,
+        )
+        base_feature_size = num_patches * num_patches
+
+        num_patch_height, num_patch_width = get_anyres_image_grid_shape(
+            image_size=(input_height, input_width),
+            grid_pinpoints=hf_config.image_grid_pinpoints,
+            patch_size=vision_config.image_size,
+        )
+
+        (
+            unpadded_feature_size,
+            newline_feature_size,
+        ) = _get_llava_next_num_unpadded_features(input_height, input_width,
+                                                  num_patches,
+                                                  num_patch_height,
+                                                  num_patch_width)
+
+        return unpadded_feature_size + newline_feature_size + base_feature_size
+
+    msg = f"Unsupported vision config: {type(vision_config)}"
+    raise NotImplementedError(msg)
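
Continuing the worked example above: the base grid contributes 24 * 24 = 576 features, so _get_llava_next_image_feature_size would return 576 + 1296 + 27 = 1899 image tokens for a 1920x1080 input (again assuming a 2x2 anyres grid).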
+def dummy_data_for_llava_next(ctx: InputContext, seq_len: int):
+    multimodal_config = ctx.get_multimodal_config()
+    hf_config = ctx.get_hf_config(LlavaNextConfig)
+    vision_config = hf_config.vision_config
+
+    #TODO: change the logic for dummy data to support dynamic shape
+    _, _, dummy_height, dummy_width = multimodal_config.image_input_shape
+    image_feature_size = _get_llava_next_image_feature_size(
+        hf_config, input_height=dummy_height, input_width=dummy_width)
+
+    if isinstance(vision_config, CLIPVisionConfig):
+        seq_data = dummy_seq_data_for_clip(
+            vision_config,
+            seq_len,
+            image_token_id=hf_config.image_token_index,
+            image_feature_size_override=image_feature_size,
+        )
+
+        image_input_type = multimodal_config.image_input_type
+        ImageInputType = VisionLanguageConfig.ImageInputType
+        mm_data: MultiModalData
+        if image_input_type == ImageInputType.PIXEL_VALUES:
+            mm_data = dummy_pixel_data_for_clip(
+                vision_config,
+                image_width_override=dummy_width,
+                image_height_override=dummy_height,
+            )
+        elif image_input_type == ImageInputType.IMAGE_FEATURES:
+            mm_data = dummy_feature_data_for_clip(
+                vision_config,
+                image_feature_size_override=image_feature_size,
+            )
+
+        return seq_data, mm_data
+
+    msg = f"Unsupported vision config: {type(vision_config)}"
+    raise NotImplementedError(msg)
+def _pixel_mapper(ctx: InputContext,
+                  data: ImagePixelData) -> Dict[str, torch.Tensor]:
     image = data.image

     if isinstance(image, torch.Tensor):
-        pixel_values = image.to(model_config.dtype)
+        pixel_values = image.to(ctx.model_config.dtype)
         batch_size, _, _, h, w = pixel_values.shape
         image_sizes = torch.tensor([(w, h) for _ in range(batch_size)])

         return {"pixel_values": pixel_values, "image_sizes": image_sizes}

     # Temporary patch before dynamic number of image tokens is supported
-    _, _, h, w = vlm_config.image_input_shape
+    _, _, h, w = ctx.get_multimodal_config().image_input_shape
     if (w, h) != (image.width, image.height):
         logger.warning(
             "Dynamic image shape is currently not supported. "
@@ -101,11 +180,12 @@ def _image_pixel_processor(
         data.image = image.resize((w, h))

     return MULTIMODAL_REGISTRY._get_plugin_for_data_type(ImagePixelData) \
-        ._default_input_processor(data, model_config, vlm_config)
+        ._default_input_mapper(ctx, data)


-@MULTIMODAL_REGISTRY.register_image_pixel_input(_image_pixel_processor)
-@MULTIMODAL_REGISTRY.register_dummy_data(_get_dummy_image_data)
+@MULTIMODAL_REGISTRY.register_image_feature_input_mapper()
+@MULTIMODAL_REGISTRY.register_image_pixel_input_mapper(_pixel_mapper)
+@INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next)
 class LlavaNextForConditionalGeneration(nn.Module, SupportsVision):

     def __init__(self,
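
The decorators above are the general pattern this commit establishes: per-modality input mappers live on MULTIMODAL_REGISTRY (calling register_image_feature_input_mapper() with no argument keeps the plugin's default mapper), while dummy data for memory profiling moves to the new INPUT_REGISTRY. A minimal sketch of wiring a hypothetical new vision-language model into both registries (the MyVLM* names and both callback bodies are invented for illustration):

    import torch.nn as nn

    from vllm.inputs import INPUT_REGISTRY, InputContext
    from vllm.multimodal import MULTIMODAL_REGISTRY
    from vllm.model_executor.models.interfaces import SupportsVision

    def dummy_data_for_my_vlm(ctx: InputContext, seq_len: int):
        ...  # return (SequenceData, MultiModalData) sized for profiling

    def _my_pixel_mapper(ctx: InputContext, data):
        ...  # map ImagePixelData to keyword tensors, e.g. {"pixel_values": ...}

    @MULTIMODAL_REGISTRY.register_image_pixel_input_mapper(_my_pixel_mapper)
    @INPUT_REGISTRY.register_dummy_data(dummy_data_for_my_vlm)
    class MyVLMForConditionalGeneration(nn.Module, SupportsVision):
        ...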