Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
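Most of the hunks below are mechanical restyling of the same Python code. As a rough illustration of the style change (not taken from this commit; "example.module" and the option names are hypothetical), yapf hangs wrapped imports under the opening parenthesis, while ruff's Black-compatible formatter indents one level and, when the parenthesized list ends with a trailing comma, keeps it exploded one name per line:

    # Hypothetical module, for illustration only; not part of this diff.
    # yapf + isort (old): continuation aligned under the opening parenthesis
    from example.module import (FirstOption, FourthOption, SecondOption,
                                ThirdOption)

    # ruff format (new): one indent level; the magic trailing comma keeps
    # the list exploded at one name per line
    from example.module import (
        FirstOption,
        FourthOption,
        SecondOption,
        ThirdOption,
    )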
@@ -6,22 +6,27 @@ from typing import Optional, Union
 import numpy as np
 import pytest
-from mistral_common.protocol.instruct.messages import (ImageChunk, TextChunk,
-                                                       UserMessage)
+from mistral_common.protocol.instruct.messages import ImageChunk, TextChunk, UserMessage
 from mistral_common.protocol.instruct.request import ChatCompletionRequest
 from PIL import Image
 
 from vllm.config import ModelConfig
-from vllm.config.multimodal import (AudioDummyOptions, BaseDummyOptions,
-                                    ImageDummyOptions, VideoDummyOptions)
+from vllm.config.multimodal import (
+    AudioDummyOptions,
+    BaseDummyOptions,
+    ImageDummyOptions,
+    VideoDummyOptions,
+)
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
 from vllm.multimodal.cache import MultiModalProcessorOnlyCache
 from vllm.multimodal.inputs import MultiModalInputs
-from vllm.multimodal.processing import (BaseMultiModalProcessor,
-                                        InputProcessingContext)
-from vllm.transformers_utils.tokenizer import (AnyTokenizer, MistralTokenizer,
-                                               cached_tokenizer_from_config,
-                                               encode_tokens)
+from vllm.multimodal.processing import BaseMultiModalProcessor, InputProcessingContext
+from vllm.transformers_utils.tokenizer import (
+    AnyTokenizer,
+    MistralTokenizer,
+    cached_tokenizer_from_config,
+    encode_tokens,
+)
 
 from ....multimodal.utils import random_audio, random_image, random_video
 from ...registry import HF_EXAMPLE_MODELS
@@ -36,14 +41,17 @@ def glm4_1v_patch_mm_data(mm_data: MultiModalDataDict) -> MultiModalDataDict:
     # GLM4.1V doesn't support multiple videos
     video = mm_data["video"]
     num_frames = len(video)
-    mm_data["video"] = (video, {
-        "total_num_frames": num_frames,
-        "fps": num_frames,
-        "duration": 1,
-        "frames_indices": [i for i in range(num_frames)],
-        "video_backend": "opencv",
-        "do_sample_frames": True,
-    })
+    mm_data["video"] = (
+        video,
+        {
+            "total_num_frames": num_frames,
+            "fps": num_frames,
+            "duration": 1,
+            "frames_indices": [i for i in range(num_frames)],
+            "video_backend": "opencv",
+            "do_sample_frames": True,
+        },
+    )
     return mm_data
 
 
@@ -102,7 +110,8 @@ def _test_processing_correctness(
         mm_processor_cache_gb=2048,
         skip_tokenizer_init=model_info.skip_tokenizer_init,
         enforce_eager=model_info.enforce_eager,
-        dtype=model_info.dtype)
+        dtype=model_info.dtype,
+    )
 
     model_cls = MULTIMODAL_REGISTRY._get_model_cls(model_config)
     factories = MULTIMODAL_REGISTRY._processor_factories[model_cls]
@@ -145,27 +154,22 @@ def _test_processing_correctness(
     input_to_hit = {
         "image": Image.new("RGB", size=(128, 128)),
         "video": np.zeros((4, 128, 128, 3), dtype=np.uint8),
-        "audio": (np.zeros((512, )), 16000),
+        "audio": (np.zeros((512,)), 16000),
     }
     input_factory = {
-        "image":
-        partial(random_image, rng, min_wh=128, max_wh=256),
-        "video":
-        partial(random_video,
-                rng,
-                min_frames=2,
-                max_frames=16,
-                min_wh=128,
-                max_wh=256),
-        "audio":
-        partial(random_audio, rng, min_len=512, max_len=1024, sr=16000),
+        "image": partial(random_image, rng, min_wh=128, max_wh=256),
+        "video": partial(
+            random_video, rng, min_frames=2, max_frames=16, min_wh=128, max_wh=256
+        ),
+        "audio": partial(random_audio, rng, min_len=512, max_len=1024, sr=16000),
     }
 
     for batch_idx in range(num_batches):
         mm_data = {
-            k:
-            [(input_to_hit[k] if rng.rand() < hit_rate else input_factory[k]())
-             for _ in range(rng.randint(limit + 1))]
+            k: [
+                (input_to_hit[k] if rng.rand() < hit_rate else input_factory[k]())
+                for _ in range(rng.randint(limit + 1))
+            ]
             for k, limit in limit_mm_per_prompt_ints.items()
         }
 
@@ -174,12 +178,16 @@ def _test_processing_correctness(
         # Mistral chat outputs tokens directly, rather than text prompts
        if isinstance(tokenizer, MistralTokenizer):
            images = mm_data.get("image", [])
-            request = ChatCompletionRequest(messages=[
-                UserMessage(content=[
-                    TextChunk(text=""),
-                    *(ImageChunk(image=image) for image in images),
-                ]),
-            ])
+            request = ChatCompletionRequest(
+                messages=[
+                    UserMessage(
+                        content=[
+                            TextChunk(text=""),
+                            *(ImageChunk(image=image) for image in images),
+                        ]
+                    ),
+                ]
+            )
             res = tokenizer.mistral.encode_chat_completion(request)
             prompt = res.tokens
         else:
@@ -303,16 +311,14 @@ def _test_processing_correctness_one(
         baseline_text_result,
         baseline_tokenized_result,
         ignore_mm_keys=ignore_mm_keys,
-        msg=f"Failed ({batch_idx=}, {text_prompt=}, "
-        f"{token_prompt=}, {mm_data=})",
+        msg=f"Failed ({batch_idx=}, {text_prompt=}, {token_prompt=}, {mm_data=})",
     )
 
     _assert_inputs_equal(
         cached_text_result,
         cached_tokenized_result,
         ignore_mm_keys=ignore_mm_keys,
-        msg=f"Failed ({batch_idx=}, {text_prompt=}, "
-        f"{token_prompt=}, {mm_data=})",
+        msg=f"Failed ({batch_idx=}, {text_prompt=}, {token_prompt=}, {mm_data=})",
     )