[Bugfix] Fix profiling dummy data for Pixtral (#18677)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -9,15 +9,15 @@ from mistral_common.protocol.instruct.messages import (ImageChunk, TextChunk,
|
||||
UserMessage)
|
||||
from mistral_common.protocol.instruct.request import ChatCompletionRequest
|
||||
from PIL import Image
|
||||
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.inputs import InputProcessingContext
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
|
||||
from vllm.multimodal.inputs import MultiModalInputs
|
||||
from vllm.multimodal.processing import BaseMultiModalProcessor, ProcessingCache
|
||||
from vllm.transformers_utils.tokenizer import (MistralTokenizer,
|
||||
cached_tokenizer_from_config)
|
||||
from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
|
||||
cached_tokenizer_from_config,
|
||||
encode_tokens)
|
||||
|
||||
from ....multimodal.utils import random_audio, random_image, random_video
|
||||
from ...registry import HF_EXAMPLE_MODELS
|
||||
@@ -28,7 +28,6 @@ def _test_processing_correctness(
|
||||
hit_rate: float,
|
||||
num_batches: int,
|
||||
simplify_rate: float,
|
||||
ignore_mm_keys: Optional[set[str]] = None,
|
||||
):
|
||||
model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
|
||||
model_info.check_available_online(on_fail="skip")
|
||||
@@ -99,10 +98,23 @@ def _test_processing_correctness(
|
||||
}
|
||||
|
||||
mm_counts = {k: len(vs) for k, vs in mm_data.items()}
|
||||
prompt = dummy_inputs.get_dummy_processor_inputs(
|
||||
model_config.max_model_len,
|
||||
mm_counts,
|
||||
).prompt_text
|
||||
|
||||
# Mistral chat outputs tokens directly, rather than text prompts
|
||||
if isinstance(tokenizer, MistralTokenizer):
|
||||
images = mm_data.get("image", [])
|
||||
request = ChatCompletionRequest(messages=[
|
||||
UserMessage(content=[
|
||||
TextChunk(text=""),
|
||||
*(ImageChunk(image=image) for image in images),
|
||||
]),
|
||||
])
|
||||
res = tokenizer.mistral.encode_chat_completion(request)
|
||||
prompt = res.tokens
|
||||
else:
|
||||
prompt = dummy_inputs.get_dummy_processor_inputs(
|
||||
model_config.max_model_len,
|
||||
mm_counts,
|
||||
).prompt
|
||||
|
||||
# Drop unnecessary keys and test single -> multi conversion
|
||||
if rng.rand() < simplify_rate:
|
||||
@@ -112,67 +124,59 @@ def _test_processing_correctness(
|
||||
elif len(mm_data[k]) == 1:
|
||||
mm_data[k] = mm_data[k][0]
|
||||
|
||||
if isinstance(tokenizer, MistralTokenizer):
|
||||
_test_processing_correctness_mistral(
|
||||
model_config,
|
||||
tokenizer,
|
||||
prompt,
|
||||
mm_data,
|
||||
baseline_processor,
|
||||
cached_processor,
|
||||
batch_idx,
|
||||
ignore_mm_keys=ignore_mm_keys,
|
||||
)
|
||||
else:
|
||||
_test_processing_correctness_hf(
|
||||
model_config,
|
||||
tokenizer,
|
||||
prompt,
|
||||
mm_data,
|
||||
baseline_processor,
|
||||
cached_processor,
|
||||
batch_idx,
|
||||
ignore_mm_keys=ignore_mm_keys,
|
||||
)
|
||||
_test_processing_correctness_one(
|
||||
model_config,
|
||||
tokenizer,
|
||||
prompt,
|
||||
mm_data,
|
||||
baseline_processor,
|
||||
cached_processor,
|
||||
batch_idx,
|
||||
)
|
||||
|
||||
|
||||
def _test_processing_correctness_hf(
|
||||
# For some multimodal models, tokenizer will always add bos_token
|
||||
# at the beginning of prompt by default, causing hf_processor outputs
|
||||
# incorrect token ids. So we need use `add_special_tokens=False` here
|
||||
# to leave bos_token to be added by the processor.
|
||||
_ADD_SPECIAL_TOKENS_OVERRIDES = {
|
||||
"mllama": False,
|
||||
"ovis": False,
|
||||
"ultravox": False,
|
||||
"whisper": False,
|
||||
}
|
||||
|
||||
_IGNORE_MM_KEYS = {
|
||||
# In Ultravox, the audio_features can be different depending on padding
|
||||
# The slight difference should not be a problem though, since
|
||||
# attention_mask lets us ignore the difference.
|
||||
"ultravox": {"audio_features"},
|
||||
}
|
||||
|
||||
|
||||
def _test_processing_correctness_one(
|
||||
model_config: ModelConfig,
|
||||
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
|
||||
prompt: str,
|
||||
tokenizer: AnyTokenizer,
|
||||
prompt: Union[str, list[int]],
|
||||
mm_data: MultiModalDataDict,
|
||||
baseline_processor: BaseMultiModalProcessor,
|
||||
cached_processor: BaseMultiModalProcessor,
|
||||
batch_idx: int,
|
||||
ignore_mm_keys: Optional[set[str]] = None,
|
||||
):
|
||||
if model_config.hf_config.model_type in ("mllama", "ovis", "ultravox",
|
||||
"whisper"):
|
||||
# For some multimodal models, tokenizer will always add bos_token
|
||||
# at the beginning of prompt by default, causing hf_processor outputs
|
||||
# incorrect token ids. So we need use `add_special_tokens=False` here
|
||||
# to leave bos_token to be added by the processor.
|
||||
token_prompt = tokenizer.encode(prompt, add_special_tokens=False)
|
||||
model_type = model_config.hf_config.model_type
|
||||
ignore_mm_keys = _IGNORE_MM_KEYS.get(model_type, set[str]())
|
||||
|
||||
if isinstance(prompt, str):
|
||||
text_prompt = prompt
|
||||
token_prompt = encode_tokens(
|
||||
tokenizer,
|
||||
prompt,
|
||||
add_special_tokens=_ADD_SPECIAL_TOKENS_OVERRIDES.get(model_type),
|
||||
)
|
||||
else:
|
||||
token_prompt = tokenizer.encode(prompt)
|
||||
|
||||
baseline_result = baseline_processor.apply(
|
||||
prompt,
|
||||
mm_data=mm_data,
|
||||
hf_processor_mm_kwargs={},
|
||||
)
|
||||
cached_result = cached_processor.apply(
|
||||
prompt,
|
||||
mm_data=mm_data,
|
||||
hf_processor_mm_kwargs={},
|
||||
)
|
||||
|
||||
_assert_inputs_equal(
|
||||
baseline_result,
|
||||
cached_result,
|
||||
ignore_mm_keys=ignore_mm_keys,
|
||||
msg=f"Failed ({batch_idx=}, {prompt=}, {mm_data=})",
|
||||
)
|
||||
# Mistral does not support decode_tokens with skip_special_tokens=False
|
||||
text_prompt = None
|
||||
token_prompt = prompt
|
||||
|
||||
baseline_tokenized_result = baseline_processor.apply(
|
||||
token_prompt,
|
||||
@@ -180,56 +184,6 @@ def _test_processing_correctness_hf(
|
||||
hf_processor_mm_kwargs={},
|
||||
)
|
||||
|
||||
_assert_inputs_equal(
|
||||
baseline_result,
|
||||
baseline_tokenized_result,
|
||||
ignore_mm_keys=ignore_mm_keys,
|
||||
msg=f"Failed ({batch_idx=}, {prompt=}, {mm_data=})",
|
||||
)
|
||||
|
||||
cached_tokenized_result = cached_processor.apply(
|
||||
token_prompt,
|
||||
mm_data=mm_data,
|
||||
hf_processor_mm_kwargs={},
|
||||
)
|
||||
|
||||
_assert_inputs_equal(
|
||||
cached_result,
|
||||
cached_tokenized_result,
|
||||
ignore_mm_keys=ignore_mm_keys,
|
||||
msg=f"Failed ({batch_idx=}, {prompt=}, {mm_data=})",
|
||||
)
|
||||
|
||||
|
||||
def _test_processing_correctness_mistral(
|
||||
model_config: ModelConfig,
|
||||
tokenizer: MistralTokenizer,
|
||||
prompt: str,
|
||||
mm_data: MultiModalDataDict,
|
||||
baseline_processor: BaseMultiModalProcessor,
|
||||
cached_processor: BaseMultiModalProcessor,
|
||||
batch_idx: int,
|
||||
ignore_mm_keys: Optional[set[str]] = None,
|
||||
):
|
||||
images = mm_data.get("image", [])
|
||||
if not isinstance(images, list):
|
||||
images = [images]
|
||||
|
||||
request = ChatCompletionRequest(messages=[
|
||||
UserMessage(content=[
|
||||
TextChunk(text=prompt),
|
||||
*(ImageChunk(image=image) for image in images),
|
||||
]),
|
||||
])
|
||||
res = tokenizer.mistral.encode_chat_completion(request)
|
||||
token_prompt = res.tokens
|
||||
|
||||
# Mistral chat outputs tokens directly, rather than text prompts
|
||||
baseline_tokenized_result = baseline_processor.apply(
|
||||
token_prompt,
|
||||
mm_data=mm_data,
|
||||
hf_processor_mm_kwargs={},
|
||||
)
|
||||
cached_tokenized_result = cached_processor.apply(
|
||||
token_prompt,
|
||||
mm_data=mm_data,
|
||||
@@ -240,9 +194,44 @@ def _test_processing_correctness_mistral(
|
||||
baseline_tokenized_result,
|
||||
cached_tokenized_result,
|
||||
ignore_mm_keys=ignore_mm_keys,
|
||||
msg=f"Failed ({batch_idx=}, {prompt=}, {mm_data=})",
|
||||
msg=f"Failed ({batch_idx=}, {token_prompt=}, {mm_data=})",
|
||||
)
|
||||
|
||||
if text_prompt is not None:
|
||||
baseline_text_result = baseline_processor.apply(
|
||||
text_prompt,
|
||||
mm_data=mm_data,
|
||||
hf_processor_mm_kwargs={},
|
||||
)
|
||||
cached_text_result = cached_processor.apply(
|
||||
text_prompt,
|
||||
mm_data=mm_data,
|
||||
hf_processor_mm_kwargs={},
|
||||
)
|
||||
|
||||
_assert_inputs_equal(
|
||||
baseline_text_result,
|
||||
cached_text_result,
|
||||
ignore_mm_keys=ignore_mm_keys,
|
||||
msg=f"Failed ({batch_idx=}, {text_prompt=}, {mm_data=})",
|
||||
)
|
||||
|
||||
_assert_inputs_equal(
|
||||
baseline_text_result,
|
||||
baseline_tokenized_result,
|
||||
ignore_mm_keys=ignore_mm_keys,
|
||||
msg=f"Failed ({batch_idx=}, {text_prompt=}, "
|
||||
f"{token_prompt=}, {mm_data=})",
|
||||
)
|
||||
|
||||
_assert_inputs_equal(
|
||||
cached_text_result,
|
||||
cached_tokenized_result,
|
||||
ignore_mm_keys=ignore_mm_keys,
|
||||
msg=f"Failed ({batch_idx=}, {text_prompt=}, "
|
||||
f"{token_prompt=}, {mm_data=})",
|
||||
)
|
||||
|
||||
|
||||
# yapf: disable
|
||||
@pytest.mark.parametrize("model_id", [
|
||||
@@ -281,6 +270,7 @@ def _test_processing_correctness_mistral(
|
||||
"AIDC-AI/Ovis2-1B",
|
||||
"google/paligemma-3b-mix-224",
|
||||
"google/paligemma2-3b-ft-docci-448",
|
||||
"microsoft/Phi-3.5-vision-instruct",
|
||||
"microsoft/Phi-4-multimodal-instruct",
|
||||
"mistralai/Pixtral-12B-2409",
|
||||
"mistral-community/pixtral-12b",
|
||||
@@ -303,41 +293,6 @@ def test_processing_correctness(
|
||||
num_batches: int,
|
||||
simplify_rate: float,
|
||||
):
|
||||
ignore_mm_keys = None
|
||||
if 'ultravox' in model_id:
|
||||
# In Ultravox, the audio_features can be different depending on padding
|
||||
# The slight difference should not be a problem though, since
|
||||
# attention_mask lets us ignore the difference.
|
||||
ignore_mm_keys = {"audio_features"}
|
||||
|
||||
_test_processing_correctness(
|
||||
model_id,
|
||||
hit_rate=hit_rate,
|
||||
num_batches=num_batches,
|
||||
simplify_rate=simplify_rate,
|
||||
ignore_mm_keys=ignore_mm_keys,
|
||||
)
|
||||
|
||||
|
||||
# yapf: disable
|
||||
@pytest.mark.parametrize("model_id", ["microsoft/Phi-3.5-vision-instruct"])
|
||||
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
|
||||
@pytest.mark.parametrize("num_batches", [32])
|
||||
@pytest.mark.parametrize("simplify_rate", [1.0])
|
||||
# yapf: enable
|
||||
def test_processing_correctness_phi3v(
|
||||
model_id: str,
|
||||
hit_rate: float,
|
||||
num_batches: int,
|
||||
simplify_rate: float,
|
||||
):
|
||||
# HACK - this is an attempted workaround for the following bug
|
||||
# https://github.com/huggingface/transformers/issues/34307
|
||||
from transformers import AutoImageProcessor # noqa: F401
|
||||
from transformers import AutoProcessor # noqa: F401
|
||||
|
||||
AutoImageProcessor.from_pretrained(model_id, trust_remote_code=True)
|
||||
|
||||
_test_processing_correctness(
|
||||
model_id,
|
||||
hit_rate=hit_rate,
|
||||
@@ -356,16 +311,10 @@ def _assert_inputs_equal(
|
||||
if ignore_mm_keys is None:
|
||||
ignore_mm_keys = set()
|
||||
|
||||
if msg is None:
|
||||
assert "mm_kwargs" in a and "mm_kwargs" in b
|
||||
else:
|
||||
assert "mm_kwargs" in a and "mm_kwargs" in b, msg
|
||||
assert "mm_kwargs" in a and "mm_kwargs" in b, msg
|
||||
|
||||
for key in ignore_mm_keys:
|
||||
a["mm_kwargs"].pop(key, None)
|
||||
b["mm_kwargs"].pop(key, None)
|
||||
|
||||
if msg is None:
|
||||
assert a == b
|
||||
else:
|
||||
assert a == b, msg
|
||||
assert a == b, msg
|
||||
|
||||
Reference in New Issue
Block a user