[Model] Merged input processor for Phi-3-Vision models (#10977)
Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
@@ -12,22 +12,18 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import itertools
-import re
-from functools import cached_property, lru_cache
-from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Set,
-                    Tuple, TypedDict, Union)
+from functools import cached_property
+from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
+                    TypedDict, Union)
 
-import numpy as np
 import torch
 import torch.nn as nn
-from PIL import Image
-from transformers import CLIPVisionConfig, PretrainedConfig
+from transformers import (BatchFeature, CLIPVisionConfig, PretrainedConfig,
+                          ProcessorMixin)
 
 from vllm.attention import AttentionMetadata
-from vllm.config import ModelConfig, VllmConfig
-from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData,
-                         InputContext, token_inputs)
+from vllm.config import VllmConfig
+from vllm.inputs import InputContext
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
@@ -36,12 +32,18 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
 from vllm.model_executor.models.clip import CLIPVisionModel
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import NestedTensors, PlaceholderRange
-from vllm.multimodal.utils import cached_get_tokenizer, repeat_and_pad_token
+from vllm.multimodal.image import cached_get_image_processor
+from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors
+from vllm.multimodal.processing import (BaseMultiModalProcessor,
+                                        InputProcessingContext,
+                                        ModalityProcessingMetadata,
+                                        MultiModalDataDict,
+                                        MultiModalProcessingMetadata,
+                                        PromptReplacement)
 from vllm.sequence import IntermediateTensors
-from vllm.utils import is_list_of
 
-from .clip import dummy_image_for_clip, dummy_seq_data_for_clip
+from .clip import dummy_image_for_clip
 from .interfaces import SupportsMultiModal, SupportsPP
 from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
                     init_vllm_registered_model, maybe_prefix,
@@ -303,231 +305,99 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
         return image_features_hd_newline
 
 
-# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L57
-def _calc_padded_size(*, width: int, height: int, padding_unit: int = 336):
-    target_height = int(np.ceil(height / padding_unit) * padding_unit)
-    top_padding = int((target_height - height) / 2)
-    bottom_padding = target_height - height - top_padding
-    padded_width = width
-    padded_height = height + top_padding + bottom_padding
-    return padded_width, padded_height
-
-
-# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L90
-def _calc_hd_transform_size(*, width: int, height: int, hd_num: int):
-    transposed = False
-    if width < height:
-        width, height = height, width
-        transposed = True
-
-    ratio = width / height
-    scale = 1
-    while scale * np.ceil(scale / ratio) <= hd_num:
-        scale += 1
-    scale -= 1
-
-    new_width = int(scale * 336)
-    new_height = int(new_width / ratio)
-
-    padded_width, padded_height = _calc_padded_size(width=new_width,
-                                                    height=new_height)
-
-    if transposed:
-        padded_width, padded_height = padded_height, padded_width
-
-    return padded_width, padded_height
-
-
-# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L181
-def get_phi3v_image_feature_size(
-    hf_config: Dict[str, Any],
-    *,
-    input_height: int,
-    input_width: int,
-    num_crops: int,
-) -> int:
-    if num_crops is None:
-        num_crops = hf_config.get("num_crops", 16)
-    new_width, new_height = _calc_hd_transform_size(width=input_width,
-                                                    height=input_height,
-                                                    hd_num=num_crops)
-
-    return (new_height // 336 * new_width // 336 + 1) * 144 + 1 \
-        + (new_height // 336 + 1) * 12
-
-
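Note: the deleted helpers above implement the HD-transform sizing from the HuggingFace image processor. The following is a minimal, self-contained mirror of that removed logic (not part of the commit), useful for sanity-checking the new processor-delegated path; the hd_num default of 16 comes from the removed get_phi3v_image_feature_size:

    import math

    def phi3v_feature_size(width: int, height: int, hd_num: int = 16) -> int:
        # Mirror of the removed _calc_hd_transform_size + token-count formula.
        transposed = width < height
        if transposed:
            width, height = height, width
        ratio = width / height
        scale = 1
        while scale * math.ceil(scale / ratio) <= hd_num:
            scale += 1
        scale -= 1
        new_width = int(scale * 336)
        new_height = int(new_width / ratio)
        # _calc_padded_size: pad height up to a multiple of 336
        new_height = math.ceil(new_height / 336) * 336
        if transposed:
            new_width, new_height = new_height, new_width
        return (new_height // 336 * new_width // 336 + 1) * 144 + 1 \
            + (new_height // 336 + 1) * 12

    print(phi3v_feature_size(1344, 1008))  # 1921 tokens for a 1344x1008 image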
 def get_max_phi3v_image_tokens(ctx: InputContext,
                                *,
                                num_crops: Optional[int] = None):
-    return get_phi3v_image_feature_size(
-        ctx.get_hf_image_processor_config(),
-        input_height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
-        input_width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
-        num_crops=num_crops,
-    )
+    mm_processor_kwargs = {}
+    if num_crops is not None:
+        mm_processor_kwargs["num_crops"] = num_crops
+
+    model_config = ctx.model_config
+    image_processor = cached_get_image_processor(
+        model_config.model,
+        trust_remote_code=model_config.trust_remote_code,
+        **mm_processor_kwargs,
+    )
+
+    num_tokens = image_processor.calc_num_image_tokens_from_image_size(
+        width=MAX_IMAGE_FEATURE_SIZE_WIDTH,
+        height=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
+    )
+    return num_tokens
 
 
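With the hand-rolled formula gone, token counting is delegated to the HuggingFace image processor. A hedged sketch of that delegation outside vLLM, assuming transformers can fetch the remote image_processing_phi3_v.py cited above (calc_num_image_tokens_from_image_size is defined by that remote code, as used in the diff):

    from transformers import AutoImageProcessor

    processor = AutoImageProcessor.from_pretrained(
        "microsoft/Phi-3-vision-128k-instruct",
        trust_remote_code=True,
        num_crops=16,  # same kwarg the processor override below forwards
    )
    # Should agree with the hand-computed 1921 above for the same size.
    print(processor.calc_num_image_tokens_from_image_size(width=1344,
                                                          height=1008))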
-def dummy_data_for_phi3v(ctx: InputContext,
-                         seq_len: int,
-                         mm_counts: Mapping[str, int],
-                         *,
-                         num_crops: Optional[int] = None):
+def dummy_mm_kwargs_for_phi3v(ctx: InputProcessingContext,
+                              mm_counts: Mapping[str, int]):
     num_images = mm_counts["image"]
 
-    image_feature_size = get_max_phi3v_image_tokens(ctx, num_crops=num_crops)
-
-    seq_data, ranges = dummy_seq_data_for_clip(
-        CLIP_VIT_LARGE_PATCH14_336_CONFIG,
-        seq_len,
-        num_images,
-        image_token_id=_IMAGE_TOKEN_ID,
-        image_feature_size_override=image_feature_size,
-    )
-    mm_data = dummy_image_for_clip(
+    data = dummy_image_for_clip(
         CLIP_VIT_LARGE_PATCH14_336_CONFIG,
         num_images,
         image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
         image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
     )
 
-    return DummyData(seq_data, mm_data, ranges)
+    hf_processor = ctx.get_hf_processor()
+    image_processor = hf_processor.image_processor  # type: ignore
+    hf_inputs = image_processor.preprocess(data['image'], return_tensors="pt")
+
+    return MultiModalKwargs(**hf_inputs)
 
 
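The dummy path now runs the real HF image processor over a synthetic worst-case image, so memory profiling sees realistic tensors instead of fabricated ones. A minimal sketch of that flow, reusing the processor from the previous snippet; the image size here is an arbitrary stand-in, not the real MAX_IMAGE_FEATURE_SIZE_* constants defined elsewhere in this file:

    from PIL import Image

    dummy_image = Image.new("RGB", (448, 448))  # stand-in size only
    hf_inputs = processor.preprocess(dummy_image, return_tensors="pt")
    # These keyed tensors are what MultiModalKwargs(**hf_inputs) wraps.
    print(list(hf_inputs.keys()))  # e.g. pixel_values, image_sizes, ...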
-@lru_cache
-def _get_image_placeholder_token_id_candidates(
-    model_config: ModelConfig,
-    idx: int,
-) -> List[List[int]]:
-    assert idx > 0
-
-    tokenizer = cached_get_tokenizer(model_config.tokenizer)
-
-    # This is used when the image token is at the start of the string
-    start_candidate = tokenizer.encode(f"<|image_{idx}|>",
-                                       add_special_tokens=False)
-
-    # This is used when the image token is in the middle of the string
-    # We need to get the token for "<", not "▁<"
-    # https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/raw/main/tokenizer.json
-    a_token_id, = tokenizer.encode("a", add_special_tokens=False)
-    a_token_id_, *middle_candidate = tokenizer.encode(f"a<|image_{idx}|>",
-                                                      add_special_tokens=False)
-    assert a_token_id == a_token_id_
-
-    return [start_candidate, middle_candidate]
+def create_metadata_for_phi3v(
+        ctx: InputProcessingContext) -> MultiModalProcessingMetadata:
+    return {
+        "image":
+        ModalityProcessingMetadata(prompt_repls=[
+            PromptReplacement(target=[_IMAGE_TOKEN_ID],
+                              repl_unit=[_IMAGE_TOKEN_ID],
+                              repl_count=get_max_phi3v_image_tokens(ctx)),
+        ]),
+    }
 
 
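create_metadata_for_phi3v replaces the hand-rolled token scanning deleted above: each occurrence of the image token in the tokenized prompt is expanded in place to repl_count copies of itself. A toy illustration of that behavior, not the actual PromptReplacement machinery (32044 is the value this file uses for _IMAGE_TOKEN_ID):

    _IMAGE_TOKEN_ID = 32044

    def expand_image_tokens(token_ids: list[int], repl_count: int) -> list[int]:
        # Expand every image token to repl_count copies; leave others alone.
        out: list[int] = []
        for t in token_ids:
            out.extend([t] * (repl_count if t == _IMAGE_TOKEN_ID else 1))
        return out

    print(expand_image_tokens([1, 450, _IMAGE_TOKEN_ID, 2], repl_count=4))
    # [1, 450, 32044, 32044, 32044, 32044, 2]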
-def input_processor_for_phi3v(ctx: InputContext,
-                              inputs: DecoderOnlyInputs,
-                              *,
-                              num_crops: Optional[int] = None):
-    multi_modal_data = inputs.get("multi_modal_data")
-    if multi_modal_data is None or "image" not in multi_modal_data:
-        return inputs
-
-    model_config = ctx.model_config
-    hf_config = ctx.get_hf_image_processor_config()
+class Phi3VProcessor(BaseMultiModalProcessor):
+
+    def __init__(self, ctx: InputProcessingContext) -> None:
+        super().__init__(
+            ctx=ctx,
+            metadata=create_metadata_for_phi3v(ctx),
+        )
 
-    image_data = multi_modal_data["image"]
-    if isinstance(image_data, Image.Image):
-        w, h = image_data.size
-        image_feature_size = [
-            get_phi3v_image_feature_size(hf_config,
-                                         input_width=w,
-                                         input_height=h,
-                                         num_crops=num_crops)
-        ]
-        image_data = [image_data]
-    elif is_list_of(image_data, Image.Image):
-        image_feature_size = []
-        for image in image_data:
-            w, h = image.size
-            image_feature_size.append(
-                get_phi3v_image_feature_size(hf_config,
-                                             input_width=w,
-                                             input_height=h,
-                                             num_crops=num_crops))
-    elif isinstance(image_data, torch.Tensor):
-        image_feature_size = [image_data.shape[0]]
-        image_data = [image_data]
-    elif is_list_of(image_data, torch.Tensor):
-        image_feature_size = [item.shape[0] for item in image_data]
-    else:
-        raise TypeError(f"Invalid image type: {type(image_data)}")
+    def _get_hf_processor(
+        self,
+        *,
+        num_crops: Optional[int] = None,
+    ) -> ProcessorMixin:
+        if num_crops is not None:
+            return self.ctx.get_hf_processor(num_crops=num_crops)
+        return self.ctx.get_hf_processor()
 
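_get_hf_processor is where per-model processor kwargs reach the HF processor under the new scheme; num_crops arrives through vLLM's mm_processor_kwargs override rather than through a custom input processor. A usage sketch:

    from vllm import LLM

    llm = LLM(
        model="microsoft/Phi-3-vision-128k-instruct",
        trust_remote_code=True,
        mm_processor_kwargs={"num_crops": 16},  # forwarded to _get_hf_processor
    )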
-    prompt = inputs.get("prompt")
-    if prompt is None:
-        # for async server request, we assume prompt and its token_ids is always
-        # in correct format. And num_image_tags == len(image_data) always True.
-        image_idx = range(1, len(image_data) + 1)
-        new_prompt = None
-    else:
-        image_idx = sorted(map(int, re.findall(r"<\|image_(\d+)\|>+", prompt)))
-        if prompt.count("<|image|>") > 0:
-            logger.warning("Please follow the prompt format that is "
-                           "documented on HuggingFace which does not involve "
-                           "repeating <|image|> tokens.")
-        elif (num_image_tags := len(image_idx)) > 1:
-            assert num_image_tags == len(
-                image_data), "The count of image_placeholder not match image's"
-        new_prompt = prompt
+    def _apply_hf_processor(
+        self,
+        prompt: str,
+        mm_data: MultiModalDataDict,
+        mm_processor_kwargs: Mapping[str, object],
+    ) -> BatchFeature:
+        processed_outputs = super()._apply_hf_processor(
+            prompt, mm_data, mm_processor_kwargs)
+        # Phi3v processor has inserted -1, -2 etc as placeholder in prompt_ids,
+        # which will cause OverflowError when decoding the prompt_ids.
+        # Therefore, we need to do an early replacement here
+        token_ids = processed_outputs['input_ids']
+        token_ids[token_ids < 0] = _IMAGE_TOKEN_ID
+        processed_outputs['input_ids'] = token_ids
+        return processed_outputs
 
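The early replacement in _apply_hf_processor exists because HF's Phi-3-vision processor marks image slots with negative ids (-1, -2, ...), which break detokenization. A toy reproduction of the masking step with made-up token ids:

    import torch

    input_ids = torch.tensor([[1, 450, -1, -1, -1, 29871, 2]])
    input_ids[input_ids < 0] = 32044  # _IMAGE_TOKEN_ID
    print(input_ids)
    # tensor([[    1,   450, 32044, 32044, 32044, 29871,     2]])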
-    prompt_token_ids = inputs["prompt_token_ids"].copy()
-
-    # masked placeholder with image token id
-    for idx in image_idx:
-        candidates = _get_image_placeholder_token_id_candidates(model_config,
-                                                                idx=idx)
-
-        for candidate in candidates:
-            for i in range(len(prompt_token_ids) - len(candidate) + 1):
-                if prompt_token_ids[i:i + len(candidate)] == candidate:
-                    prompt_token_ids[i:i +
-                                     len(candidate)] = ([_IMAGE_TOKEN_ID] *
-                                                        len(candidate))
-                    break
-
-    # merge consecutive tag ids
-    merged_token_ids: List[int] = []
-    for is_placeholder, token_ids in itertools.groupby(
-            prompt_token_ids, lambda x: x == _IMAGE_TOKEN_ID):
-        if is_placeholder:
-            merged_token_ids.append(_IMAGE_TOKEN_ID)
-        else:
-            merged_token_ids.extend(list(token_ids))
-
-    # TODO: Move this to utils or integrate with clip.
-    new_token_ids: List[int] = []
-    placeholder_ranges: List[PlaceholderRange] = []
-    placeholder_idx = 0
-    while merged_token_ids:
-        token_id = merged_token_ids.pop(0)
-        if token_id == _IMAGE_TOKEN_ID:
-            replacement_ids = repeat_and_pad_token(
-                _IMAGE_TOKEN_ID,
-                repeat_count=image_feature_size[placeholder_idx],
-            )
-            placeholder_ranges.append({
-                "offset": len(new_token_ids),
-                "length": len(replacement_ids)
-            })
-            new_token_ids.extend(replacement_ids)
-            placeholder_idx += 1
-        else:
-            new_token_ids.append(token_id)
-
-    # NOTE: Create a defensive copy of the original inputs
-    return token_inputs(prompt_token_ids=new_token_ids,
-                        prompt=new_prompt,
-                        multi_modal_data=multi_modal_data,
-                        multi_modal_placeholders={"image": placeholder_ranges})
+    def _get_dummy_mm_kwargs(
+        self,
+        mm_counts: Mapping[str, int],
+    ) -> MultiModalKwargs:
+        return dummy_mm_kwargs_for_phi3v(self.ctx, mm_counts)
 
 
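The deleted block above also computed PlaceholderRange entries by hand; under BaseMultiModalProcessor those ranges fall out of the prompt replacements instead. For reference, a sketch with toy numbers of the data shape being produced either way:

    # One range per image: where its expanded token run starts, and its length.
    placeholder_ranges = [
        {"offset": 2, "length": 4},   # first image occupies positions 2..5
        {"offset": 10, "length": 4},  # second image occupies positions 10..13
    ]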
-@MULTIMODAL_REGISTRY.register_image_input_mapper()
 @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_phi3v_image_tokens)
-@INPUT_REGISTRY.register_dummy_data(dummy_data_for_phi3v)
-@INPUT_REGISTRY.register_input_processor(input_processor_for_phi3v)
+@MULTIMODAL_REGISTRY.register_processor(Phi3VProcessor)
 class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
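For completeness, an end-to-end usage sketch of the model under the new processor path (assumes a local example.jpg; prompt format per the model card, with mm_processor_kwargs available as in the earlier sketch):

    from PIL import Image
    from vllm import LLM, SamplingParams

    llm = LLM(model="microsoft/Phi-3-vision-128k-instruct",
              trust_remote_code=True)

    prompt = ("<|user|>\n<|image_1|>\n"
              "What is shown in this image?<|end|>\n<|assistant|>\n")
    image = Image.open("example.jpg")

    outputs = llm.generate(
        {"prompt": prompt, "multi_modal_data": {"image": image}},
        SamplingParams(max_tokens=64),
    )
    print(outputs[0].outputs[0].text)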