[Misc] Move processors to transformers_utils (#35953)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -13,11 +13,7 @@ import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import LayerNorm
|
||||
from torchvision import transforms
|
||||
from torchvision.transforms import InterpolationMode
|
||||
from transformers import BatchFeature, PreTrainedTokenizer, TensorType
|
||||
from transformers.image_utils import ImageInput
|
||||
from transformers.tokenization_utils_base import TextInput
|
||||
from transformers import BatchFeature
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
@@ -50,7 +46,8 @@ from vllm.multimodal.processing import (
|
||||
PromptUpdate,
|
||||
)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.configs import ChatGLMConfig
|
||||
from vllm.transformers_utils.configs.chatglm import ChatGLMConfig
|
||||
from vllm.transformers_utils.processors.glm4v import GLM4VProcessor
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
from .chatglm import ChatGLMBaseModel, ChatGLMModel, GLMTransformer
|
||||
@@ -386,81 +383,19 @@ class GLM4VModel(ChatGLMModel):
|
||||
)
|
||||
|
||||
|
||||
class GLM4VProcessor:
|
||||
"""
|
||||
This model doesn't define its own HF processor,
|
||||
so we implement our own one here.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: ChatGLMConfig,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
vision_config = config.vision_config
|
||||
image_size = vision_config["image_size"]
|
||||
|
||||
self.image_transform = transforms.Compose(
|
||||
[
|
||||
transforms.Resize(
|
||||
(image_size, image_size),
|
||||
interpolation=InterpolationMode.BICUBIC,
|
||||
),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(
|
||||
mean=(0.48145466, 0.4578275, 0.40821073),
|
||||
std=(0.26862954, 0.26130258, 0.27577711),
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
text: TextInput | list[TextInput] | None = None,
|
||||
images: ImageInput | list[ImageInput] | None = None,
|
||||
return_tensors: str | TensorType | None = None,
|
||||
) -> BatchFeature:
|
||||
if text is None:
|
||||
text = []
|
||||
if not isinstance(text, list):
|
||||
text = [text]
|
||||
if images is None:
|
||||
images = []
|
||||
if not isinstance(images, list):
|
||||
images = [images]
|
||||
|
||||
text_inputs = self.tokenizer(text)
|
||||
|
||||
if len(images) == 0:
|
||||
image_inputs = {}
|
||||
else:
|
||||
pixel_values = [self.image_transform(image) for image in images]
|
||||
image_inputs = {"pixel_values": torch.stack(pixel_values)}
|
||||
|
||||
return BatchFeature(
|
||||
{
|
||||
**text_inputs,
|
||||
**image_inputs,
|
||||
},
|
||||
tensor_type=return_tensors,
|
||||
)
|
||||
|
||||
|
||||
class GLM4VProcessingInfo(BaseProcessingInfo):
|
||||
def get_hf_config(self):
|
||||
return self.ctx.get_hf_config(ChatGLMConfig)
|
||||
|
||||
def get_hf_processor(self, **kwargs: object) -> GLM4VProcessor:
|
||||
config = self.get_hf_config()
|
||||
vision_config = config.vision_config
|
||||
image_size = vision_config["image_size"]
|
||||
|
||||
return self.ctx.init_processor(
|
||||
GLM4VProcessor,
|
||||
config=self.get_hf_config(),
|
||||
tokenizer=self.get_tokenizer(),
|
||||
**kwargs,
|
||||
**{**kwargs, "image_size": image_size},
|
||||
)
|
||||
|
||||
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
import math
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from dataclasses import dataclass
|
||||
from functools import cached_property, partial
|
||||
from functools import partial
|
||||
from itertools import islice
|
||||
from typing import Annotated
|
||||
|
||||
@@ -13,9 +13,11 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from transformers import BatchFeature, PretrainedConfig, ProcessorMixin, TensorType
|
||||
from transformers.image_utils import ImageInput
|
||||
from transformers.tokenization_utils_base import TextInput
|
||||
from transformers import (
|
||||
BaseImageProcessor,
|
||||
BatchFeature,
|
||||
PretrainedConfig,
|
||||
)
|
||||
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
@@ -1017,117 +1019,28 @@ def select_tiling(
|
||||
return candidate_tilings[ix]
|
||||
|
||||
|
||||
class MolmoProcessorWrapper:
|
||||
"""
|
||||
Wraps `MolmoProcessor` so that it can be called directly.
|
||||
def _as_2tuple(x: int | tuple[int, int]) -> tuple[int, int]:
|
||||
if isinstance(x, int):
|
||||
return x, x
|
||||
|
||||
The original definition can be found here:
|
||||
https://huggingface.co/allenai/Molmo-7B-D-0924/blob/main/preprocessing_molmo.py
|
||||
"""
|
||||
return x
|
||||
|
||||
def __init__(self, processor: ProcessorMixin):
|
||||
super().__init__()
|
||||
|
||||
self.processor = processor
|
||||
|
||||
@cached_property
|
||||
def vocab(self) -> dict[str, int]:
|
||||
return self.processor.tokenizer.vocab # type: ignore
|
||||
|
||||
@cached_property
|
||||
def max_crops(self) -> int:
|
||||
image_processor = self.processor.image_processor # type: ignore
|
||||
|
||||
max_crops = image_processor.max_crops
|
||||
assert isinstance(max_crops, int)
|
||||
|
||||
return max_crops
|
||||
|
||||
@cached_property
|
||||
def base_image_input_size(self) -> tuple[int, int]:
|
||||
image_processor = self.processor.image_processor # type: ignore
|
||||
|
||||
base_image_input_size = image_processor.base_image_input_size
|
||||
if isinstance(base_image_input_size, int):
|
||||
return base_image_input_size, base_image_input_size
|
||||
|
||||
return tuple(base_image_input_size)
|
||||
|
||||
@cached_property
|
||||
def image_patch_size(self) -> int:
|
||||
image_processor = self.processor.image_processor # type: ignore
|
||||
|
||||
image_patch_size = image_processor.image_patch_size
|
||||
assert isinstance(image_patch_size, int)
|
||||
|
||||
return image_patch_size
|
||||
|
||||
@cached_property
|
||||
def overlap_margins(self) -> tuple[int, int]:
|
||||
image_processor = self.processor.image_processor # type: ignore
|
||||
|
||||
left_margin, right_margin = image_processor.overlap_margins
|
||||
assert isinstance(left_margin, int)
|
||||
assert isinstance(right_margin, int)
|
||||
|
||||
return left_margin, right_margin
|
||||
|
||||
@cached_property
|
||||
def image_token_length_w(self) -> int:
|
||||
image_processor = self.processor.image_processor # type: ignore
|
||||
|
||||
image_token_length_w = image_processor.image_token_length_w
|
||||
assert isinstance(image_token_length_w, int)
|
||||
|
||||
return image_token_length_w
|
||||
|
||||
@cached_property
|
||||
def image_token_length_h(self) -> int:
|
||||
image_processor = self.processor.image_processor # type: ignore
|
||||
|
||||
image_token_length_h = image_processor.image_token_length_h
|
||||
assert isinstance(image_token_length_h, int)
|
||||
|
||||
return image_token_length_h
|
||||
|
||||
@property
|
||||
def message_format(self) -> str | None:
|
||||
return "role"
|
||||
|
||||
@property
|
||||
def always_start_with_space(self) -> bool:
|
||||
return True
|
||||
|
||||
@cached_property
|
||||
def image_patch_id(self) -> int:
|
||||
return self.vocab[IMAGE_PATCH_TOKEN]
|
||||
|
||||
@cached_property
|
||||
def im_col_id(self) -> int:
|
||||
return self.vocab[IM_COL_TOKEN]
|
||||
|
||||
@cached_property
|
||||
def im_start_id(self) -> int:
|
||||
return self.vocab[IM_START_TOKEN]
|
||||
|
||||
@cached_property
|
||||
def im_end_id(self) -> int:
|
||||
return self.vocab[IM_END_TOKEN]
|
||||
|
||||
@property
|
||||
def pooling_size(self) -> int:
|
||||
return POOLING_SIZE
|
||||
class MolmoProcessingInfo(BaseProcessingInfo):
|
||||
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
|
||||
return {"image": None}
|
||||
|
||||
def select_tiling(
|
||||
self,
|
||||
*,
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
image_processor: BaseImageProcessor,
|
||||
) -> tuple[int, int]:
|
||||
max_crops = self.max_crops
|
||||
left_margin, right_margin = self.overlap_margins
|
||||
base_image_input_size = self.base_image_input_size
|
||||
base_image_input_d = self.image_patch_size
|
||||
max_crops = image_processor.max_crops
|
||||
left_margin, right_margin = image_processor.overlap_margins
|
||||
base_image_input_size = _as_2tuple(image_processor.base_image_input_size)
|
||||
base_image_input_d = image_processor.image_patch_size
|
||||
|
||||
total_margin_pixels = base_image_input_d * (right_margin + left_margin)
|
||||
crop_patches = base_image_input_size[0] // base_image_input_d
|
||||
@@ -1147,16 +1060,18 @@ class MolmoProcessorWrapper:
|
||||
*,
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
image_processor: BaseImageProcessor,
|
||||
) -> tuple[int, int]:
|
||||
left_margin, right_margin = self.overlap_margins
|
||||
base_image_input_size = self.base_image_input_size
|
||||
base_image_input_d = self.image_patch_size
|
||||
pooling_size = self.pooling_size
|
||||
left_margin, right_margin = image_processor.overlap_margins
|
||||
base_image_input_size = _as_2tuple(image_processor.base_image_input_size)
|
||||
base_image_input_d = image_processor.image_patch_size
|
||||
pooling_size = POOLING_SIZE
|
||||
|
||||
crop_patches = base_image_input_size[0] // base_image_input_d
|
||||
tiling_w, tiling_h = self.select_tiling(
|
||||
image_height=image_height,
|
||||
image_width=image_width,
|
||||
image_processor=image_processor,
|
||||
)
|
||||
|
||||
nrows, ncols = get_patches_grid_size(
|
||||
@@ -1170,70 +1085,22 @@ class MolmoProcessorWrapper:
|
||||
|
||||
return ncols, nrows
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
text: TextInput | list[TextInput] | None = None,
|
||||
images: ImageInput | list[ImageInput] | None = None,
|
||||
return_tensors: str | TensorType | None = None,
|
||||
**kwargs,
|
||||
) -> BatchFeature:
|
||||
outputs = self.processor.process( # type: ignore
|
||||
text, images, **kwargs
|
||||
)
|
||||
|
||||
if images is None:
|
||||
images = []
|
||||
if not isinstance(images, list):
|
||||
images = [images]
|
||||
|
||||
input_ids: torch.Tensor = outputs.pop("input_ids")
|
||||
outputs["input_ids"] = input_ids.unsqueeze(0)
|
||||
|
||||
image_input_idx = outputs.pop("image_input_idx", None)
|
||||
if image_input_idx is not None:
|
||||
feat_is_patch = image_input_idx >= 0
|
||||
|
||||
tilings = [
|
||||
self.select_tiling(
|
||||
image_width=image.size[0],
|
||||
image_height=image.size[1],
|
||||
)
|
||||
for image in images
|
||||
]
|
||||
# For each image: tiling_h * tiling_w + extra
|
||||
num_crops = torch.tensor(tilings).prod(-1) + 1
|
||||
assert num_crops.sum() == len(feat_is_patch)
|
||||
|
||||
outputs["image_input_idx"] = image_input_idx
|
||||
outputs["num_crops"] = num_crops
|
||||
outputs["img_patch_id"] = self.image_patch_id
|
||||
|
||||
return BatchFeature(outputs)
|
||||
|
||||
|
||||
class MolmoProcessingInfo(BaseProcessingInfo):
|
||||
def get_hf_processor(self, **kwargs: object) -> MolmoProcessorWrapper:
|
||||
processor = self.ctx.get_hf_processor(**kwargs)
|
||||
return MolmoProcessorWrapper(processor)
|
||||
|
||||
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
|
||||
return {"image": None}
|
||||
|
||||
def get_num_image_tokens(
|
||||
self,
|
||||
*,
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
processor: MolmoProcessorWrapper,
|
||||
image_processor: BaseImageProcessor,
|
||||
) -> int:
|
||||
ncols, nrows = processor.get_patches_grid_size(
|
||||
ncols, nrows = self.get_patches_grid_size(
|
||||
image_width=image_width,
|
||||
image_height=image_height,
|
||||
image_processor=image_processor,
|
||||
)
|
||||
pooling_size = processor.pooling_size
|
||||
pooling_size = POOLING_SIZE
|
||||
|
||||
image_token_length_w = processor.image_token_length_w
|
||||
image_token_length_h = processor.image_token_length_h
|
||||
image_token_length_w = image_processor.image_token_length_w
|
||||
image_token_length_h = image_processor.image_token_length_h
|
||||
|
||||
# Calculate total tokens: 2 for start/end + (w+1)*h for column separators
|
||||
extra = 2 + (image_token_length_w + 1) * image_token_length_h
|
||||
@@ -1243,9 +1110,10 @@ class MolmoProcessingInfo(BaseProcessingInfo):
|
||||
|
||||
def get_image_size_with_most_features(self) -> ImageSize:
|
||||
processor = self.get_hf_processor()
|
||||
image_processor = processor.image_processor
|
||||
|
||||
tilings = get_candidate_tilings(processor.max_crops)
|
||||
base_h, base_w = processor.base_image_input_size
|
||||
tilings = get_candidate_tilings(image_processor.max_crops)
|
||||
base_h, base_w = _as_2tuple(image_processor.base_image_input_size)
|
||||
|
||||
largest_feature_size, largest_feature_pinpoint = 0, None
|
||||
for wr, hr in tilings:
|
||||
@@ -1254,7 +1122,7 @@ class MolmoProcessingInfo(BaseProcessingInfo):
|
||||
feat_size = self.get_num_image_tokens(
|
||||
image_width=width,
|
||||
image_height=height,
|
||||
processor=processor,
|
||||
image_processor=image_processor,
|
||||
)
|
||||
if feat_size > largest_feature_size:
|
||||
largest_feature_size = feat_size
|
||||
@@ -1292,6 +1160,54 @@ class MolmoDummyInputsBuilder(BaseDummyInputsBuilder[MolmoProcessingInfo]):
|
||||
|
||||
|
||||
class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
|
||||
def _call_hf_processor(
|
||||
self,
|
||||
prompt: str,
|
||||
mm_data: Mapping[str, object],
|
||||
mm_kwargs: Mapping[str, object],
|
||||
tok_kwargs: Mapping[str, object],
|
||||
) -> BatchFeature:
|
||||
hf_processor = self.info.get_hf_processor(**mm_kwargs)
|
||||
processed_outputs = self.info.ctx.call_hf_processor(
|
||||
hf_processor.process,
|
||||
dict(text=prompt, **mm_data),
|
||||
dict(**mm_kwargs, **tok_kwargs),
|
||||
)
|
||||
|
||||
tokenizer = hf_processor.tokenizer
|
||||
image_patch_id = tokenizer.vocab[IMAGE_PATCH_TOKEN]
|
||||
|
||||
image_processor = hf_processor.image_processor
|
||||
|
||||
input_ids: torch.Tensor = processed_outputs.pop("input_ids")
|
||||
processed_outputs["input_ids"] = input_ids.unsqueeze(0)
|
||||
|
||||
if (images := mm_data.get("images")) is not None:
|
||||
mm_items = self.info.parse_mm_data({"image": images}, validate=False)
|
||||
parsed_images = mm_items.get_items("image", ImageProcessorItems)
|
||||
image_sizes = [
|
||||
parsed_images.get_image_size(i) for i in range(len(parsed_images))
|
||||
]
|
||||
|
||||
feat_is_patch = processed_outputs["image_input_idx"] >= 0
|
||||
|
||||
tilings = [
|
||||
self.info.select_tiling(
|
||||
image_width=image_size.width,
|
||||
image_height=image_size.height,
|
||||
image_processor=image_processor,
|
||||
)
|
||||
for image_size in image_sizes
|
||||
]
|
||||
# For each image: tiling_h * tiling_w + extra
|
||||
num_crops = torch.tensor(tilings).prod(-1) + 1
|
||||
assert num_crops.sum() == len(feat_is_patch)
|
||||
|
||||
processed_outputs["num_crops"] = num_crops
|
||||
processed_outputs["img_patch_id"] = image_patch_id
|
||||
|
||||
return processed_outputs
|
||||
|
||||
def _apply_hf_processor_tokens_only(
|
||||
self,
|
||||
prompt_tokens: list[int],
|
||||
@@ -1301,18 +1217,19 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
|
||||
# The chat template is already applied to the prompt tokens
|
||||
# Use message_format="none" to avoid applying it again
|
||||
# Prepend an empty space if `always_start_with_space` is True
|
||||
tokens = processor.processor.get_tokens_input( # type: ignore
|
||||
tokens = processor.get_tokens_input(
|
||||
self.info.get_tokenizer().decode(prompt_tokens),
|
||||
message_format="none",
|
||||
always_start_with_space=processor.always_start_with_space,
|
||||
always_start_with_space=True,
|
||||
)
|
||||
|
||||
# Prepend a BOS token id to the tokens
|
||||
processed_data = self.info.ctx.call_hf_processor(
|
||||
processor, # type: ignore
|
||||
processor.process,
|
||||
dict(tokens=tokens),
|
||||
)
|
||||
(prompt_ids,) = processed_data.pop("input_ids").tolist()
|
||||
prompt_ids = processed_data.pop("input_ids").tolist()
|
||||
print(prompt_ids, len(prompt_ids))
|
||||
|
||||
return prompt_ids
|
||||
|
||||
@@ -1338,16 +1255,18 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
out_mm_kwargs: MultiModalKwargsItems,
|
||||
) -> Sequence[PromptUpdate]:
|
||||
tokenizer = self.info.get_tokenizer()
|
||||
vocab = tokenizer.get_vocab()
|
||||
img_patch_id = vocab[IMAGE_PATCH_TOKEN]
|
||||
img_col_id = vocab[IM_COL_TOKEN]
|
||||
img_start_id = vocab[IM_START_TOKEN]
|
||||
img_end_id = vocab[IM_END_TOKEN]
|
||||
|
||||
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
|
||||
|
||||
image_token_length_w = processor.image_token_length_w
|
||||
image_token_length_h = processor.image_token_length_h
|
||||
pooling_size = processor.pooling_size
|
||||
|
||||
img_patch_id = processor.image_patch_id
|
||||
img_col_id = processor.im_col_id
|
||||
img_start_id = processor.im_start_id
|
||||
img_end_id = processor.im_end_id
|
||||
image_processor = processor.image_processor
|
||||
image_token_length_w = image_processor.image_token_length_w
|
||||
image_token_length_h = image_processor.image_token_length_h
|
||||
pooling_size = POOLING_SIZE
|
||||
|
||||
extra_row = [img_patch_id] * image_token_length_w + [img_col_id]
|
||||
extra_joint = [img_start_id] + extra_row * image_token_length_h + [img_end_id]
|
||||
@@ -1356,9 +1275,10 @@ class MolmoMultiModalProcessor(BaseMultiModalProcessor[MolmoProcessingInfo]):
|
||||
images = mm_items.get_items("image", ImageProcessorItems)
|
||||
image_size = images.get_image_size(item_idx)
|
||||
|
||||
ncols, nrows = processor.get_patches_grid_size(
|
||||
ncols, nrows = self.info.get_patches_grid_size(
|
||||
image_width=image_size.width,
|
||||
image_height=image_size.height,
|
||||
image_processor=image_processor,
|
||||
)
|
||||
|
||||
joint_row = [img_patch_id] * ((ncols + 1) // pooling_size) + [img_col_id]
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
import math
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from dataclasses import dataclass, fields
|
||||
from functools import cached_property
|
||||
from typing import Annotated, Literal
|
||||
|
||||
import torch
|
||||
@@ -13,10 +12,7 @@ import torch.nn.functional as F
|
||||
from mistral_common.protocol.instruct.chunk import ImageChunk, TextChunk
|
||||
from mistral_common.protocol.instruct.messages import UserMessage
|
||||
from mistral_common.protocol.instruct.request import ChatCompletionRequest
|
||||
from mistral_common.tokens.tokenizers.multimodal import ImageEncoder
|
||||
from PIL import Image
|
||||
from transformers import BatchFeature, PixtralVisionConfig, TensorType
|
||||
from transformers.image_utils import ImageInput
|
||||
from transformers import PixtralVisionConfig
|
||||
from transformers.models.pixtral.image_processing_pixtral import (
|
||||
_num_image_tokens as _get_pixtral_hf_num_image_tokens,
|
||||
)
|
||||
@@ -25,7 +21,6 @@ from transformers.models.pixtral.modeling_pixtral import (
|
||||
apply_rotary_pos_emb,
|
||||
position_ids_in_meshgrid,
|
||||
)
|
||||
from transformers.tokenization_utils_base import TextInput
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
@@ -66,6 +61,7 @@ from vllm.platforms import current_platform
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.tokenizers import cached_tokenizer_from_config
|
||||
from vllm.tokenizers.mistral import MistralTokenizer
|
||||
from vllm.transformers_utils.processors.pixtral import MistralCommonPixtralProcessor
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
from .interfaces import (
|
||||
@@ -121,93 +117,6 @@ class PixtralImagePixelInputs(TensorSchema):
|
||||
]
|
||||
|
||||
|
||||
class PixtralProcessorAdapter:
|
||||
"""
|
||||
Provide a HF-compatible interface for
|
||||
`mistral_common.tokens.tokenizers.multimodal.ImageEncoder`.
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer: MistralTokenizer) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
@property
|
||||
def image_processor(self) -> ImageEncoder:
|
||||
image_encoder = self.tokenizer.instruct.mm_encoder
|
||||
assert isinstance(image_encoder, ImageEncoder)
|
||||
return image_encoder
|
||||
|
||||
@cached_property
|
||||
def image_break_id(self) -> int:
|
||||
return self.image_processor.special_ids.img_break
|
||||
|
||||
@cached_property
|
||||
def image_token_id(self) -> int:
|
||||
return self.image_processor.special_ids.img
|
||||
|
||||
@cached_property
|
||||
def image_end_id(self) -> int:
|
||||
return self.image_processor.special_ids.img_end
|
||||
|
||||
@cached_property
|
||||
def image_size(self) -> int:
|
||||
return self.image_processor.mm_config.max_image_size
|
||||
|
||||
@cached_property
|
||||
def patch_size(self) -> int:
|
||||
return self.image_processor.mm_config.image_patch_size
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
text: TextInput | list[TextInput] | None = None,
|
||||
images: ImageInput | list[ImageInput] | None = None,
|
||||
return_tensors: str | TensorType | None = None,
|
||||
**kwargs,
|
||||
) -> Mapping[str, NestedTensors]:
|
||||
if text is None:
|
||||
text = []
|
||||
if not isinstance(text, list):
|
||||
text = [text]
|
||||
if images is None:
|
||||
images = []
|
||||
if not isinstance(images, list):
|
||||
images = [images]
|
||||
|
||||
if not images:
|
||||
input_ids = self.tokenizer(text).input_ids
|
||||
|
||||
return {"input_ids": torch.tensor(input_ids)}
|
||||
|
||||
# Allow dummy text, which is used for profiling as well as token inputs
|
||||
if any(len(t) > 0 for t in text):
|
||||
raise ValueError(
|
||||
"You've passed text inputs instead of token inputs. "
|
||||
"Make sure to process your input via `mistral_common`'s "
|
||||
"tokenizer or pass a chat completion request. "
|
||||
"For more info, see: "
|
||||
"https://github.com/vllm-project/vllm/issues/8411."
|
||||
)
|
||||
|
||||
images_processed = list[torch.Tensor]()
|
||||
images_tokens = list[torch.Tensor]()
|
||||
|
||||
for image in images:
|
||||
image_inputs = self.image_processor(ImageChunk(image=image))
|
||||
image_processed = torch.tensor(image_inputs.image)
|
||||
image_tokens = torch.tensor(image_inputs.tokens)
|
||||
|
||||
images_processed.append(image_processed)
|
||||
images_tokens.append(image_tokens)
|
||||
|
||||
return BatchFeature(
|
||||
{
|
||||
"input_ids": torch.cat(images_tokens)[None].expand(len(text), -1),
|
||||
"images": images_processed,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class PixtralProcessingInfo(BaseProcessingInfo):
|
||||
def get_tokenizer(self) -> MistralTokenizer:
|
||||
tokenizer = cached_tokenizer_from_config(self.ctx.model_config)
|
||||
@@ -216,28 +125,19 @@ class PixtralProcessingInfo(BaseProcessingInfo):
|
||||
|
||||
return tokenizer
|
||||
|
||||
def get_hf_processor(self) -> PixtralProcessorAdapter:
|
||||
return PixtralProcessorAdapter(self.get_tokenizer())
|
||||
def get_hf_processor(self, **kwargs) -> MistralCommonPixtralProcessor:
|
||||
return self.ctx.init_processor(
|
||||
MistralCommonPixtralProcessor,
|
||||
tokenizer=self.get_tokenizer(),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
|
||||
return {"image": None}
|
||||
|
||||
def get_num_image_tokens(
|
||||
self,
|
||||
*,
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
processor: PixtralProcessorAdapter,
|
||||
) -> int:
|
||||
ncols, nrows = processor.image_processor._image_to_num_tokens(
|
||||
Image.new("RGB", (image_width, image_height))
|
||||
)
|
||||
|
||||
return ncols * nrows
|
||||
|
||||
def get_image_size_with_most_features(self) -> ImageSize:
|
||||
image_processor = self.get_hf_processor().image_processor
|
||||
max_image_size = image_processor.mm_config.max_image_size
|
||||
max_image_size = image_processor.mm_encoder.mm_config.max_image_size
|
||||
|
||||
return ImageSize(width=max_image_size, height=max_image_size)
|
||||
|
||||
@@ -321,8 +221,9 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo])
|
||||
images = mm_items.get_items("image", ImageProcessorItems)
|
||||
image_size = images.get_image_size(item_idx)
|
||||
|
||||
ncols, nrows = processor.image_processor._image_to_num_tokens(
|
||||
Image.new("RGB", (image_size.width, image_size.height))
|
||||
_, nrows, ncols = processor.image_processor.get_number_of_image_patches(
|
||||
image_size.height,
|
||||
image_size.width,
|
||||
)
|
||||
|
||||
tokens = ([image_token_id] * ncols + [image_break_id]) * nrows
|
||||
|
||||
@@ -14,11 +14,7 @@ from typing import Annotated, Literal, TypeAlias
|
||||
import regex as re
|
||||
import torch
|
||||
from torch import nn
|
||||
from torchvision import transforms
|
||||
from torchvision.transforms import InterpolationMode
|
||||
from transformers import BatchFeature, PretrainedConfig, PreTrainedTokenizer, TensorType
|
||||
from transformers.image_utils import ImageInput
|
||||
from transformers.tokenization_utils_base import TextInput
|
||||
from transformers import BatchFeature
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
@@ -48,6 +44,7 @@ from vllm.multimodal.processing import (
|
||||
PromptUpdateDetails,
|
||||
)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.processors.qwen_vl import QwenVLProcessor
|
||||
from vllm.utils.tensor_schema import TensorSchema, TensorShape
|
||||
|
||||
from .interfaces import (
|
||||
@@ -434,96 +431,16 @@ class QwenVLModel(QWenModel):
|
||||
)
|
||||
|
||||
|
||||
class QwenVLProcessor:
|
||||
"""
|
||||
This model doesn't define its own HF processor,
|
||||
so we implement our own one here.
|
||||
|
||||
We call the wrapped tokenizer to automatically insert image pad tokens:
|
||||
https://huggingface.co/Qwen/Qwen-VL/blob/main/tokenization_qwen.py#L245
|
||||
|
||||
The image processor is defined here:
|
||||
https://huggingface.co/Qwen/Qwen-VL/blob/main/visual.py#L354
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
class QwenVLProcessingInfo(BaseProcessingInfo):
|
||||
def get_hf_processor(self, **kwargs: object) -> QwenVLProcessor:
|
||||
config = self.get_hf_config()
|
||||
vision_config = config.visual
|
||||
image_size = vision_config["image_size"]
|
||||
|
||||
self.image_transform = transforms.Compose(
|
||||
[
|
||||
transforms.Resize(
|
||||
(image_size, image_size),
|
||||
interpolation=InterpolationMode.BICUBIC,
|
||||
),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(
|
||||
mean=(0.48145466, 0.4578275, 0.40821073),
|
||||
std=(0.26862954, 0.26130258, 0.27577711),
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
@property
|
||||
def image_start_tag(self) -> str:
|
||||
return self.tokenizer.image_start_tag # type: ignore
|
||||
|
||||
@property
|
||||
def image_end_tag(self) -> str:
|
||||
return self.tokenizer.image_end_tag # type: ignore
|
||||
|
||||
@property
|
||||
def image_pad_tag(self) -> str:
|
||||
return self.tokenizer.image_pad_tag # type: ignore
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
text: TextInput | list[TextInput] | None = None,
|
||||
images: ImageInput | list[ImageInput] | None = None,
|
||||
return_tensors: str | TensorType | None = None,
|
||||
) -> BatchFeature:
|
||||
if text is None:
|
||||
text = []
|
||||
if not isinstance(text, list):
|
||||
text = [text]
|
||||
if images is None:
|
||||
images = []
|
||||
if not isinstance(images, list):
|
||||
images = [images]
|
||||
|
||||
text_inputs = self.tokenizer(text)
|
||||
|
||||
if len(images) == 0:
|
||||
image_inputs = {}
|
||||
else:
|
||||
pixel_values = [self.image_transform(image) for image in images]
|
||||
image_inputs = {"pixel_values": torch.stack(pixel_values)}
|
||||
|
||||
return BatchFeature(
|
||||
{
|
||||
**text_inputs,
|
||||
**image_inputs,
|
||||
},
|
||||
tensor_type=return_tensors,
|
||||
)
|
||||
|
||||
|
||||
class QwenVLProcessingInfo(BaseProcessingInfo):
|
||||
def get_hf_processor(self, **kwargs: object) -> QwenVLProcessor:
|
||||
return self.ctx.init_processor(
|
||||
QwenVLProcessor,
|
||||
config=self.get_hf_config(),
|
||||
tokenizer=self.get_tokenizer(),
|
||||
**kwargs,
|
||||
**{**kwargs, "image_size": image_size},
|
||||
)
|
||||
|
||||
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
|
||||
|
||||
@@ -3,25 +3,19 @@
|
||||
|
||||
import math
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from functools import cached_property, partial
|
||||
from math import ceil
|
||||
from functools import partial
|
||||
from typing import Literal, cast
|
||||
|
||||
import numpy as np
|
||||
import regex as re
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mistral_common.audio import mel_filter_bank
|
||||
from mistral_common.audio import Audio, mel_filter_bank
|
||||
from mistral_common.protocol.instruct.chunk import AudioChunk, RawAudio, TextChunk
|
||||
from mistral_common.protocol.instruct.messages import UserMessage
|
||||
from mistral_common.protocol.instruct.request import ChatCompletionRequest
|
||||
from mistral_common.protocol.transcription.request import TranscriptionRequest
|
||||
from mistral_common.tokens.tokenizers.audio import (
|
||||
Audio,
|
||||
AudioEncoder,
|
||||
)
|
||||
from transformers import BatchFeature, TensorType, WhisperConfig
|
||||
from transformers.tokenization_utils_base import TextInput
|
||||
from transformers import BatchFeature, WhisperConfig
|
||||
|
||||
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
@@ -62,6 +56,7 @@ from vllm.multimodal.processing.processor import (
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.tokenizers import cached_tokenizer_from_config
|
||||
from vllm.tokenizers.mistral import MistralTokenizer
|
||||
from vllm.transformers_utils.processors.voxtral import MistralCommonVoxtralProcessor
|
||||
|
||||
from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsTranscription
|
||||
from .utils import init_vllm_registered_model, maybe_prefix
|
||||
@@ -81,98 +76,6 @@ ISO639_1_SUPPORTED_LANGS = {
|
||||
}
|
||||
|
||||
|
||||
class VoxtralProcessorAdapter:
|
||||
"""
|
||||
Provide a HF-compatible interface for
|
||||
:class:`mistral_common.tokens.tokenizers.multimodal.AudioEncoder`.
|
||||
"""
|
||||
|
||||
def __init__(self, tokenizer: MistralTokenizer) -> None:
|
||||
super().__init__()
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
@cached_property
|
||||
def _audio_processor(self) -> AudioEncoder:
|
||||
audio_encoder = self.tokenizer.instruct.audio_encoder
|
||||
assert isinstance(audio_encoder, AudioEncoder)
|
||||
return audio_encoder
|
||||
|
||||
@cached_property
|
||||
def audio_token_id(self) -> int:
|
||||
return self._audio_processor.special_ids.audio
|
||||
|
||||
@cached_property
|
||||
def begin_audio_token_id(self) -> int:
|
||||
return self._audio_processor.special_ids.begin_audio
|
||||
|
||||
@cached_property
|
||||
def sampling_rate(self) -> int:
|
||||
return self._audio_processor.audio_config.sampling_rate
|
||||
|
||||
@cached_property
|
||||
def frame_rate(self) -> float:
|
||||
return self._audio_processor.audio_config.frame_rate
|
||||
|
||||
def get_num_audio_tokens(
|
||||
self,
|
||||
audio_length: int,
|
||||
) -> int:
|
||||
return ceil(audio_length / (self.sampling_rate // self.frame_rate))
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
text: TextInput | list[TextInput] | None = None,
|
||||
audios: np.ndarray | list[np.ndarray] | None = None,
|
||||
return_tensors: str | TensorType | None = None,
|
||||
**kwargs,
|
||||
) -> Mapping[str, NestedTensors]:
|
||||
if text is None:
|
||||
text = []
|
||||
if not isinstance(text, list):
|
||||
text = [text]
|
||||
if audios is None:
|
||||
audios = []
|
||||
if not isinstance(audios, list):
|
||||
audios = [audios]
|
||||
|
||||
if not audios:
|
||||
input_ids = self.tokenizer(text).input_ids
|
||||
return {"input_ids": torch.tensor(input_ids)}
|
||||
|
||||
# Allow dummy text, which is used for profiling as well as token inputs
|
||||
if any(len(t) > 0 for t in text):
|
||||
raise ValueError(
|
||||
"You've passed text inputs instead of token inputs. "
|
||||
"Make sure to process your input via `mistral_common`'s "
|
||||
"tokenizer or pass a chat completion request. "
|
||||
"For more info, see: "
|
||||
"https://github.com/vllm-project/vllm/issues/8411."
|
||||
)
|
||||
|
||||
audios_tokens = list[torch.Tensor]()
|
||||
audios_processed = list[torch.Tensor]()
|
||||
for audio in audios:
|
||||
assert isinstance(audio, np.ndarray)
|
||||
assert audio.ndim == 1
|
||||
|
||||
if not self._audio_processor.audio_config.is_streaming:
|
||||
audio = self._audio_processor.pad(audio, self.sampling_rate)
|
||||
|
||||
audio_tokens = [self.begin_audio_token_id] + [
|
||||
self.audio_token_id
|
||||
] * self.get_num_audio_tokens(len(audio))
|
||||
|
||||
audios_tokens.append(torch.tensor(audio_tokens))
|
||||
audios_processed.append(torch.tensor(audio))
|
||||
|
||||
return BatchFeature(
|
||||
{
|
||||
"input_ids": torch.cat(audios_tokens)[None].expand(len(text), -1),
|
||||
"audio_arrays": audios_processed,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class VoxtralProcessingInfo(BaseProcessingInfo):
|
||||
def get_tokenizer(self) -> MistralTokenizer:
|
||||
tokenizer = cached_tokenizer_from_config(self.ctx.model_config)
|
||||
@@ -181,12 +84,18 @@ class VoxtralProcessingInfo(BaseProcessingInfo):
|
||||
|
||||
return tokenizer
|
||||
|
||||
def get_hf_processor(self) -> VoxtralProcessorAdapter:
|
||||
return VoxtralProcessorAdapter(self.get_tokenizer())
|
||||
def get_hf_processor(self, **kwargs) -> MistralCommonVoxtralProcessor:
|
||||
return self.ctx.init_processor(
|
||||
MistralCommonVoxtralProcessor,
|
||||
tokenizer=self.get_tokenizer(),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def get_data_parser(self):
|
||||
feature_extractor = self.get_hf_processor().feature_extractor
|
||||
|
||||
return MultiModalDataParser(
|
||||
target_sr=self.get_hf_processor().sampling_rate,
|
||||
target_sr=feature_extractor.sampling_rate,
|
||||
target_channels=1,
|
||||
expected_hidden_size=self._get_expected_hidden_size(),
|
||||
)
|
||||
@@ -205,9 +114,10 @@ class VoxtralProcessingInfo(BaseProcessingInfo):
|
||||
return self.ctx.model_config.max_model_len
|
||||
|
||||
def get_max_audio_array_len(self) -> int:
|
||||
processor = self.get_hf_processor()
|
||||
feature_extractor = self.get_hf_processor().feature_extractor
|
||||
|
||||
return self.get_max_audio_tokens() * int(
|
||||
processor.sampling_rate // processor.frame_rate
|
||||
feature_extractor.sampling_rate // feature_extractor.frame_rate
|
||||
)
|
||||
|
||||
|
||||
@@ -242,6 +152,7 @@ class VoxtralDummyInputsBuilder(BaseDummyInputsBuilder[VoxtralProcessingInfo]):
|
||||
mm_options: Mapping[str, BaseDummyOptions],
|
||||
) -> ProcessorInputs:
|
||||
tokenizer = self.info.get_tokenizer()
|
||||
feature_extractor = self.info.get_hf_processor().feature_extractor
|
||||
|
||||
dummy_text = self.get_dummy_text(mm_counts)
|
||||
dummy_mm_data = self.get_dummy_mm_data(seq_len, mm_counts, mm_options)
|
||||
@@ -252,7 +163,7 @@ class VoxtralDummyInputsBuilder(BaseDummyInputsBuilder[VoxtralProcessingInfo]):
|
||||
for audio in dummy_audios:
|
||||
audio_item = Audio(
|
||||
audio_array=audio,
|
||||
sampling_rate=self.info.get_hf_processor().sampling_rate,
|
||||
sampling_rate=feature_extractor.sampling_rate,
|
||||
format=format,
|
||||
)
|
||||
chunk = AudioChunk(input_audio=RawAudio.from_audio(audio_item))
|
||||
@@ -292,33 +203,26 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo])
|
||||
# skip validation here
|
||||
...
|
||||
|
||||
def _apply_hf_processor_mm_only(
|
||||
def _call_hf_processor(
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
tokenization_kwargs: Mapping[str, object],
|
||||
prompt: str,
|
||||
mm_data: Mapping[str, object],
|
||||
mm_kwargs: Mapping[str, object],
|
||||
tok_kwargs: Mapping[str, object],
|
||||
) -> BatchFeature:
|
||||
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
|
||||
processor_data, passthrough_data = self._get_hf_mm_data(mm_items)
|
||||
audios = processor_data.get("audios", [])
|
||||
if not isinstance(audios, list):
|
||||
audios = [audios]
|
||||
mm_data = dict(mm_data)
|
||||
audios = mm_data.pop("audios", [])
|
||||
|
||||
audio_config = processor._audio_processor.audio_config
|
||||
audio_tensors: list[torch.Tensor] = []
|
||||
for audio in audios:
|
||||
audio = np.asarray(audio, dtype=np.float32).ravel()
|
||||
if not audio_config.is_streaming:
|
||||
audio = processor._audio_processor.pad(
|
||||
audio,
|
||||
processor.sampling_rate,
|
||||
audio_config.is_streaming,
|
||||
)
|
||||
audio_tensors.append(torch.tensor(audio))
|
||||
if audios:
|
||||
# MistralCommonVoxtralProcessor accepts "audio"
|
||||
mm_data["audio"] = audios
|
||||
|
||||
result = BatchFeature({"audio_arrays": audio_tensors} if audio_tensors else {})
|
||||
result.update(passthrough_data)
|
||||
return result
|
||||
return super()._call_hf_processor(
|
||||
prompt=prompt,
|
||||
mm_data=mm_data,
|
||||
mm_kwargs=mm_kwargs,
|
||||
tok_kwargs=tok_kwargs,
|
||||
)
|
||||
|
||||
def _get_prompt_updates(
|
||||
self,
|
||||
@@ -327,6 +231,7 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo])
|
||||
out_mm_kwargs: MultiModalKwargsItems,
|
||||
) -> Sequence[PromptUpdate]:
|
||||
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
|
||||
feature_extractor = processor.feature_extractor
|
||||
|
||||
audio_id = processor.audio_token_id
|
||||
out_mm_data = out_mm_kwargs.require_data()
|
||||
@@ -348,7 +253,7 @@ class VoxtralMultiModalProcessor(BaseMultiModalProcessor[VoxtralProcessingInfo])
|
||||
audios = mm_items.get_items("audio", AudioProcessorItems)
|
||||
audio_len = audios.get_audio_length(item_idx)
|
||||
|
||||
nb_audio_tokens = processor.get_num_audio_tokens(audio_len)
|
||||
nb_audio_tokens = feature_extractor.get_num_audio_tokens(audio_len)
|
||||
|
||||
return [audio_id] * nb_audio_tokens
|
||||
|
||||
@@ -560,8 +465,8 @@ class VoxtralForConditionalGeneration(
|
||||
This is used for estimating the amount of processing for this audio.
|
||||
"""
|
||||
tokenizer = cached_tokenizer_from_config(model_config)
|
||||
adapter = VoxtralProcessorAdapter(tokenizer)
|
||||
return adapter.get_num_audio_tokens(
|
||||
adapter = MistralCommonVoxtralProcessor(tokenizer)
|
||||
return adapter.feature_extractor.get_num_audio_tokens(
|
||||
int(audio_duration_s * stt_config.sample_rate)
|
||||
)
|
||||
|
||||
|
||||
@@ -8,12 +8,13 @@ from typing import Literal
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from mistral_common.audio import Audio
|
||||
from mistral_common.protocol.instruct.chunk import RawAudio
|
||||
from mistral_common.protocol.transcription.request import (
|
||||
StreamingMode,
|
||||
TranscriptionRequest,
|
||||
)
|
||||
from mistral_common.tokens.tokenizers.audio import Audio, AudioConfig
|
||||
from mistral_common.tokens.tokenizers.audio import AudioConfig
|
||||
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
|
||||
|
||||
Reference in New Issue
Block a user