diff --git a/vllm/model_executor/models/glm4v.py b/vllm/model_executor/models/glm4v.py
index 3513419cb..959839e77 100644
--- a/vllm/model_executor/models/glm4v.py
+++ b/vllm/model_executor/models/glm4v.py
@@ -13,11 +13,7 @@ import numpy as np
 import torch
 from torch import nn
 from torch.nn import LayerNorm
-from torchvision import transforms
-from torchvision.transforms import InterpolationMode
-from transformers import BatchFeature, PreTrainedTokenizer, TensorType
-from transformers.image_utils import ImageInput
-from transformers.tokenization_utils_base import TextInput
+from transformers import BatchFeature
 
 from vllm.config import VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
@@ -50,7 +46,8 @@ from vllm.multimodal.processing import (
     PromptUpdate,
 )
 from vllm.sequence import IntermediateTensors
-from vllm.transformers_utils.configs import ChatGLMConfig
+from vllm.transformers_utils.configs.chatglm import ChatGLMConfig
+from vllm.transformers_utils.processors.glm4v import GLM4VProcessor
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .chatglm import ChatGLMBaseModel, ChatGLMModel, GLMTransformer
@@ -386,81 +383,19 @@ class GLM4VModel(ChatGLMModel):
     )
 
 
-class GLM4VProcessor:
-    """
-    This model doesn't define its own HF processor,
-    so we implement our own one here.
-    """
-
-    def __init__(
-        self,
-        config: ChatGLMConfig,
-        tokenizer: PreTrainedTokenizer,
-    ) -> None:
-        super().__init__()
-
-        self.config = config
-        self.tokenizer = tokenizer
-
-        vision_config = config.vision_config
-        image_size = vision_config["image_size"]
-
-        self.image_transform = transforms.Compose(
-            [
-                transforms.Resize(
-                    (image_size, image_size),
-                    interpolation=InterpolationMode.BICUBIC,
-                ),
-                transforms.ToTensor(),
-                transforms.Normalize(
-                    mean=(0.48145466, 0.4578275, 0.40821073),
-                    std=(0.26862954, 0.26130258, 0.27577711),
-                ),
-            ]
-        )
-
-    def __call__(
-        self,
-        text: TextInput | list[TextInput] | None = None,
-        images: ImageInput | list[ImageInput] | None = None,
-        return_tensors: str | TensorType | None = None,
-    ) -> BatchFeature:
-        if text is None:
-            text = []
-        if not isinstance(text, list):
-            text = [text]
-        if images is None:
-            images = []
-        if not isinstance(images, list):
-            images = [images]
-
-        text_inputs = self.tokenizer(text)
-
-        if len(images) == 0:
-            image_inputs = {}
-        else:
-            pixel_values = [self.image_transform(image) for image in images]
-            image_inputs = {"pixel_values": torch.stack(pixel_values)}
-
-        return BatchFeature(
-            {
-                **text_inputs,
-                **image_inputs,
-            },
-            tensor_type=return_tensors,
-        )
-
-
 class GLM4VProcessingInfo(BaseProcessingInfo):
     def get_hf_config(self):
         return self.ctx.get_hf_config(ChatGLMConfig)
 
     def get_hf_processor(self, **kwargs: object) -> GLM4VProcessor:
+        config = self.get_hf_config()
+        vision_config = config.vision_config
+        image_size = vision_config["image_size"]
+
         return self.ctx.init_processor(
             GLM4VProcessor,
-            config=self.get_hf_config(),
             tokenizer=self.get_tokenizer(),
-            **kwargs,
+            **{**kwargs, "image_size": image_size},
         )
 
     def get_supported_mm_limits(self) -> Mapping[str, int | None]:
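Note on the kwargs merge in `get_hf_processor` above: spreading `{**kwargs, "image_size": image_size}` means a caller-supplied `image_size` is always overridden by the value read from the model config. A minimal sketch of that dict semantics (key names and values are illustrative, not taken from this diff):

```python
# config_image_size stands in for vision_config["image_size"].
config_image_size = 1120

caller_kwargs = {"image_size": 224, "other_kwarg": 4}  # hypothetical caller input
merged = {**caller_kwargs, "image_size": config_image_size}
assert merged == {"other_kwarg": 4, "image_size": 1120}  # the config value wins
```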
diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py
index ba6d569b7..faac00a4e 100644
--- a/vllm/model_executor/models/molmo.py
+++ b/vllm/model_executor/models/molmo.py
@@ -4,7 +4,7 @@
 import math
 from collections.abc import Iterable, Mapping, Sequence
 from dataclasses import dataclass
-from functools import cached_property, partial
+from functools import partial
 from itertools import islice
 from typing import Annotated
 
@@ -13,9 +13,11 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
-from transformers import BatchFeature, PretrainedConfig, ProcessorMixin, TensorType
-from transformers.image_utils import ImageInput
-from transformers.tokenization_utils_base import TextInput
+from transformers import (
+    BaseImageProcessor,
+    BatchFeature,
+    PretrainedConfig,
+)
 
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
@@ -1017,117 +1019,28 @@ def select_tiling(
     return candidate_tilings[ix]
 
 
-class MolmoProcessorWrapper:
-    """
-    Wraps `MolmoProcessor` so that it can be called directly.
+def _as_2tuple(x: int | tuple[int, int]) -> tuple[int, int]:
+    if isinstance(x, int):
+        return x, x
 
-    The original definition can be found here:
-    https://huggingface.co/allenai/Molmo-7B-D-0924/blob/main/preprocessing_molmo.py
-    """
+    return x
 
-    def __init__(self, processor: ProcessorMixin):
-        super().__init__()
-        self.processor = processor
-
-    @cached_property
-    def vocab(self) -> dict[str, int]:
-        return self.processor.tokenizer.vocab  # type: ignore
-
-    @cached_property
-    def max_crops(self) -> int:
-        image_processor = self.processor.image_processor  # type: ignore
-
-        max_crops = image_processor.max_crops
-        assert isinstance(max_crops, int)
-
-        return max_crops
-
-    @cached_property
-    def base_image_input_size(self) -> tuple[int, int]:
-        image_processor = self.processor.image_processor  # type: ignore
-
-        base_image_input_size = image_processor.base_image_input_size
-        if isinstance(base_image_input_size, int):
-            return base_image_input_size, base_image_input_size
-
-        return tuple(base_image_input_size)
-
-    @cached_property
-    def image_patch_size(self) -> int:
-        image_processor = self.processor.image_processor  # type: ignore
-
-        image_patch_size = image_processor.image_patch_size
-        assert isinstance(image_patch_size, int)
-
-        return image_patch_size
-
-    @cached_property
-    def overlap_margins(self) -> tuple[int, int]:
-        image_processor = self.processor.image_processor  # type: ignore
-
-        left_margin, right_margin = image_processor.overlap_margins
-        assert isinstance(left_margin, int)
-        assert isinstance(right_margin, int)
-
-        return left_margin, right_margin
-
-    @cached_property
-    def image_token_length_w(self) -> int:
-        image_processor = self.processor.image_processor  # type: ignore
-
-        image_token_length_w = image_processor.image_token_length_w
-        assert isinstance(image_token_length_w, int)
-
-        return image_token_length_w
-
-    @cached_property
-    def image_token_length_h(self) -> int:
-        image_processor = self.processor.image_processor  # type: ignore
-
-        image_token_length_h = image_processor.image_token_length_h
-        assert isinstance(image_token_length_h, int)
-
-        return image_token_length_h
-
-    @property
-    def message_format(self) -> str | None:
-        return "role"
-
-    @property
-    def always_start_with_space(self) -> bool:
-        return True
-
-    @cached_property
-    def image_patch_id(self) -> int:
-        return self.vocab[IMAGE_PATCH_TOKEN]
-
-    @cached_property
-    def im_col_id(self) -> int:
-        return self.vocab[IM_COL_TOKEN]
-
-    @cached_property
-    def im_start_id(self) -> int:
-        return self.vocab[IM_START_TOKEN]
-
-    @cached_property
-    def im_end_id(self) -> int:
-        return self.vocab[IM_END_TOKEN]
-
-    @property
-    def pooling_size(self) -> int:
-        return POOLING_SIZE
+class MolmoProcessingInfo(BaseProcessingInfo):
+    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
+        return {"image": None}
 
     def select_tiling(
         self,
         *,
         image_width: int,
         image_height: int,
+        image_processor: BaseImageProcessor,
    ) -> tuple[int, int]:
-        max_crops = self.max_crops
-        left_margin, right_margin = self.overlap_margins
-        base_image_input_size = self.base_image_input_size
-        base_image_input_d = self.image_patch_size
+        max_crops = image_processor.max_crops
+        left_margin, right_margin = image_processor.overlap_margins
+        base_image_input_size = _as_2tuple(image_processor.base_image_input_size)
+        base_image_input_d = image_processor.image_patch_size
 
         total_margin_pixels = base_image_input_d * (right_margin + left_margin)
         crop_patches = base_image_input_size[0] // base_image_input_d
@@ -1147,16 +1060,18 @@ class MolmoProcessorWrapper:
         *,
         image_width: int,
         image_height: int,
+        image_processor: BaseImageProcessor,
     ) -> tuple[int, int]:
-        left_margin, right_margin = self.overlap_margins
-        base_image_input_size = self.base_image_input_size
-        base_image_input_d = self.image_patch_size
-        pooling_size = self.pooling_size
+        left_margin, right_margin = image_processor.overlap_margins
+        base_image_input_size = _as_2tuple(image_processor.base_image_input_size)
+        base_image_input_d = image_processor.image_patch_size
+        pooling_size = POOLING_SIZE
 
         crop_patches = base_image_input_size[0] // base_image_input_d
 
         tiling_w, tiling_h = self.select_tiling(
             image_height=image_height,
             image_width=image_width,
+            image_processor=image_processor,
         )
         nrows, ncols = get_patches_grid_size(
@@ -1170,70 +1085,22 @@ class MolmoProcessorWrapper:
         return ncols, nrows
 
-    def __call__(
-        self,
-        text: TextInput | list[TextInput] | None = None,
-        images: ImageInput | list[ImageInput] | None = None,
-        return_tensors: str | TensorType | None = None,
-        **kwargs,
-    ) -> BatchFeature:
-        outputs = self.processor.process(  # type: ignore
-            text, images, **kwargs
-        )
-
-        if images is None:
-            images = []
-        if not isinstance(images, list):
-            images = [images]
-
-        input_ids: torch.Tensor = outputs.pop("input_ids")
-        outputs["input_ids"] = input_ids.unsqueeze(0)
-
-        image_input_idx = outputs.pop("image_input_idx", None)
-        if image_input_idx is not None:
-            feat_is_patch = image_input_idx >= 0
-
-            tilings = [
-                self.select_tiling(
-                    image_width=image.size[0],
-                    image_height=image.size[1],
-                )
-                for image in images
-            ]
-            # For each image: tiling_h * tiling_w + extra
-            num_crops = torch.tensor(tilings).prod(-1) + 1
-            assert num_crops.sum() == len(feat_is_patch)
-
-            outputs["image_input_idx"] = image_input_idx
-            outputs["num_crops"] = num_crops
-            outputs["img_patch_id"] = self.image_patch_id
-
-        return BatchFeature(outputs)
-
-
-class MolmoProcessingInfo(BaseProcessingInfo):
-    def get_hf_processor(self, **kwargs: object) -> MolmoProcessorWrapper:
-        processor = self.ctx.get_hf_processor(**kwargs)
-        return MolmoProcessorWrapper(processor)
-
-    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
-        return {"image": None}
-
     def get_num_image_tokens(
         self,
         *,
         image_width: int,
         image_height: int,
-        processor: MolmoProcessorWrapper,
+        image_processor: BaseImageProcessor,
     ) -> int:
-        ncols, nrows = processor.get_patches_grid_size(
+        ncols, nrows = self.get_patches_grid_size(
             image_width=image_width,
             image_height=image_height,
+            image_processor=image_processor,
         )
-        pooling_size = processor.pooling_size
+        pooling_size = POOLING_SIZE
 
-        image_token_length_w = processor.image_token_length_w
-        image_token_length_h = processor.image_token_length_h
+        image_token_length_w = image_processor.image_token_length_w
+        image_token_length_h = image_processor.image_token_length_h
 
         # Calculate total tokens: 2 for start/end + (w+1)*h for column separators
         extra = 2 + (image_token_length_w + 1) * image_token_length_h
@@ -1243,9 +1110,10 @@ class MolmoProcessingInfo(BaseProcessingInfo):
 
     def get_image_size_with_most_features(self) -> ImageSize:
         processor = self.get_hf_processor()
+        image_processor = processor.image_processor
 
-        tilings = get_candidate_tilings(processor.max_crops)
-        base_h, base_w = processor.base_image_input_size
+        tilings = get_candidate_tilings(image_processor.max_crops)
+        base_h, base_w = _as_2tuple(image_processor.base_image_input_size)
 
         largest_feature_size, largest_feature_pinpoint = 0, None
         for wr, hr in tilings:
@@ -1254,7 +1122,7 @@ class MolmoProcessingInfo(BaseProcessingInfo):
             feat_size = self.get_num_image_tokens(
                 image_width=width,
                 image_height=height,
-                processor=processor,
+                image_processor=image_processor,
             )
             if feat_size > largest_feature_size:
                 largest_feature_size = feat_size
@@ -1292,6 +1160,54 @@ class MolmoDummyInputsBuilder(BaseDummyInputsBuilder[MolmoProcessingInfo]):
 
 
 class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
+    def _call_hf_processor(
+        self,
+        prompt: str,
+        mm_data: Mapping[str, object],
+        mm_kwargs: Mapping[str, object],
+        tok_kwargs: Mapping[str, object],
+    ) -> BatchFeature:
+        hf_processor = self.info.get_hf_processor(**mm_kwargs)
+        processed_outputs = self.info.ctx.call_hf_processor(
+            hf_processor.process,
+            dict(text=prompt, **mm_data),
+            dict(**mm_kwargs, **tok_kwargs),
+        )
+
+        tokenizer = hf_processor.tokenizer
+        image_patch_id = tokenizer.vocab[IMAGE_PATCH_TOKEN]
+
+        image_processor = hf_processor.image_processor
+
+        input_ids: torch.Tensor = processed_outputs.pop("input_ids")
+        processed_outputs["input_ids"] = input_ids.unsqueeze(0)
+
+        if (images := mm_data.get("images")) is not None:
+            mm_items = self.info.parse_mm_data({"image": images}, validate=False)
+            parsed_images = mm_items.get_items("image", ImageProcessorItems)
+            image_sizes = [
+                parsed_images.get_image_size(i) for i in range(len(parsed_images))
+            ]
+
+            feat_is_patch = processed_outputs["image_input_idx"] >= 0
+
+            tilings = [
+                self.info.select_tiling(
+                    image_width=image_size.width,
+                    image_height=image_size.height,
+                    image_processor=image_processor,
+                )
+                for image_size in image_sizes
+            ]
+            # For each image: tiling_h * tiling_w + extra
+            num_crops = torch.tensor(tilings).prod(-1) + 1
+            assert num_crops.sum() == len(feat_is_patch)
+
+            processed_outputs["num_crops"] = num_crops
+            processed_outputs["img_patch_id"] = image_patch_id
+
+        return processed_outputs
+
     def _apply_hf_processor_tokens_only(
         self,
         prompt_tokens: list[int],
@@ -1301,18 +1217,18 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
         # The chat template is already applied to the prompt tokens
         # Use message_format="none" to avoid applying it again
         # Prepend an empty space if `always_start_with_space` is True
-        tokens = processor.processor.get_tokens_input(  # type: ignore
+        tokens = processor.get_tokens_input(
            self.info.get_tokenizer().decode(prompt_tokens),
            message_format="none",
-            always_start_with_space=processor.always_start_with_space,
+            always_start_with_space=True,
        )
 
         # Prepend a BOS token id to the tokens
         processed_data = self.info.ctx.call_hf_processor(
-            processor,  # type: ignore
+            processor.process,
             dict(tokens=tokens),
         )
-        (prompt_ids,) = processed_data.pop("input_ids").tolist()
+        prompt_ids = processed_data.pop("input_ids").tolist()
 
         return prompt_ids
 
@@ -1338,16 +1255,18 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
         hf_processor_mm_kwargs: Mapping[str, object],
         out_mm_kwargs: MultiModalKwargsItems,
     ) -> Sequence[PromptUpdate]:
+        tokenizer = self.info.get_tokenizer()
+        vocab = tokenizer.get_vocab()
+        img_patch_id = vocab[IMAGE_PATCH_TOKEN]
+        img_col_id = vocab[IM_COL_TOKEN]
+        img_start_id = vocab[IM_START_TOKEN]
+        img_end_id = vocab[IM_END_TOKEN]
+
         processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
-
-        image_token_length_w = processor.image_token_length_w
-        image_token_length_h = processor.image_token_length_h
-        pooling_size = processor.pooling_size
-
-        img_patch_id = processor.image_patch_id
-        img_col_id = processor.im_col_id
-        img_start_id = processor.im_start_id
-        img_end_id = processor.im_end_id
+        image_processor = processor.image_processor
+        image_token_length_w = image_processor.image_token_length_w
+        image_token_length_h = image_processor.image_token_length_h
+        pooling_size = POOLING_SIZE
 
         extra_row = [img_patch_id] * image_token_length_w + [img_col_id]
         extra_joint = [img_start_id] + extra_row * image_token_length_h + [img_end_id]
@@ -1356,9 +1275,10 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
             images = mm_items.get_items("image", ImageProcessorItems)
             image_size = images.get_image_size(item_idx)
 
-            ncols, nrows = processor.get_patches_grid_size(
+            ncols, nrows = self.info.get_patches_grid_size(
                 image_width=image_size.width,
                 image_height=image_size.height,
+                image_processor=image_processor,
             )
 
             joint_row = [img_patch_id] * ((ncols + 1) // pooling_size) + [img_col_id]
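For reference, the crop accounting that the new `_call_hf_processor` asserts on: each image contributes `tiling_w * tiling_h` crops plus one extra global crop, and the concatenated per-crop features must match that total. A small sketch under assumed tiling results:

```python
import torch

# Suppose select_tiling returned (2, 3) for one image and (1, 1) for another
# (these tilings are made up for illustration).
tilings = [(2, 3), (1, 1)]
num_crops = torch.tensor(tilings).prod(-1) + 1  # tensor([7, 2])

# feat_is_patch has one row per crop, so its length must be 7 + 2 = 9,
# which is exactly what the assert in _call_hf_processor checks.
assert num_crops.sum().item() == 9
```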
diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py
index ebcc5d8b8..43e95c67a 100644
--- a/vllm/model_executor/models/pixtral.py
+++ b/vllm/model_executor/models/pixtral.py
@@ -4,7 +4,6 @@
 import math
 from collections.abc import Iterable, Mapping, Sequence
 from dataclasses import dataclass, fields
-from functools import cached_property
 from typing import Annotated, Literal
 
 import torch
@@ -13,10 +12,7 @@ import torch.nn.functional as F
 from mistral_common.protocol.instruct.chunk import ImageChunk, TextChunk
 from mistral_common.protocol.instruct.messages import UserMessage
 from mistral_common.protocol.instruct.request import ChatCompletionRequest
-from mistral_common.tokens.tokenizers.multimodal import ImageEncoder
-from PIL import Image
-from transformers import BatchFeature, PixtralVisionConfig, TensorType
-from transformers.image_utils import ImageInput
+from transformers import PixtralVisionConfig
 from transformers.models.pixtral.image_processing_pixtral import (
     _num_image_tokens as _get_pixtral_hf_num_image_tokens,
 )
@@ -25,7 +21,6 @@ from transformers.models.pixtral.modeling_pixtral import (
     apply_rotary_pos_emb,
     position_ids_in_meshgrid,
 )
-from transformers.tokenization_utils_base import TextInput
 
 from vllm.config import VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
@@ -66,6 +61,7 @@ from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
 from vllm.tokenizers import cached_tokenizer_from_config
 from vllm.tokenizers.mistral import MistralTokenizer
+from vllm.transformers_utils.processors.pixtral import MistralCommonPixtralProcessor
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .interfaces import (
@@ -121,93 +117,6 @@ class PixtralImagePixelInputs(TensorSchema):
     ]
 
 
-class PixtralProcessorAdapter:
-    """
-    Provide a HF-compatible interface for
-    `mistral_common.tokens.tokenizers.multimodal.ImageEncoder`.
-    """
-
-    def __init__(self, tokenizer: MistralTokenizer) -> None:
-        super().__init__()
-
-        self.tokenizer = tokenizer
-
-    @property
-    def image_processor(self) -> ImageEncoder:
-        image_encoder = self.tokenizer.instruct.mm_encoder
-        assert isinstance(image_encoder, ImageEncoder)
-        return image_encoder
-
-    @cached_property
-    def image_break_id(self) -> int:
-        return self.image_processor.special_ids.img_break
-
-    @cached_property
-    def image_token_id(self) -> int:
-        return self.image_processor.special_ids.img
-
-    @cached_property
-    def image_end_id(self) -> int:
-        return self.image_processor.special_ids.img_end
-
-    @cached_property
-    def image_size(self) -> int:
-        return self.image_processor.mm_config.max_image_size
-
-    @cached_property
-    def patch_size(self) -> int:
-        return self.image_processor.mm_config.image_patch_size
-
-    def __call__(
-        self,
-        text: TextInput | list[TextInput] | None = None,
-        images: ImageInput | list[ImageInput] | None = None,
-        return_tensors: str | TensorType | None = None,
-        **kwargs,
-    ) -> Mapping[str, NestedTensors]:
-        if text is None:
-            text = []
-        if not isinstance(text, list):
-            text = [text]
-        if images is None:
-            images = []
-        if not isinstance(images, list):
-            images = [images]
-
-        if not images:
-            input_ids = self.tokenizer(text).input_ids
-
-            return {"input_ids": torch.tensor(input_ids)}
-
-        # Allow dummy text, which is used for profiling as well as token inputs
-        if any(len(t) > 0 for t in text):
-            raise ValueError(
-                "You've passed text inputs instead of token inputs. "
-                "Make sure to process your input via `mistral_common`'s "
-                "tokenizer or pass a chat completion request. "
-                "For more info, see: "
-                "https://github.com/vllm-project/vllm/issues/8411."
-            )
-
-        images_processed = list[torch.Tensor]()
-        images_tokens = list[torch.Tensor]()
-
-        for image in images:
-            image_inputs = self.image_processor(ImageChunk(image=image))
-            image_processed = torch.tensor(image_inputs.image)
-            image_tokens = torch.tensor(image_inputs.tokens)
-
-            images_processed.append(image_processed)
-            images_tokens.append(image_tokens)
-
-        return BatchFeature(
-            {
-                "input_ids": torch.cat(images_tokens)[None].expand(len(text), -1),
-                "images": images_processed,
-            }
-        )
-
-
 class PixtralProcessingInfo(BaseProcessingInfo):
     def get_tokenizer(self) -> MistralTokenizer:
         tokenizer = cached_tokenizer_from_config(self.ctx.model_config)
@@ -216,28 +125,19 @@ class PixtralProcessingInfo(BaseProcessingInfo):
 
         return tokenizer
 
-    def get_hf_processor(self) -> PixtralProcessorAdapter:
-        return PixtralProcessorAdapter(self.get_tokenizer())
+    def get_hf_processor(self, **kwargs) -> MistralCommonPixtralProcessor:
+        return self.ctx.init_processor(
+            MistralCommonPixtralProcessor,
+            tokenizer=self.get_tokenizer(),
+            **kwargs,
+        )
 
     def get_supported_mm_limits(self) -> Mapping[str, int | None]:
         return {"image": None}
 
-    def get_num_image_tokens(
-        self,
-        *,
-        image_width: int,
-        image_height: int,
-        processor: PixtralProcessorAdapter,
-    ) -> int:
-        ncols, nrows = processor.image_processor._image_to_num_tokens(
-            Image.new("RGB", (image_width, image_height))
-        )
-
-        return ncols * nrows
-
     def get_image_size_with_most_features(self) -> ImageSize:
         image_processor = self.get_hf_processor().image_processor
-        max_image_size = image_processor.mm_config.max_image_size
+        max_image_size = image_processor.mm_encoder.mm_config.max_image_size
 
         return ImageSize(width=max_image_size, height=max_image_size)
@@ -321,8 +221,9 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo])
             images = mm_items.get_items("image", ImageProcessorItems)
             image_size = images.get_image_size(item_idx)
 
-            ncols, nrows = processor.image_processor._image_to_num_tokens(
-                Image.new("RGB", (image_size.width, image_size.height))
+            _, nrows, ncols = processor.image_processor.get_number_of_image_patches(
+                image_size.height,
+                image_size.width,
             )
 
             tokens = ([image_token_id] * ncols + [image_break_id]) * nrows
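The replacement API returns the grid as well as the total, so the prompt-update code above can build the token layout without re-deriving it. A sketch of the relationship, with placeholder ids and an assumed 4x3 grid (the real grid comes from `get_number_of_image_patches(height, width)`):

```python
# Placeholder token ids; in practice these are the tokenizer's special ids.
image_token_id, image_break_id = 10, 12
total, nrows, ncols = 12, 3, 4
assert total == ncols * nrows

tokens = ([image_token_id] * ncols + [image_break_id]) * nrows
assert len(tokens) == (ncols + 1) * nrows  # one break token appended per row
```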
diff --git a/vllm/model_executor/models/qwen_vl.py b/vllm/model_executor/models/qwen_vl.py
index 1eb8ecc2d..468944d04 100644
--- a/vllm/model_executor/models/qwen_vl.py
+++ b/vllm/model_executor/models/qwen_vl.py
@@ -14,11 +14,7 @@ from typing import Annotated, Literal, TypeAlias
 import regex as re
 import torch
 from torch import nn
-from torchvision import transforms
-from torchvision.transforms import InterpolationMode
-from transformers import BatchFeature, PretrainedConfig, PreTrainedTokenizer, TensorType
-from transformers.image_utils import ImageInput
-from transformers.tokenization_utils_base import TextInput
+from transformers import BatchFeature
 
 from vllm.config import VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
@@ -48,6 +44,7 @@ from vllm.multimodal.processing import (
     PromptUpdateDetails,
 )
 from vllm.sequence import IntermediateTensors
+from vllm.transformers_utils.processors.qwen_vl import QwenVLProcessor
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
 from .interfaces import (
@@ -434,96 +431,16 @@ class QwenVLModel(QWenModel):
     )
 
 
-class QwenVLProcessor:
-    """
-    This model doesn't define its own HF processor,
-    so we implement our own one here.
-
-    We call the wrapped tokenizer to automatically insert image pad tokens:
-    https://huggingface.co/Qwen/Qwen-VL/blob/main/tokenization_qwen.py#L245
-
-    The image processor is defined here:
-    https://huggingface.co/Qwen/Qwen-VL/blob/main/visual.py#L354
-    """
-
-    def __init__(
-        self,
-        config: PretrainedConfig,
-        tokenizer: PreTrainedTokenizer,
-    ) -> None:
-        super().__init__()
-
-        self.config = config
-        self.tokenizer = tokenizer
-
+class QwenVLProcessingInfo(BaseProcessingInfo):
+    def get_hf_processor(self, **kwargs: object) -> QwenVLProcessor:
+        config = self.get_hf_config()
         vision_config = config.visual
         image_size = vision_config["image_size"]
 
-        self.image_transform = transforms.Compose(
-            [
-                transforms.Resize(
-                    (image_size, image_size),
-                    interpolation=InterpolationMode.BICUBIC,
-                ),
-                transforms.ToTensor(),
-                transforms.Normalize(
-                    mean=(0.48145466, 0.4578275, 0.40821073),
-                    std=(0.26862954, 0.26130258, 0.27577711),
-                ),
-            ]
-        )
-
-    @property
-    def image_start_tag(self) -> str:
-        return self.tokenizer.image_start_tag  # type: ignore
-
-    @property
-    def image_end_tag(self) -> str:
-        return self.tokenizer.image_end_tag  # type: ignore
-
-    @property
-    def image_pad_tag(self) -> str:
-        return self.tokenizer.image_pad_tag  # type: ignore
-
-    def __call__(
-        self,
-        text: TextInput | list[TextInput] | None = None,
-        images: ImageInput | list[ImageInput] | None = None,
-        return_tensors: str | TensorType | None = None,
-    ) -> BatchFeature:
-        if text is None:
-            text = []
-        if not isinstance(text, list):
-            text = [text]
-        if images is None:
-            images = []
-        if not isinstance(images, list):
-            images = [images]
-
-        text_inputs = self.tokenizer(text)
-
-        if len(images) == 0:
-            image_inputs = {}
-        else:
-            pixel_values = [self.image_transform(image) for image in images]
-            image_inputs = {"pixel_values": torch.stack(pixel_values)}
-
-        return BatchFeature(
-            {
-                **text_inputs,
-                **image_inputs,
-            },
-            tensor_type=return_tensors,
-        )
-
-
-class QwenVLProcessingInfo(BaseProcessingInfo):
-    def get_hf_processor(self, **kwargs: object) -> QwenVLProcessor:
         return self.ctx.init_processor(
             QwenVLProcessor,
-            config=self.get_hf_config(),
             tokenizer=self.get_tokenizer(),
-            **kwargs,
+            **{**kwargs, "image_size": image_size},
         )
 
     def get_supported_mm_limits(self) -> Mapping[str, int | None]:
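The torchvision pipeline deleted above is reproduced here for comparison; the `QwenVLImageProcessorFast` added later in this patch is configured to match it: bicubic resize to a square `image_size`, rescale to [0, 1], and CLIP-style normalization.

```python
from torchvision import transforms
from torchvision.transforms import InterpolationMode

image_size = 448  # matches the new processor's default; the real value comes from config.visual

legacy_transform = transforms.Compose(
    [
        transforms.Resize(
            (image_size, image_size),
            interpolation=InterpolationMode.BICUBIC,
        ),
        transforms.ToTensor(),  # rescales pixel values to [0, 1]
        transforms.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),  # CLIP mean
            std=(0.26862954, 0.26130258, 0.27577711),  # CLIP std
        ),
    ]
)
```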
diff --git a/vllm/model_executor/models/voxtral.py b/vllm/model_executor/models/voxtral.py
index 964869a3c..d3eaf284b 100644
--- a/vllm/model_executor/models/voxtral.py
+++ b/vllm/model_executor/models/voxtral.py
@@ -3,25 +3,19 @@
 import math
 from collections.abc import Iterable, Mapping, Sequence
-from functools import cached_property, partial
-from math import ceil
+from functools import partial
 from typing import Literal, cast
 
 import numpy as np
 import regex as re
 import torch
 import torch.nn as nn
-from mistral_common.audio import mel_filter_bank
+from mistral_common.audio import Audio, mel_filter_bank
 from mistral_common.protocol.instruct.chunk import AudioChunk, RawAudio, TextChunk
 from mistral_common.protocol.instruct.messages import UserMessage
 from mistral_common.protocol.instruct.request import ChatCompletionRequest
 from mistral_common.protocol.transcription.request import TranscriptionRequest
-from mistral_common.tokens.tokenizers.audio import (
-    Audio,
-    AudioEncoder,
-)
-from transformers import BatchFeature, TensorType, WhisperConfig
-from transformers.tokenization_utils_base import TextInput
+from transformers import BatchFeature, WhisperConfig
 
 from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
@@ -62,6 +56,7 @@ from vllm.multimodal.processing.processor import (
 from vllm.sequence import IntermediateTensors
 from vllm.tokenizers import cached_tokenizer_from_config
 from vllm.tokenizers.mistral import MistralTokenizer
+from vllm.transformers_utils.processors.voxtral import MistralCommonVoxtralProcessor
 
 from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsTranscription
 from .utils import init_vllm_registered_model, maybe_prefix
@@ -81,98 +76,6 @@ ISO639_1_SUPPORTED_LANGS = {
 }
 
 
-class VoxtralProcessorAdapter:
-    """
-    Provide a HF-compatible interface for
-    :class:`mistral_common.tokens.tokenizers.multimodal.AudioEncoder`.
-    """
-
-    def __init__(self, tokenizer: MistralTokenizer) -> None:
-        super().__init__()
-        self.tokenizer = tokenizer
-
-    @cached_property
-    def _audio_processor(self) -> AudioEncoder:
-        audio_encoder = self.tokenizer.instruct.audio_encoder
-        assert isinstance(audio_encoder, AudioEncoder)
-        return audio_encoder
-
-    @cached_property
-    def audio_token_id(self) -> int:
-        return self._audio_processor.special_ids.audio
-
-    @cached_property
-    def begin_audio_token_id(self) -> int:
-        return self._audio_processor.special_ids.begin_audio
-
-    @cached_property
-    def sampling_rate(self) -> int:
-        return self._audio_processor.audio_config.sampling_rate
-
-    @cached_property
-    def frame_rate(self) -> float:
-        return self._audio_processor.audio_config.frame_rate
-
-    def get_num_audio_tokens(
-        self,
-        audio_length: int,
-    ) -> int:
-        return ceil(audio_length / (self.sampling_rate // self.frame_rate))
-
-    def __call__(
-        self,
-        text: TextInput | list[TextInput] | None = None,
-        audios: np.ndarray | list[np.ndarray] | None = None,
-        return_tensors: str | TensorType | None = None,
-        **kwargs,
-    ) -> Mapping[str, NestedTensors]:
-        if text is None:
-            text = []
-        if not isinstance(text, list):
-            text = [text]
-        if audios is None:
-            audios = []
-        if not isinstance(audios, list):
-            audios = [audios]
-
-        if not audios:
-            input_ids = self.tokenizer(text).input_ids
-            return {"input_ids": torch.tensor(input_ids)}
-
-        # Allow dummy text, which is used for profiling as well as token inputs
-        if any(len(t) > 0 for t in text):
-            raise ValueError(
-                "You've passed text inputs instead of token inputs. "
-                "Make sure to process your input via `mistral_common`'s "
-                "tokenizer or pass a chat completion request. "
-                "For more info, see: "
-                "https://github.com/vllm-project/vllm/issues/8411."
-            )
-
-        audios_tokens = list[torch.Tensor]()
-        audios_processed = list[torch.Tensor]()
-        for audio in audios:
-            assert isinstance(audio, np.ndarray)
-            assert audio.ndim == 1
-
-            if not self._audio_processor.audio_config.is_streaming:
-                audio = self._audio_processor.pad(audio, self.sampling_rate)
-
-            audio_tokens = [self.begin_audio_token_id] + [
-                self.audio_token_id
-            ] * self.get_num_audio_tokens(len(audio))
-
-            audios_tokens.append(torch.tensor(audio_tokens))
-            audios_processed.append(torch.tensor(audio))
-
-        return BatchFeature(
-            {
-                "input_ids": torch.cat(audios_tokens)[None].expand(len(text), -1),
-                "audio_arrays": audios_processed,
-            }
-        )
-
-
 class VoxtralProcessingInfo(BaseProcessingInfo):
     def get_tokenizer(self) -> MistralTokenizer:
         tokenizer = cached_tokenizer_from_config(self.ctx.model_config)
@@ -181,12 +84,18 @@ class VoxtralProcessingInfo(BaseProcessingInfo):
 
         return tokenizer
 
-    def get_hf_processor(self) -> VoxtralProcessorAdapter:
-        return VoxtralProcessorAdapter(self.get_tokenizer())
+    def get_hf_processor(self, **kwargs) -> MistralCommonVoxtralProcessor:
+        return self.ctx.init_processor(
+            MistralCommonVoxtralProcessor,
+            tokenizer=self.get_tokenizer(),
+            **kwargs,
+        )
 
     def get_data_parser(self):
+        feature_extractor = self.get_hf_processor().feature_extractor
+
         return MultiModalDataParser(
-            target_sr=self.get_hf_processor().sampling_rate,
+            target_sr=feature_extractor.sampling_rate,
             target_channels=1,
             expected_hidden_size=self._get_expected_hidden_size(),
         )
@@ -205,9 +114,10 @@ class VoxtralProcessingInfo(BaseProcessingInfo):
         return self.ctx.model_config.max_model_len
 
     def get_max_audio_array_len(self) -> int:
-        processor = self.get_hf_processor()
+        feature_extractor = self.get_hf_processor().feature_extractor
+
         return self.get_max_audio_tokens() * int(
-            processor.sampling_rate // processor.frame_rate
+            feature_extractor.sampling_rate // feature_extractor.frame_rate
         )
@@ -242,6 +152,7 @@ class VoxtralDummyInputsBuilder(BaseDummyInputsBuilder[VoxtralProcessingInfo]):
         mm_options: Mapping[str, BaseDummyOptions],
     ) -> ProcessorInputs:
         tokenizer = self.info.get_tokenizer()
+        feature_extractor = self.info.get_hf_processor().feature_extractor
 
         dummy_text = self.get_dummy_text(mm_counts)
         dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts, mm_options)
@@ -252,7 +163,7 @@ class VoxtralDummyInputsBuilder(BaseDummyInputsBuilder[VoxtralProcessingInfo]):
         for audio in dummy_audios:
             audio_item = Audio(
                 audio_array=audio,
-                sampling_rate=self.info.get_hf_processor().sampling_rate,
+                sampling_rate=feature_extractor.sampling_rate,
                 format=format,
             )
             chunk = AudioChunk(input_audio=RawAudio.from_audio(audio_item))
@@ -292,33 +203,26 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo])
         # skip validation here
         ...
 
-    def _apply_hf_processor_mm_only(
+    def _call_hf_processor(
         self,
-        mm_items: MultiModalDataItems,
-        hf_processor_mm_kwargs: Mapping[str, object],
-        tokenization_kwargs: Mapping[str, object],
+        prompt: str,
+        mm_data: Mapping[str, object],
+        mm_kwargs: Mapping[str, object],
+        tok_kwargs: Mapping[str, object],
     ) -> BatchFeature:
-        processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
-        processor_data, passthrough_data = self._get_hf_mm_data(mm_items)
-        audios = processor_data.get("audios", [])
-        if not isinstance(audios, list):
-            audios = [audios]
+        mm_data = dict(mm_data)
+        audios = mm_data.pop("audios", [])
 
-        audio_config = processor._audio_processor.audio_config
-        audio_tensors: list[torch.Tensor] = []
-        for audio in audios:
-            audio = np.asarray(audio, dtype=np.float32).ravel()
-            if not audio_config.is_streaming:
-                audio = processor._audio_processor.pad(
-                    audio,
-                    processor.sampling_rate,
-                    audio_config.is_streaming,
-                )
-            audio_tensors.append(torch.tensor(audio))
+        if audios:
+            # MistralCommonVoxtralProcessor accepts "audio"
+            mm_data["audio"] = audios
 
-        result = BatchFeature({"audio_arrays": audio_tensors} if audio_tensors else {})
-        result.update(passthrough_data)
-        return result
+        return super()._call_hf_processor(
+            prompt=prompt,
+            mm_data=mm_data,
+            mm_kwargs=mm_kwargs,
+            tok_kwargs=tok_kwargs,
+        )
 
     def _get_prompt_updates(
         self,
@@ -327,6 +231,7 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo])
         out_mm_kwargs: MultiModalKwargsItems,
     ) -> Sequence[PromptUpdate]:
         processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
+        feature_extractor = processor.feature_extractor
         audio_id = processor.audio_token_id
 
         out_mm_data = out_mm_kwargs.require_data()
@@ -348,7 +253,7 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo])
             audios = mm_items.get_items("audio", AudioProcessorItems)
             audio_len = audios.get_audio_length(item_idx)
-            nb_audio_tokens = processor.get_num_audio_tokens(audio_len)
+            nb_audio_tokens = feature_extractor.get_num_audio_tokens(audio_len)
 
             return [audio_id] * nb_audio_tokens
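`get_num_audio_tokens`, now on the feature extractor (see the new voxtral processor file later in this patch), is unchanged arithmetic: one token per `sampling_rate // frame_rate` samples, rounded up. With the rates assumed below for illustration:

```python
from math import ceil

# Assumed values; the real ones come from audio_encoder.audio_config
# (e.g. 16 kHz audio encoded at 12.5 frames per second).
sampling_rate, frame_rate = 16000, 12.5
samples_per_token = sampling_rate // frame_rate  # 1280.0

assert ceil(32000 / samples_per_token) == 25  # 2 s of audio -> 25 tokens
```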
""" tokenizer = cached_tokenizer_from_config(model_config) - adapter = VoxtralProcessorAdapter(tokenizer) - return adapter.get_num_audio_tokens( + adapter = MistralCommonVoxtralProcessor(tokenizer) + return adapter.feature_extractor.get_num_audio_tokens( int(audio_duration_s * stt_config.sample_rate) ) diff --git a/vllm/model_executor/models/voxtral_realtime.py b/vllm/model_executor/models/voxtral_realtime.py index 08e583caa..bb2c701e9 100644 --- a/vllm/model_executor/models/voxtral_realtime.py +++ b/vllm/model_executor/models/voxtral_realtime.py @@ -8,12 +8,13 @@ from typing import Literal import numpy as np import torch +from mistral_common.audio import Audio from mistral_common.protocol.instruct.chunk import RawAudio from mistral_common.protocol.transcription.request import ( StreamingMode, TranscriptionRequest, ) -from mistral_common.tokens.tokenizers.audio import Audio, AudioConfig +from mistral_common.tokens.tokenizers.audio import AudioConfig from vllm.compilation.decorators import support_torch_compile from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig diff --git a/vllm/multimodal/processing/context.py b/vllm/multimodal/processing/context.py index 9cf3863fe..98a41f69b 100644 --- a/vllm/multimodal/processing/context.py +++ b/vllm/multimodal/processing/context.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import time from abc import abstractmethod -from collections.abc import Mapping +from collections.abc import Callable, Mapping from contextlib import contextmanager from dataclasses import dataclass, field from functools import cached_property @@ -241,13 +241,13 @@ class InputProcessingContext: def call_hf_processor( self, - hf_processor: ProcessorMixin, + hf_processor: Callable[..., BatchFeature] | ProcessorMixin, data: Mapping[str, object], kwargs: Mapping[str, object] = {}, *, num_tries: int = 1, max_tries: int = 5, - ) -> BatchFeature | JSONTree: + ) -> BatchFeature: """ Call `hf_processor` on the prompt `data` (text, image, audio...) with configurable options `kwargs`. @@ -300,7 +300,7 @@ class InputProcessingContext: if isinstance(output, BatchFeature): output_ = self._postprocess_output(output.data) - return BatchFeature(output_) + return BatchFeature(output_) # type: ignore logger.warning_once( "%s did not return `BatchFeature`. " @@ -309,7 +309,7 @@ class InputProcessingContext: type(hf_processor).__name__, ) - return self._postprocess_output(output) + return self._postprocess_output(output) # type: ignore class BaseProcessingInfo: diff --git a/vllm/transformers_utils/processor.py b/vllm/transformers_utils/processor.py index 1319e2943..2605a5f84 100644 --- a/vllm/transformers_utils/processor.py +++ b/vllm/transformers_utils/processor.py @@ -241,12 +241,13 @@ def get_processor_kwargs_type( call_kwargs_annotations = call_kwargs.annotation if call_kwargs else None # if the processor has explicit kwargs annotation, use it - if call_kwargs_annotations not in (None, inspect._empty): + if call_kwargs_annotations not in (None, inspect._empty): # noqa: SIM102 # get_type_hints will parse all type annotations at runtime, # and if an annotation refers to a type or # name that hasn’t been imported or defined, it will raise an error. # So we use __annotations__ to get the raw annotations directly. 
-        return get_args(call_kwargs_annotations)[0]
+        if anno_args := get_args(call_kwargs_annotations):
+            return anno_args[0]
 
     # otherwise, try to get from ProcessorKwargs
     module_name = type(processor).__module__
@@ -266,7 +267,13 @@ def get_processor_kwargs_keys(
     kwargs_cls: type[processing_utils.ProcessingKwargs],
 ) -> set[str]:
     dynamic_kwargs: set[str] = set()
-    modality_kwargs = {"text_kwargs", "images_kwargs", "videos_kwargs", "audio_kwargs"}
+    modality_kwargs = {
+        "text_kwargs",
+        "images_kwargs",
+        "videos_kwargs",
+        "audio_kwargs",
+        "common_kwargs",
+    }
 
     try:
         # get kwargs annotations in processor
diff --git a/vllm/transformers_utils/processors/__init__.py b/vllm/transformers_utils/processors/__init__.py
index ff2263f3e..50c944e9d 100644
--- a/vllm/transformers_utils/processors/__init__.py
+++ b/vllm/transformers_utils/processors/__init__.py
@@ -15,10 +15,14 @@ _CLASS_TO_MODULE: dict[str, str] = {
     "DeepseekVLV2Processor": "vllm.transformers_utils.processors.deepseek_vl2",
     "FireRedASR2Processor": "vllm.transformers_utils.processors.fireredasr2",
     "FunASRProcessor": "vllm.transformers_utils.processors.funasr",
+    "GLM4VProcessor": "vllm.transformers_utils.processors.glm4v",
     "HunYuanVLProcessor": "vllm.transformers_utils.processors.hunyuan_vl",
     "HunYuanVLImageProcessor": "vllm.transformers_utils.processors.hunyuan_vl_image",
+    "MistralCommonPixtralProcessor": "vllm.transformers_utils.processors.pixtral",
+    "MistralCommonVoxtralProcessor": "vllm.transformers_utils.processors.voxtral",
     "OvisProcessor": "vllm.transformers_utils.processors.ovis",
     "Ovis2_5Processor": "vllm.transformers_utils.processors.ovis2_5",
+    "QwenVLProcessor": "vllm.transformers_utils.processors.qwen_vl",
     "Qwen3ASRProcessor": "vllm.transformers_utils.processors.qwen3_asr",
 }
 
@@ -28,10 +32,14 @@ __all__ = [
     "DeepseekVLV2Processor",
     "FireRedASR2Processor",
     "FunASRProcessor",
+    "GLM4VProcessor",
     "HunYuanVLProcessor",
     "HunYuanVLImageProcessor",
+    "MistralCommonPixtralProcessor",
+    "MistralCommonVoxtralProcessor",
     "OvisProcessor",
     "Ovis2_5Processor",
+    "QwenVLProcessor",
     "Qwen3ASRProcessor",
 ]
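`_CLASS_TO_MODULE` feeds this package's lazy-import hook; the module-level `__getattr__` itself is not part of this diff, but it presumably resolves entries along these lines (a sketch of the assumed pattern only):

```python
import importlib

def __getattr__(name: str):
    # Resolve a processor class lazily on first attribute access.
    if name in _CLASS_TO_MODULE:
        module = importlib.import_module(_CLASS_TO_MODULE[name])
        return getattr(module, name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```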
+ """ + + resample = PILImageResampling.BICUBIC + image_mean = [0.48145466, 0.4578275, 0.40821073] + image_std = [0.26862954, 0.26130258, 0.27577711] + size = {"height": 1120, "width": 1120} + do_resize = True + do_rescale = True + do_normalize = True + + +class GLM4VProcessor(ProcessorMixin): + attributes = ["image_processor", "tokenizer"] + + def __init__( + self, + tokenizer: PreTrainedTokenizer, + image_size: int, + ) -> None: + self.tokenizer = tokenizer + self.image_processor = GLM4VImageProcessorFast( + size={"width": image_size, "height": image_size} + ) diff --git a/vllm/transformers_utils/processors/pixtral.py b/vllm/transformers_utils/processors/pixtral.py new file mode 100644 index 000000000..8e9b241e8 --- /dev/null +++ b/vllm/transformers_utils/processors/pixtral.py @@ -0,0 +1,116 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import torch +from mistral_common.protocol.instruct.chunk import ImageChunk +from mistral_common.tokens.tokenizers.multimodal import ImageEncoder +from PIL import Image +from transformers import BatchFeature, ProcessorMixin, TensorType +from transformers.audio_utils import AudioInput +from transformers.image_utils import ImageInput +from transformers.tokenization_utils_base import PreTokenizedInput, TextInput +from transformers.video_utils import VideoInput + +from vllm.tokenizers.mistral import MistralTokenizer + + +class MistralCommonImageProcessor: + """ + Provide a HF-compatible interface for + `mistral_common.tokens.tokenizers.multimodal.ImageEncoder`. + """ + + def __init__(self, mm_encoder: ImageEncoder) -> None: + self.mm_encoder = mm_encoder + + def __call__( + self, + images: ImageInput, + return_tensors: str | TensorType | None = None, + **kwargs, + ) -> BatchFeature: + images_lst = [images] if not isinstance(images, list) else images + + images_processed = list[torch.Tensor]() + + for image in images_lst: + image_inputs = self.mm_encoder(ImageChunk(image=image)) + image_processed = torch.tensor(image_inputs.image) + + images_processed.append(image_processed) + + return BatchFeature({"images": images_processed}, tensor_type=return_tensors) + + def get_number_of_image_patches( + self, + height: int, + width: int, + ) -> tuple[int, int, int]: + image = Image.new("RGB", (width, height)) + ncols, nrows = self.mm_encoder._image_to_num_tokens(image) + return ncols * nrows, nrows, ncols + + +class MistralCommonPixtralProcessor(ProcessorMixin): + attributes = ["image_processor", "tokenizer"] + + def __init__(self, tokenizer: MistralTokenizer) -> None: + self.tokenizer = tokenizer.transformers_tokenizer + self.image_processor = MistralCommonImageProcessor( + tokenizer.instruct.mm_encoder + ) + + self._image_special_ids = self.image_processor.mm_encoder.special_ids + + @property + def image_break_id(self) -> int: + return self._image_special_ids.img_break + + @property + def image_token_id(self) -> int: + return self._image_special_ids.img + + @property + def image_end_id(self) -> int: + return self._image_special_ids.img_end + + def __call__( + self, + images: ImageInput | None = None, + text: TextInput + | PreTokenizedInput + | list[TextInput] + | list[PreTokenizedInput] + | None = None, + videos: VideoInput | None = None, + audio: AudioInput | None = None, + **kwargs, + ): + if images is None and text is None and videos is None and audio is None: + raise ValueError( + f"You need to provide at least one input to " + f"call {self.__class__.__name__}" + ) + + kwargs = self._merge_kwargs( 
diff --git a/vllm/transformers_utils/processors/pixtral.py b/vllm/transformers_utils/processors/pixtral.py
new file mode 100644
index 000000000..8e9b241e8
--- /dev/null
+++ b/vllm/transformers_utils/processors/pixtral.py
@@ -0,0 +1,116 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import torch
+from mistral_common.protocol.instruct.chunk import ImageChunk
+from mistral_common.tokens.tokenizers.multimodal import ImageEncoder
+from PIL import Image
+from transformers import BatchFeature, ProcessorMixin, TensorType
+from transformers.audio_utils import AudioInput
+from transformers.image_utils import ImageInput
+from transformers.tokenization_utils_base import PreTokenizedInput, TextInput
+from transformers.video_utils import VideoInput
+
+from vllm.tokenizers.mistral import MistralTokenizer
+
+
+class MistralCommonImageProcessor:
+    """
+    Provide a HF-compatible interface for
+    `mistral_common.tokens.tokenizers.multimodal.ImageEncoder`.
+    """
+
+    def __init__(self, mm_encoder: ImageEncoder) -> None:
+        self.mm_encoder = mm_encoder
+
+    def __call__(
+        self,
+        images: ImageInput,
+        return_tensors: str | TensorType | None = None,
+        **kwargs,
+    ) -> BatchFeature:
+        images_lst = [images] if not isinstance(images, list) else images
+
+        images_processed = list[torch.Tensor]()
+
+        for image in images_lst:
+            image_inputs = self.mm_encoder(ImageChunk(image=image))
+            image_processed = torch.tensor(image_inputs.image)
+
+            images_processed.append(image_processed)
+
+        return BatchFeature({"images": images_processed}, tensor_type=return_tensors)
+
+    def get_number_of_image_patches(
+        self,
+        height: int,
+        width: int,
+    ) -> tuple[int, int, int]:
+        image = Image.new("RGB", (width, height))
+        ncols, nrows = self.mm_encoder._image_to_num_tokens(image)
+        return ncols * nrows, nrows, ncols
+
+
+class MistralCommonPixtralProcessor(ProcessorMixin):
+    attributes = ["image_processor", "tokenizer"]
+
+    def __init__(self, tokenizer: MistralTokenizer) -> None:
+        self.tokenizer = tokenizer.transformers_tokenizer
+        self.image_processor = MistralCommonImageProcessor(
+            tokenizer.instruct.mm_encoder
+        )
+
+        self._image_special_ids = self.image_processor.mm_encoder.special_ids
+
+    @property
+    def image_break_id(self) -> int:
+        return self._image_special_ids.img_break
+
+    @property
+    def image_token_id(self) -> int:
+        return self._image_special_ids.img
+
+    @property
+    def image_end_id(self) -> int:
+        return self._image_special_ids.img_end
+
+    def __call__(
+        self,
+        images: ImageInput | None = None,
+        text: TextInput
+        | PreTokenizedInput
+        | list[TextInput]
+        | list[PreTokenizedInput]
+        | None = None,
+        videos: VideoInput | None = None,
+        audio: AudioInput | None = None,
+        **kwargs,
+    ):
+        if images is None and text is None and videos is None and audio is None:
+            raise ValueError(
+                f"You need to provide at least one input to "
+                f"call {self.__class__.__name__}"
+            )
+
+        kwargs = self._merge_kwargs(
+            self.valid_processor_kwargs,
+            tokenizer_init_kwargs={},
+            **kwargs,
+        )
+        kwargs["text_kwargs"]["return_tensors"] = "pt"
+        kwargs["images_kwargs"]["return_tensors"] = None  # Avoid padding issue
+
+        attribute_to_kwargs = {
+            "tokenizer": (text, "text_kwargs"),
+            "image_processor": (images, "images_kwargs"),
+            "video_processor": (videos, "videos_kwargs"),
+            "feature_extractor": (audio, "audio_kwargs"),
+        }
+        outputs = {}
+        for attribute_name in self.attributes:
+            attribute = getattr(self, attribute_name, None)
+            input_data, input_kwargs = attribute_to_kwargs[attribute_name]
+            if input_data is not None and attribute is not None:
+                attribute_output = attribute(input_data, **kwargs[input_kwargs])
+                outputs.update(attribute_output)
+
+        return BatchFeature(outputs)
diff --git a/vllm/transformers_utils/processors/qwen_vl.py b/vllm/transformers_utils/processors/qwen_vl.py
new file mode 100644
index 000000000..d7b4f1c43
--- /dev/null
+++ b/vllm/transformers_utils/processors/qwen_vl.py
@@ -0,0 +1,48 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from transformers.image_processing_utils_fast import BaseImageProcessorFast
+from transformers.image_utils import PILImageResampling
+from transformers.processing_utils import ProcessorMixin
+
+from vllm.tokenizers.qwen_vl import QwenVLTokenizer
+
+
+class QwenVLImageProcessorFast(BaseImageProcessorFast):
+    """
+    Port of https://huggingface.co/Qwen/Qwen-VL/blob/main/visual.py#L354
+    to HF Transformers.
+    """
+
+    resample = PILImageResampling.BICUBIC
+    image_mean = [0.48145466, 0.4578275, 0.40821073]
+    image_std = [0.26862954, 0.26130258, 0.27577711]
+    size = {"height": 448, "width": 448}
+    do_resize = True
+    do_rescale = True
+    do_normalize = True
+
+
+class QwenVLProcessor(ProcessorMixin):
+    attributes = ["image_processor", "tokenizer"]
+
+    def __init__(
+        self,
+        tokenizer: QwenVLTokenizer,
+        image_size: int,
+    ) -> None:
+        self.tokenizer = tokenizer
+        self.image_processor = QwenVLImageProcessorFast(
+            size={"width": image_size, "height": image_size}
+        )
+
+    @property
+    def image_start_tag(self) -> str:
+        return self.tokenizer.image_start_tag  # type: ignore[attr-defined]
+
+    @property
+    def image_end_tag(self) -> str:
+        return self.tokenizer.image_end_tag  # type: ignore[attr-defined]
+
+    @property
+    def image_pad_tag(self) -> str:
+        return self.tokenizer.image_pad_tag  # type: ignore[attr-defined]
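The three tag properties exist so the multimodal processor can splice image placeholders into the prompt. A hypothetical assembly; the tag strings and pad count below are assumed Qwen-VL conventions, not values taken from this diff:

```python
# In practice these come from processor.image_start_tag / image_end_tag /
# image_pad_tag, which read them off the Qwen-VL tokenizer.
image_start_tag, image_end_tag, image_pad_tag = "<img>", "</img>", "<imgpad>"
num_pad_tokens = 256  # assumed fixed pad count per image slot

prompt = (
    f"{image_start_tag}{image_pad_tag * num_pad_tokens}{image_end_tag}"
    "Describe the image."
)
```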
+ """ + + def __init__(self, audio_encoder: AudioEncoder) -> None: + self.audio_encoder = audio_encoder + + @property + def sampling_rate(self): + return self.audio_encoder.audio_config.sampling_rate + + @property + def frame_rate(self): + return self.audio_encoder.audio_config.frame_rate + + def __call__( + self, + audios: AudioInput, + return_tensors: str | TensorType | None = None, + **kwargs, + ) -> BatchFeature: + audios_lst = [audios] if not isinstance(audios, list) else audios + + audios_processed = list[torch.Tensor]() + + for audio in audios_lst: + audio = np.asarray(audio, dtype=np.float32).ravel() + if not self.audio_encoder.audio_config.is_streaming: + audio = self.audio_encoder.pad(audio, self.sampling_rate) + + audios_processed.append(torch.tensor(audio)) + + return BatchFeature( + {"audio_arrays": audios_processed}, tensor_type=return_tensors + ) + + def get_num_audio_tokens(self, audio_length: int) -> int: + return ceil(audio_length / (self.sampling_rate // self.frame_rate)) + + +class MistralCommonVoxtralProcessor(ProcessorMixin): + attributes = ["feature_extractor", "tokenizer"] + + def __init__(self, tokenizer: MistralTokenizer) -> None: + self.tokenizer = tokenizer.transformers_tokenizer + self.feature_extractor = MistralCommonFeatureExtractor( + tokenizer.instruct.audio_encoder + ) + + self._audio_special_ids = self.feature_extractor.audio_encoder.special_ids + + @property + def audio_token_id(self) -> int: + return self._audio_special_ids.audio + + @property + def begin_audio_token_id(self) -> int: + return self._audio_special_ids.begin_audio + + def __call__( + self, + images: ImageInput | None = None, + text: TextInput + | PreTokenizedInput + | list[TextInput] + | list[PreTokenizedInput] + | None = None, + videos: VideoInput | None = None, + audio: AudioInput | None = None, + **kwargs, + ): + if images is None and text is None and videos is None and audio is None: + raise ValueError( + f"You need to provide at least one input to " + f"call {self.__class__.__name__}" + ) + + kwargs = self._merge_kwargs( + self.valid_processor_kwargs, + tokenizer_init_kwargs={}, + **kwargs, + ) + kwargs["text_kwargs"]["return_tensors"] = "pt" + kwargs["audio_kwargs"]["return_tensors"] = None # Avoid padding issue + + attribute_to_kwargs = { + "tokenizer": (text, "text_kwargs"), + "image_processor": (images, "images_kwargs"), + "video_processor": (videos, "videos_kwargs"), + "feature_extractor": (audio, "audio_kwargs"), + } + outputs = {} + for attribute_name in self.attributes: + attribute = getattr(self, attribute_name, None) + input_data, input_kwargs = attribute_to_kwargs[attribute_name] + if input_data is not None and attribute is not None: + attribute_output = attribute(input_data, **kwargs[input_kwargs]) + outputs.update(attribute_output) + + return BatchFeature(outputs)