[V1][VLM] V1 support for selected single-image models. (#11632)

Signed-off-by: Roger Wang <ywang@roblox.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: Isotr0py <2037008807@qq.com>
Roger Wang authored 2024-12-31 13:17:22 -08:00, committed by GitHub
commit e7c7c5e822 (parent 8c3230d8c1)
19 changed files with 575 additions and 621 deletions


@@ -15,32 +15,30 @@
# limitations under the License.
""" PyTorch Fuyu model."""
import math
from array import array
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
TypedDict)
import torch
import torch.nn as nn
import torch.utils.checkpoint
from PIL import Image
from transformers import FuyuImageProcessor
from transformers import (BatchFeature, FuyuConfig, FuyuImageProcessor,
FuyuProcessor)
from vllm.attention import AttentionMetadata
from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs)
from vllm.inputs import InputContext
from vllm.model_executor.layers.linear import ColumnParallelLinear
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.models.persimmon import PersimmonForCausalLM
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.image import cached_get_image_processor
from vllm.multimodal.inputs import NestedTensors
from vllm.multimodal.utils import (cached_get_tokenizer,
consecutive_placeholder_ranges)
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
SequenceData)
from vllm.utils import is_list_of
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalInputsV2, MultiModalKwargs,
NestedTensors, PlaceholderRange)
from vllm.multimodal.parse import ImageProcessorItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
MultiModalDataItems, ProcessorInputs,
PromptReplacement)
from vllm.sequence import IntermediateTensors
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, flatten_bn, maybe_prefix,
@@ -54,178 +52,193 @@ MAX_IMAGE_FEATURE_SIZE_HEIGHT = 1080
MAX_IMAGE_FEATURE_SIZE_WIDTH = 1920
class FuyuImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
class FuyuImagePatchInputs(TypedDict):
type: Literal["image_patches"]
data: torch.Tensor
"""
Shape:
(batch_size, num_patches, patch_size_x * patch_size_y * num_channels)
`(batch_size * num_patches, patch_size_x * patch_size_y * num_channels)`
"""
patches_per_image: List[int]
"""
List of number of total patches for each image in the batch.
This is used to restore the first two dimensions of `data`.
"""
def _calculate_num_image_tokens(
height: int,
width: int,
def _get_fuyu_num_image_tokens(
image_height: int,
image_width: int,
) -> Tuple[int, int]:
"""
calculate number of image tokens needed for a given image size
The expected Fuyu image prompts is in format:
Calculate the number of image tokens needed for a given image size.
The expected Fuyu image prompts can be expressed as:
.. code-block::
(image_token * ncols + newline_token) * nrows
args:
image_size: Tuple[int, int] - (width, height) of the image
returns:
ncols: int - number of image tokens in x direction
nrows: int - number of image tokens in y direction
Args:
image_size: Tuple[int, int] - `(width, height)` of the image
Returns:
ncols: int - number of image tokens in `x` direction
nrows: int - number of image tokens in `y` direction
"""
ncol = math.ceil(width / 30)
nrow = math.ceil(height / 30)
return ncol, nrow
def get_max_fuyu_image_feature_size():
return _calculate_num_image_tokens(
height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
)
ncols = math.ceil(image_width / 30)
nrows = math.ceil(image_height / 30)
return ncols, nrows
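As a quick worked check of the formula above, using the module's maximum image size (1920 x 1080) and Fuyu's 30-pixel patch stride:

ncols, nrows = _get_fuyu_num_image_tokens(image_height=1080, image_width=1920)
assert (ncols, nrows) == (64, 36)   # ceil(1920 / 30), ceil(1080 / 30)
assert (ncols + 1) * nrows == 2340  # one newline token appended per row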
def get_max_fuyu_image_tokens(ctx: InputContext):
ncol, nrow = get_max_fuyu_image_feature_size()
return (ncol + 1) * nrow
def dummy_seq_data_for_fuyu(ctx: InputContext, seq_len: int, num_images: int):
ncol, nrow = get_max_fuyu_image_feature_size()
image_feature_size = get_max_fuyu_image_tokens(ctx)
image_token_ids = (
array(VLLM_TOKEN_ID_ARRAY_TYPE, [_IMAGE_TOKEN_ID]) * ncol +
array(VLLM_TOKEN_ID_ARRAY_TYPE, [_NEWLINE_TOKEN_ID])) * nrow
token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, image_token_ids) * num_images
token_ids += array(VLLM_TOKEN_ID_ARRAY_TYPE,
[0]) * (seq_len - image_feature_size * num_images)
return SequenceData(token_ids), {
"image":
consecutive_placeholder_ranges(num_items=num_images,
item_size=image_feature_size)
}
def dummy_image_for_fuyu(
num_images: int,
*,
image_width: int,
image_height: int,
):
image = Image.new("RGB", (image_width, image_height), color=0)
return {"image": image if num_images == 1 else [image] * num_images}
def dummy_data_for_fuyu(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]):
num_images = mm_counts["image"]
seq_data, ranges = dummy_seq_data_for_fuyu(ctx, seq_len, num_images)
mm_data = dummy_image_for_fuyu(num_images,
image_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
image_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT)
return DummyData(seq_data, mm_data, ranges)
def _fuyu_image_preprocess(image_processor: FuyuImageProcessor,
data: List[Image.Image]):
image_encoding = image_processor.preprocess(data, return_tensors="pt")
batch_images = torch.stack([img[0] for img in image_encoding["images"]
]).unsqueeze(1)
image_unpadded_heights = torch.tensor(
image_encoding["image_unpadded_heights"])
image_unpadded_widths = torch.tensor(
image_encoding["image_unpadded_widths"])
batch_size = len(image_encoding["images"])
image_present = torch.ones(batch_size, 1, 1)
model_image_input = image_processor.preprocess_with_tokenizer_info(
image_input=batch_images,
image_present=image_present,
image_unpadded_h=image_unpadded_heights,
image_unpadded_w=image_unpadded_widths,
image_placeholder_id=_IMAGE_TOKEN_ID,
image_newline_id=_NEWLINE_TOKEN_ID,
variable_sized=True,
ncols, nrows = _get_fuyu_num_image_tokens(
image_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
image_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
)
return model_image_input
return (ncols + 1) * nrows
def input_processor_for_fuyu(ctx: InputContext, inputs: DecoderOnlyInputs):
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return inputs
class FuyuMultiModalProcessor(BaseMultiModalProcessor):
model_config = ctx.model_config
image_data = multi_modal_data["image"]
new_multi_modal_data = {}
image_list = image_data if isinstance(image_data, list) else [image_data]
def _get_hf_processor(self) -> FuyuProcessor:
return self.ctx.get_hf_processor(FuyuProcessor)
# process image data
if is_list_of(image_list, Image.Image):
# Fuyu's image_processor can also finish token padding
image_processor: FuyuImageProcessor = cached_get_image_processor(
model_config.model)
def _call_hf_processor(
self,
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
) -> BatchFeature:
model_image_input = _fuyu_image_preprocess(image_processor, image_data)
image_patches = torch.cat([
image_patch[0]
for image_patch in model_image_input["image_patches"]
])
new_multi_modal_data["image"] = image_patches
if not mm_data:
# Avoid warning from HF logger for text-only input
# Input_ids format: bos_token_id + prompt_token_ids + boa_token_id
# Tokenizer won't add boa_token_id by default, we add it manually.
tokenizer = self._get_tokenizer()
boa_token_id: int = tokenizer.vocab["<0x04>"] # type: ignore
prompt_ids = tokenizer.encode(prompt) + [boa_token_id]
return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
elif is_list_of(image_list, torch.Tensor):
raise NotImplementedError("Embeddings input is not supported yet")
else:
raise TypeError(f"Invalid image type: {type(image_data)}")
processed_outputs = super()._call_hf_processor(
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
)
# process prompts
prompt = inputs.get("prompt")
prompt_token_ids = inputs["prompt_token_ids"]
tokenizer = cached_get_tokenizer(model_config.model)
# dim0 is batch_size, dim1 is subseq_size which will always be 1
image_input_ids: List[List[
torch.Tensor]] = model_image_input["image_input_ids"]
image_input_ids = image_input_ids[0][0].tolist()
bos_token = tokenizer.encode("<s>", add_special_tokens=False)[1:]
boa_token = tokenizer.encode("\x04", add_special_tokens=False)[1:]
image_patches = processed_outputs.get("image_patches")
if image_patches is not None:
images = mm_data["images"]
assert isinstance(images, list)
new_prompt = prompt + "\x04"
new_prompt_token_ids = image_input_ids + bos_token + prompt_token_ids[
1:] + boa_token
# Original output: (1, num_images, Pn, Px * Py * C)
# New output: (num_images, Pn, Px * Py * C)
assert (isinstance(image_patches, list)
and len(image_patches) == 1)
assert (isinstance(image_patches[0], torch.Tensor)
and len(image_patches[0]) == len(images))
return token_inputs(prompt=new_prompt,
prompt_token_ids=new_prompt_token_ids,
multi_modal_data=new_multi_modal_data)
processed_outputs["image_patches"] = image_patches[0]
return processed_outputs
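A small illustration (hypothetical sizes) of the unwrapping asserted above: the HF processor returns image_patches as a one-element list holding a (num_images, Pn, Px * Py * C) tensor, and indexing element 0 drops that wrapper.

import torch
hf_image_patches = [torch.randn(2, 308, 30 * 30 * 3)]  # 2 images, 308 patches each
assert len(hf_image_patches) == 1
assert hf_image_patches[0].shape == (2, 308, 2700)      # (num_images, Pn, Px * Py * C)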
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(image_patches=MultiModalFieldConfig.batched("image"))
def _get_prompt_replacements(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
hf_config = self.ctx.get_hf_config(FuyuConfig)
bos_token_id = hf_config.bos_token_id
tokenizer = self._get_tokenizer()
eot_token_id = tokenizer.bos_token_id
assert isinstance(eot_token_id, int)
hf_processor = self._get_hf_processor()
image_processor: FuyuImageProcessor = hf_processor.image_processor
target_size = image_processor.size
target_height, target_width = (target_size["height"],
target_size["width"])
def get_replacement_fuyu(item_idx: int):
images = mm_items.get_items("image", ImageProcessorItems)
image_size = images.get_image_size(item_idx)
width, height = image_size.width, image_size.height
if not (width <= target_width and height <= target_height):
height_scale_factor = target_height / height
width_scale_factor = target_width / width
optimal_scale_factor = min(height_scale_factor,
width_scale_factor)
height = int(height * optimal_scale_factor)
width = int(width * optimal_scale_factor)
ncols, nrows = _get_fuyu_num_image_tokens(
image_width=width,
image_height=height,
)
return (([_IMAGE_TOKEN_ID] * ncols + [_NEWLINE_TOKEN_ID]) * nrows +
[bos_token_id])
return [
PromptReplacement(
modality="image",
target=[eot_token_id],
replacement=get_replacement_fuyu,
)
]
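For example, for a hypothetical 300 x 150 image (which fits within the 1920 x 1080 target size and so needs no rescaling), the replacement callback above expands one image placeholder into:

# ncols = ceil(300 / 30) = 10, nrows = ceil(150 / 30) = 5
# replacement = ([_IMAGE_TOKEN_ID] * 10 + [_NEWLINE_TOKEN_ID]) * 5 + [bos_token_id]
# i.e. (10 + 1) * 5 + 1 = 56 tokens for this image.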
def apply(
self,
prompt_text: str,
mm_data: MultiModalDataDict,
hf_processor_mm_kwargs: Mapping[str, object],
) -> MultiModalInputsV2:
result = super().apply(prompt_text, mm_data, hf_processor_mm_kwargs)
# Only |SPEAKER| (image) tokens should be considered as placeholders,
# so we ignore the trailing bos_token_id
result["mm_placeholders"] = {
modality: [
PlaceholderRange(offset=p["offset"], length=p["length"] - 1)
for p in ps
]
for modality, ps in result["mm_placeholders"].items()
}
return result
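Continuing that hypothetical example, the 56-token replacement ends in bos_token_id, so the placeholder recorded for the image is shortened by one:

# e.g. PlaceholderRange(offset=0, length=56) -> PlaceholderRange(offset=0, length=55),
# keeping only the |SPEAKER| and newline tokens as image placeholder positions.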
def _get_dummy_mm_inputs(
self,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
num_images = mm_counts.get("image", 0)
mm_data = {
"image":
self._get_dummy_images(width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
num_images=num_images)
}
return ProcessorInputs(
prompt_text="",
mm_data=mm_data,
)
def input_mapper_for_fuyu(ctx: InputContext, data: object):
model_config = ctx.model_config
data_list = data if isinstance(data, list) else [data]
if is_list_of(data_list, Image.Image):
# Fuyu's image_processor can also finish token padding
image_processor: FuyuImageProcessor = cached_get_image_processor(
model_config.model)
model_image_input = _fuyu_image_preprocess(image_processor, data_list)
data = torch.stack([
image_patch[0]
for image_patch in model_image_input["image_patches"]
])
# image has been processed with prompt in input processor
return MultiModalKwargs({"pixel_values": data})
@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_fuyu)
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_fuyu_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_fuyu)
@INPUT_REGISTRY.register_input_processor(input_processor_for_fuyu)
@MULTIMODAL_REGISTRY.register_processor(FuyuMultiModalProcessor)
class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
@@ -280,28 +293,32 @@ class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
return data.to(self.vision_embed_tokens.weight.dtype)
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[FuyuImagePixelInputs]:
pixel_values = kwargs.pop("pixel_values", None)
if pixel_values is not None:
if not isinstance(pixel_values, (torch.Tensor, list)):
self, **kwargs: object) -> Optional[FuyuImagePatchInputs]:
image_patches = kwargs.pop("image_patches", None)
if image_patches is not None:
if not isinstance(image_patches, (torch.Tensor, list)):
raise ValueError("Incorrect type of image patches. "
f"Got type: {type(pixel_values)}")
f"Got type: {type(image_patches)}")
return FuyuImagePixelInputs(
type="pixel_values",
image_patches_flat = flatten_bn(image_patches)
return FuyuImagePatchInputs(
type="image_patches",
data=self._validate_pixel_values(
flatten_bn(pixel_values, concat=True)),
flatten_bn(image_patches_flat, concat=True)),
patches_per_image=[x.size(0) for x in image_patches_flat],
)
return None
def _process_image_input(
self, image_input: FuyuImagePixelInputs) -> torch.Tensor:
self, image_input: FuyuImagePatchInputs) -> NestedTensors:
image_patches = image_input["data"]
patches_per_image = image_input["patches_per_image"]
assert self.vision_embed_tokens is not None
vision_embeddings, _ = self.vision_embed_tokens(image_input["data"])
return vision_embeddings
vision_embeddings, _ = self.vision_embed_tokens(image_patches)
return vision_embeddings.split(patches_per_image, dim=0)
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
image_input = self._parse_and_validate_image_input(**kwargs)