[Core] Registry for processing model inputs (#5214)
Co-authored-by: ywang96 <ywang@roblox.com>
@@ -1,22 +1,83 @@
 """Minimal implementation of CLIPVisionModel intended to be only used
 within a vision language model."""
-from typing import Optional, Tuple
+from typing import Optional
 
 import torch
 import torch.nn as nn
+from PIL import Image
 from transformers import CLIPVisionConfig
 from transformers.models.clip.modeling_clip import CLIPAttention
 
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
-from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.multimodal.image import ImageFeatureData, ImagePixelData
+from vllm.sequence import SequenceData
 
 
-def get_clip_num_patches(image_size: int, patch_size: int) -> int:
+def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int:
     assert image_size % patch_size == 0
-    return (image_size // patch_size)**2
+    return image_size // patch_size
+
+
+def get_clip_num_patches(*, image_size: int, patch_size: int) -> int:
+    grid_length = get_clip_patch_grid_length(image_size=image_size,
+                                             patch_size=patch_size)
+    return grid_length * grid_length
+
+
+def get_clip_image_feature_size(hf_config: CLIPVisionConfig) -> int:
+    return get_clip_num_patches(image_size=hf_config.image_size,
+                                patch_size=hf_config.patch_size)
+
+
+def dummy_seq_data_for_clip(
+    hf_config: CLIPVisionConfig,
+    seq_len: int,
+    *,
+    image_token_id: int,
+    image_feature_size_override: Optional[int] = None,
+):
+    if image_feature_size_override is None:
+        image_feature_size = get_clip_image_feature_size(hf_config)
+    else:
+        image_feature_size = image_feature_size_override
+
+    token_ids = [image_token_id] * image_feature_size
+    token_ids += [0] * (seq_len - image_feature_size)
+    return SequenceData(token_ids)
+
+
+def dummy_pixel_data_for_clip(
+    hf_config: CLIPVisionConfig,
+    *,
+    image_width_override: Optional[int] = None,
+    image_height_override: Optional[int] = None,
+):
+    width = height = hf_config.image_size
+    if image_width_override is not None:
+        width = image_width_override
+    if image_height_override is not None:
+        height = image_height_override
+
+    image = Image.new("RGB", (width, height), color=0)
+    return ImagePixelData(image)
+
+
+def dummy_feature_data_for_clip(
+    hf_config: CLIPVisionConfig,
+    *,
+    image_feature_size_override: Optional[int] = None,
+):
+    if image_feature_size_override is None:
+        image_feature_size = get_clip_image_feature_size(hf_config)
+    else:
+        image_feature_size = image_feature_size_override
+
+    values = torch.zeros((1, image_feature_size, hf_config.hidden_size),
+                         dtype=torch.float16)
+    return ImageFeatureData(values)
 
 
 # Adapted from https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/clip/modeling_clip.py#L164 # noqa
@@ -39,8 +100,8 @@ class CLIPVisionEmbeddings(nn.Module):
             bias=False,
         )
 
-        self.num_patches = get_clip_num_patches(self.image_size,
-                                                self.patch_size)
+        self.num_patches = get_clip_num_patches(image_size=self.image_size,
+                                                patch_size=self.patch_size)
         self.num_positions = self.num_patches + 1
         self.position_embedding = nn.Embedding(self.num_positions,
                                                self.embed_dim)
@@ -101,7 +162,7 @@ class CLIPEncoderLayer(nn.Module):
         self.layer_norm2 = nn.LayerNorm(config.hidden_size,
                                         eps=config.layer_norm_eps)
 
-    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor]:
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 
         residual = hidden_states
 
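The new helpers factor the patch arithmetic out of CLIPVisionEmbeddings: a square input image is split into a grid of non-overlapping patches, and the image feature size is the squared grid length (num_positions then adds one extra slot for the CLS embedding). A quick sanity check of that arithmetic as a sketch; the import path and the ViT-L/14-at-336px config values are assumptions for illustration, not taken from this commit:

from transformers import CLIPVisionConfig

# Assumed import path for the helper added in this commit.
from vllm.model_executor.models.clip import get_clip_image_feature_size

# Hypothetical config matching CLIP ViT-L/14 at 336x336 resolution.
hf_config = CLIPVisionConfig(image_size=336, patch_size=14)

# Grid length: 336 // 14 = 24; patch count: 24 * 24 = 576.
assert get_clip_image_feature_size(hf_config) == 576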
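The dummy_* factories let models built on this vision tower produce profiling inputs without real data: a placeholder token sequence, a blank image, or a zeroed feature tensor. A minimal usage sketch, assuming the same import path as above and a LLaVA-style image token id of 32000 (both assumptions, not part of the diff):

from transformers import CLIPVisionConfig

from vllm.model_executor.models.clip import (dummy_pixel_data_for_clip,
                                             dummy_seq_data_for_clip)

vision_config = CLIPVisionConfig(image_size=336, patch_size=14)

# Placeholder sequence: one image token per patch, zero-padded to seq_len.
seq_data = dummy_seq_data_for_clip(vision_config,
                                   seq_len=2048,
                                   image_token_id=32000)  # hypothetical id

# Black RGB image at the vision tower's native resolution.
pixel_data = dummy_pixel_data_for_clip(vision_config)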