[V1][VLM] V1 support for selected single-image models. (#11632)
Signed-off-by: Roger Wang <ywang@roblox.com> Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> Signed-off-by: Isotr0py <2037008807@qq.com> Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk> Co-authored-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
@@ -15,32 +15,30 @@
|
||||
# limitations under the License.
|
||||
""" PyTorch Fuyu model."""
|
||||
import math
|
||||
from array import array
|
||||
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
|
||||
TypedDict)
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint
|
||||
from PIL import Image
|
||||
from transformers import FuyuImageProcessor
|
||||
from transformers import (BatchFeature, FuyuConfig, FuyuImageProcessor,
|
||||
FuyuProcessor)
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
|
||||
InputContext, token_inputs)
|
||||
from vllm.inputs import InputContext
|
||||
from vllm.model_executor.layers.linear import ColumnParallelLinear
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.model_executor.models.persimmon import PersimmonForCausalLM
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
|
||||
from vllm.multimodal.image import cached_get_image_processor
|
||||
from vllm.multimodal.inputs import NestedTensors
|
||||
from vllm.multimodal.utils import (cached_get_tokenizer,
|
||||
consecutive_placeholder_ranges)
|
||||
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
|
||||
SequenceData)
|
||||
from vllm.utils import is_list_of
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
|
||||
MultiModalInputsV2, MultiModalKwargs,
|
||||
NestedTensors, PlaceholderRange)
|
||||
from vllm.multimodal.parse import ImageProcessorItems
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
MultiModalDataItems, ProcessorInputs,
|
||||
PromptReplacement)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .interfaces import SupportsMultiModal, SupportsPP
|
||||
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
|
||||
@@ -54,178 +52,193 @@ MAX_IMAGE_FEATURE_SIZE_HEIGHT = 1080
|
||||
MAX_IMAGE_FEATURE_SIZE_WIDTH = 1920
|
||||
|
||||
|
||||
class FuyuImagePixelInputs(TypedDict):
|
||||
type: Literal["pixel_values"]
|
||||
class FuyuImagePatchInputs(TypedDict):
|
||||
type: Literal["image_patches"]
|
||||
data: torch.Tensor
|
||||
"""
|
||||
Shape:
|
||||
(batch_size, num_patches, patch_size_x * patch_size_y * num_channels)
|
||||
`(batch_size * num_patches, patch_size_x * patch_size_y * num_channels)`
|
||||
"""
|
||||
|
||||
patches_per_image: List[int]
|
||||
"""
|
||||
List of number of total patches for each image in the batch.
|
||||
This is used to restore the first two dimensions of `data`.
|
||||
"""
|
||||
|
||||
|
||||
def _calculate_num_image_tokens(
|
||||
height: int,
|
||||
width: int,
|
||||
def _get_fuyu_num_image_tokens(
|
||||
image_height: int,
|
||||
image_width: int,
|
||||
) -> Tuple[int, int]:
|
||||
"""
|
||||
calculate number of image tokens needed for a given image size
|
||||
The expected Fuyu image prompts is in format:
|
||||
Calculate the number of image tokens needed for a given image size.
|
||||
|
||||
The expected Fuyu image prompts can be expressed as:
|
||||
|
||||
.. code-block::
|
||||
(image_token * ncols + newline_token) * nrows
|
||||
args:
|
||||
image_size: Tuple[int, int] - (width, height) of the image
|
||||
returns:
|
||||
ncols: int - number of image tokens in x direction
|
||||
nrows: int - number of image tokens in y direction
|
||||
|
||||
Args:
|
||||
image_size: Tuple[int, int] - `(width, height)` of the image
|
||||
|
||||
Returns:
|
||||
ncols: int - number of image tokens in `x` direction
|
||||
nrows: int - number of image tokens in `y` direction
|
||||
"""
|
||||
ncol = math.ceil(width / 30)
|
||||
nrow = math.ceil(height / 30)
|
||||
return ncol, nrow
|
||||
|
||||
|
||||
def get_max_fuyu_image_feature_size():
|
||||
|
||||
return _calculate_num_image_tokens(
|
||||
height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
|
||||
width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
|
||||
)
|
||||
ncols = math.ceil(image_width / 30)
|
||||
nrows = math.ceil(image_height / 30)
|
||||
return ncols, nrows
|
||||
|
||||
|
||||
def get_max_fuyu_image_tokens(ctx: InputContext):
|
||||
ncol, nrow = get_max_fuyu_image_feature_size()
|
||||
return (ncol + 1) * nrow
|
||||
|
||||
|
||||
def dummy_seq_data_for_fuyu(ctx: InputContext, seq_len: int, num_images: int):
|
||||
ncol, nrow = get_max_fuyu_image_feature_size()
|
||||
image_feature_size = get_max_fuyu_image_tokens(ctx)
|
||||
|
||||
image_token_ids = (
|
||||
array(VLLM_TOKEN_ID_ARRAY_TYPE, [_IMAGE_TOKEN_ID]) * ncol +
|
||||
array(VLLM_TOKEN_ID_ARRAY_TYPE, [_NEWLINE_TOKEN_ID])) * nrow
|
||||
token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, image_token_ids) * num_images
|
||||
token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
|
||||
[0]) * (seq_len - image_feature_size * num_images)
|
||||
return SequenceData(token_ids), {
|
||||
"image":
|
||||
consecutive_placeholder_ranges(num_items=num_images,
|
||||
item_size=image_feature_size)
|
||||
}
|
||||
|
||||
|
||||
def dummy_image_for_fuyu(
|
||||
num_images: int,
|
||||
*,
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
):
|
||||
image = Image.new("RGB", (image_width, image_height), color=0)
|
||||
return {"image": image if num_images == 1 else [image] * num_images}
|
||||
|
||||
|
||||
def dummy_data_for_fuyu(ctx: InputContext, seq_len: int,
|
||||
mm_counts: Mapping[str, int]):
|
||||
num_images = mm_counts["image"]
|
||||
seq_data, ranges = dummy_seq_data_for_fuyu(ctx, seq_len, num_images)
|
||||
mm_data = dummy_image_for_fuyu(num_images,
|
||||
image_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
|
||||
image_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT)
|
||||
return DummyData(seq_data, mm_data, ranges)
|
||||
|
||||
|
||||
def _fuyu_image_preprocess(image_processor: FuyuImageProcessor,
|
||||
data: List[Image.Image]):
|
||||
image_encoding = image_processor.preprocess(data, return_tensors="pt")
|
||||
batch_images = torch.stack([img[0] for img in image_encoding["images"]
|
||||
]).unsqueeze(1)
|
||||
image_unpadded_heights = torch.tensor(
|
||||
image_encoding["image_unpadded_heights"])
|
||||
image_unpadded_widths = torch.tensor(
|
||||
image_encoding["image_unpadded_widths"])
|
||||
|
||||
batch_size = len(image_encoding["images"])
|
||||
image_present = torch.ones(batch_size, 1, 1)
|
||||
model_image_input = image_processor.preprocess_with_tokenizer_info(
|
||||
image_input=batch_images,
|
||||
image_present=image_present,
|
||||
image_unpadded_h=image_unpadded_heights,
|
||||
image_unpadded_w=image_unpadded_widths,
|
||||
image_placeholder_id=_IMAGE_TOKEN_ID,
|
||||
image_newline_id=_NEWLINE_TOKEN_ID,
|
||||
variable_sized=True,
|
||||
ncols, nrows = _get_fuyu_num_image_tokens(
|
||||
image_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
|
||||
image_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
|
||||
)
|
||||
return model_image_input
|
||||
|
||||
return (ncols + 1) * nrows
|
||||
|
||||
|
||||
def input_processor_for_fuyu(ctx: InputContext, inputs: DecoderOnlyInputs):
|
||||
multi_modal_data = inputs.get("multi_modal_data")
|
||||
if multi_modal_data is None or "image" not in multi_modal_data:
|
||||
return inputs
|
||||
class FuyuMultiModalProcessor(BaseMultiModalProcessor):
|
||||
|
||||
model_config = ctx.model_config
|
||||
image_data = multi_modal_data["image"]
|
||||
new_multi_modal_data = {}
|
||||
image_list = image_data if isinstance(image_data, list) else [image_data]
|
||||
def _get_hf_processor(self) -> FuyuProcessor:
|
||||
return self.ctx.get_hf_processor(FuyuProcessor)
|
||||
|
||||
# process image data
|
||||
if is_list_of(image_list, Image.Image):
|
||||
# Fuyu's image_processor can also finish token padding
|
||||
image_processor: FuyuImageProcessor = cached_get_image_processor(
|
||||
model_config.model)
|
||||
def _call_hf_processor(
|
||||
self,
|
||||
prompt: str,
|
||||
mm_data: Mapping[str, object],
|
||||
mm_kwargs: Mapping[str, object],
|
||||
) -> BatchFeature:
|
||||
|
||||
model_image_input = _fuyu_image_preprocess(image_processor, image_data)
|
||||
image_patches = torch.cat([
|
||||
image_patch[0]
|
||||
for image_patch in model_image_input["image_patches"]
|
||||
])
|
||||
new_multi_modal_data["image"] = image_patches
|
||||
if not mm_data:
|
||||
# Avoid warning from HF logger for text-only input
|
||||
# Input_ids format: bos_token_id + prompt_token_ids + boa_token_id
|
||||
# Tokenizer won't add boa_token_id by default, we add it manually.
|
||||
tokenizer = self._get_tokenizer()
|
||||
boa_token_id: int = tokenizer.vocab["<0x04>"] # type: ignore
|
||||
prompt_ids = tokenizer.encode(prompt) + [boa_token_id]
|
||||
return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
|
||||
|
||||
elif is_list_of(image_list, torch.Tensor):
|
||||
raise NotImplementedError("Embeddings input is not supported yet")
|
||||
else:
|
||||
raise TypeError(f"Invalid image type: {type(image_data)}")
|
||||
processed_outputs = super()._call_hf_processor(
|
||||
prompt=prompt,
|
||||
mm_data=mm_data,
|
||||
mm_kwargs=mm_kwargs,
|
||||
)
|
||||
|
||||
# process prompts
|
||||
prompt = inputs.get("prompt")
|
||||
prompt_token_ids = inputs["prompt_token_ids"]
|
||||
tokenizer = cached_get_tokenizer(model_config.model)
|
||||
# dim0 is batch_size, dim1 is subseq_size which will always be 1
|
||||
image_input_ids: List[List[
|
||||
torch.Tensor]] = model_image_input["image_input_ids"]
|
||||
image_input_ids = image_input_ids[0][0].tolist()
|
||||
bos_token = tokenizer.encode("<s>", add_special_tokens=False)[1:]
|
||||
boa_token = tokenizer.encode("\x04", add_special_tokens=False)[1:]
|
||||
image_patches = processed_outputs.get("image_patches")
|
||||
if image_patches is not None:
|
||||
images = mm_data["images"]
|
||||
assert isinstance(images, list)
|
||||
|
||||
new_prompt = prompt + "\x04"
|
||||
new_prompt_token_ids = image_input_ids + bos_token + prompt_token_ids[
|
||||
1:] + boa_token
|
||||
# Original output: (1, num_images, Pn, Px * Py * C)
|
||||
# New output: (num_images, Pn, Px * Py * C)
|
||||
assert (isinstance(image_patches, list)
|
||||
and len(image_patches) == 1)
|
||||
assert (isinstance(image_patches[0], torch.Tensor)
|
||||
and len(image_patches[0]) == len(images))
|
||||
|
||||
return token_inputs(prompt=new_prompt,
|
||||
prompt_token_ids=new_prompt_token_ids,
|
||||
multi_modal_data=new_multi_modal_data)
|
||||
processed_outputs["image_patches"] = image_patches[0]
|
||||
|
||||
return processed_outputs
|
||||
|
||||
def _get_mm_fields_config(
|
||||
self,
|
||||
hf_inputs: BatchFeature,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
return dict(image_patches=MultiModalFieldConfig.batched("image"))
|
||||
|
||||
def _get_prompt_replacements(
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
out_mm_kwargs: MultiModalKwargs,
|
||||
) -> list[PromptReplacement]:
|
||||
hf_config = self.ctx.get_hf_config(FuyuConfig)
|
||||
bos_token_id = hf_config.bos_token_id
|
||||
|
||||
tokenizer = self._get_tokenizer()
|
||||
eot_token_id = tokenizer.bos_token_id
|
||||
assert isinstance(eot_token_id, int)
|
||||
|
||||
hf_processor = self._get_hf_processor()
|
||||
image_processor: FuyuImageProcessor = hf_processor.image_processor
|
||||
target_size = image_processor.size
|
||||
target_height, target_width = (target_size["height"],
|
||||
target_size["width"])
|
||||
|
||||
def get_replacement_fuyu(item_idx: int):
|
||||
images = mm_items.get_items("image", ImageProcessorItems)
|
||||
image_size = images.get_image_size(item_idx)
|
||||
width, height = image_size.width, image_size.height
|
||||
if not (width <= target_width and height <= target_height):
|
||||
height_scale_factor = target_height / height
|
||||
width_scale_factor = target_width / width
|
||||
optimal_scale_factor = min(height_scale_factor,
|
||||
width_scale_factor)
|
||||
|
||||
height = int(height * optimal_scale_factor)
|
||||
width = int(width * optimal_scale_factor)
|
||||
|
||||
ncols, nrows = _get_fuyu_num_image_tokens(
|
||||
image_width=width,
|
||||
image_height=height,
|
||||
)
|
||||
|
||||
return (([_IMAGE_TOKEN_ID] * ncols + [_NEWLINE_TOKEN_ID]) * nrows +
|
||||
[bos_token_id])
|
||||
|
||||
return [
|
||||
PromptReplacement(
|
||||
modality="image",
|
||||
target=[eot_token_id],
|
||||
replacement=get_replacement_fuyu,
|
||||
)
|
||||
]
|
||||
|
||||
def apply(
|
||||
self,
|
||||
prompt_text: str,
|
||||
mm_data: MultiModalDataDict,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> MultiModalInputsV2:
|
||||
result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs)
|
||||
|
||||
# Only |SPEAKER| (image) tokens should be considered as placeholders,
|
||||
# so we ignore the trailing bos_token_id
|
||||
result["mm_placeholders"] = {
|
||||
modality: [
|
||||
PlaceholderRange(offset=p["offset"], length=p["length"] - 1)
|
||||
for p in ps
|
||||
]
|
||||
for modality, ps in result["mm_placeholders"].items()
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
def _get_dummy_mm_inputs(
|
||||
self,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> ProcessorInputs:
|
||||
num_images = mm_counts.get("image", 0)
|
||||
|
||||
mm_data = {
|
||||
"image":
|
||||
self._get_dummy_images(width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
|
||||
height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
|
||||
num_images=num_images)
|
||||
}
|
||||
|
||||
return ProcessorInputs(
|
||||
prompt_text="",
|
||||
mm_data=mm_data,
|
||||
)
|
||||
|
||||
|
||||
def input_mapper_for_fuyu(ctx: InputContext, data: object):
|
||||
model_config = ctx.model_config
|
||||
data_list = data if isinstance(data, list) else [data]
|
||||
if is_list_of(data_list, Image.Image):
|
||||
# Fuyu's image_processor can also finish token padding
|
||||
image_processor: FuyuImageProcessor = cached_get_image_processor(
|
||||
model_config.model)
|
||||
|
||||
model_image_input = _fuyu_image_preprocess(image_processor, data_list)
|
||||
data = torch.stack([
|
||||
image_patch[0]
|
||||
for image_patch in model_image_input["image_patches"]
|
||||
])
|
||||
|
||||
# image has been processed with prompt in input processor
|
||||
return MultiModalKwargs({"pixel_values": data})
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_fuyu)
|
||||
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_fuyu_image_tokens)
|
||||
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_fuyu)
|
||||
@INPUT_REGISTRY.register_input_processor(input_processor_for_fuyu)
|
||||
@MULTIMODAL_REGISTRY.register_processor(FuyuMultiModalProcessor)
|
||||
class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
@@ -280,28 +293,32 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
return data.to(self.vision_embed_tokens.weight.dtype)
|
||||
|
||||
def _parse_and_validate_image_input(
|
||||
self, **kwargs: object) -> Optional[FuyuImagePixelInputs]:
|
||||
pixel_values = kwargs.pop("pixel_values", None)
|
||||
|
||||
if pixel_values is not None:
|
||||
if not isinstance(pixel_values, (torch.Tensor, list)):
|
||||
self, **kwargs: object) -> Optional[FuyuImagePatchInputs]:
|
||||
image_patches = kwargs.pop("image_patches", None)
|
||||
if image_patches is not None:
|
||||
if not isinstance(image_patches, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of image patches. "
|
||||
f"Got type: {type(pixel_values)}")
|
||||
f"Got type: {type(image_patches)}")
|
||||
|
||||
return FuyuImagePixelInputs(
|
||||
type="pixel_values",
|
||||
image_patches_flat = flatten_bn(image_patches)
|
||||
|
||||
return FuyuImagePatchInputs(
|
||||
type="image_patches",
|
||||
data=self._validate_pixel_values(
|
||||
flatten_bn(pixel_values, concat=True)),
|
||||
flatten_bn(image_patches_flat, concat=True)),
|
||||
patches_per_image=[x.size(0) for x in image_patches_flat],
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def _process_image_input(
|
||||
self, image_input: FuyuImagePixelInputs) -> torch.Tensor:
|
||||
self, image_input: FuyuImagePatchInputs) -> NestedTensors:
|
||||
image_patches = image_input["data"]
|
||||
patches_per_image = image_input["patches_per_image"]
|
||||
|
||||
assert self.vision_embed_tokens is not None
|
||||
vision_embeddings, _ = self.vision_embed_tokens(image_input["data"])
|
||||
return vision_embeddings
|
||||
vision_embeddings, _ = self.vision_embed_tokens(image_patches)
|
||||
return vision_embeddings.split(patches_per_image, dim=0)
|
||||
|
||||
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
|
||||
Reference in New Issue
Block a user