[Model] merged input processor for Phi-3-Vision models (#10977)

Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Isotr0py
2024-12-10 04:55:10 +08:00
committed by GitHub
parent ca871491ed
commit a811dd6608
7 changed files with 234 additions and 408 deletions

@@ -12,22 +12,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import re
from functools import cached_property, lru_cache
from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Set,
Tuple, TypedDict, Union)
from functools import cached_property
from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
TypedDict, Union)
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from transformers import CLIPVisionConfig, PretrainedConfig
from transformers import (BatchFeature, CLIPVisionConfig, PretrainedConfig,
ProcessorMixin)
from vllm.attention import AttentionMetadata
from vllm.config import ModelConfig, VllmConfig
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
InputContext, token_inputs)
from vllm.config import VllmConfig
from vllm.inputs import InputContext
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
@@ -36,12 +32,18 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.models.clip import CLIPVisionModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import NestedTensors, PlaceholderRange
from vllm.multimodal.utils import cached_get_tokenizer, repeat_and_pad_token
from vllm.multimodal.image import cached_get_image_processor
from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors
from vllm.multimodal.processing import (BaseMultiModalProcessor,
InputProcessingContext,
ModalityProcessingMetadata,
MultiModalDataDict,
MultiModalProcessingMetadata,
PromptReplacement)
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
from .clip import dummy_image_for_clip
from .interfaces import SupportsMultiModal, SupportsPP
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix,
@@ -303,231 +305,99 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
return image_features_hd_newline
# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L57
def _calc_padded_size(*, width: int, height: int, padding_unit: int = 336):
target_height = int(np.ceil(height / padding_unit) * padding_unit)
top_padding = int((target_height - height) / 2)
bottom_padding = target_height - height - top_padding
padded_width = width
padded_height = height + top_padding + bottom_padding
return padded_width, padded_height
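# Worked example of the helper above: only the height is padded, up to the
# next multiple of 336, with the padding split evenly between top and bottom.
assert _calc_padded_size(width=672, height=500) == (672, 672)
assert _calc_padded_size(width=336, height=336) == (336, 336)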
# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L90
def _calc_hd_transform_size(*, width: int, height: int, hd_num: int):
transposed = False
if width < height:
width, height = height, width
transposed = True
ratio = width / height
scale = 1
while scale * np.ceil(scale / ratio) <= hd_num:
scale += 1
scale -= 1
new_width = int(scale * 336)
new_height = int(new_width / ratio)
padded_width, padded_height = _calc_padded_size(width=new_width,
height=new_height)
if transposed:
padded_width, padded_height = padded_height, padded_width
return padded_width, padded_height
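# Worked example: a 1000x500 image (aspect ratio 2.0) with hd_num=16 picks
# scale=5, since 5 * ceil(5 / 2) = 15 <= 16 but 6 * ceil(6 / 2) = 18 > 16,
# giving new_width = 5 * 336 = 1680 and new_height = 840, padded up to 1008.
assert _calc_hd_transform_size(width=1000, height=500, hd_num=16) == (1680, 1008)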
# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L181
def get_phi3v_image_feature_size(
hf_config: Dict[str, Any],
*,
input_height: int,
input_width: int,
num_crops: int,
) -> int:
if num_crops is None:
num_crops = hf_config.get("num_crops", 16)
new_width, new_height = _calc_hd_transform_size(width=input_width,
height=input_height,
hd_num=num_crops)
return (new_height // 336 * new_width // 336 + 1) * 144 + 1 \
+ (new_height // 336 + 1) * 12
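# Worked example of the formula: the 1000x500 image above transforms to
# 1680x1008, i.e. a 5x3 grid of 336px tiles, so the feature size is
# (3 * 5 + 1) * 144 + 1 + (3 + 1) * 12 = 2353 tokens.
assert get_phi3v_image_feature_size({}, input_height=500, input_width=1000, num_crops=16) == 2353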
def get_max_phi3v_image_tokens(ctx: InputContext,
*,
num_crops: Optional[int] = None):
mm_processor_kwargs = {}
if num_crops is not None:
mm_processor_kwargs["num_crops"] = num_crops
return get_phi3v_image_feature_size(
ctx.get_hf_image_processor_config(),
input_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
input_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
num_crops=num_crops,
model_config = ctx.model_config
image_processor = cached_get_image_processor(
model_config.model,
trust_remote_code=model_config.trust_remote_code,
**mm_processor_kwargs,
)
num_tokens = image_processor.calc_num_image_tokens_from_image_size(
width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
)
return num_tokens
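# Usage sketch for the new code path (the checkpoint name and network access
# are assumptions, not part of this change):
#   processor = cached_get_image_processor(
#       "microsoft/Phi-3-vision-128k-instruct", trust_remote_code=True)
#   processor.calc_num_image_tokens_from_image_size(width=1000, height=500)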
def dummy_data_for_phi3v(ctx: InputContext,
seq_len: int,
mm_counts: Mapping[str, int],
*,
num_crops: Optional[int] = None):
def dummy_mm_kwargs_for_phi3v(ctx: InputProcessingContext,
mm_counts: Mapping[str, int]):
num_images = mm_counts["image"]
image_feature_size = get_max_phi3v_image_tokens(ctx, num_crops=num_crops)
seq_data, ranges = dummy_seq_data_for_clip(
CLIP_VIT_LARGE_PATCH14_336_CONFIG,
seq_len,
num_images,
image_token_id=_IMAGE_TOKEN_ID,
image_feature_size_override=image_feature_size,
)
mm_data = dummy_image_for_clip(
data = dummy_image_for_clip(
CLIP_VIT_LARGE_PATCH14_336_CONFIG,
num_images,
image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
)
return DummyData(seq_data, mm_data, ranges)
hf_processor = ctx.get_hf_processor()
image_processor = hf_processor.image_processor # type: ignore
hf_inputs = image_processor.preprocess(data['image'], return_tensors="pt")
return MultiModalKwargs(**hf_inputs)
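# Note: for Phi-3-V the preprocess() call above typically yields tensors
# such as "pixel_values" and "image_sizes"; MultiModalKwargs wraps that
# mapping as-is for the dummy profiling run.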
@lru_cache
def _get_image_placeholder_token_id_candidates(
model_config: ModelConfig,
idx: int,
) -> List[List[int]]:
assert idx > 0
tokenizer = cached_get_tokenizer(model_config.tokenizer)
# This is used when the image token is at the start of the string
start_candidate = tokenizer.encode(f"<|image_{idx}|>",
add_special_tokens=False)
# This is used when the image token is in the middle of the string
# We need to get the token for "<", not "▁<"
# https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/raw/main/tokenizer.json
a_token_id, = tokenizer.encode("a", add_special_tokens=False)
a_token_id_, *middle_candidate = tokenizer.encode(f"a<|image_{idx}|>",
add_special_tokens=False)
assert a_token_id == a_token_id_
return [start_candidate, middle_candidate]
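# Illustrative note (token ids hypothetical): a SentencePiece tokenizer
# encodes "<" as "▁<" at the start of a string but as a bare "<" after
# another token, so "<|image_1|>" alone and "<|image_1|>" preceded by "a"
# tokenize to different id sequences; both candidates must be searched for.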
def create_metadata_for_phi3v(
ctx: InputProcessingContext) -> MultiModalProcessingMetadata:
return {
"image":
ModalityProcessingMetadata(prompt_repls=[
PromptReplacement(target=[_IMAGE_TOKEN_ID],
repl_unit=[_IMAGE_TOKEN_ID],
repl_count=get_max_phi3v_image_tokens(ctx)),
]),
}
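# Sketch of the replacement semantics (ids illustrative, assuming
# _IMAGE_TOKEN_ID == 32044): each occurrence of the target sequence expands
# into repl_count copies of repl_unit, e.g. with repl_count=3:
#   [1, 32044, 2] -> [1, 32044, 32044, 32044, 2]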
def input_processor_for_phi3v(ctx: InputContext,
inputs: DecoderOnlyInputs,
*,
num_crops: Optional[int] = None):
multi_modal_data = inputs.get("multi_modal_data")
if multi_modal_data is None or "image" not in multi_modal_data:
return inputs
class Phi3VProcessor(BaseMultiModalProcessor):
model_config = ctx.model_config
hf_config = ctx.get_hf_image_processor_config()
def __init__(self, ctx: InputProcessingContext) -> None:
super().__init__(
ctx=ctx,
metadata=create_metadata_for_phi3v(ctx),
)
image_data = multi_modal_data["image"]
if isinstance(image_data, Image.Image):
w, h = image_data.size
image_feature_size = [
get_phi3v_image_feature_size(hf_config,
input_width=w,
input_height=h,
num_crops=num_crops)
]
image_data = [image_data]
elif is_list_of(image_data, Image.Image):
image_feature_size = []
for image in image_data:
w, h = image.size
image_feature_size.append(
get_phi3v_image_feature_size(hf_config,
input_width=w,
input_height=h,
num_crops=num_crops))
elif isinstance(image_data, torch.Tensor):
image_feature_size = [image_data.shape[0]]
image_data = [image_data]
elif is_list_of(image_data, torch.Tensor):
image_feature_size = [item.shape[0] for item in image_data]
else:
raise TypeError(f"Invalid image type: {type(image_data)}")
def _get_hf_processor(
self,
*,
num_crops: Optional[int] = None,
) -> ProcessorMixin:
if num_crops is not None:
return self.ctx.get_hf_processor(num_crops=num_crops)
return self.ctx.get_hf_processor()
prompt = inputs.get("prompt")
if prompt is None:
# For async server requests, we assume the prompt and its token_ids are
# always in the correct format, and that num_image_tags == len(image_data).
image_idx = range(1, len(image_data) + 1)
new_prompt = None
else:
image_idx = sorted(map(int, re.findall(r"<\|image_(\d+)\|>+", prompt)))
if prompt.count("<|image|>") > 0:
logger.warning("Please follow the prompt format that is "
"documented on HuggingFace which does not involve "
"repeating <|image|> tokens.")
elif (num_image_tags := len(image_idx)) > 1:
assert num_image_tags == len(
image_data), "Number of image placeholders does not match number of images"
new_prompt = prompt
def _apply_hf_processor(
self,
prompt: str,
mm_data: MultiModalDataDict,
mm_processor_kwargs: Mapping[str, object],
) -> BatchFeature:
processed_outputs = super()._apply_hf_processor(
prompt, mm_data, mm_processor_kwargs)
# The Phi3v processor inserts -1, -2, etc. as image placeholders in
# prompt_ids, which causes an OverflowError when decoding them.
# Therefore, we need to replace them with the image token ID early here.
token_ids = processed_outputs['input_ids']
token_ids[token_ids < 0] = _IMAGE_TOKEN_ID
processed_outputs['input_ids'] = token_ids
return processed_outputs
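# Minimal sketch of the masking above (ids illustrative):
#   token_ids = torch.tensor([[1, -1, -1, 29871, -2, 2]])
#   token_ids[token_ids < 0] = _IMAGE_TOKEN_ID
#   token_ids  # -> tensor([[1, 32044, 32044, 29871, 32044, 2]])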
prompt_token_ids = inputs["prompt_token_ids"].copy()
# mask placeholders with the image token id
for idx in image_idx:
candidates = _get_image_placeholder_token_id_candidates(model_config,
idx=idx)
for candidate in candidates:
for i in range(len(prompt_token_ids) - len(candidate) + 1):
if prompt_token_ids[i:i + len(candidate)] == candidate:
prompt_token_ids[i:i +
len(candidate)] = ([_IMAGE_TOKEN_ID] *
len(candidate))
break
# merge consecutive tag ids
merged_token_ids: List[int] = []
for is_placeholder, token_ids in itertools.groupby(
prompt_token_ids, lambda x: x == _IMAGE_TOKEN_ID):
if is_placeholder:
merged_token_ids.append(_IMAGE_TOKEN_ID)
else:
merged_token_ids.extend(list(token_ids))
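# Example of the merge above (ids illustrative): runs of consecutive image
# token ids collapse into a single id while other tokens pass through:
#   [5, 32044, 32044, 32044, 7] -> [5, 32044, 7]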
# TODO: Move this to utils or integrate with clip.
new_token_ids: List[int] = []
placeholder_ranges: List[PlaceholderRange] = []
placeholder_idx = 0
while merged_token_ids:
token_id = merged_token_ids.pop(0)
if token_id == _IMAGE_TOKEN_ID:
replacement_ids = repeat_and_pad_token(
_IMAGE_TOKEN_ID,
repeat_count=image_feature_size[placeholder_idx],
)
placeholder_ranges.append({
"offset": len(new_token_ids),
"length": len(replacement_ids)
})
new_token_ids.extend(replacement_ids)
placeholder_idx += 1
else:
new_token_ids.append(token_id)
# NOTE: Create a defensive copy of the original inputs
return token_inputs(prompt_token_ids=new_token_ids,
prompt=new_prompt,
multi_modal_data=multi_modal_data,
multi_modal_placeholders={"image": placeholder_ranges})
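# Example outcome (sizes hypothetical): with image_feature_size == [3] and
# merged ids [5, 32044, 7], new_token_ids becomes [5, 32044, 32044, 32044, 7]
# and placeholder_ranges == [{"offset": 1, "length": 3}].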
def _get_dummy_mm_kwargs(
self,
mm_counts: Mapping[str, int],
) -> MultiModalKwargs:
return dummy_mm_kwargs_for_phi3v(self.ctx, mm_counts)
@MULTIMODAL_REGISTRY.register_image_input_mapper()
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_phi3v_image_tokens)
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_phi3v)
@INPUT_REGISTRY.register_input_processor(input_processor_for_phi3v)
@MULTIMODAL_REGISTRY.register_processor(Phi3VProcessor)
class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):