[Core] Registry for processing model inputs (#5214)

Co-authored-by: ywang96 <ywang@roblox.com>
Author: Cyrus Leung
Date: 2024-06-28 20:09:56 +08:00
Committed by: GitHub
Parent: 0d0e3a42ac
Commit: 5cbe8d155c
26 changed files with 784 additions and 398 deletions


@@ -22,7 +22,8 @@ from PIL import Image
 from transformers import CLIPVisionConfig, PretrainedConfig

 from vllm.attention import AttentionMetadata
-from vllm.config import CacheConfig, ModelConfig, VisionLanguageConfig
+from vllm.config import CacheConfig, VisionLanguageConfig
+from vllm.inputs import INPUT_REGISTRY, InputContext
 from vllm.logger import init_logger
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization.base_config import (
@@ -34,9 +35,10 @@ from vllm.model_executor.models.clip import CLIPVisionModel
 from vllm.model_executor.models.llama import LlamaModel
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.image import ImagePixelData, get_dummy_image_data
+from vllm.multimodal.image import ImagePixelData
 from vllm.sequence import SamplerOutput

+from .clip import dummy_pixel_data_for_clip, dummy_seq_data_for_clip
 from .interfaces import SupportsVision

 logger = init_logger(__name__)
@@ -107,7 +109,6 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
         self.num_img_tokens = config.img_processor['num_img_tokens']

         self.image_dim_out = image_dim_out
-        self.img_sizes = None

         # global_gn and sub_gn for hd transform, serves as line separator
         self.use_hd_transform = config.embd_layer.get('use_hd_transform',
@@ -134,7 +135,6 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
         self.img_projection = nn.Sequential(*layers)

         self.vocab_size = config.vocab_size
-        self.img_features = None

         self.layer_idx = config.img_processor.get('layer_idx', -2)
         self.type_feature = config.img_processor.get('type_feature', 'patch')
@@ -260,9 +260,44 @@ class Phi3VImagePixelInputs(TypedDict):
"""Shape: (batch_size, 2)"""
# FIXME(Isotr0py): Remove these after dynamic num_img_tokens is supported
# copied from https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py
def calc_padded_size(width, height, padding_unit=336):
def _get_phi3v_image_feature_size(
*,
input_height: int,
input_width: int,
) -> int:
h, w = input_height, input_width
# https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L178
return (h // 336 * w // 336 + 1) * 144 + 1 + (h // 336 + 1) * 12
def dummy_data_for_phi3v(ctx: InputContext, seq_len: int):
multimodal_config = ctx.get_multimodal_config()
#TODO: change the logic for dummy data to support dynamic shape
_, _, dummy_height, dummy_width = multimodal_config.image_input_shape
image_feature_size = _get_phi3v_image_feature_size(
input_height=dummy_height,
input_width=dummy_width,
)
seq_data = dummy_seq_data_for_clip(
CLIP_VIT_LARGE_PATCH14_336_CONFIG,
seq_len,
image_token_id=32044,
image_feature_size_override=image_feature_size,
)
mm_data = dummy_pixel_data_for_clip(
CLIP_VIT_LARGE_PATCH14_336_CONFIG,
image_width_override=dummy_width,
image_height_override=dummy_height,
)
return seq_data, mm_data
# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py
def _calc_padded_size(*, width: int, height: int, padding_unit: int = 336):
target_height = int(np.ceil(height / padding_unit) * padding_unit)
top_padding = int((target_height - height) / 2)
bottom_padding = target_height - height - top_padding
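As a quick sanity check on the feature-size formula added in this hunk, a worked example; the 1008 x 1344 shape is illustrative only, and the tile/global reading in the comments is an interpretation, not something stated in the diff:

    # _get_phi3v_image_feature_size for an illustrative 1008 x 1344 input:
    h, w = 1008, 1344
    # h // 336 = 3 tile rows and w // 336 = 4 tile columns -> 12 HD tiles,
    # plus 1 global view, at 144 tokens each; the final term contributes
    # 12 separator tokens for each of (h // 336 + 1) = 4 rows.
    tokens = (h // 336 * w // 336 + 1) * 144 + 1 + (h // 336 + 1) * 12
    assert tokens == (12 + 1) * 144 + 1 + 4 * 12 == 1921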
@@ -271,8 +306,8 @@ def calc_padded_size(width, height, padding_unit=336):
     return padded_width, padded_height


-# copied from https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py
-def calc_hd_transform_size(width, height, hd_num=16):
+# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py
+def _calc_hd_transform_size(*, width: int, height: int, hd_num: int = 16):
     transposed = False
     if width < height:
         width, height = height, width
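For the padding helper carried over above, the arithmetic centers the image vertically: with height = 500, for instance, target_height = ceil(500 / 336) * 336 = 672, giving top_padding = 86 and bottom_padding = 86.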
@@ -287,7 +322,8 @@ def calc_hd_transform_size(width, height, hd_num=16):
     new_width = int(scale * 336)
     new_height = int(new_width / ratio)
-    padded_width, padded_height = calc_padded_size(new_width, new_height)
+    padded_width, padded_height = _calc_padded_size(width=new_width,
+                                                    height=new_height)

     if transposed:
         padded_width, padded_height = padded_height, padded_width
@@ -295,17 +331,15 @@ def calc_hd_transform_size(width, height, hd_num=16):
     return padded_width, padded_height


-def _image_processor(
-    data: ImagePixelData,
-    model_config: ModelConfig,
-    vlm_config: VisionLanguageConfig,
-) -> Dict[str, torch.Tensor]:
+def _image_processor(ctx: InputContext,
+                     data: ImagePixelData) -> Dict[str, torch.Tensor]:
     image = data.image

     if isinstance(image, Image.Image):
         # Temporary patch before dynamic number of image tokens is supported
-        _, _, h, w = vlm_config.image_input_shape
-        if (w, h) != calc_hd_transform_size(image.width, image.height):
+        _, _, h, w = ctx.get_multimodal_config().image_input_shape
+        if (w, h) != _calc_hd_transform_size(width=image.width,
+                                             height=image.height):
             logger.warning(
                 "Dynamic image shape is currently not supported. "
                 "Resizing input image to (%d, %d).", w, h)
@@ -313,11 +347,11 @@ def _image_processor(
             data.image = image.resize((w, h))

     return MULTIMODAL_REGISTRY._get_plugin_for_data_type(ImagePixelData) \
-        ._default_input_processor(data, model_config, vlm_config)
+        ._default_input_mapper(ctx, data)


-@MULTIMODAL_REGISTRY.register_image_pixel_input(_image_processor)
-@MULTIMODAL_REGISTRY.register_dummy_data(get_dummy_image_data)
+@MULTIMODAL_REGISTRY.register_image_pixel_input_mapper(_image_processor)
+@INPUT_REGISTRY.register_dummy_data(dummy_data_for_phi3v)
 class Phi3VForCausalLM(nn.Module, SupportsVision):

     def __init__(self,
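Taken together, the wiring this commit asks of each model is small: dummy profiling data moves to INPUT_REGISTRY while modality-specific input mappers stay on MULTIMODAL_REGISTRY, both attached as class decorators. A condensed sketch of the pattern; MyVisionModel and its two helpers are placeholders, not part of this change:

    import torch.nn as nn
    from vllm.inputs import INPUT_REGISTRY, InputContext
    from vllm.multimodal import MULTIMODAL_REGISTRY

    def dummy_data_for_my_model(ctx: InputContext, seq_len: int):
        # Return (sequence data, multimodal data) sized for worst-case
        # memory profiling, as dummy_data_for_phi3v does above.
        ...

    def my_input_mapper(ctx: InputContext, data):
        # Map raw multimodal input to model keyword arguments.
        ...

    @MULTIMODAL_REGISTRY.register_image_pixel_input_mapper(my_input_mapper)
    @INPUT_REGISTRY.register_dummy_data(dummy_data_for_my_model)
    class MyVisionModel(nn.Module):
        ...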