[VLM] Merged multi-modal processor for Pixtral (#12211)
Signed-off-by: remi <remi@mistral.ai>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
@@ -2,17 +2,23 @@
 import copy
 from functools import partial
-from typing import Optional
+from typing import Optional, Union
 
 import numpy as np
 import pytest
+from mistral_common.protocol.instruct.messages import (ImageChunk, TextChunk,
+                                                       UserMessage)
+from mistral_common.protocol.instruct.request import ChatCompletionRequest
 from PIL import Image
+from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
 
 from vllm.config import ModelConfig
 from vllm.inputs import InputProcessingContext
-from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.processing import ProcessingCache
-from vllm.transformers_utils.tokenizer import cached_tokenizer_from_config
+from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
+from vllm.multimodal.inputs import MultiModalInputs
+from vllm.multimodal.processing import BaseMultiModalProcessor, ProcessingCache
+from vllm.transformers_utils.tokenizer import (MistralTokenizer,
+                                               cached_tokenizer_from_config)
 
 from ....multimodal.utils import random_audio, random_image, random_video
 from ...registry import HF_EXAMPLE_MODELS
@@ -85,14 +91,6 @@ def _test_processing_correctness(
         partial(random_audio, rng, min_len=512, max_len=1024, sr=16000),
     }
 
-    tokenizer_encode_kwargs = {}
-    if model_config.hf_config.model_type in ("mllama", "whisper", "ultravox"):
-        # For some multimodal models, tokenizer will always add bos_token
-        # at the beginning of prompt by default, causing hf_processor outputs
-        # incorrect token ids. So we need use `add_special_tokens=False` here
-        # to leave bos_token to be added by the processor.
-        tokenizer_encode_kwargs = {"add_special_tokens": False}
-
     for batch_idx in range(num_batches):
         mm_data = {
             k:
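Note on the deletion above: some HF tokenizers unconditionally prepend bos_token in encode(), so encoding the text prompt here and letting the processor add its own bos would duplicate it. The next hunk moves this workaround into the new per-tokenizer helper. A minimal standalone sketch of the failure mode (facebook/opt-125m is only an illustrative bos-adding tokenizer, not part of this commit):

# Sketch: with default settings, encode() prepends bos_token; the
# multi-modal processor would then add a second bos on top of it.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("facebook/opt-125m")  # assumed example
with_bos = tok.encode("a photo of a cat")
without_bos = tok.encode("a photo of a cat", add_special_tokens=False)

assert with_bos[0] == tok.bos_token_id  # bos added by default
assert with_bos[1:] == without_bos      # otherwise identical token ids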
@@ -115,43 +113,131 @@ def _test_processing_correctness(
         elif len(mm_data[k]) == 1:
             mm_data[k] = mm_data[k][0]
 
-        baseline_result = baseline_processor.apply(
-            prompt,
-            mm_data=mm_data,
-            hf_processor_mm_kwargs={},
-        )
-        cached_result = cached_processor.apply(
-            prompt,
-            mm_data=mm_data,
-            hf_processor_mm_kwargs={},
-        )
-
-        assert _drop_mm_kwargs_keys(
-            baseline_result, ignore_mm_keys) == _drop_mm_kwargs_keys(
-                cached_result, ignore_mm_keys), (
-                    f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")
-
-        baseline_tokenized_result = baseline_processor.apply(
-            tokenizer.encode(prompt, **tokenizer_encode_kwargs),
-            mm_data=mm_data,
-            hf_processor_mm_kwargs={},
-        )
-
-        assert _drop_mm_kwargs_keys(
-            baseline_result, ignore_mm_keys) == _drop_mm_kwargs_keys(
-                baseline_tokenized_result, ignore_mm_keys), (
-                    f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")
-
-        cached_tokenized_result = cached_processor.apply(
-            tokenizer.encode(prompt, **tokenizer_encode_kwargs),
-            mm_data=mm_data,
-            hf_processor_mm_kwargs={},
-        )
-
-        assert _drop_mm_kwargs_keys(
-            cached_result, ignore_mm_keys) == _drop_mm_kwargs_keys(
-                cached_tokenized_result, ignore_mm_keys), (
-                    f"Failed ({batch_idx=}, {prompt=}, {mm_data=})")
+        if isinstance(tokenizer, MistralTokenizer):
+            _test_processing_correctness_mistral(
+                model_config,
+                tokenizer,
+                prompt,
+                mm_data,
+                baseline_processor,
+                cached_processor,
+                batch_idx,
+                ignore_mm_keys=ignore_mm_keys,
+            )
+        else:
+            _test_processing_correctness_hf(
+                model_config,
+                tokenizer,
+                prompt,
+                mm_data,
+                baseline_processor,
+                cached_processor,
+                batch_idx,
+                ignore_mm_keys=ignore_mm_keys,
+            )
+
+
+def _test_processing_correctness_hf(
+    model_config: ModelConfig,
+    tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+    prompt: str,
+    mm_data: MultiModalDataDict,
+    baseline_processor: BaseMultiModalProcessor,
+    cached_processor: BaseMultiModalProcessor,
+    batch_idx: int,
+    ignore_mm_keys: Optional[list[str]] = None,
+):
+    if model_config.hf_config.model_type in ("mllama", "whisper", "ultravox"):
+        # For some multimodal models, tokenizer will always add bos_token
+        # at the beginning of prompt by default, causing hf_processor outputs
+        # incorrect token ids. So we need use `add_special_tokens=False` here
+        # to leave bos_token to be added by the processor.
+        token_prompt = tokenizer.encode(prompt, add_special_tokens=False)
+    else:
+        token_prompt = tokenizer.encode(prompt)
+
+    baseline_result = baseline_processor.apply(
+        prompt,
+        mm_data=mm_data,
+        hf_processor_mm_kwargs={},
+    )
+    cached_result = cached_processor.apply(
+        prompt,
+        mm_data=mm_data,
+        hf_processor_mm_kwargs={},
+    )
+
+    assert _inputs_equal(
+        baseline_result,
+        cached_result,
+        ignore_mm_keys,
+    ), f"Failed ({batch_idx=}, {prompt=}, {mm_data=})"
+
+    baseline_tokenized_result = baseline_processor.apply(
+        token_prompt,
+        mm_data=mm_data,
+        hf_processor_mm_kwargs={},
+    )
+
+    assert _inputs_equal(
+        baseline_result,
+        baseline_tokenized_result,
+        ignore_mm_keys,
+    ), f"Failed ({batch_idx=}, {prompt=}, {mm_data=})"
+
+    cached_tokenized_result = cached_processor.apply(
+        token_prompt,
+        mm_data=mm_data,
+        hf_processor_mm_kwargs={},
+    )
+
+    assert _inputs_equal(
+        cached_result,
+        cached_tokenized_result,
+        ignore_mm_keys,
+    ), f"Failed ({batch_idx=}, {prompt=}, {mm_data=})"
+
+
+def _test_processing_correctness_mistral(
+    model_config: ModelConfig,
+    tokenizer: MistralTokenizer,
+    prompt: str,
+    mm_data: MultiModalDataDict,
+    baseline_processor: BaseMultiModalProcessor,
+    cached_processor: BaseMultiModalProcessor,
+    batch_idx: int,
+    ignore_mm_keys: Optional[list[str]] = None,
+):
+    images = mm_data.get("image", [])
+    if not isinstance(images, list):
+        images = [images]
+
+    request = ChatCompletionRequest(messages=[
+        UserMessage(content=[
+            TextChunk(text=prompt),
+            *(ImageChunk(image=image) for image in images),
+        ]),
+    ])
+    res = tokenizer.mistral.encode_chat_completion(request)
+    token_prompt = res.tokens
+
+    # Mistral chat outputs tokens directly, rather than text prompts
+    baseline_tokenized_result = baseline_processor.apply(
+        token_prompt,
+        mm_data=mm_data,
+        hf_processor_mm_kwargs={},
+    )
+    cached_tokenized_result = cached_processor.apply(
+        token_prompt,
+        mm_data=mm_data,
+        hf_processor_mm_kwargs={},
+    )
+
+    assert _inputs_equal(
+        baseline_tokenized_result,
+        cached_tokenized_result,
+        ignore_mm_keys,
+    ), f"Failed ({batch_idx=}, {prompt=}, {mm_data=})"
 
 
 # yapf: disable
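For context on the Mistral branch above: mistral_common tokenizes chat requests straight to token ids, so there is no text-prompt variant to round-trip. A self-contained sketch of that tokenization path, assuming a locally downloaded Pixtral tokenizer file ("tekken.json" is an assumed path); in the test itself the same underlying tokenizer is reached through vLLM's MistralTokenizer wrapper as tokenizer.mistral:

# Sketch of direct mistral_common usage; the tokenizer file path is assumed.
from mistral_common.protocol.instruct.messages import (ImageChunk, TextChunk,
                                                       UserMessage)
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from PIL import Image

mistral_tok = MistralTokenizer.from_file("tekken.json")  # assumed local file
request = ChatCompletionRequest(messages=[
    UserMessage(content=[
        TextChunk(text="Describe the image."),
        ImageChunk(image=Image.new("RGB", (64, 64))),
    ]),
])
encoded = mistral_tok.encode_chat_completion(request)
token_prompt = encoded.tokens  # image placeholder tokens are already inlined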
@@ -173,6 +259,7 @@ def _test_processing_correctness(
     "llava-hf/llava-onevision-qwen2-0.5b-ov-hf",
     "meta-llama/Llama-3.2-11B-Vision-Instruct",
     "TIGER-Lab/Mantis-8B-siglip-llama3",
+    "mistralai/Pixtral-12B-2409",
     "mistral-community/pixtral-12b",
     "openbmb/MiniCPM-o-2_6",
     "openbmb/MiniCPM-V-2_6",
@@ -241,8 +328,19 @@ def test_processing_correctness_phi3v(
     )
 
 
-def _drop_mm_kwargs_keys(result: dict,
-                         ignore_mm_keys: Optional[list[str]] = None) -> dict:
+def _inputs_equal(
+    a: MultiModalInputs,
+    b: MultiModalInputs,
+    ignore_mm_keys: Optional[list[str]] = None,
+):
+    return _drop_mm_kwargs_keys(a, ignore_mm_keys) == _drop_mm_kwargs_keys(
+        b, ignore_mm_keys)
+
+
+def _drop_mm_kwargs_keys(
+    result: MultiModalInputs,
+    ignore_mm_keys: Optional[list[str]] = None,
+) -> MultiModalInputs:
     """Drop specified keys from result['mm_kwargs'].
 
     This is mainly to avoid doing exact match of audio_features in ultravox.
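With the helper pair above, every assertion in the rewritten tests reduces to dict equality after pruning volatile entries from mm_kwargs; per the docstring, this avoids exact-matching ultravox's audio_features across the two code paths. A toy sketch of the comparison semantics, using plain dicts in place of real MultiModalInputs objects:

# Toy model of _inputs_equal / _drop_mm_kwargs_keys over plain dicts.
def drop_mm_kwargs_keys(result: dict, ignore_mm_keys=None) -> dict:
    if ignore_mm_keys:
        result = dict(result)
        result["mm_kwargs"] = {
            k: v
            for k, v in result["mm_kwargs"].items() if k not in ignore_mm_keys
        }
    return result

a = {"prompt_token_ids": [1, 2, 3], "mm_kwargs": {"audio_features": [0.10]}}
b = {"prompt_token_ids": [1, 2, 3], "mm_kwargs": {"audio_features": [0.11]}}

# Equal once the non-deterministic key is ignored, unequal otherwise.
assert drop_mm_kwargs_keys(a, ["audio_features"]) == \
       drop_mm_kwargs_keys(b, ["audio_features"])
assert a != b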