[Model] Refactor Step3-VL processor to HF style (#37579)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2026-03-20 14:05:08 +08:00
committed by GitHub
parent e2d1c8b5e8
commit 30108fc8b0
4 changed files with 235 additions and 167 deletions

View File

@@ -39,7 +39,11 @@ from vllm.multimodal.processing import (
)
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs.step3_vl import Step3VisionEncoderConfig
from vllm.transformers_utils.processors.step3_vl import Step3VLProcessor
from vllm.transformers_utils.processors.step3_vl import (
MAX_IMAGE_SIZE,
Step3VLImageProcessor,
Step3VLProcessor,
)
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
@@ -86,21 +90,30 @@ Step3VLImageInputs: TypeAlias = Step3VLImagePixelInputs | Step3VLImageEmbeddingI
class Step3VLProcessingInfo(BaseProcessingInfo):
def get_image_processor(self, **kwargs):
    """Construct the ``Step3VLImageProcessor`` for this model.

    ``enable_patch`` defaults to the value found on the model's
    ``vision_config`` (falling back to ``True`` when the attribute is
    absent) unless the caller passes it explicitly.
    """
    config = self.get_hf_config()
    # Computed eagerly (as setdefault does) so vision_config access
    # behaves identically whether or not the caller supplied the key.
    default_enable_patch = getattr(config.vision_config, "enable_patch", True)
    kwargs.setdefault("enable_patch", default_enable_patch)
    return Step3VLImageProcessor(**kwargs)
def get_hf_processor(self) -> Step3VLProcessor:
return Step3VLProcessor(
self.get_hf_config(),
self.get_tokenizer(),
tokenizer=self.get_tokenizer(),
image_processor=self.get_image_processor(),
)
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"image": None}
def get_max_image_tokens(self) -> int:
hf_processor = self.get_hf_processor()
return hf_processor.get_num_image_tokens(
self.get_image_size_with_most_features().width,
self.get_image_size_with_most_features().height,
)
image_processor = self.get_image_processor()
target_width, target_height = self.get_image_size_with_most_features()
return image_processor.get_num_image_tokens(target_width, target_height)
def get_mm_max_tokens_per_item(
self,
@@ -110,20 +123,7 @@ class Step3VLProcessingInfo(BaseProcessingInfo):
return {"image": self.get_max_image_tokens()}
def get_image_size_with_most_features(self) -> ImageSize:
return ImageSize(3024, 3024)
def get_num_mm_tokens(self, mm_data: MultiModalDataDict) -> int:
if len(mm_data) != 1 or "image" not in mm_data:
raise ValueError("mm_data could only contain one key 'image' for steo1o")
image_data = mm_data["image"]
if not isinstance(image_data, (list, tuple)):
image_data = [image_data]
return sum(
self.get_hf_processor().get_num_image_tokens(img.width, img.height)
for img in image_data
)
return ImageSize(MAX_IMAGE_SIZE, MAX_IMAGE_SIZE)
class Step3VLDummyInputsBuilder(BaseDummyInputsBuilder[Step3VLProcessingInfo]):
@@ -165,13 +165,11 @@ class Step3VLMultiModalProcessor(BaseMultiModalProcessor[Step3VLProcessingInfo])
def get_replacement_step1o(item_idx: int):
out_item = out_mm_kwargs["image"][item_idx]
num_patches = int(out_item["num_patches"].data)
if num_patches > 0:
patch_newline_mask = out_item["patch_newline_mask"].data
image_repl_ids = hf_processor._get_image_repl_features(
1, num_patches, patch_newline_mask.tolist()
)[1]
else:
image_repl_ids = hf_processor._get_image_repl_features(1, 0, None)[1]
patch_newline_mask = out_item["patch_newline_mask"].data
image_repl_ids = hf_processor.get_image_repl_feature_ids(
1, num_patches, patch_newline_mask.tolist()
)
return PromptUpdateDetails.select_token_id(
seq=image_repl_ids,
embed_token_id=image_placeholder_token_id,

View File

@@ -558,6 +558,7 @@ class InternVLProcessor(ProcessorMixin):
else:
text_inputs = {}
combined_outputs = {**text_inputs, **image_inputs, **video_inputs}
return BatchFeature(combined_outputs, tensor_type=return_tensors)
return BatchFeature(
data={**text_inputs, **image_inputs, **video_inputs},
tensor_type=return_tensors,
)

View File

@@ -19,7 +19,6 @@ class KimiK25Processor(ProcessorMixin):
self.media_token_id = media_token_id
assert self.media_token_id is not None
# We do not support str input for text here
def __call__(
self,
vision_chunks: list[VisionChunk] | None = None,

View File

@@ -8,13 +8,13 @@ import torch
from PIL import Image
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from transformers import BatchFeature, PretrainedConfig, TensorType
from transformers import BatchFeature, ProcessorMixin, TensorType
from vllm.tokenizers import TokenizerLike
MAX_IMAGE_SIZE: int = 3024
ImageWithPatches = tuple[Image.Image, list[Image.Image], list[bool] | None]
ImageWithPatches = tuple[Image.Image, list[Image.Image], list[bool]]
class Step3VisionProcessor:
@@ -185,7 +185,7 @@ class ImagePatcher:
def __call__(
self, img: Image.Image
) -> tuple[Image.Image, list[Image.Image], list[bool] | None]:
) -> tuple[Image.Image, list[Image.Image], list[bool]]:
img_width, img_height = img.size
new_img_width, new_img_height = self.get_image_size_for_padding(
img_width, img_height
@@ -203,7 +203,7 @@ class ImagePatcher:
)
if window_size == 0 or not self.enable_patch:
return img, [], None
return img, [], []
else:
new_img_width, new_img_height = self.get_image_size_for_crop(
new_img_width, new_img_height, window_size
@@ -236,43 +236,28 @@ class ImagePatcher:
return (
img,
patches,
[i in newlines for i in range(len(patches))]
if len(patches) > 0
else None,
[i in newlines for i in range(len(patches))],
)
class Step3VLProcessor:
class Step3VLImageProcessor:
def __init__(
self,
config: PretrainedConfig,
tokenizer: TokenizerLike,
image_size: int = 728,
patch_size: int = 504,
num_image_feature_size: int = 169,
num_patch_feature_size: int = 81,
enable_patch: bool = True,
) -> None:
super().__init__()
self.config = config
self.tokenizer = tokenizer
self.image_size = 728
self.patch_size = 504
self.image_size = image_size
self.patch_size = patch_size
self.num_image_feature_size = num_image_feature_size
self.num_patch_feature_size = num_patch_feature_size
self.image_preprocessor = Step3VisionProcessor(
self.image_size, "bilinear", self.patch_size
image_size, "bilinear", patch_size
)
self.num_image_feature_size = 169
self.num_patch_feature_size = 81
self.image_token = "<im_patch>"
self.image_feature_placeholder = self.image_token * self.num_image_feature_size
self.patch_feature_placeholder = self.image_token * self.num_patch_feature_size
# Respect vision config switch to enable/disable patch extraction.
# For video understanding, it's preferable to disable patch.
enable_patch = getattr(self.config.vision_config, "enable_patch", True)
self.patcher = ImagePatcher(enable_patch=enable_patch)
@property
def image_token_id(self) -> int:
return self.tokenizer.get_vocab()[self.image_token]
def get_num_image_tokens(self, img_width: int, img_height: int) -> int:
num_patches, num_newlines = self.patcher.get_num_patches(img_width, img_height)
@@ -299,58 +284,168 @@ class Step3VLProcessor:
for img in images
]
def _get_patch_repl(
def __call__(
self,
num_patches: int,
patch_newline_mask: list[bool] | None,
) -> tuple[str, list[int]]:
text = ""
token_ids = []
for i in range(num_patches):
assert (
patch_newline_mask is not None
and len(patch_newline_mask) == num_patches
)
text += f"<patch_start>{self.patch_feature_placeholder}<patch_end>"
token_ids.extend(
[self.tokenizer.convert_tokens_to_ids("<patch_start>")]
+ [self.image_token_id] * self.num_patch_feature_size
+ [self.tokenizer.convert_tokens_to_ids("<patch_end>")]
)
if patch_newline_mask and patch_newline_mask[i]:
text += "<patch_newline>"
token_ids.append(
self.tokenizer.convert_tokens_to_ids("<patch_newline>")
)
return text, token_ids
images: Image.Image | list[Image.Image] | None = None,
return_tensors: str | TensorType | None = None,
) -> BatchFeature:
if images is None:
images = []
if not isinstance(images, list):
images = [images]
def _get_image_repl(
split_images_data = self._split_images(images)
pixel_values_lst = []
patch_pixel_values_lst = []
patch_newline_mask_lst = []
num_patches = []
for raw_img, img_patches, patch_newline_mask in split_images_data:
pixel_values_lst.extend(self._convert_images_to_pixel_values([raw_img]))
num_patches.append(len(img_patches))
patch_pixel_values_lst.extend(
self._convert_images_to_pixel_values(img_patches, is_patch=True)
)
patch_newline_mask_lst.extend(patch_newline_mask)
pixel_values = torch.cat(pixel_values_lst)
patch_size = self.patch_size
image_inputs = {
"pixel_values": pixel_values,
"num_patches": num_patches,
"patch_pixel_values": (
torch.cat(patch_pixel_values_lst)
if patch_pixel_values_lst
else pixel_values.new_empty((0, 3, patch_size, patch_size))
),
"patch_newline_mask": torch.tensor(
patch_newline_mask_lst, dtype=torch.bool
),
}
return BatchFeature(image_inputs, tensor_type=return_tensors)
class Step3VLProcessor(ProcessorMixin):
attributes = ["image_processor", "tokenizer"]
def __init__(
self,
num_images: int,
) -> tuple[str, list[int]]:
text = f"<im_start>{self.image_feature_placeholder}<im_end>"
token_ids = (
[self.tokenizer.convert_tokens_to_ids("<im_start>")]
+ [self.image_token_id] * self.num_image_feature_size
+ [self.tokenizer.convert_tokens_to_ids("<im_end>")]
image_processor: Step3VLImageProcessor,
tokenizer: TokenizerLike,
) -> None:
self.image_processor = image_processor
self.tokenizer = tokenizer
self.image_start_token = image_start_token = "<im_start>"
self.image_end_token = image_end_token = "<im_end>"
self.patch_start_token = patch_start_token = "<patch_start>"
self.patch_end_token = patch_end_token = "<patch_end>"
self.patch_newline_token = patch_newline_token = "<patch_newline>"
self.image_start_token_id = tokenizer.convert_tokens_to_ids(image_start_token)
self.image_end_token_id = tokenizer.convert_tokens_to_ids(image_end_token)
self.patch_start_token_id = tokenizer.convert_tokens_to_ids(patch_start_token)
self.patch_end_token_id = tokenizer.convert_tokens_to_ids(patch_end_token)
self.patch_newline_token_id = tokenizer.convert_tokens_to_ids(
patch_newline_token
)
return text * num_images, token_ids * num_images
def _get_image_repl_features(
self.image_token = image_token = "<im_patch>"
self.image_feature_tokens = image_token * image_processor.num_image_feature_size
self.patch_feature_tokens = image_token * image_processor.num_patch_feature_size
self.image_token_id = image_token_id = tokenizer.convert_tokens_to_ids(
image_token
)
self.image_feature_token_ids = [
image_token_id
] * image_processor.num_image_feature_size
self.patch_feature_token_ids = [
image_token_id
] * image_processor.num_patch_feature_size
def _get_patch_repl_text(
self,
num_patches: int,
patch_newline_mask: list[bool],
) -> str:
assert len(patch_newline_mask) == num_patches
parts = []
for i in range(num_patches):
parts.extend(
[
self.patch_start_token,
self.patch_feature_tokens,
self.patch_end_token,
]
)
if patch_newline_mask[i]:
parts.append(self.patch_newline_token)
return "".join(parts)
def _get_patch_repl_ids(
self,
num_patches: int,
patch_newline_mask: list[bool],
) -> list[int]:
assert len(patch_newline_mask) == num_patches
parts = []
for i in range(num_patches):
parts.extend(
[
self.patch_start_token_id,
*self.patch_feature_token_ids,
self.patch_end_token_id,
]
)
if patch_newline_mask[i]:
parts.append(self.patch_newline_token_id)
return parts
def _get_image_repl_text(
self,
num_images: int,
) -> str:
parts = [
self.image_start_token,
self.image_feature_tokens,
self.image_end_token,
] * num_images
return "".join(parts)
def _get_image_repl_ids(
self,
num_images: int,
) -> list[int]:
part = [
self.image_start_token_id,
*self.image_feature_token_ids,
self.image_end_token_id,
]
return part * num_images
def get_image_repl_feature_text(
self,
num_images: int,
num_patches: int,
patch_new_line_idx: list[bool] | None,
) -> tuple[str, list[int]]:
if num_patches > 0:
patch_repl, patch_repl_ids = self._get_patch_repl(
num_patches, patch_new_line_idx
)
else:
patch_repl = ""
patch_repl_ids = []
image_repl, image_repl_ids = self._get_image_repl(num_images)
return patch_repl + image_repl, patch_repl_ids + image_repl_ids
patch_new_line_idx: list[bool],
) -> str:
patch_repl = self._get_patch_repl_text(num_patches, patch_new_line_idx)
image_repl = self._get_image_repl_text(num_images)
return patch_repl + image_repl
def get_image_repl_feature_ids(
    self,
    num_images: int,
    num_patches: int,
    patch_new_line_idx: list[bool],
) -> list[int]:
    """Full token-id replacement for one image item.

    Patch-level ids come first, followed by the whole-image ids.
    """
    ids = self._get_patch_repl_ids(num_patches, patch_new_line_idx)
    ids += self._get_image_repl_ids(num_images)
    return ids
def replace_placeholder(self, text: str, placeholder: str, repls: list[str]) -> str:
parts = text.split(placeholder)
@@ -373,69 +468,44 @@ class Step3VLProcessor:
images: Image.Image | list[Image.Image] | None = None,
return_tensors: str | TensorType | None = None,
) -> BatchFeature:
if text is None:
text = []
if not isinstance(text, list):
text = [text]
if images is None:
images = []
if not isinstance(images, list):
images = [images]
if len(images) == 0:
if images is not None:
image_inputs = self.image_processor(
images=images,
return_tensors=return_tensors,
)
num_patches = image_inputs["num_patches"]
patch_newline_mask = image_inputs["patch_newline_mask"]
else:
image_inputs = {}
num_patches = []
patch_newline_mask = []
if text is not None:
if not isinstance(text, list):
text = [text]
if image_inputs:
image_token = self.image_token
image_repl_str_lst = []
start = 0
for n_patches in num_patches:
image_repl_str = self.get_image_repl_feature_text(
1, n_patches, patch_newline_mask[start : start + n_patches]
)
image_repl_str_lst.append(image_repl_str)
start += n_patches
text = [
self.replace_placeholder(t, image_token, image_repl_str_lst)
for t in text
]
text_inputs = self.tokenizer(text)
else:
split_images_data = self._split_images(images)
pixel_values_lst = []
patch_pixel_values_lst = []
patch_newline_mask_lst = []
image_repl_str_lst = []
image_repl_ids_lst = []
num_patches = []
for raw_img, img_patches, patch_newline_mask in split_images_data:
pixel_values_lst.extend(self._convert_images_to_pixel_values([raw_img]))
if len(img_patches) > 0:
patch_pixel_values_lst.extend(
self._convert_images_to_pixel_values(img_patches, is_patch=True)
)
num_patches.append(len(img_patches))
image_repl_str, image_repl_ids = self._get_image_repl_features(
1, len(img_patches), patch_newline_mask
)
image_repl_str_lst.append(image_repl_str)
image_repl_ids_lst.extend(image_repl_ids)
if patch_newline_mask is not None:
patch_newline_mask_lst.extend(patch_newline_mask)
pixel_values = torch.cat(pixel_values_lst)
patch_size = self.patch_size
image_inputs = {
"pixel_values": pixel_values,
"num_patches": num_patches,
"patch_pixel_values": (
torch.cat(patch_pixel_values_lst)
if patch_pixel_values_lst
else pixel_values.new_empty((0, 3, patch_size, patch_size))
),
"patch_newline_mask": torch.tensor(
patch_newline_mask_lst, dtype=torch.bool
),
}
text = [
self.replace_placeholder(t, self.image_token, image_repl_str_lst)
for t in text
]
text_inputs = self.tokenizer(text)
text_inputs = {}
return BatchFeature(
{
**text_inputs,
**image_inputs,
},
data={**text_inputs, **image_inputs},
tensor_type=return_tensors,
)