[Model] Use helper function to run MM processors with token inputs (where applicable) (#38018)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -1210,6 +1210,17 @@ class KeyeDummyInputsBuilder(KeyeBaseDummyInputsBuilder[KeyeProcessingInfo]):
|
||||
|
||||
|
||||
class KeyeMultiModalProcessor(BaseMultiModalProcessor[KeyeProcessingInfo]):
|
||||
def _call_hf_processor(
|
||||
self,
|
||||
prompt: str,
|
||||
mm_data: Mapping[str, object],
|
||||
mm_kwargs: Mapping[str, object],
|
||||
tok_kwargs: Mapping[str, object],
|
||||
) -> BatchFeature:
|
||||
# Override to use the text path instead of token path to use the
|
||||
# video-specific logic in processing_keye.py
|
||||
return super()._call_hf_processor(prompt, mm_data, mm_kwargs, tok_kwargs)
|
||||
|
||||
def _get_prompt_updates(
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
|
||||
@@ -371,6 +371,17 @@ class KeyeVL1_5ProcessingInfo(KeyeProcessingInfo):
|
||||
|
||||
|
||||
class KeyeVL1_5MultiModalProcessor(BaseMultiModalProcessor[KeyeVL1_5ProcessingInfo]):
|
||||
def _call_hf_processor(
|
||||
self,
|
||||
prompt: str,
|
||||
mm_data: Mapping[str, object],
|
||||
mm_kwargs: Mapping[str, object],
|
||||
tok_kwargs: Mapping[str, object],
|
||||
) -> BatchFeature:
|
||||
# Override to use the text path instead of token path to use the
|
||||
# video-specific logic in processing_keye.py
|
||||
return super()._call_hf_processor(prompt, mm_data, mm_kwargs, tok_kwargs)
|
||||
|
||||
def _get_prompt_updates(
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
|
||||
@@ -215,6 +215,17 @@ class KimiK25MultiModalProcessor(BaseMultiModalProcessor[KimiK25ProcessingInfo])
|
||||
grid_thws=MultiModalFieldConfig.batched("vision_chunk"),
|
||||
)
|
||||
|
||||
def _call_hf_processor(
|
||||
self,
|
||||
prompt: str,
|
||||
mm_data: Mapping[str, object],
|
||||
mm_kwargs: Mapping[str, object],
|
||||
tok_kwargs: Mapping[str, object],
|
||||
) -> BatchFeature:
|
||||
# Override to use the text path instead of token path because vision chunk
|
||||
# is not considered
|
||||
return super()._call_hf_processor(prompt, mm_data, mm_kwargs, tok_kwargs)
|
||||
|
||||
def _get_prompt_updates(
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
|
||||
@@ -12,7 +12,7 @@ import torch.nn.functional as F
|
||||
from mistral_common.protocol.instruct.chunk import ImageChunk, TextChunk
|
||||
from mistral_common.protocol.instruct.messages import UserMessage
|
||||
from mistral_common.protocol.instruct.request import ChatCompletionRequest
|
||||
from transformers import PixtralVisionConfig
|
||||
from transformers import BatchFeature, PixtralVisionConfig
|
||||
from transformers.models.pixtral.image_processing_pixtral import (
|
||||
_num_image_tokens as _get_pixtral_hf_num_image_tokens,
|
||||
)
|
||||
@@ -62,6 +62,7 @@ from vllm.sequence import IntermediateTensors
|
||||
from vllm.tokenizers import cached_tokenizer_from_config
|
||||
from vllm.tokenizers.mistral import MistralTokenizer
|
||||
from vllm.transformers_utils.processors.pixtral import MistralCommonPixtralProcessor
|
||||
from vllm.utils.collection_utils import is_list_of
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
from .interfaces import (
|
||||
@@ -213,6 +214,27 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo])
|
||||
) -> Mapping[str, MultiModalFieldConfig]:
|
||||
return dict(images=MultiModalFieldConfig.batched("image"))
|
||||
|
||||
def _call_hf_processor(
|
||||
self,
|
||||
prompt: str,
|
||||
mm_data: Mapping[str, object],
|
||||
mm_kwargs: Mapping[str, object],
|
||||
tok_kwargs: Mapping[str, object],
|
||||
) -> BatchFeature:
|
||||
outputs = super()._call_hf_processor(
|
||||
prompt=prompt,
|
||||
mm_data=mm_data,
|
||||
mm_kwargs=mm_kwargs,
|
||||
# Avoid padding issue
|
||||
tok_kwargs={**tok_kwargs, "return_tensors": None},
|
||||
)
|
||||
|
||||
# Missing batch dimension
|
||||
if is_list_of(outputs["input_ids"], int):
|
||||
outputs["input_ids"] = [outputs["input_ids"]]
|
||||
|
||||
return outputs
|
||||
|
||||
def _get_prompt_updates(
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
|
||||
@@ -929,6 +929,17 @@ class Qwen2_5_VLMultiModalProcessor(Qwen2VLMultiModalProcessor):
|
||||
second_per_grid_ts=MultiModalFieldConfig.batched("video"),
|
||||
)
|
||||
|
||||
def _call_hf_processor(
|
||||
self,
|
||||
prompt: str,
|
||||
mm_data: Mapping[str, object],
|
||||
mm_kwargs: Mapping[str, object],
|
||||
tok_kwargs: Mapping[str, object],
|
||||
) -> BatchFeature:
|
||||
# Override to use the text path instead of token path to use the
|
||||
# video-specific logic in processing_qwen2_5_vl.py
|
||||
return super()._call_hf_processor(prompt, mm_data, mm_kwargs, tok_kwargs)
|
||||
|
||||
def _get_prompt_updates(
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
|
||||
@@ -56,6 +56,7 @@ from vllm.sequence import IntermediateTensors
|
||||
from vllm.tokenizers import cached_tokenizer_from_config
|
||||
from vllm.tokenizers.mistral import MistralTokenizer
|
||||
from vllm.transformers_utils.processors.voxtral import MistralCommonVoxtralProcessor
|
||||
from vllm.utils.collection_utils import is_list_of
|
||||
|
||||
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsTranscription
|
||||
from .utils import init_vllm_registered_model, maybe_prefix
|
||||
@@ -208,7 +209,7 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo])
|
||||
) -> None:
|
||||
# mistral_common's tokenizer's does not follow HF's placeholder norms
|
||||
# skip validation here
|
||||
...
|
||||
pass
|
||||
|
||||
def _call_hf_processor(
|
||||
self,
|
||||
@@ -224,13 +225,20 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo])
|
||||
# MistralCommonVoxtralProcessor accepts "audio"
|
||||
mm_data["audio"] = audios
|
||||
|
||||
return super()._call_hf_processor(
|
||||
outputs = super()._call_hf_processor(
|
||||
prompt=prompt,
|
||||
mm_data=mm_data,
|
||||
mm_kwargs=mm_kwargs,
|
||||
tok_kwargs=tok_kwargs,
|
||||
# Avoid padding issue
|
||||
tok_kwargs={**tok_kwargs, "return_tensors": None},
|
||||
)
|
||||
|
||||
# Missing batch dimension
|
||||
if is_list_of(outputs["input_ids"], int):
|
||||
outputs["input_ids"] = [outputs["input_ids"]]
|
||||
|
||||
return outputs
|
||||
|
||||
def _get_prompt_updates(
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
|
||||
@@ -196,7 +196,7 @@ class InputProcessingContext:
|
||||
|
||||
tokenizer = self.tokenizer
|
||||
if is_mistral_tokenizer(tokenizer):
|
||||
tokenizer = tokenizer.transformers_tokenizer
|
||||
tokenizer = tokenizer.transformers_tokenizer # type: ignore[union-attr]
|
||||
|
||||
merged_kwargs = self.get_merged_mm_kwargs(kwargs)
|
||||
merged_kwargs.pop("tokenizer", None)
|
||||
@@ -263,9 +263,10 @@ class InputProcessingContext:
|
||||
requires_kw_only=False,
|
||||
allow_var_kwargs=True,
|
||||
)
|
||||
allowed_kwargs.setdefault("return_tensors", "pt")
|
||||
|
||||
try:
|
||||
output = hf_processor(**data, **allowed_kwargs, return_tensors="pt")
|
||||
output = hf_processor(**data, **allowed_kwargs)
|
||||
except Exception as exc:
|
||||
# See https://github.com/huggingface/tokenizers/issues/537
|
||||
if (
|
||||
|
||||
@@ -5,8 +5,15 @@ from collections import defaultdict
|
||||
from collections.abc import Callable, Generator, ItemsView, Iterable, Mapping, Sequence
|
||||
from dataclasses import dataclass, field, replace
|
||||
from enum import Enum
|
||||
from functools import lru_cache
|
||||
from typing import TYPE_CHECKING, Generic, NamedTuple, Protocol, TypeAlias, cast
|
||||
from functools import lru_cache, partial
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Generic,
|
||||
NamedTuple,
|
||||
Protocol,
|
||||
TypeAlias,
|
||||
cast,
|
||||
)
|
||||
|
||||
import regex as re
|
||||
import torch
|
||||
@@ -21,6 +28,7 @@ from vllm.inputs import (
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.tokenizers import TokenizerLike
|
||||
from vllm.transformers_utils.processor import call_hf_processor_mm_only
|
||||
from vllm.utils.collection_utils import flatten_2d_lists, full_groupby
|
||||
|
||||
from ..inputs import (
|
||||
@@ -1150,7 +1158,11 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
|
||||
)
|
||||
processed_data.update(passthrough_data)
|
||||
|
||||
(prompt_ids,) = processed_data.pop("input_ids").tolist()
|
||||
input_ids = processed_data.pop("input_ids")
|
||||
if not isinstance(input_ids, list):
|
||||
input_ids = input_ids.tolist()
|
||||
|
||||
(prompt_ids,) = input_ids
|
||||
|
||||
is_update_applied = self._hf_processor_applies_updates(
|
||||
prompt_text=prompt_text,
|
||||
@@ -1213,16 +1225,35 @@ class BaseMultiModalProcessor(ABC, Generic[_I]):
|
||||
[`DummyInputsBuilder`][vllm.multimodal.processing.BaseDummyInputsBuilder]
|
||||
to go along with the multi-modal data.
|
||||
"""
|
||||
mm_counts = mm_items.get_all_counts()
|
||||
# Custom logic based on text inputs
|
||||
if type(self)._call_hf_processor != BaseMultiModalProcessor._call_hf_processor:
|
||||
mm_counts = mm_items.get_all_counts()
|
||||
|
||||
_, mm_processed_data, _ = self._apply_hf_processor_text_mm(
|
||||
prompt_text=self.dummy_inputs.get_dummy_text(mm_counts),
|
||||
mm_items=mm_items,
|
||||
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
_, mm_processed_data, _ = self._apply_hf_processor_text_mm(
|
||||
prompt_text=self.dummy_inputs.get_dummy_text(mm_counts),
|
||||
mm_items=mm_items,
|
||||
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
|
||||
tokenization_kwargs=tokenization_kwargs,
|
||||
)
|
||||
|
||||
return mm_processed_data
|
||||
|
||||
valid_mm_items = mm_items.select(
|
||||
{k for k, c in mm_items.get_all_counts().items() if c > 0}
|
||||
)
|
||||
processor_data, passthrough_data = self._get_hf_mm_data(valid_mm_items)
|
||||
|
||||
return mm_processed_data
|
||||
processed_data = self.info.ctx.call_hf_processor(
|
||||
partial(
|
||||
call_hf_processor_mm_only,
|
||||
self.info.get_hf_processor(**hf_processor_mm_kwargs),
|
||||
),
|
||||
processor_data,
|
||||
dict(**hf_processor_mm_kwargs, **tokenization_kwargs),
|
||||
)
|
||||
processed_data.update(passthrough_data)
|
||||
|
||||
return processed_data
|
||||
|
||||
def _apply_hf_processor_main(
|
||||
self,
|
||||
|
||||
@@ -11,12 +11,16 @@ from transformers import (
|
||||
AutoImageProcessor,
|
||||
AutoProcessor,
|
||||
AutoVideoProcessor,
|
||||
BatchFeature,
|
||||
processing_utils,
|
||||
)
|
||||
from transformers.audio_utils import AudioInput
|
||||
from transformers.feature_extraction_utils import FeatureExtractionMixin
|
||||
from transformers.image_processing_utils import BaseImageProcessor
|
||||
from transformers.image_utils import ImageInput
|
||||
from transformers.processing_utils import ProcessorMixin
|
||||
from transformers.video_processing_utils import BaseVideoProcessor
|
||||
from transformers.video_utils import VideoInput
|
||||
from typing_extensions import TypeVar
|
||||
|
||||
from vllm.logger import init_logger
|
||||
@@ -272,7 +276,6 @@ def get_processor_kwargs_keys(
|
||||
"images_kwargs",
|
||||
"videos_kwargs",
|
||||
"audio_kwargs",
|
||||
"common_kwargs",
|
||||
}
|
||||
|
||||
try:
|
||||
@@ -523,3 +526,43 @@ def cached_video_processor_from_config(
|
||||
processor_cls_overrides=processor_cls, # type: ignore[arg-type]
|
||||
**_merge_mm_kwargs(model_config, AutoVideoProcessor, **kwargs),
|
||||
)
|
||||
|
||||
|
||||
def call_hf_processor_mm_only(
|
||||
processor: ProcessorMixin,
|
||||
images: ImageInput | None = None,
|
||||
videos: VideoInput | None = None,
|
||||
audio: AudioInput | None = None,
|
||||
**kwargs,
|
||||
) -> BatchFeature:
|
||||
output_kwargs = processor._merge_kwargs(
|
||||
get_processor_kwargs_type(processor),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if audio is not None and (
|
||||
feature_extractor := getattr(processor, "feature_extractor", None)
|
||||
):
|
||||
audio_inputs = feature_extractor(audio, **output_kwargs["audio_kwargs"])
|
||||
audio_inputs["feature_attention_mask"] = audio_inputs.pop("attention_mask")
|
||||
else:
|
||||
audio_inputs = {}
|
||||
|
||||
if images is not None and (
|
||||
image_processor := getattr(processor, "image_processor", None)
|
||||
):
|
||||
images_inputs = image_processor(images=images, **output_kwargs["images_kwargs"])
|
||||
else:
|
||||
images_inputs = {}
|
||||
|
||||
if videos is not None and (
|
||||
video_processor := getattr(processor, "video_processor", None)
|
||||
):
|
||||
videos_inputs = video_processor(videos=videos, **output_kwargs["videos_kwargs"])
|
||||
else:
|
||||
videos_inputs = {}
|
||||
|
||||
return BatchFeature(
|
||||
data={**audio_inputs, **images_inputs, **videos_inputs},
|
||||
tensor_type=kwargs.get("return_tensors"),
|
||||
)
|
||||
|
||||
@@ -1,16 +1,15 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from typing import Any
|
||||
from typing import Any, TypedDict
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from PIL import Image
|
||||
from transformers import BatchFeature, ProcessorMixin, TensorType
|
||||
from typing_extensions import TypedDict, Unpack
|
||||
from transformers.processing_utils import ProcessingKwargs
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from vllm.tokenizers.hf import HfTokenizer
|
||||
|
||||
@@ -308,15 +307,22 @@ def process_vision_for_patches(
|
||||
return patches, dims_virtual
|
||||
|
||||
|
||||
class IsaacImageProcessorKwargs(TypedDict, total=False):
|
||||
class IsaacImagesKwargs(TypedDict, total=False):
|
||||
patch_size: int
|
||||
max_num_patches: int
|
||||
min_num_patches: int
|
||||
pixel_shuffle_scale: int
|
||||
|
||||
|
||||
class IsaacProcessorKwargs(ProcessingKwargs, total=False): # type: ignore[call-arg]
|
||||
images_kwargs: IsaacImagesKwargs
|
||||
_defaults = {
|
||||
"text_kwargs": {"padding": False},
|
||||
"images_kwargs": {},
|
||||
}
|
||||
|
||||
|
||||
class IsaacImageProcessor:
|
||||
valid_kwargs = IsaacImageProcessorKwargs
|
||||
model_input_names = ["pixel_values", "image_grid_thw"]
|
||||
|
||||
def __init__(
|
||||
@@ -335,7 +341,7 @@ class IsaacImageProcessor:
|
||||
self,
|
||||
images: Image.Image | list[Image.Image],
|
||||
return_tensors: str | TensorType | None = None,
|
||||
**kwargs: Unpack[IsaacImageProcessorKwargs],
|
||||
**kwargs: Unpack[IsaacImagesKwargs],
|
||||
) -> BatchFeature:
|
||||
"""Preprocess images into format compatible with vLLM input processing."""
|
||||
if not isinstance(images, list):
|
||||
@@ -349,10 +355,16 @@ class IsaacImageProcessor:
|
||||
|
||||
patches, dims_virtual = process_vision_for_patches(
|
||||
image_tensor,
|
||||
patch_size=self.patch_size,
|
||||
max_num_patches=self.vision_max_num_patches,
|
||||
min_num_patches=self.vision_min_num_patches,
|
||||
pixel_shuffle_scale=self.pixel_shuffle_scale,
|
||||
patch_size=kwargs.get("patch_size", self.patch_size),
|
||||
max_num_patches=kwargs.get(
|
||||
"max_num_patches", self.vision_max_num_patches
|
||||
),
|
||||
min_num_patches=kwargs.get(
|
||||
"min_num_patches", self.vision_min_num_patches
|
||||
),
|
||||
pixel_shuffle_scale=kwargs.get(
|
||||
"pixel_shuffle_scale", self.pixel_shuffle_scale
|
||||
),
|
||||
)
|
||||
|
||||
# Isaac packs a dummy temporal dim for images
|
||||
@@ -405,13 +417,17 @@ class IsaacProcessor(ProcessorMixin):
|
||||
text: str | list[str] | None = None,
|
||||
images: Image.Image | list[Image.Image] | None = None,
|
||||
return_tensors: str | TensorType | None = None,
|
||||
**kwargs,
|
||||
**kwargs: Unpack[IsaacProcessorKwargs], # type: ignore[misc]
|
||||
) -> BatchFeature:
|
||||
output_kwargs = self._merge_kwargs(
|
||||
IsaacProcessorKwargs,
|
||||
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if images is not None:
|
||||
image_inputs = self.image_processor(
|
||||
images,
|
||||
return_tensors=return_tensors,
|
||||
**kwargs,
|
||||
images, **output_kwargs["images_kwargs"]
|
||||
)
|
||||
image_grid_thw = image_inputs["image_grid_thw"]
|
||||
else:
|
||||
@@ -435,7 +451,7 @@ class IsaacProcessor(ProcessorMixin):
|
||||
index += 1
|
||||
text[i] = text[i].replace("<|placeholder|>", "<|image_pad|>")
|
||||
|
||||
text_inputs = self.tokenizer(text, return_tensors=return_tensors)
|
||||
text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
|
||||
else:
|
||||
text_inputs = {}
|
||||
|
||||
|
||||
@@ -5,10 +5,7 @@ from mistral_common.protocol.instruct.chunk import ImageChunk
|
||||
from mistral_common.tokens.tokenizers.multimodal import ImageEncoder
|
||||
from PIL import Image
|
||||
from transformers import BatchFeature, ProcessorMixin, TensorType
|
||||
from transformers.audio_utils import AudioInput
|
||||
from transformers.image_utils import ImageInput
|
||||
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
|
||||
from transformers.video_utils import VideoInput
|
||||
|
||||
from vllm.tokenizers.mistral import MistralTokenizer
|
||||
|
||||
@@ -55,62 +52,16 @@ class MistralCommonPixtralProcessor(ProcessorMixin):
|
||||
|
||||
def __init__(self, tokenizer: MistralTokenizer) -> None:
|
||||
self.tokenizer = tokenizer.transformers_tokenizer
|
||||
|
||||
# Back-compatibility for Transformers v4
|
||||
if not hasattr(self.tokenizer, "init_kwargs"):
|
||||
self.tokenizer.init_kwargs = {}
|
||||
|
||||
self.image_processor = MistralCommonImageProcessor(
|
||||
tokenizer.instruct.mm_encoder
|
||||
)
|
||||
|
||||
self._image_special_ids = self.image_processor.mm_encoder.special_ids
|
||||
|
||||
@property
|
||||
def image_break_id(self) -> int:
|
||||
return self._image_special_ids.img_break
|
||||
|
||||
@property
|
||||
def image_token_id(self) -> int:
|
||||
return self._image_special_ids.img
|
||||
|
||||
@property
|
||||
def image_end_id(self) -> int:
|
||||
return self._image_special_ids.img_end
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
images: ImageInput | None = None,
|
||||
text: TextInput
|
||||
| PreTokenizedInput
|
||||
| list[TextInput]
|
||||
| list[PreTokenizedInput]
|
||||
| None = None,
|
||||
videos: VideoInput | None = None,
|
||||
audio: AudioInput | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
if images is None and text is None and videos is None and audio is None:
|
||||
raise ValueError(
|
||||
f"You need to provide at least one input to "
|
||||
f"call {self.__class__.__name__}"
|
||||
)
|
||||
|
||||
kwargs = self._merge_kwargs(
|
||||
self.valid_processor_kwargs,
|
||||
tokenizer_init_kwargs={},
|
||||
**kwargs,
|
||||
)
|
||||
kwargs["text_kwargs"]["return_tensors"] = "pt"
|
||||
kwargs["images_kwargs"]["return_tensors"] = None # Avoid padding issue
|
||||
|
||||
attribute_to_kwargs = {
|
||||
"tokenizer": (text, "text_kwargs"),
|
||||
"image_processor": (images, "images_kwargs"),
|
||||
"video_processor": (videos, "videos_kwargs"),
|
||||
"feature_extractor": (audio, "audio_kwargs"),
|
||||
}
|
||||
outputs = {}
|
||||
for attribute_name in self.attributes:
|
||||
attribute = getattr(self, attribute_name, None)
|
||||
input_data, input_kwargs = attribute_to_kwargs[attribute_name]
|
||||
if input_data is not None and attribute is not None:
|
||||
attribute_output = attribute(input_data, **kwargs[input_kwargs])
|
||||
outputs.update(attribute_output)
|
||||
|
||||
return BatchFeature(outputs)
|
||||
image_special_ids = self.image_processor.mm_encoder.special_ids
|
||||
self.image_break_id = image_special_ids.img_break
|
||||
self.image_token_id = image_special_ids.img
|
||||
self.image_end_id = image_special_ids.img_end
|
||||
|
||||
@@ -8,9 +8,6 @@ import torch
|
||||
from mistral_common.tokens.tokenizers.audio import AudioEncoder
|
||||
from transformers import BatchFeature, ProcessorMixin, TensorType
|
||||
from transformers.audio_utils import AudioInput
|
||||
from transformers.image_utils import ImageInput
|
||||
from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
|
||||
from transformers.video_utils import VideoInput
|
||||
|
||||
from vllm.tokenizers.mistral import MistralTokenizer
|
||||
|
||||
@@ -62,58 +59,15 @@ class MistralCommonVoxtralProcessor(ProcessorMixin):
|
||||
|
||||
def __init__(self, tokenizer: MistralTokenizer) -> None:
|
||||
self.tokenizer = tokenizer.transformers_tokenizer
|
||||
|
||||
# Back-compatibility for Transformers v4
|
||||
if not hasattr(self.tokenizer, "init_kwargs"):
|
||||
self.tokenizer.init_kwargs = {}
|
||||
|
||||
self.feature_extractor = MistralCommonFeatureExtractor(
|
||||
tokenizer.instruct.audio_encoder
|
||||
)
|
||||
|
||||
self._audio_special_ids = self.feature_extractor.audio_encoder.special_ids
|
||||
|
||||
@property
|
||||
def audio_token_id(self) -> int:
|
||||
return self._audio_special_ids.audio
|
||||
|
||||
@property
|
||||
def begin_audio_token_id(self) -> int:
|
||||
return self._audio_special_ids.begin_audio
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
images: ImageInput | None = None,
|
||||
text: TextInput
|
||||
| PreTokenizedInput
|
||||
| list[TextInput]
|
||||
| list[PreTokenizedInput]
|
||||
| None = None,
|
||||
videos: VideoInput | None = None,
|
||||
audio: AudioInput | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
if images is None and text is None and videos is None and audio is None:
|
||||
raise ValueError(
|
||||
f"You need to provide at least one input to "
|
||||
f"call {self.__class__.__name__}"
|
||||
)
|
||||
|
||||
kwargs = self._merge_kwargs(
|
||||
self.valid_processor_kwargs,
|
||||
tokenizer_init_kwargs={},
|
||||
**kwargs,
|
||||
)
|
||||
kwargs["text_kwargs"]["return_tensors"] = "pt"
|
||||
kwargs["audio_kwargs"]["return_tensors"] = None # Avoid padding issue
|
||||
|
||||
attribute_to_kwargs = {
|
||||
"tokenizer": (text, "text_kwargs"),
|
||||
"image_processor": (images, "images_kwargs"),
|
||||
"video_processor": (videos, "videos_kwargs"),
|
||||
"feature_extractor": (audio, "audio_kwargs"),
|
||||
}
|
||||
outputs = {}
|
||||
for attribute_name in self.attributes:
|
||||
attribute = getattr(self, attribute_name, None)
|
||||
input_data, input_kwargs = attribute_to_kwargs[attribute_name]
|
||||
if input_data is not None and attribute is not None:
|
||||
attribute_output = attribute(input_data, **kwargs[input_kwargs])
|
||||
outputs.update(attribute_output)
|
||||
|
||||
return BatchFeature(outputs)
|
||||
audio_special_ids = self.feature_extractor.audio_encoder.special_ids
|
||||
self.audio_token_id = audio_special_ids.audio
|
||||
self.begin_audio_token_id = audio_special_ids.begin_audio
|
||||
|
||||
Reference in New Issue
Block a user