[Model] Use helper function to run MM processors with token inputs (where applicable) (#38018)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2026-03-26 16:44:04 +08:00
committed by GitHub
parent 52069012fe
commit 502c41a8f6
12 changed files with 215 additions and 145 deletions

View File

@@ -1210,6 +1210,17 @@ class KeyeDummyInputsBuilder(KeyeBaseDummyInputsBuilder[KeyeProcessingInfo]):
class KeyeMultiModalProcessor(BaseMultiModalProcessor[KeyeProcessingInfo]):
def _call_hf_processor(
    self,
    prompt: str,
    mm_data: Mapping[str, object],
    mm_kwargs: Mapping[str, object],
    tok_kwargs: Mapping[str, object],
) -> BatchFeature:
    """Forward all arguments, unchanged, to the parent implementation.

    The override is intentional even though it adds no logic of its own:
    it selects the text path instead of the token path so that the
    video-specific logic in processing_keye.py is used.
    """
    features = super()._call_hf_processor(
        prompt=prompt,
        mm_data=mm_data,
        mm_kwargs=mm_kwargs,
        tok_kwargs=tok_kwargs,
    )
    return features
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,

View File

@@ -371,6 +371,17 @@ class KeyeVL1_5ProcessingInfo(KeyeProcessingInfo):
class KeyeVL1_5MultiModalProcessor(BaseMultiModalProcessor[KeyeVL1_5ProcessingInfo]):
def _call_hf_processor(
    self,
    prompt: str,
    mm_data: Mapping[str, object],
    mm_kwargs: Mapping[str, object],
    tok_kwargs: Mapping[str, object],
) -> BatchFeature:
    """Forward all arguments, unchanged, to the parent implementation.

    The override is intentional even though it adds no logic of its own:
    it selects the text path instead of the token path so that the
    video-specific logic in processing_keye.py is used.
    """
    features = super()._call_hf_processor(
        prompt=prompt,
        mm_data=mm_data,
        mm_kwargs=mm_kwargs,
        tok_kwargs=tok_kwargs,
    )
    return features
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,

View File

@@ -215,6 +215,17 @@ class KimiK25MultiModalProcessor(BaseMultiModalProcessor[KimiK25ProcessingInfo])
grid_thws=MultiModalFieldConfig.batched("vision_chunk"),
)
def _call_hf_processor(
    self,
    prompt: str,
    mm_data: Mapping[str, object],
    mm_kwargs: Mapping[str, object],
    tok_kwargs: Mapping[str, object],
) -> BatchFeature:
    """Forward all arguments, unchanged, to the parent implementation.

    The override is intentional even though it adds no logic of its own:
    it selects the text path rather than the token path, where the vision
    chunk is not considered.
    """
    features = super()._call_hf_processor(
        prompt=prompt,
        mm_data=mm_data,
        mm_kwargs=mm_kwargs,
        tok_kwargs=tok_kwargs,
    )
    return features
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,

View File

@@ -12,7 +12,7 @@ import torch.nn.functional as F
from mistral_common.protocol.instruct.chunk import ImageChunk, TextChunk
from mistral_common.protocol.instruct.messages import UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from transformers import PixtralVisionConfig
from transformers import BatchFeature, PixtralVisionConfig
from transformers.models.pixtral.image_processing_pixtral import (
_num_image_tokens as _get_pixtral_hf_num_image_tokens,
)
@@ -62,6 +62,7 @@ from vllm.sequence import IntermediateTensors
from vllm.tokenizers import cached_tokenizer_from_config
from vllm.tokenizers.mistral import MistralTokenizer
from vllm.transformers_utils.processors.pixtral import MistralCommonPixtralProcessor
from vllm.utils.collection_utils import is_list_of
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import (
@@ -213,6 +214,27 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo])
) -> Mapping[str, MultiModalFieldConfig]:
return dict(images=MultiModalFieldConfig.batched("image"))
def _call_hf_processor(
    self,
    prompt: str,
    mm_data: Mapping[str, object],
    mm_kwargs: Mapping[str, object],
    tok_kwargs: Mapping[str, object],
) -> BatchFeature:
    """Call the parent processor with tensor conversion turned off.

    ``return_tensors`` is forced to ``None`` in the tokenization kwargs to
    avoid a padding issue. When the resulting ``input_ids`` is a flat list
    of ints, it is wrapped in an outer list to restore the batch dimension.
    """
    # Copy so the caller's mapping is left untouched.
    untensorized_tok_kwargs = dict(tok_kwargs)
    untensorized_tok_kwargs["return_tensors"] = None

    features = super()._call_hf_processor(
        prompt=prompt,
        mm_data=mm_data,
        mm_kwargs=mm_kwargs,
        tok_kwargs=untensorized_tok_kwargs,
    )

    # A flat list of token ids is missing its batch dimension.
    token_ids = features["input_ids"]
    if is_list_of(token_ids, int):
        features["input_ids"] = [token_ids]

    return features
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,

View File

@@ -929,6 +929,17 @@ class Qwen2_5_VLMultiModalProcessor(Qwen2VLMultiModalProcessor):
second_per_grid_ts=MultiModalFieldConfig.batched("video"),
)
def _call_hf_processor(
    self,
    prompt: str,
    mm_data: Mapping[str, object],
    mm_kwargs: Mapping[str, object],
    tok_kwargs: Mapping[str, object],
) -> BatchFeature:
    """Forward all arguments, unchanged, to the parent implementation.

    The override is intentional even though it adds no logic of its own:
    it selects the text path instead of the token path so that the
    video-specific logic in processing_qwen2_5_vl.py is used.
    """
    parent_call = super()._call_hf_processor
    features = parent_call(prompt, mm_data, mm_kwargs, tok_kwargs)
    return features
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,

View File

@@ -56,6 +56,7 @@ from vllm.sequence import IntermediateTensors
from vllm.tokenizers import cached_tokenizer_from_config
from vllm.tokenizers.mistral import MistralTokenizer
from vllm.transformers_utils.processors.voxtral import MistralCommonVoxtralProcessor
from vllm.utils.collection_utils import is_list_of
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsTranscription
from .utils import init_vllm_registered_model, maybe_prefix
@@ -208,7 +209,7 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo])
) -> None:
# mistral_common's tokenizer does not follow HF's placeholder norms,
# so skip validation here
...
pass
def _call_hf_processor(
self,
@@ -224,13 +225,20 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo])
# MistralCommonVoxtralProcessor accepts "audio"
mm_data["audio"] = audios
return super()._call_hf_processor(
outputs = super()._call_hf_processor(
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
tok_kwargs=tok_kwargs,
# Avoid padding issue
tok_kwargs={**tok_kwargs, "return_tensors": None},
)
# Missing batch dimension
if is_list_of(outputs["input_ids"], int):
outputs["input_ids"] = [outputs["input_ids"]]
return outputs
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,

View File

@@ -196,7 +196,7 @@ class InputProcessingContext:
tokenizer = self.tokenizer
if is_mistral_tokenizer(tokenizer):
tokenizer = tokenizer.transformers_tokenizer
tokenizer = tokenizer.transformers_tokenizer # type: ignore[union-attr]
merged_kwargs = self.get_merged_mm_kwargs(kwargs)
merged_kwargs.pop("tokenizer", None)
@@ -263,9 +263,10 @@ class InputProcessingContext:
requires_kw_only=False,
allow_var_kwargs=True,
)
allowed_kwargs.setdefault("return_tensors", "pt")
try:
output = hf_processor(**data, **allowed_kwargs, return_tensors="pt")
output = hf_processor(**data, **allowed_kwargs)
except Exception as exc:
# See https://github.com/huggingface/tokenizers/issues/537
if (

View File

@@ -5,8 +5,15 @@ from collections import defaultdict
from collections.abc import Callable, Generator, ItemsView, Iterable, Mapping, Sequence
from dataclasses import dataclass, field, replace
from enum import Enum
from functools import lru_cache
from typing import TYPE_CHECKING, Generic, NamedTuple, Protocol, TypeAlias, cast
from functools import lru_cache, partial
from typing import (
TYPE_CHECKING,
Generic,
NamedTuple,
Protocol,
TypeAlias,
cast,
)
import regex as re
import torch
@@ -21,6 +28,7 @@ from vllm.inputs import (
)
from vllm.logger import init_logger
from vllm.tokenizers import TokenizerLike
from vllm.transformers_utils.processor import call_hf_processor_mm_only
from vllm.utils.collection_utils import flatten_2d_lists, full_groupby
from ..inputs import (
@@ -1150,7 +1158,11 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
)
processed_data.update(passthrough_data)
(prompt_ids,) = processed_data.pop("input_ids").tolist()
input_ids = processed_data.pop("input_ids")
if not isinstance(input_ids, list):
input_ids = input_ids.tolist()
(prompt_ids,) = input_ids
is_update_applied = self._hf_processor_applies_updates(
prompt_text=prompt_text,
@@ -1213,16 +1225,35 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
[`DummyInputsBuilder`][vllm.multimodal.processing.BaseDummyInputsBuilder]
to go along with the multi-modal data.
"""
mm_counts = mm_items.get_all_counts()
# Custom logic based on text inputs
if type(self)._call_hf_processor != BaseMultiModalProcessor._call_hf_processor:
mm_counts = mm_items.get_all_counts()
_, mm_processed_data, _ = self._apply_hf_processor_text_mm(
prompt_text=self.dummy_inputs.get_dummy_text(mm_counts),
mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
_, mm_processed_data, _ = self._apply_hf_processor_text_mm(
prompt_text=self.dummy_inputs.get_dummy_text(mm_counts),
mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
tokenization_kwargs=tokenization_kwargs,
)
return mm_processed_data
valid_mm_items = mm_items.select(
{k for k, c in mm_items.get_all_counts().items() if c > 0}
)
processor_data, passthrough_data = self._get_hf_mm_data(valid_mm_items)
return mm_processed_data
processed_data = self.info.ctx.call_hf_processor(
partial(
call_hf_processor_mm_only,
self.info.get_hf_processor(**hf_processor_mm_kwargs),
),
processor_data,
dict(**hf_processor_mm_kwargs, **tokenization_kwargs),
)
processed_data.update(passthrough_data)
return processed_data
def _apply_hf_processor_main(
self,

View File

@@ -11,12 +11,16 @@ from transformers import (
AutoImageProcessor,
AutoProcessor,
AutoVideoProcessor,
BatchFeature,
processing_utils,
)
from transformers.audio_utils import AudioInput
from transformers.feature_extraction_utils import FeatureExtractionMixin
from transformers.image_processing_utils import BaseImageProcessor
from transformers.image_utils import ImageInput
from transformers.processing_utils import ProcessorMixin
from transformers.video_processing_utils import BaseVideoProcessor
from transformers.video_utils import VideoInput
from typing_extensions import TypeVar
from vllm.logger import init_logger
@@ -272,7 +276,6 @@ def get_processor_kwargs_keys(
"images_kwargs",
"videos_kwargs",
"audio_kwargs",
"common_kwargs",
}
try:
@@ -523,3 +526,43 @@ def cached_video_processor_from_config(
processor_cls_overrides=processor_cls, # type: ignore[arg-type]
**_merge_mm_kwargs(model_config, AutoVideoProcessor, **kwargs),
)
def call_hf_processor_mm_only(
    processor: ProcessorMixin,
    images: ImageInput | None = None,
    videos: VideoInput | None = None,
    audio: AudioInput | None = None,
    **kwargs,
) -> BatchFeature:
    """Run only the multi-modal sub-processors of ``processor``.

    The keyword arguments are split per modality via
    ``processor._merge_kwargs``. Each modality is processed only when its
    input is not ``None`` and the processor exposes a truthy matching
    attribute (``feature_extractor``, ``image_processor`` or
    ``video_processor``); otherwise it contributes nothing to the result.

    Args:
        processor: The HF processor whose sub-processors are invoked.
        images: Image inputs, forwarded to ``processor.image_processor``.
        videos: Video inputs, forwarded to ``processor.video_processor``.
        audio: Audio inputs, forwarded to ``processor.feature_extractor``.
        **kwargs: Processor keyword arguments. ``return_tensors``, if
            present, is also used as the tensor type of the returned batch.

    Returns:
        A ``BatchFeature`` merging the audio, image and video outputs, in
        that order (later modalities win on key collisions).
    """
    output_kwargs = processor._merge_kwargs(
        get_processor_kwargs_type(processor),
        **kwargs,
    )

    if audio is not None and (
        feature_extractor := getattr(processor, "feature_extractor", None)
    ):
        audio_inputs = feature_extractor(audio, **output_kwargs["audio_kwargs"])
        # Rename the mask so it is distinguishable as the audio feature mask.
        # The key is only present when the feature extractor was asked to
        # return it, so guard the rename instead of raising KeyError.
        if "attention_mask" in audio_inputs:
            audio_inputs["feature_attention_mask"] = audio_inputs.pop("attention_mask")
    else:
        audio_inputs = {}

    if images is not None and (
        image_processor := getattr(processor, "image_processor", None)
    ):
        images_inputs = image_processor(images=images, **output_kwargs["images_kwargs"])
    else:
        images_inputs = {}

    if videos is not None and (
        video_processor := getattr(processor, "video_processor", None)
    ):
        videos_inputs = video_processor(videos=videos, **output_kwargs["videos_kwargs"])
    else:
        videos_inputs = {}

    return BatchFeature(
        data={**audio_inputs, **images_inputs, **videos_inputs},
        tensor_type=kwargs.get("return_tensors"),
    )

View File

@@ -1,16 +1,15 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import math
from typing import Any
from typing import Any, TypedDict
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from transformers import BatchFeature, ProcessorMixin, TensorType
from typing_extensions import TypedDict, Unpack
from transformers.processing_utils import ProcessingKwargs
from typing_extensions import Unpack
from vllm.tokenizers.hf import HfTokenizer
@@ -308,15 +307,22 @@ def process_vision_for_patches(
return patches, dims_virtual
class IsaacImageProcessorKwargs(TypedDict, total=False):
class IsaacImagesKwargs(TypedDict, total=False):
patch_size: int
max_num_patches: int
min_num_patches: int
pixel_shuffle_scale: int
class IsaacProcessorKwargs(ProcessingKwargs, total=False): # type: ignore[call-arg]
images_kwargs: IsaacImagesKwargs
_defaults = {
"text_kwargs": {"padding": False},
"images_kwargs": {},
}
class IsaacImageProcessor:
valid_kwargs = IsaacImageProcessorKwargs
model_input_names = ["pixel_values", "image_grid_thw"]
def __init__(
@@ -335,7 +341,7 @@ class IsaacImageProcessor:
self,
images: Image.Image | list[Image.Image],
return_tensors: str | TensorType | None = None,
**kwargs: Unpack[IsaacImageProcessorKwargs],
**kwargs: Unpack[IsaacImagesKwargs],
) -> BatchFeature:
"""Preprocess images into format compatible with vLLM input processing."""
if not isinstance(images, list):
@@ -349,10 +355,16 @@ class IsaacImageProcessor:
patches, dims_virtual = process_vision_for_patches(
image_tensor,
patch_size=self.patch_size,
max_num_patches=self.vision_max_num_patches,
min_num_patches=self.vision_min_num_patches,
pixel_shuffle_scale=self.pixel_shuffle_scale,
patch_size=kwargs.get("patch_size", self.patch_size),
max_num_patches=kwargs.get(
"max_num_patches", self.vision_max_num_patches
),
min_num_patches=kwargs.get(
"min_num_patches", self.vision_min_num_patches
),
pixel_shuffle_scale=kwargs.get(
"pixel_shuffle_scale", self.pixel_shuffle_scale
),
)
# Isaac packs a dummy temporal dim for images
@@ -405,13 +417,17 @@ class IsaacProcessor(ProcessorMixin):
text: str | list[str] | None = None,
images: Image.Image | list[Image.Image] | None = None,
return_tensors: str | TensorType | None = None,
**kwargs,
**kwargs: Unpack[IsaacProcessorKwargs], # type: ignore[misc]
) -> BatchFeature:
output_kwargs = self._merge_kwargs(
IsaacProcessorKwargs,
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
**kwargs,
)
if images is not None:
image_inputs = self.image_processor(
images,
return_tensors=return_tensors,
**kwargs,
images, **output_kwargs["images_kwargs"]
)
image_grid_thw = image_inputs["image_grid_thw"]
else:
@@ -435,7 +451,7 @@ class IsaacProcessor(ProcessorMixin):
index += 1
text[i] = text[i].replace("<|placeholder|>", "<|image_pad|>")
text_inputs = self.tokenizer(text, return_tensors=return_tensors)
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
else:
text_inputs = {}

View File

@@ -5,10 +5,7 @@ from mistral_common.protocol.instruct.chunk import ImageChunk
from mistral_common.tokens.tokenizers.multimodal import ImageEncoder
from PIL import Image
from transformers import BatchFeature, ProcessorMixin, TensorType
from transformers.audio_utils import AudioInput
from transformers.image_utils import ImageInput
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
from transformers.video_utils import VideoInput
from vllm.tokenizers.mistral import MistralTokenizer
@@ -55,62 +52,16 @@ class MistralCommonPixtralProcessor(ProcessorMixin):
def __init__(self, tokenizer: MistralTokenizer) -> None:
self.tokenizer = tokenizer.transformers_tokenizer
# Back-compatibility for Transformers v4
if not hasattr(self.tokenizer, "init_kwargs"):
self.tokenizer.init_kwargs = {}
self.image_processor = MistralCommonImageProcessor(
tokenizer.instruct.mm_encoder
)
self._image_special_ids = self.image_processor.mm_encoder.special_ids
@property
def image_break_id(self) -> int:
return self._image_special_ids.img_break
@property
def image_token_id(self) -> int:
return self._image_special_ids.img
@property
def image_end_id(self) -> int:
return self._image_special_ids.img_end
def __call__(
self,
images: ImageInput | None = None,
text: TextInput
| PreTokenizedInput
| list[TextInput]
| list[PreTokenizedInput]
| None = None,
videos: VideoInput | None = None,
audio: AudioInput | None = None,
**kwargs,
):
if images is None and text is None and videos is None and audio is None:
raise ValueError(
f"You need to provide at least one input to "
f"call {self.__class__.__name__}"
)
kwargs = self._merge_kwargs(
self.valid_processor_kwargs,
tokenizer_init_kwargs={},
**kwargs,
)
kwargs["text_kwargs"]["return_tensors"] = "pt"
kwargs["images_kwargs"]["return_tensors"] = None # Avoid padding issue
attribute_to_kwargs = {
"tokenizer": (text, "text_kwargs"),
"image_processor": (images, "images_kwargs"),
"video_processor": (videos, "videos_kwargs"),
"feature_extractor": (audio, "audio_kwargs"),
}
outputs = {}
for attribute_name in self.attributes:
attribute = getattr(self, attribute_name, None)
input_data, input_kwargs = attribute_to_kwargs[attribute_name]
if input_data is not None and attribute is not None:
attribute_output = attribute(input_data, **kwargs[input_kwargs])
outputs.update(attribute_output)
return BatchFeature(outputs)
image_special_ids = self.image_processor.mm_encoder.special_ids
self.image_break_id = image_special_ids.img_break
self.image_token_id = image_special_ids.img
self.image_end_id = image_special_ids.img_end

View File

@@ -8,9 +8,6 @@ import torch
from mistral_common.tokens.tokenizers.audio import AudioEncoder
from transformers import BatchFeature, ProcessorMixin, TensorType
from transformers.audio_utils import AudioInput
from transformers.image_utils import ImageInput
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
from transformers.video_utils import VideoInput
from vllm.tokenizers.mistral import MistralTokenizer
@@ -62,58 +59,15 @@ class MistralCommonVoxtralProcessor(ProcessorMixin):
def __init__(self, tokenizer: MistralTokenizer) -> None:
self.tokenizer = tokenizer.transformers_tokenizer
# Back-compatibility for Transformers v4
if not hasattr(self.tokenizer, "init_kwargs"):
self.tokenizer.init_kwargs = {}
self.feature_extractor = MistralCommonFeatureExtractor(
tokenizer.instruct.audio_encoder
)
self._audio_special_ids = self.feature_extractor.audio_encoder.special_ids
@property
def audio_token_id(self) -> int:
return self._audio_special_ids.audio
@property
def begin_audio_token_id(self) -> int:
return self._audio_special_ids.begin_audio
def __call__(
self,
images: ImageInput | None = None,
text: TextInput
| PreTokenizedInput
| list[TextInput]
| list[PreTokenizedInput]
| None = None,
videos: VideoInput | None = None,
audio: AudioInput | None = None,
**kwargs,
):
if images is None and text is None and videos is None and audio is None:
raise ValueError(
f"You need to provide at least one input to "
f"call {self.__class__.__name__}"
)
kwargs = self._merge_kwargs(
self.valid_processor_kwargs,
tokenizer_init_kwargs={},
**kwargs,
)
kwargs["text_kwargs"]["return_tensors"] = "pt"
kwargs["audio_kwargs"]["return_tensors"] = None # Avoid padding issue
attribute_to_kwargs = {
"tokenizer": (text, "text_kwargs"),
"image_processor": (images, "images_kwargs"),
"video_processor": (videos, "videos_kwargs"),
"feature_extractor": (audio, "audio_kwargs"),
}
outputs = {}
for attribute_name in self.attributes:
attribute = getattr(self, attribute_name, None)
input_data, input_kwargs = attribute_to_kwargs[attribute_name]
if input_data is not None and attribute is not None:
attribute_output = attribute(input_data, **kwargs[input_kwargs])
outputs.update(attribute_output)
return BatchFeature(outputs)
audio_special_ids = self.feature_extractor.audio_encoder.special_ids
self.audio_token_id = audio_special_ids.audio
self.begin_audio_token_id = audio_special_ids.begin_audio