[Model] Refactor Step3-VL processor to HF style (#37579)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -39,7 +39,11 @@ from vllm.multimodal.processing import (
|
||||
)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.configs.step3_vl import Step3VisionEncoderConfig
|
||||
from vllm.transformers_utils.processors.step3_vl import Step3VLProcessor
|
||||
from vllm.transformers_utils.processors.step3_vl import (
|
||||
MAX_IMAGE_SIZE,
|
||||
Step3VLImageProcessor,
|
||||
Step3VLProcessor,
|
||||
)
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
|
||||
@@ -86,21 +90,30 @@ Step3VLImageInputs: TypeAlias = Step3VLImagePixelInputs | Step3VLImageEmbeddingI
|
||||
|
||||
|
||||
class Step3VLProcessingInfo(BaseProcessingInfo):
|
||||
def get_image_processor(self, **kwargs):
    """Build a ``Step3VLImageProcessor`` from keyword overrides.

    When the caller does not supply ``enable_patch``, the value is read
    from ``vision_config.enable_patch`` on the HF config and falls back to
    ``True`` if the vision config has no such attribute. All keyword
    arguments are then forwarded to the image processor constructor.
    """
    vision_config = self.get_hf_config().vision_config

    # Only fill in the default; an explicit caller value always wins.
    if "enable_patch" not in kwargs:
        kwargs["enable_patch"] = getattr(vision_config, "enable_patch", True)

    return Step3VLImageProcessor(**kwargs)
|
||||
|
||||
def get_hf_processor(self) -> Step3VLProcessor:
|
||||
return Step3VLProcessor(
|
||||
self.get_hf_config(),
|
||||
self.get_tokenizer(),
|
||||
tokenizer=self.get_tokenizer(),
|
||||
image_processor=self.get_image_processor(),
|
||||
)
|
||||
|
||||
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
|
||||
return {"image": None}
|
||||
|
||||
def get_max_image_tokens(self) -> int:
|
||||
hf_processor = self.get_hf_processor()
|
||||
return hf_processor.get_num_image_tokens(
|
||||
self.get_image_size_with_most_features().width,
|
||||
self.get_image_size_with_most_features().height,
|
||||
)
|
||||
image_processor = self.get_image_processor()
|
||||
target_width, target_height = self.get_image_size_with_most_features()
|
||||
|
||||
return image_processor.get_num_image_tokens(target_width, target_height)
|
||||
|
||||
def get_mm_max_tokens_per_item(
|
||||
self,
|
||||
@@ -110,20 +123,7 @@ class Step3VLProcessingInfo(BaseProcessingInfo):
|
||||
return {"image": self.get_max_image_tokens()}
|
||||
|
||||
def get_image_size_with_most_features(self) -> ImageSize:
|
||||
return ImageSize(3024, 3024)
|
||||
|
||||
def get_num_mm_tokens(self, mm_data: MultiModalDataDict) -> int:
|
||||
if len(mm_data) != 1 or "image" not in mm_data:
|
||||
raise ValueError("mm_data could only contain one key 'image' for steo1o")
|
||||
|
||||
image_data = mm_data["image"]
|
||||
if not isinstance(image_data, (list, tuple)):
|
||||
image_data = [image_data]
|
||||
|
||||
return sum(
|
||||
self.get_hf_processor().get_num_image_tokens(img.width, img.height)
|
||||
for img in image_data
|
||||
)
|
||||
return ImageSize(MAX_IMAGE_SIZE, MAX_IMAGE_SIZE)
|
||||
|
||||
|
||||
class Step3VLDummyInputsBuilder(BaseDummyInputsBuilder[Step3VLProcessingInfo]):
|
||||
@@ -165,13 +165,11 @@ class Step3VLMultiModalProcessor(BaseMultiModalProcessor[Step3VLProcessingInfo])
|
||||
def get_replacement_step1o(item_idx: int):
|
||||
out_item = out_mm_kwargs["image"][item_idx]
|
||||
num_patches = int(out_item["num_patches"].data)
|
||||
if num_patches > 0:
|
||||
patch_newline_mask = out_item["patch_newline_mask"].data
|
||||
image_repl_ids = hf_processor._get_image_repl_features(
|
||||
1, num_patches, patch_newline_mask.tolist()
|
||||
)[1]
|
||||
else:
|
||||
image_repl_ids = hf_processor._get_image_repl_features(1, 0, None)[1]
|
||||
patch_newline_mask = out_item["patch_newline_mask"].data
|
||||
image_repl_ids = hf_processor.get_image_repl_feature_ids(
|
||||
1, num_patches, patch_newline_mask.tolist()
|
||||
)
|
||||
|
||||
return PromptUpdateDetails.select_token_id(
|
||||
seq=image_repl_ids,
|
||||
embed_token_id=image_placeholder_token_id,
|
||||
|
||||
@@ -558,6 +558,7 @@ class InternVLProcessor(ProcessorMixin):
|
||||
else:
|
||||
text_inputs = {}
|
||||
|
||||
combined_outputs = {**text_inputs, **image_inputs, **video_inputs}
|
||||
|
||||
return BatchFeature(combined_outputs, tensor_type=return_tensors)
|
||||
return BatchFeature(
|
||||
data={**text_inputs, **image_inputs, **video_inputs},
|
||||
tensor_type=return_tensors,
|
||||
)
|
||||
|
||||
@@ -19,7 +19,6 @@ class KimiK25Processor(ProcessorMixin):
|
||||
self.media_token_id = media_token_id
|
||||
assert self.media_token_id is not None
|
||||
|
||||
# We do not support str input for text here
|
||||
def __call__(
|
||||
self,
|
||||
vision_chunks: list[VisionChunk] | None = None,
|
||||
|
||||
@@ -8,13 +8,13 @@ import torch
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
from torchvision.transforms.functional import InterpolationMode
|
||||
from transformers import BatchFeature, PretrainedConfig, TensorType
|
||||
from transformers import BatchFeature, ProcessorMixin, TensorType
|
||||
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
|
||||
MAX_IMAGE_SIZE: int = 3024
|
||||
|
||||
ImageWithPatches = tuple[Image.Image, list[Image.Image], list[bool] | None]
|
||||
ImageWithPatches = tuple[Image.Image, list[Image.Image], list[bool]]
|
||||
|
||||
|
||||
class Step3VisionProcessor:
|
||||
@@ -185,7 +185,7 @@ class ImagePatcher:
|
||||
|
||||
def __call__(
|
||||
self, img: Image.Image
|
||||
) -> tuple[Image.Image, list[Image.Image], list[bool] | None]:
|
||||
) -> tuple[Image.Image, list[Image.Image], list[bool]]:
|
||||
img_width, img_height = img.size
|
||||
new_img_width, new_img_height = self.get_image_size_for_padding(
|
||||
img_width, img_height
|
||||
@@ -203,7 +203,7 @@ class ImagePatcher:
|
||||
)
|
||||
|
||||
if window_size == 0 or not self.enable_patch:
|
||||
return img, [], None
|
||||
return img, [], []
|
||||
else:
|
||||
new_img_width, new_img_height = self.get_image_size_for_crop(
|
||||
new_img_width, new_img_height, window_size
|
||||
@@ -236,43 +236,28 @@ class ImagePatcher:
|
||||
return (
|
||||
img,
|
||||
patches,
|
||||
[i in newlines for i in range(len(patches))]
|
||||
if len(patches) > 0
|
||||
else None,
|
||||
[i in newlines for i in range(len(patches))],
|
||||
)
|
||||
|
||||
|
||||
class Step3VLProcessor:
|
||||
class Step3VLImageProcessor:
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
tokenizer: TokenizerLike,
|
||||
image_size: int = 728,
|
||||
patch_size: int = 504,
|
||||
num_image_feature_size: int = 169,
|
||||
num_patch_feature_size: int = 81,
|
||||
enable_patch: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
self.tokenizer = tokenizer
|
||||
self.image_size = 728
|
||||
self.patch_size = 504
|
||||
self.image_size = image_size
|
||||
self.patch_size = patch_size
|
||||
self.num_image_feature_size = num_image_feature_size
|
||||
self.num_patch_feature_size = num_patch_feature_size
|
||||
self.image_preprocessor = Step3VisionProcessor(
|
||||
self.image_size, "bilinear", self.patch_size
|
||||
image_size, "bilinear", patch_size
|
||||
)
|
||||
|
||||
self.num_image_feature_size = 169
|
||||
self.num_patch_feature_size = 81
|
||||
self.image_token = "<im_patch>"
|
||||
self.image_feature_placeholder = self.image_token * self.num_image_feature_size
|
||||
self.patch_feature_placeholder = self.image_token * self.num_patch_feature_size
|
||||
|
||||
# Respect vision config switch to enable/disable patch extraction.
|
||||
# For video understanding, it's preferable to disable patch.
|
||||
enable_patch = getattr(self.config.vision_config, "enable_patch", True)
|
||||
self.patcher = ImagePatcher(enable_patch=enable_patch)
|
||||
|
||||
@property
|
||||
def image_token_id(self) -> int:
|
||||
return self.tokenizer.get_vocab()[self.image_token]
|
||||
|
||||
def get_num_image_tokens(self, img_width: int, img_height: int) -> int:
|
||||
num_patches, num_newlines = self.patcher.get_num_patches(img_width, img_height)
|
||||
|
||||
@@ -299,58 +284,168 @@ class Step3VLProcessor:
|
||||
for img in images
|
||||
]
|
||||
|
||||
def _get_patch_repl(
|
||||
def __call__(
|
||||
self,
|
||||
num_patches: int,
|
||||
patch_newline_mask: list[bool] | None,
|
||||
) -> tuple[str, list[int]]:
|
||||
text = ""
|
||||
token_ids = []
|
||||
for i in range(num_patches):
|
||||
assert (
|
||||
patch_newline_mask is not None
|
||||
and len(patch_newline_mask) == num_patches
|
||||
)
|
||||
text += f"<patch_start>{self.patch_feature_placeholder}<patch_end>"
|
||||
token_ids.extend(
|
||||
[self.tokenizer.convert_tokens_to_ids("<patch_start>")]
|
||||
+ [self.image_token_id] * self.num_patch_feature_size
|
||||
+ [self.tokenizer.convert_tokens_to_ids("<patch_end>")]
|
||||
)
|
||||
if patch_newline_mask and patch_newline_mask[i]:
|
||||
text += "<patch_newline>"
|
||||
token_ids.append(
|
||||
self.tokenizer.convert_tokens_to_ids("<patch_newline>")
|
||||
)
|
||||
return text, token_ids
|
||||
images: Image.Image | list[Image.Image] | None = None,
|
||||
return_tensors: str | TensorType | None = None,
|
||||
) -> BatchFeature:
|
||||
if images is None:
|
||||
images = []
|
||||
if not isinstance(images, list):
|
||||
images = [images]
|
||||
|
||||
def _get_image_repl(
|
||||
split_images_data = self._split_images(images)
|
||||
pixel_values_lst = []
|
||||
patch_pixel_values_lst = []
|
||||
patch_newline_mask_lst = []
|
||||
num_patches = []
|
||||
for raw_img, img_patches, patch_newline_mask in split_images_data:
|
||||
pixel_values_lst.extend(self._convert_images_to_pixel_values([raw_img]))
|
||||
num_patches.append(len(img_patches))
|
||||
patch_pixel_values_lst.extend(
|
||||
self._convert_images_to_pixel_values(img_patches, is_patch=True)
|
||||
)
|
||||
patch_newline_mask_lst.extend(patch_newline_mask)
|
||||
|
||||
pixel_values = torch.cat(pixel_values_lst)
|
||||
patch_size = self.patch_size
|
||||
image_inputs = {
|
||||
"pixel_values": pixel_values,
|
||||
"num_patches": num_patches,
|
||||
"patch_pixel_values": (
|
||||
torch.cat(patch_pixel_values_lst)
|
||||
if patch_pixel_values_lst
|
||||
else pixel_values.new_empty((0, 3, patch_size, patch_size))
|
||||
),
|
||||
"patch_newline_mask": torch.tensor(
|
||||
patch_newline_mask_lst, dtype=torch.bool
|
||||
),
|
||||
}
|
||||
return BatchFeature(image_inputs, tensor_type=return_tensors)
|
||||
|
||||
|
||||
class Step3VLProcessor(ProcessorMixin):
|
||||
attributes = ["image_processor", "tokenizer"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_images: int,
|
||||
) -> tuple[str, list[int]]:
|
||||
text = f"<im_start>{self.image_feature_placeholder}<im_end>"
|
||||
token_ids = (
|
||||
[self.tokenizer.convert_tokens_to_ids("<im_start>")]
|
||||
+ [self.image_token_id] * self.num_image_feature_size
|
||||
+ [self.tokenizer.convert_tokens_to_ids("<im_end>")]
|
||||
image_processor: Step3VLImageProcessor,
|
||||
tokenizer: TokenizerLike,
|
||||
) -> None:
|
||||
self.image_processor = image_processor
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
self.image_start_token = image_start_token = "<im_start>"
|
||||
self.image_end_token = image_end_token = "<im_end>"
|
||||
self.patch_start_token = patch_start_token = "<patch_start>"
|
||||
self.patch_end_token = patch_end_token = "<patch_end>"
|
||||
self.patch_newline_token = patch_newline_token = "<patch_newline>"
|
||||
self.image_start_token_id = tokenizer.convert_tokens_to_ids(image_start_token)
|
||||
self.image_end_token_id = tokenizer.convert_tokens_to_ids(image_end_token)
|
||||
self.patch_start_token_id = tokenizer.convert_tokens_to_ids(patch_start_token)
|
||||
self.patch_end_token_id = tokenizer.convert_tokens_to_ids(patch_end_token)
|
||||
self.patch_newline_token_id = tokenizer.convert_tokens_to_ids(
|
||||
patch_newline_token
|
||||
)
|
||||
return text * num_images, token_ids * num_images
|
||||
|
||||
def _get_image_repl_features(
|
||||
self.image_token = image_token = "<im_patch>"
|
||||
self.image_feature_tokens = image_token * image_processor.num_image_feature_size
|
||||
self.patch_feature_tokens = image_token * image_processor.num_patch_feature_size
|
||||
|
||||
self.image_token_id = image_token_id = tokenizer.convert_tokens_to_ids(
|
||||
image_token
|
||||
)
|
||||
self.image_feature_token_ids = [
|
||||
image_token_id
|
||||
] * image_processor.num_image_feature_size
|
||||
self.patch_feature_token_ids = [
|
||||
image_token_id
|
||||
] * image_processor.num_patch_feature_size
|
||||
|
||||
def _get_patch_repl_text(
    self,
    num_patches: int,
    patch_newline_mask: list[bool],
) -> str:
    """Render the replacement text for the patches of one image.

    Each patch becomes ``patch_start_token + patch_feature_tokens +
    patch_end_token``; ``patch_newline_token`` is appended after patch ``i``
    when ``patch_newline_mask[i]`` is truthy.

    Args:
        num_patches: Number of patches to render.
        patch_newline_mask: One flag per patch; must have ``num_patches``
            entries.

    Returns:
        The concatenated text, or an empty string when ``num_patches`` is 0.
    """
    assert len(patch_newline_mask) == num_patches

    # The per-patch text never varies, so assemble it once up front.
    patch_block = (
        self.patch_start_token + self.patch_feature_tokens + self.patch_end_token
    )
    newline = self.patch_newline_token

    return "".join(
        patch_block + newline if patch_newline_mask[idx] else patch_block
        for idx in range(num_patches)
    )
|
||||
|
||||
def _get_patch_repl_ids(
    self,
    num_patches: int,
    patch_newline_mask: list[bool],
) -> list[int]:
    """Build the replacement token ids for the patches of one image.

    Each patch contributes ``patch_start_token_id``, then every id in
    ``patch_feature_token_ids``, then ``patch_end_token_id``;
    ``patch_newline_token_id`` follows patch ``i`` when
    ``patch_newline_mask[i]`` is truthy.

    Args:
        num_patches: Number of patches to emit ids for.
        patch_newline_mask: One flag per patch; must have ``num_patches``
            entries.

    Returns:
        A flat list of token ids (empty when ``num_patches`` is 0).
    """
    assert len(patch_newline_mask) == num_patches

    # Ids for a single patch are constant, so build them once.
    patch_block = [
        self.patch_start_token_id,
        *self.patch_feature_token_ids,
        self.patch_end_token_id,
    ]

    token_ids: list[int] = []
    for idx in range(num_patches):
        token_ids += patch_block
        if patch_newline_mask[idx]:
            token_ids.append(self.patch_newline_token_id)

    return token_ids
|
||||
|
||||
def _get_image_repl_text(
    self,
    num_images: int,
) -> str:
    """Render the replacement text for ``num_images`` whole images.

    One image is ``image_start_token + image_feature_tokens +
    image_end_token``; the result is that string repeated ``num_images``
    times.
    """
    single_image = (
        self.image_start_token + self.image_feature_tokens + self.image_end_token
    )
    return single_image * num_images
|
||||
|
||||
def _get_image_repl_ids(
    self,
    num_images: int,
) -> list[int]:
    """Build the replacement token ids for ``num_images`` whole images.

    One image is ``image_start_token_id``, every id in
    ``image_feature_token_ids``, then ``image_end_token_id``; that sequence
    is repeated once per image.
    """
    single_image = [self.image_start_token_id]
    single_image.extend(self.image_feature_token_ids)
    single_image.append(self.image_end_token_id)

    token_ids: list[int] = []
    for _ in range(num_images):
        token_ids.extend(single_image)
    return token_ids
|
||||
|
||||
def get_image_repl_feature_text(
|
||||
self,
|
||||
num_images: int,
|
||||
num_patches: int,
|
||||
patch_new_line_idx: list[bool] | None,
|
||||
) -> tuple[str, list[int]]:
|
||||
if num_patches > 0:
|
||||
patch_repl, patch_repl_ids = self._get_patch_repl(
|
||||
num_patches, patch_new_line_idx
|
||||
)
|
||||
else:
|
||||
patch_repl = ""
|
||||
patch_repl_ids = []
|
||||
image_repl, image_repl_ids = self._get_image_repl(num_images)
|
||||
return patch_repl + image_repl, patch_repl_ids + image_repl_ids
|
||||
patch_new_line_idx: list[bool],
|
||||
) -> str:
|
||||
patch_repl = self._get_patch_repl_text(num_patches, patch_new_line_idx)
|
||||
image_repl = self._get_image_repl_text(num_images)
|
||||
return patch_repl + image_repl
|
||||
|
||||
def get_image_repl_feature_ids(
    self,
    num_images: int,
    num_patches: int,
    patch_new_line_idx: list[bool],
) -> list[int]:
    """Return the full replacement token ids: patch ids, then image ids.

    Args:
        num_images: Number of whole-image segments to emit.
        num_patches: Number of patch segments to emit before the images.
        patch_new_line_idx: Per-patch newline flags, forwarded to
            ``_get_patch_repl_ids``.

    Returns:
        The patch token ids followed by the image token ids.
    """
    # Patches come first in the sequence; whole-image tokens follow.
    feature_ids = self._get_patch_repl_ids(num_patches, patch_new_line_idx)
    return feature_ids + self._get_image_repl_ids(num_images)
|
||||
|
||||
def replace_placeholder(self, text: str, placeholder: str, repls: list[str]) -> str:
|
||||
parts = text.split(placeholder)
|
||||
@@ -373,69 +468,44 @@ class Step3VLProcessor:
|
||||
images: Image.Image | list[Image.Image] | None = None,
|
||||
return_tensors: str | TensorType | None = None,
|
||||
) -> BatchFeature:
|
||||
if text is None:
|
||||
text = []
|
||||
if not isinstance(text, list):
|
||||
text = [text]
|
||||
if images is None:
|
||||
images = []
|
||||
if not isinstance(images, list):
|
||||
images = [images]
|
||||
|
||||
if len(images) == 0:
|
||||
if images is not None:
|
||||
image_inputs = self.image_processor(
|
||||
images=images,
|
||||
return_tensors=return_tensors,
|
||||
)
|
||||
num_patches = image_inputs["num_patches"]
|
||||
patch_newline_mask = image_inputs["patch_newline_mask"]
|
||||
else:
|
||||
image_inputs = {}
|
||||
num_patches = []
|
||||
patch_newline_mask = []
|
||||
|
||||
if text is not None:
|
||||
if not isinstance(text, list):
|
||||
text = [text]
|
||||
|
||||
if image_inputs:
|
||||
image_token = self.image_token
|
||||
image_repl_str_lst = []
|
||||
start = 0
|
||||
for n_patches in num_patches:
|
||||
image_repl_str = self.get_image_repl_feature_text(
|
||||
1, n_patches, patch_newline_mask[start : start + n_patches]
|
||||
)
|
||||
image_repl_str_lst.append(image_repl_str)
|
||||
|
||||
start += n_patches
|
||||
|
||||
text = [
|
||||
self.replace_placeholder(t, image_token, image_repl_str_lst)
|
||||
for t in text
|
||||
]
|
||||
|
||||
text_inputs = self.tokenizer(text)
|
||||
else:
|
||||
split_images_data = self._split_images(images)
|
||||
pixel_values_lst = []
|
||||
patch_pixel_values_lst = []
|
||||
patch_newline_mask_lst = []
|
||||
image_repl_str_lst = []
|
||||
image_repl_ids_lst = []
|
||||
num_patches = []
|
||||
for raw_img, img_patches, patch_newline_mask in split_images_data:
|
||||
pixel_values_lst.extend(self._convert_images_to_pixel_values([raw_img]))
|
||||
|
||||
if len(img_patches) > 0:
|
||||
patch_pixel_values_lst.extend(
|
||||
self._convert_images_to_pixel_values(img_patches, is_patch=True)
|
||||
)
|
||||
num_patches.append(len(img_patches))
|
||||
|
||||
image_repl_str, image_repl_ids = self._get_image_repl_features(
|
||||
1, len(img_patches), patch_newline_mask
|
||||
)
|
||||
image_repl_str_lst.append(image_repl_str)
|
||||
image_repl_ids_lst.extend(image_repl_ids)
|
||||
|
||||
if patch_newline_mask is not None:
|
||||
patch_newline_mask_lst.extend(patch_newline_mask)
|
||||
|
||||
pixel_values = torch.cat(pixel_values_lst)
|
||||
patch_size = self.patch_size
|
||||
image_inputs = {
|
||||
"pixel_values": pixel_values,
|
||||
"num_patches": num_patches,
|
||||
"patch_pixel_values": (
|
||||
torch.cat(patch_pixel_values_lst)
|
||||
if patch_pixel_values_lst
|
||||
else pixel_values.new_empty((0, 3, patch_size, patch_size))
|
||||
),
|
||||
"patch_newline_mask": torch.tensor(
|
||||
patch_newline_mask_lst, dtype=torch.bool
|
||||
),
|
||||
}
|
||||
|
||||
text = [
|
||||
self.replace_placeholder(t, self.image_token, image_repl_str_lst)
|
||||
for t in text
|
||||
]
|
||||
text_inputs = self.tokenizer(text)
|
||||
text_inputs = {}
|
||||
|
||||
return BatchFeature(
|
||||
{
|
||||
**text_inputs,
|
||||
**image_inputs,
|
||||
},
|
||||
data={**text_inputs, **image_inputs},
|
||||
tensor_type=return_tensors,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user