[Misc] Move processors to transformers_utils (#35953)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2026-03-09 11:31:39 +08:00
committed by GitHub
parent bd2659a566
commit d62856b928
13 changed files with 507 additions and 595 deletions

View File

@@ -13,11 +13,7 @@ import numpy as np
import torch
from torch import nn
from torch.nn import LayerNorm
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from transformers import BatchFeature, PreTrainedTokenizer, TensorType
from transformers.image_utils import ImageInput
from transformers.tokenization_utils_base import TextInput
from transformers import BatchFeature
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
@@ -50,7 +46,8 @@ from vllm.multimodal.processing import (
PromptUpdate,
)
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.configs import ChatGLMConfig
from vllm.transformers_utils.configs.chatglm import ChatGLMConfig
from vllm.transformers_utils.processors.glm4v import GLM4VProcessor
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .chatglm import ChatGLMBaseModel, ChatGLMModel, GLMTransformer
@@ -386,81 +383,19 @@ class GLM4VModel(ChatGLMModel):
)
class GLM4VProcessor:
"""
This model doesn't define its own HF processor,
so we implement our own one here.
"""
def __init__(
self,
config: ChatGLMConfig,
tokenizer: PreTrainedTokenizer,
) -> None:
super().__init__()
self.config = config
self.tokenizer = tokenizer
vision_config = config.vision_config
image_size = vision_config["image_size"]
self.image_transform = transforms.Compose(
[
transforms.Resize(
(image_size, image_size),
interpolation=InterpolationMode.BICUBIC,
),
transforms.ToTensor(),
transforms.Normalize(
mean=(0.48145466, 0.4578275, 0.40821073),
std=(0.26862954, 0.26130258, 0.27577711),
),
]
)
def __call__(
self,
text: TextInput | list[TextInput] | None = None,
images: ImageInput | list[ImageInput] | None = None,
return_tensors: str | TensorType | None = None,
) -> BatchFeature:
if text is None:
text = []
if not isinstance(text, list):
text = [text]
if images is None:
images = []
if not isinstance(images, list):
images = [images]
text_inputs = self.tokenizer(text)
if len(images) == 0:
image_inputs = {}
else:
pixel_values = [self.image_transform(image) for image in images]
image_inputs = {"pixel_values": torch.stack(pixel_values)}
return BatchFeature(
{
**text_inputs,
**image_inputs,
},
tensor_type=return_tensors,
)
class GLM4VProcessingInfo(BaseProcessingInfo):
def get_hf_config(self):
return self.ctx.get_hf_config(ChatGLMConfig)
def get_hf_processor(self, **kwargs: object) -> GLM4VProcessor:
config = self.get_hf_config()
vision_config = config.vision_config
image_size = vision_config["image_size"]
return self.ctx.init_processor(
GLM4VProcessor,
config=self.get_hf_config(),
tokenizer=self.get_tokenizer(),
**kwargs,
**{**kwargs, "image_size": image_size},
)
def get_supported_mm_limits(self) -> Mapping[str, int | None]:

View File

@@ -4,7 +4,7 @@
import math
from collections.abc import Iterable, Mapping, Sequence
from dataclasses import dataclass
from functools import cached_property, partial
from functools import partial
from itertools import islice
from typing import Annotated
@@ -13,9 +13,11 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers import BatchFeature, PretrainedConfig, ProcessorMixin, TensorType
from transformers.image_utils import ImageInput
from transformers.tokenization_utils_base import TextInput
from transformers import (
BaseImageProcessor,
BatchFeature,
PretrainedConfig,
)
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
@@ -1017,117 +1019,28 @@ def select_tiling(
return candidate_tilings[ix]
class MolmoProcessorWrapper:
"""
Wraps `MolmoProcessor` so that it can be called directly.
def _as_2tuple(x: int | tuple[int, int]) -> tuple[int, int]:
if isinstance(x, int):
return x, x
The original definition can be found here:
https://huggingface.co/allenai/Molmo-7B-D-0924/blob/main/preprocessing_molmo.py
"""
return x
def __init__(self, processor: ProcessorMixin):
super().__init__()
self.processor = processor
@cached_property
def vocab(self) -> dict[str, int]:
return self.processor.tokenizer.vocab # type: ignore
@cached_property
def max_crops(self) -> int:
image_processor = self.processor.image_processor # type: ignore
max_crops = image_processor.max_crops
assert isinstance(max_crops, int)
return max_crops
@cached_property
def base_image_input_size(self) -> tuple[int, int]:
image_processor = self.processor.image_processor # type: ignore
base_image_input_size = image_processor.base_image_input_size
if isinstance(base_image_input_size, int):
return base_image_input_size, base_image_input_size
return tuple(base_image_input_size)
@cached_property
def image_patch_size(self) -> int:
image_processor = self.processor.image_processor # type: ignore
image_patch_size = image_processor.image_patch_size
assert isinstance(image_patch_size, int)
return image_patch_size
@cached_property
def overlap_margins(self) -> tuple[int, int]:
image_processor = self.processor.image_processor # type: ignore
left_margin, right_margin = image_processor.overlap_margins
assert isinstance(left_margin, int)
assert isinstance(right_margin, int)
return left_margin, right_margin
@cached_property
def image_token_length_w(self) -> int:
image_processor = self.processor.image_processor # type: ignore
image_token_length_w = image_processor.image_token_length_w
assert isinstance(image_token_length_w, int)
return image_token_length_w
@cached_property
def image_token_length_h(self) -> int:
image_processor = self.processor.image_processor # type: ignore
image_token_length_h = image_processor.image_token_length_h
assert isinstance(image_token_length_h, int)
return image_token_length_h
@property
def message_format(self) -> str | None:
return "role"
@property
def always_start_with_space(self) -> bool:
return True
@cached_property
def image_patch_id(self) -> int:
return self.vocab[IMAGE_PATCH_TOKEN]
@cached_property
def im_col_id(self) -> int:
return self.vocab[IM_COL_TOKEN]
@cached_property
def im_start_id(self) -> int:
return self.vocab[IM_START_TOKEN]
@cached_property
def im_end_id(self) -> int:
return self.vocab[IM_END_TOKEN]
@property
def pooling_size(self) -> int:
return POOLING_SIZE
class MolmoProcessingInfo(BaseProcessingInfo):
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"image": None}
def select_tiling(
self,
*,
image_width: int,
image_height: int,
image_processor: BaseImageProcessor,
) -> tuple[int, int]:
max_crops = self.max_crops
left_margin, right_margin = self.overlap_margins
base_image_input_size = self.base_image_input_size
base_image_input_d = self.image_patch_size
max_crops = image_processor.max_crops
left_margin, right_margin = image_processor.overlap_margins
base_image_input_size = _as_2tuple(image_processor.base_image_input_size)
base_image_input_d = image_processor.image_patch_size
total_margin_pixels = base_image_input_d * (right_margin + left_margin)
crop_patches = base_image_input_size[0] // base_image_input_d
@@ -1147,16 +1060,18 @@ class MolmoProcessorWrapper:
*,
image_width: int,
image_height: int,
image_processor: BaseImageProcessor,
) -> tuple[int, int]:
left_margin, right_margin = self.overlap_margins
base_image_input_size = self.base_image_input_size
base_image_input_d = self.image_patch_size
pooling_size = self.pooling_size
left_margin, right_margin = image_processor.overlap_margins
base_image_input_size = _as_2tuple(image_processor.base_image_input_size)
base_image_input_d = image_processor.image_patch_size
pooling_size = POOLING_SIZE
crop_patches = base_image_input_size[0] // base_image_input_d
tiling_w, tiling_h = self.select_tiling(
image_height=image_height,
image_width=image_width,
image_processor=image_processor,
)
nrows, ncols = get_patches_grid_size(
@@ -1170,70 +1085,22 @@ class MolmoProcessorWrapper:
return ncols, nrows
def __call__(
self,
text: TextInput | list[TextInput] | None = None,
images: ImageInput | list[ImageInput] | None = None,
return_tensors: str | TensorType | None = None,
**kwargs,
) -> BatchFeature:
outputs = self.processor.process( # type: ignore
text, images, **kwargs
)
if images is None:
images = []
if not isinstance(images, list):
images = [images]
input_ids: torch.Tensor = outputs.pop("input_ids")
outputs["input_ids"] = input_ids.unsqueeze(0)
image_input_idx = outputs.pop("image_input_idx", None)
if image_input_idx is not None:
feat_is_patch = image_input_idx >= 0
tilings = [
self.select_tiling(
image_width=image.size[0],
image_height=image.size[1],
)
for image in images
]
# For each image: tiling_h * tiling_w + extra
num_crops = torch.tensor(tilings).prod(-1) + 1
assert num_crops.sum() == len(feat_is_patch)
outputs["image_input_idx"] = image_input_idx
outputs["num_crops"] = num_crops
outputs["img_patch_id"] = self.image_patch_id
return BatchFeature(outputs)
class MolmoProcessingInfo(BaseProcessingInfo):
def get_hf_processor(self, **kwargs: object) -> MolmoProcessorWrapper:
processor = self.ctx.get_hf_processor(**kwargs)
return MolmoProcessorWrapper(processor)
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"image": None}
def get_num_image_tokens(
self,
*,
image_width: int,
image_height: int,
processor: MolmoProcessorWrapper,
image_processor: BaseImageProcessor,
) -> int:
ncols, nrows = processor.get_patches_grid_size(
ncols, nrows = self.get_patches_grid_size(
image_width=image_width,
image_height=image_height,
image_processor=image_processor,
)
pooling_size = processor.pooling_size
pooling_size = POOLING_SIZE
image_token_length_w = processor.image_token_length_w
image_token_length_h = processor.image_token_length_h
image_token_length_w = image_processor.image_token_length_w
image_token_length_h = image_processor.image_token_length_h
# Calculate total tokens: 2 for start/end + (w+1)*h for column separators
extra = 2 + (image_token_length_w + 1) * image_token_length_h
@@ -1243,9 +1110,10 @@ class MolmoProcessingInfo(BaseProcessingInfo):
def get_image_size_with_most_features(self) -> ImageSize:
processor = self.get_hf_processor()
image_processor = processor.image_processor
tilings = get_candidate_tilings(processor.max_crops)
base_h, base_w = processor.base_image_input_size
tilings = get_candidate_tilings(image_processor.max_crops)
base_h, base_w = _as_2tuple(image_processor.base_image_input_size)
largest_feature_size, largest_feature_pinpoint = 0, None
for wr, hr in tilings:
@@ -1254,7 +1122,7 @@ class MolmoProcessingInfo(BaseProcessingInfo):
feat_size = self.get_num_image_tokens(
image_width=width,
image_height=height,
processor=processor,
image_processor=image_processor,
)
if feat_size > largest_feature_size:
largest_feature_size = feat_size
@@ -1292,6 +1160,54 @@ class MolmoDummyInputsBuilder(BaseDummyInputsBuilder[MolmoProcessingInfo]):
class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
def _call_hf_processor(
self,
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature:
hf_processor = self.info.get_hf_processor(**mm_kwargs)
processed_outputs = self.info.ctx.call_hf_processor(
hf_processor.process,
dict(text=prompt, **mm_data),
dict(**mm_kwargs, **tok_kwargs),
)
tokenizer = hf_processor.tokenizer
image_patch_id = tokenizer.vocab[IMAGE_PATCH_TOKEN]
image_processor = hf_processor.image_processor
input_ids: torch.Tensor = processed_outputs.pop("input_ids")
processed_outputs["input_ids"] = input_ids.unsqueeze(0)
if (images := mm_data.get("images")) is not None:
mm_items = self.info.parse_mm_data({"image": images}, validate=False)
parsed_images = mm_items.get_items("image", ImageProcessorItems)
image_sizes = [
parsed_images.get_image_size(i) for i in range(len(parsed_images))
]
feat_is_patch = processed_outputs["image_input_idx"] >= 0
tilings = [
self.info.select_tiling(
image_width=image_size.width,
image_height=image_size.height,
image_processor=image_processor,
)
for image_size in image_sizes
]
# For each image: tiling_h * tiling_w + extra
num_crops = torch.tensor(tilings).prod(-1) + 1
assert num_crops.sum() == len(feat_is_patch)
processed_outputs["num_crops"] = num_crops
processed_outputs["img_patch_id"] = image_patch_id
return processed_outputs
def _apply_hf_processor_tokens_only(
self,
prompt_tokens: list[int],
@@ -1301,18 +1217,19 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
# The chat template is already applied to the prompt tokens
# Use message_format="none" to avoid applying it again
# Prepend an empty space if `always_start_with_space` is True
tokens = processor.processor.get_tokens_input( # type: ignore
tokens = processor.get_tokens_input(
self.info.get_tokenizer().decode(prompt_tokens),
message_format="none",
always_start_with_space=processor.always_start_with_space,
always_start_with_space=True,
)
# Prepend a BOS token id to the tokens
processed_data = self.info.ctx.call_hf_processor(
processor, # type: ignore
processor.process,
dict(tokens=tokens),
)
(prompt_ids,) = processed_data.pop("input_ids").tolist()
prompt_ids = processed_data.pop("input_ids").tolist()
print(prompt_ids, len(prompt_ids))
return prompt_ids
@@ -1338,16 +1255,18 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]:
tokenizer = self.info.get_tokenizer()
vocab = tokenizer.get_vocab()
img_patch_id = vocab[IMAGE_PATCH_TOKEN]
img_col_id = vocab[IM_COL_TOKEN]
img_start_id = vocab[IM_START_TOKEN]
img_end_id = vocab[IM_END_TOKEN]
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
image_token_length_w = processor.image_token_length_w
image_token_length_h = processor.image_token_length_h
pooling_size = processor.pooling_size
img_patch_id = processor.image_patch_id
img_col_id = processor.im_col_id
img_start_id = processor.im_start_id
img_end_id = processor.im_end_id
image_processor = processor.image_processor
image_token_length_w = image_processor.image_token_length_w
image_token_length_h = image_processor.image_token_length_h
pooling_size = POOLING_SIZE
extra_row = [img_patch_id] * image_token_length_w + [img_col_id]
extra_joint = [img_start_id] + extra_row * image_token_length_h + [img_end_id]
@@ -1356,9 +1275,10 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
images = mm_items.get_items("image", ImageProcessorItems)
image_size = images.get_image_size(item_idx)
ncols, nrows = processor.get_patches_grid_size(
ncols, nrows = self.info.get_patches_grid_size(
image_width=image_size.width,
image_height=image_size.height,
image_processor=image_processor,
)
joint_row = [img_patch_id] * ((ncols + 1) // pooling_size) + [img_col_id]

View File

@@ -4,7 +4,6 @@
import math
from collections.abc import Iterable, Mapping, Sequence
from dataclasses import dataclass, fields
from functools import cached_property
from typing import Annotated, Literal
import torch
@@ -13,10 +12,7 @@ import torch.nn.functional as F
from mistral_common.protocol.instruct.chunk import ImageChunk, TextChunk
from mistral_common.protocol.instruct.messages import UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.multimodal import ImageEncoder
from PIL import Image
from transformers import BatchFeature, PixtralVisionConfig, TensorType
from transformers.image_utils import ImageInput
from transformers import PixtralVisionConfig
from transformers.models.pixtral.image_processing_pixtral import (
_num_image_tokens as _get_pixtral_hf_num_image_tokens,
)
@@ -25,7 +21,6 @@ from transformers.models.pixtral.modeling_pixtral import (
apply_rotary_pos_emb,
position_ids_in_meshgrid,
)
from transformers.tokenization_utils_base import TextInput
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
@@ -66,6 +61,7 @@ from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.tokenizers import cached_tokenizer_from_config
from vllm.tokenizers.mistral import MistralTokenizer
from vllm.transformers_utils.processors.pixtral import MistralCommonPixtralProcessor
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import (
@@ -121,93 +117,6 @@ class PixtralImagePixelInputs(TensorSchema):
]
class PixtralProcessorAdapter:
"""
Provide a HF-compatible interface for
`mistral_common.tokens.tokenizers.multimodal.ImageEncoder`.
"""
def __init__(self, tokenizer: MistralTokenizer) -> None:
super().__init__()
self.tokenizer = tokenizer
@property
def image_processor(self) -> ImageEncoder:
image_encoder = self.tokenizer.instruct.mm_encoder
assert isinstance(image_encoder, ImageEncoder)
return image_encoder
@cached_property
def image_break_id(self) -> int:
return self.image_processor.special_ids.img_break
@cached_property
def image_token_id(self) -> int:
return self.image_processor.special_ids.img
@cached_property
def image_end_id(self) -> int:
return self.image_processor.special_ids.img_end
@cached_property
def image_size(self) -> int:
return self.image_processor.mm_config.max_image_size
@cached_property
def patch_size(self) -> int:
return self.image_processor.mm_config.image_patch_size
def __call__(
self,
text: TextInput | list[TextInput] | None = None,
images: ImageInput | list[ImageInput] | None = None,
return_tensors: str | TensorType | None = None,
**kwargs,
) -> Mapping[str, NestedTensors]:
if text is None:
text = []
if not isinstance(text, list):
text = [text]
if images is None:
images = []
if not isinstance(images, list):
images = [images]
if not images:
input_ids = self.tokenizer(text).input_ids
return {"input_ids": torch.tensor(input_ids)}
# Allow dummy text, which is used for profiling as well as token inputs
if any(len(t) > 0 for t in text):
raise ValueError(
"You've passed text inputs instead of token inputs. "
"Make sure to process your input via `mistral_common`'s "
"tokenizer or pass a chat completion request. "
"For more info, see: "
"https://github.com/vllm-project/vllm/issues/8411."
)
images_processed = list[torch.Tensor]()
images_tokens = list[torch.Tensor]()
for image in images:
image_inputs = self.image_processor(ImageChunk(image=image))
image_processed = torch.tensor(image_inputs.image)
image_tokens = torch.tensor(image_inputs.tokens)
images_processed.append(image_processed)
images_tokens.append(image_tokens)
return BatchFeature(
{
"input_ids": torch.cat(images_tokens)[None].expand(len(text), -1),
"images": images_processed,
}
)
class PixtralProcessingInfo(BaseProcessingInfo):
def get_tokenizer(self) -> MistralTokenizer:
tokenizer = cached_tokenizer_from_config(self.ctx.model_config)
@@ -216,28 +125,19 @@ class PixtralProcessingInfo(BaseProcessingInfo):
return tokenizer
def get_hf_processor(self) -> PixtralProcessorAdapter:
return PixtralProcessorAdapter(self.get_tokenizer())
def get_hf_processor(self, **kwargs) -> MistralCommonPixtralProcessor:
return self.ctx.init_processor(
MistralCommonPixtralProcessor,
tokenizer=self.get_tokenizer(),
**kwargs,
)
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"image": None}
def get_num_image_tokens(
self,
*,
image_width: int,
image_height: int,
processor: PixtralProcessorAdapter,
) -> int:
ncols, nrows = processor.image_processor._image_to_num_tokens(
Image.new("RGB", (image_width, image_height))
)
return ncols * nrows
def get_image_size_with_most_features(self) -> ImageSize:
image_processor = self.get_hf_processor().image_processor
max_image_size = image_processor.mm_config.max_image_size
max_image_size = image_processor.mm_encoder.mm_config.max_image_size
return ImageSize(width=max_image_size, height=max_image_size)
@@ -321,8 +221,9 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo])
images = mm_items.get_items("image", ImageProcessorItems)
image_size = images.get_image_size(item_idx)
ncols, nrows = processor.image_processor._image_to_num_tokens(
Image.new("RGB", (image_size.width, image_size.height))
_, nrows, ncols = processor.image_processor.get_number_of_image_patches(
image_size.height,
image_size.width,
)
tokens = ([image_token_id] * ncols + [image_break_id]) * nrows

View File

@@ -14,11 +14,7 @@ from typing import Annotated, Literal, TypeAlias
import regex as re
import torch
from torch import nn
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from transformers import BatchFeature, PretrainedConfig, PreTrainedTokenizer, TensorType
from transformers.image_utils import ImageInput
from transformers.tokenization_utils_base import TextInput
from transformers import BatchFeature
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
@@ -48,6 +44,7 @@ from vllm.multimodal.processing import (
PromptUpdateDetails,
)
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.processors.qwen_vl import QwenVLProcessor
from vllm.utils.tensor_schema import TensorSchema, TensorShape
from .interfaces import (
@@ -434,96 +431,16 @@ class QwenVLModel(QWenModel):
)
class QwenVLProcessor:
"""
This model doesn't define its own HF processor,
so we implement our own one here.
We call the wrapped tokenizer to automatically insert image pad tokens:
https://huggingface.co/Qwen/Qwen-VL/blob/main/tokenization_qwen.py#L245
The image processor is defined here:
https://huggingface.co/Qwen/Qwen-VL/blob/main/visual.py#L354
"""
def __init__(
self,
config: PretrainedConfig,
tokenizer: PreTrainedTokenizer,
) -> None:
super().__init__()
self.config = config
self.tokenizer = tokenizer
class QwenVLProcessingInfo(BaseProcessingInfo):
def get_hf_processor(self, **kwargs: object) -> QwenVLProcessor:
config = self.get_hf_config()
vision_config = config.visual
image_size = vision_config["image_size"]
self.image_transform = transforms.Compose(
[
transforms.Resize(
(image_size, image_size),
interpolation=InterpolationMode.BICUBIC,
),
transforms.ToTensor(),
transforms.Normalize(
mean=(0.48145466, 0.4578275, 0.40821073),
std=(0.26862954, 0.26130258, 0.27577711),
),
]
)
@property
def image_start_tag(self) -> str:
return self.tokenizer.image_start_tag # type: ignore
@property
def image_end_tag(self) -> str:
return self.tokenizer.image_end_tag # type: ignore
@property
def image_pad_tag(self) -> str:
return self.tokenizer.image_pad_tag # type: ignore
def __call__(
self,
text: TextInput | list[TextInput] | None = None,
images: ImageInput | list[ImageInput] | None = None,
return_tensors: str | TensorType | None = None,
) -> BatchFeature:
if text is None:
text = []
if not isinstance(text, list):
text = [text]
if images is None:
images = []
if not isinstance(images, list):
images = [images]
text_inputs = self.tokenizer(text)
if len(images) == 0:
image_inputs = {}
else:
pixel_values = [self.image_transform(image) for image in images]
image_inputs = {"pixel_values": torch.stack(pixel_values)}
return BatchFeature(
{
**text_inputs,
**image_inputs,
},
tensor_type=return_tensors,
)
class QwenVLProcessingInfo(BaseProcessingInfo):
def get_hf_processor(self, **kwargs: object) -> QwenVLProcessor:
return self.ctx.init_processor(
QwenVLProcessor,
config=self.get_hf_config(),
tokenizer=self.get_tokenizer(),
**kwargs,
**{**kwargs, "image_size": image_size},
)
def get_supported_mm_limits(self) -> Mapping[str, int | None]:

View File

@@ -3,25 +3,19 @@
import math
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property, partial
from math import ceil
from functools import partial
from typing import Literal, cast
import numpy as np
import regex as re
import torch
import torch.nn as nn
from mistral_common.audio import mel_filter_bank
from mistral_common.audio import Audio, mel_filter_bank
from mistral_common.protocol.instruct.chunk import AudioChunk, RawAudio, TextChunk
from mistral_common.protocol.instruct.messages import UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.protocol.transcription.request import TranscriptionRequest
from mistral_common.tokens.tokenizers.audio import (
Audio,
AudioEncoder,
)
from transformers import BatchFeature, TensorType, WhisperConfig
from transformers.tokenization_utils_base import TextInput
from transformers import BatchFeature, WhisperConfig
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.config.multimodal import BaseDummyOptions
@@ -62,6 +56,7 @@ from vllm.multimodal.processing.processor import (
from vllm.sequence import IntermediateTensors
from vllm.tokenizers import cached_tokenizer_from_config
from vllm.tokenizers.mistral import MistralTokenizer
from vllm.transformers_utils.processors.voxtral import MistralCommonVoxtralProcessor
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsTranscription
from .utils import init_vllm_registered_model, maybe_prefix
@@ -81,98 +76,6 @@ ISO639_1_SUPPORTED_LANGS = {
}
class VoxtralProcessorAdapter:
"""
Provide a HF-compatible interface for
:class:`mistral_common.tokens.tokenizers.multimodal.AudioEncoder`.
"""
def __init__(self, tokenizer: MistralTokenizer) -> None:
super().__init__()
self.tokenizer = tokenizer
@cached_property
def _audio_processor(self) -> AudioEncoder:
audio_encoder = self.tokenizer.instruct.audio_encoder
assert isinstance(audio_encoder, AudioEncoder)
return audio_encoder
@cached_property
def audio_token_id(self) -> int:
return self._audio_processor.special_ids.audio
@cached_property
def begin_audio_token_id(self) -> int:
return self._audio_processor.special_ids.begin_audio
@cached_property
def sampling_rate(self) -> int:
return self._audio_processor.audio_config.sampling_rate
@cached_property
def frame_rate(self) -> float:
return self._audio_processor.audio_config.frame_rate
def get_num_audio_tokens(
self,
audio_length: int,
) -> int:
return ceil(audio_length / (self.sampling_rate // self.frame_rate))
def __call__(
self,
text: TextInput | list[TextInput] | None = None,
audios: np.ndarray | list[np.ndarray] | None = None,
return_tensors: str | TensorType | None = None,
**kwargs,
) -> Mapping[str, NestedTensors]:
if text is None:
text = []
if not isinstance(text, list):
text = [text]
if audios is None:
audios = []
if not isinstance(audios, list):
audios = [audios]
if not audios:
input_ids = self.tokenizer(text).input_ids
return {"input_ids": torch.tensor(input_ids)}
# Allow dummy text, which is used for profiling as well as token inputs
if any(len(t) > 0 for t in text):
raise ValueError(
"You've passed text inputs instead of token inputs. "
"Make sure to process your input via `mistral_common`'s "
"tokenizer or pass a chat completion request. "
"For more info, see: "
"https://github.com/vllm-project/vllm/issues/8411."
)
audios_tokens = list[torch.Tensor]()
audios_processed = list[torch.Tensor]()
for audio in audios:
assert isinstance(audio, np.ndarray)
assert audio.ndim == 1
if not self._audio_processor.audio_config.is_streaming:
audio = self._audio_processor.pad(audio, self.sampling_rate)
audio_tokens = [self.begin_audio_token_id] + [
self.audio_token_id
] * self.get_num_audio_tokens(len(audio))
audios_tokens.append(torch.tensor(audio_tokens))
audios_processed.append(torch.tensor(audio))
return BatchFeature(
{
"input_ids": torch.cat(audios_tokens)[None].expand(len(text), -1),
"audio_arrays": audios_processed,
}
)
class VoxtralProcessingInfo(BaseProcessingInfo):
def get_tokenizer(self) -> MistralTokenizer:
tokenizer = cached_tokenizer_from_config(self.ctx.model_config)
@@ -181,12 +84,18 @@ class VoxtralProcessingInfo(BaseProcessingInfo):
return tokenizer
def get_hf_processor(self) -> VoxtralProcessorAdapter:
return VoxtralProcessorAdapter(self.get_tokenizer())
def get_hf_processor(self, **kwargs) -> MistralCommonVoxtralProcessor:
return self.ctx.init_processor(
MistralCommonVoxtralProcessor,
tokenizer=self.get_tokenizer(),
**kwargs,
)
def get_data_parser(self):
feature_extractor = self.get_hf_processor().feature_extractor
return MultiModalDataParser(
target_sr=self.get_hf_processor().sampling_rate,
target_sr=feature_extractor.sampling_rate,
target_channels=1,
expected_hidden_size=self._get_expected_hidden_size(),
)
@@ -205,9 +114,10 @@ class VoxtralProcessingInfo(BaseProcessingInfo):
return self.ctx.model_config.max_model_len
def get_max_audio_array_len(self) -> int:
processor = self.get_hf_processor()
feature_extractor = self.get_hf_processor().feature_extractor
return self.get_max_audio_tokens() * int(
processor.sampling_rate // processor.frame_rate
feature_extractor.sampling_rate // feature_extractor.frame_rate
)
@@ -242,6 +152,7 @@ class VoxtralDummyInputsBuilder(BaseDummyInputsBuilder[VoxtralProcessingInfo]):
mm_options: Mapping[str, BaseDummyOptions],
) -> ProcessorInputs:
tokenizer = self.info.get_tokenizer()
feature_extractor = self.info.get_hf_processor().feature_extractor
dummy_text = self.get_dummy_text(mm_counts)
dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts, mm_options)
@@ -252,7 +163,7 @@ class VoxtralDummyInputsBuilder(BaseDummyInputsBuilder[VoxtralProcessingInfo]):
for audio in dummy_audios:
audio_item = Audio(
audio_array=audio,
sampling_rate=self.info.get_hf_processor().sampling_rate,
sampling_rate=feature_extractor.sampling_rate,
format=format,
)
chunk = AudioChunk(input_audio=RawAudio.from_audio(audio_item))
@@ -292,33 +203,26 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo])
# skip validation here
...
def _apply_hf_processor_mm_only(
def _call_hf_processor(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
tok_kwargs: Mapping[str, object],
) -> BatchFeature:
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
processor_data, passthrough_data = self._get_hf_mm_data(mm_items)
audios = processor_data.get("audios", [])
if not isinstance(audios, list):
audios = [audios]
mm_data = dict(mm_data)
audios = mm_data.pop("audios", [])
audio_config = processor._audio_processor.audio_config
audio_tensors: list[torch.Tensor] = []
for audio in audios:
audio = np.asarray(audio, dtype=np.float32).ravel()
if not audio_config.is_streaming:
audio = processor._audio_processor.pad(
audio,
processor.sampling_rate,
audio_config.is_streaming,
)
audio_tensors.append(torch.tensor(audio))
if audios:
# MistralCommonVoxtralProcessor accepts "audio"
mm_data["audio"] = audios
result = BatchFeature({"audio_arrays": audio_tensors} if audio_tensors else {})
result.update(passthrough_data)
return result
return super()._call_hf_processor(
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
tok_kwargs=tok_kwargs,
)
def _get_prompt_updates(
self,
@@ -327,6 +231,7 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo])
out_mm_kwargs: MultiModalKwargsItems,
) -> Sequence[PromptUpdate]:
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
feature_extractor = processor.feature_extractor
audio_id = processor.audio_token_id
out_mm_data = out_mm_kwargs.require_data()
@@ -348,7 +253,7 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo])
audios = mm_items.get_items("audio", AudioProcessorItems)
audio_len = audios.get_audio_length(item_idx)
nb_audio_tokens = processor.get_num_audio_tokens(audio_len)
nb_audio_tokens = feature_extractor.get_num_audio_tokens(audio_len)
return [audio_id] * nb_audio_tokens
@@ -560,8 +465,8 @@ class VoxtralForConditionalGeneration(
This is used for estimating the amount of processing for this audio.
"""
tokenizer = cached_tokenizer_from_config(model_config)
adapter = VoxtralProcessorAdapter(tokenizer)
return adapter.get_num_audio_tokens(
adapter = MistralCommonVoxtralProcessor(tokenizer)
return adapter.feature_extractor.get_num_audio_tokens(
int(audio_duration_s * stt_config.sample_rate)
)

View File

@@ -8,12 +8,13 @@ from typing import Literal
import numpy as np
import torch
from mistral_common.audio import Audio
from mistral_common.protocol.instruct.chunk import RawAudio
from mistral_common.protocol.transcription.request import (
StreamingMode,
TranscriptionRequest,
)
from mistral_common.tokens.tokenizers.audio import Audio, AudioConfig
from mistral_common.tokens.tokenizers.audio import AudioConfig
from vllm.compilation.decorators import support_torch_compile
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig