Convert formatting to use ruff instead of yapf + isort (#26247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-05 15:06:22 +01:00
committed by GitHub
parent 17edd8a807
commit d6953beb91
1508 changed files with 115244 additions and 94146 deletions

View File

@@ -8,9 +8,12 @@ from unittest.mock import patch
import numpy as np
import pytest
from vllm.multimodal.audio import (AudioMediaIO, AudioResampler,
resample_audio_librosa,
resample_audio_scipy)
from vllm.multimodal.audio import (
AudioMediaIO,
AudioResampler,
resample_audio_librosa,
resample_audio_scipy,
)
@pytest.fixture
@@ -21,12 +24,10 @@ def dummy_audio():
def test_resample_audio_librosa(dummy_audio):
with patch("vllm.multimodal.audio.librosa.resample") as mock_resample:
mock_resample.return_value = dummy_audio * 2
out = resample_audio_librosa(dummy_audio,
orig_sr=44100,
target_sr=22050)
mock_resample.assert_called_once_with(dummy_audio,
orig_sr=44100,
target_sr=22050)
out = resample_audio_librosa(dummy_audio, orig_sr=44100, target_sr=22050)
mock_resample.assert_called_once_with(
dummy_audio, orig_sr=44100, target_sr=22050
)
assert np.all(out == dummy_audio * 2)
@@ -40,8 +41,7 @@ def test_resample_audio_scipy(dummy_audio):
assert np.all(out_same == dummy_audio)
@pytest.mark.xfail(
reason="resample_audio_scipy is buggy for non-integer ratios")
@pytest.mark.xfail(reason="resample_audio_scipy is buggy for non-integer ratios")
def test_resample_audio_scipy_non_integer_ratio(dummy_audio):
out = resample_audio_scipy(dummy_audio, orig_sr=5, target_sr=3)
@@ -54,13 +54,12 @@ def test_resample_audio_scipy_non_integer_ratio(dummy_audio):
def test_audio_resampler_librosa_calls_resample(dummy_audio):
resampler = AudioResampler(target_sr=22050, method="librosa")
with patch(
"vllm.multimodal.audio.resample_audio_librosa") as mock_resample:
with patch("vllm.multimodal.audio.resample_audio_librosa") as mock_resample:
mock_resample.return_value = dummy_audio
out = resampler.resample(dummy_audio, orig_sr=44100)
mock_resample.assert_called_once_with(dummy_audio,
orig_sr=44100,
target_sr=22050)
mock_resample.assert_called_once_with(
dummy_audio, orig_sr=44100, target_sr=22050
)
assert np.all(out == dummy_audio)
@@ -69,9 +68,9 @@ def test_audio_resampler_scipy_calls_resample(dummy_audio):
with patch("vllm.multimodal.audio.resample_audio_scipy") as mock_resample:
mock_resample.return_value = dummy_audio
out = resampler.resample(dummy_audio, orig_sr=44100)
mock_resample.assert_called_once_with(dummy_audio,
orig_sr=44100,
target_sr=22050)
mock_resample.assert_called_once_with(
dummy_audio, orig_sr=44100, target_sr=22050
)
assert np.all(out == dummy_audio)

View File

@@ -8,15 +8,20 @@ import torch
from vllm.config import ModelConfig, ParallelConfig, VllmConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import (MultiModalCache,
MultiModalProcessorCacheItem,
MultiModalProcessorCacheItemMetadata,
engine_receiver_cache_from_config,
processor_cache_from_config)
from vllm.multimodal.cache import (
MultiModalCache,
MultiModalProcessorCacheItem,
MultiModalProcessorCacheItemMetadata,
engine_receiver_cache_from_config,
processor_cache_from_config,
)
from vllm.multimodal.hasher import MultiModalHasher
from vllm.multimodal.inputs import (MultiModalFieldElem, MultiModalKwargsItem,
MultiModalKwargsItems,
MultiModalSharedField)
from vllm.multimodal.inputs import (
MultiModalFieldElem,
MultiModalKwargsItem,
MultiModalKwargsItems,
MultiModalSharedField,
)
from vllm.multimodal.processing import PromptInsertion
pytestmark = pytest.mark.cpu_test
@@ -30,9 +35,9 @@ def _dummy_elem(
rng: Optional[np.random.RandomState] = None,
):
if rng is None:
data = torch.empty((size, ), dtype=torch.int8)
data = torch.empty((size,), dtype=torch.int8)
else:
data = torch.from_numpy(rng.randint(4, size=(size, ), dtype=np.int8))
data = torch.from_numpy(rng.randint(4, size=(size,), dtype=np.int8))
return MultiModalFieldElem(
modality=modality,
@@ -48,10 +53,9 @@ def _dummy_item(
*,
rng: Optional[np.random.RandomState] = None,
):
return MultiModalKwargsItem.from_elems([
_dummy_elem(modality, key, size, rng=rng)
for key, size in size_by_key.items()
])
return MultiModalKwargsItem.from_elems(
[_dummy_elem(modality, key, size, rng=rng) for key, size in size_by_key.items()]
)
def _dummy_items(
@@ -59,10 +63,12 @@ def _dummy_items(
*,
rng: Optional[np.random.RandomState] = None,
):
return MultiModalKwargsItems.from_seq([
_dummy_item(modality, size_by_key, rng=rng)
for modality, size_by_key in size_by_key_modality.items()
])
return MultiModalKwargsItems.from_seq(
[
_dummy_item(modality, size_by_key, rng=rng)
for modality, size_by_key in size_by_key_modality.items()
]
)
# yapf: disable

View File

@@ -90,8 +90,6 @@ def test_hash_image_exif_id():
hasher = MultiModalHasher
# first image has UUID in ImageID, so it should hash to that UUID
assert hasher.hash_kwargs(image=image1) == hasher.hash_kwargs(
image=id.bytes)
assert hasher.hash_kwargs(image=image1) == hasher.hash_kwargs(image=id.bytes)
# second image has non-UUID in ImageID, so it should hash to the image data
assert hasher.hash_kwargs(image=image2) == hasher.hash_kwargs(
image=image2a)
assert hasher.hash_kwargs(image=image2) == hasher.hash_kwargs(image=image2a)

View File

@@ -43,8 +43,7 @@ def test_rgba_to_rgb():
def test_rgba_to_rgb_custom_background(tmp_path):
"""Test RGBA to RGB conversion with custom background colors."""
# Create a simple RGBA image with transparent and opaque pixels
rgba_image = Image.new("RGBA", (10, 10),
(255, 0, 0, 255)) # Red with full opacity
rgba_image = Image.new("RGBA", (10, 10), (255, 0, 0, 255)) # Red with full opacity
# Make top-left quadrant transparent
for i in range(5):
@@ -94,7 +93,7 @@ def test_rgba_to_rgb_custom_background(tmp_path):
assert blue_numpy[0][0][2] == 255 # B
# Test 4: Test with load_bytes method
with open(test_image_path, 'rb') as f:
with open(test_image_path, "rb") as f:
image_data = f.read()
image_io_green = ImageMediaIO(rgba_background_color=(0, 255, 0))
@@ -111,39 +110,47 @@ def test_rgba_background_color_validation():
"""Test that invalid rgba_background_color values are properly rejected."""
# Test invalid types
with pytest.raises(ValueError,
match="rgba_background_color must be a list or tuple"):
with pytest.raises(
ValueError, match="rgba_background_color must be a list or tuple"
):
ImageMediaIO(rgba_background_color="255,255,255")
with pytest.raises(ValueError,
match="rgba_background_color must be a list or tuple"):
with pytest.raises(
ValueError, match="rgba_background_color must be a list or tuple"
):
ImageMediaIO(rgba_background_color=255)
# Test wrong number of elements
with pytest.raises(ValueError,
match="rgba_background_color must be a list or tuple"):
with pytest.raises(
ValueError, match="rgba_background_color must be a list or tuple"
):
ImageMediaIO(rgba_background_color=(255, 255))
with pytest.raises(ValueError,
match="rgba_background_color must be a list or tuple"):
with pytest.raises(
ValueError, match="rgba_background_color must be a list or tuple"
):
ImageMediaIO(rgba_background_color=(255, 255, 255, 255))
# Test non-integer values
with pytest.raises(ValueError,
match="rgba_background_color must be a list or tuple"):
with pytest.raises(
ValueError, match="rgba_background_color must be a list or tuple"
):
ImageMediaIO(rgba_background_color=(255.0, 255.0, 255.0))
with pytest.raises(ValueError,
match="rgba_background_color must be a list or tuple"):
with pytest.raises(
ValueError, match="rgba_background_color must be a list or tuple"
):
ImageMediaIO(rgba_background_color=(255, "255", 255))
# Test out of range values
with pytest.raises(ValueError,
match="rgba_background_color must be a list or tuple"):
with pytest.raises(
ValueError, match="rgba_background_color must be a list or tuple"
):
ImageMediaIO(rgba_background_color=(256, 255, 255))
with pytest.raises(ValueError,
match="rgba_background_color must be a list or tuple"):
with pytest.raises(
ValueError, match="rgba_background_color must be a list or tuple"
):
ImageMediaIO(rgba_background_color=(255, -1, 255))
# Test that valid values work

View File

@@ -9,8 +9,7 @@ from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors
pytestmark = pytest.mark.cpu_test
def assert_nested_tensors_equal(expected: NestedTensors,
actual: NestedTensors):
def assert_nested_tensors_equal(expected: NestedTensors, actual: NestedTensors):
assert type(expected) == type(actual) # noqa: E721
if isinstance(expected, torch.Tensor):
assert torch.equal(expected, actual)
@@ -19,8 +18,9 @@ def assert_nested_tensors_equal(expected: NestedTensors,
assert_nested_tensors_equal(expected_item, actual_item)
def assert_multimodal_inputs_equal(expected: MultiModalKwargs,
actual: MultiModalKwargs):
def assert_multimodal_inputs_equal(
expected: MultiModalKwargs, actual: MultiModalKwargs
):
assert set(expected.keys()) == set(actual.keys())
for key in expected:
assert_nested_tensors_equal(expected[key], actual[key])
@@ -52,19 +52,10 @@ def test_multimodal_input_batch_nested_tensors():
a = torch.rand([2, 3])
b = torch.rand([2, 3])
c = torch.rand([2, 3])
result = MultiModalKwargs.batch([{
"image": [a]
}, {
"image": [b]
}, {
"image": [c]
}])
assert_multimodal_inputs_equal(result, {
"image":
torch.stack([a.unsqueeze(0),
b.unsqueeze(0),
c.unsqueeze(0)])
})
result = MultiModalKwargs.batch([{"image": [a]}, {"image": [b]}, {"image": [c]}])
assert_multimodal_inputs_equal(
result, {"image": torch.stack([a.unsqueeze(0), b.unsqueeze(0), c.unsqueeze(0)])}
)
def test_multimodal_input_batch_heterogeneous_lists():
@@ -73,8 +64,8 @@ def test_multimodal_input_batch_heterogeneous_lists():
c = torch.rand([1, 2, 3])
result = MultiModalKwargs.batch([{"image": [a, b]}, {"image": [c]}])
assert_multimodal_inputs_equal(
result,
{"image": [torch.stack([a, b]), c.unsqueeze(0)]})
result, {"image": [torch.stack([a, b]), c.unsqueeze(0)]}
)
def test_multimodal_input_batch_multiple_batchable_lists():
@@ -84,9 +75,8 @@ def test_multimodal_input_batch_multiple_batchable_lists():
d = torch.rand([1, 2, 3])
result = MultiModalKwargs.batch([{"image": [a, b]}, {"image": [c, d]}])
assert_multimodal_inputs_equal(
result,
{"image": torch.stack([torch.stack([a, b]),
torch.stack([c, d])])})
result, {"image": torch.stack([torch.stack([a, b]), torch.stack([c, d])])}
)
def test_multimodal_input_batch_mixed_stacking_depths():

View File

@@ -9,16 +9,22 @@ import pytest
from vllm.config import ModelConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
# yapf conflicts with isort for this block
# yapf: disable
from vllm.multimodal.processing import (InputProcessingContext,
PlaceholderFeaturesInfo,
PromptIndexTargets, PromptInsertion,
PromptReplacement, apply_text_matches,
apply_token_matches,
find_mm_placeholders,
iter_token_matches,
replace_token_matches)
from vllm.multimodal.processing import (
InputProcessingContext,
PlaceholderFeaturesInfo,
PromptIndexTargets,
PromptInsertion,
PromptReplacement,
apply_text_matches,
apply_token_matches,
find_mm_placeholders,
iter_token_matches,
replace_token_matches,
)
# yapf: enable
from vllm.multimodal.profiling import MultiModalProfiler
from vllm.transformers_utils.tokenizer import AnyTokenizer

View File

@@ -19,22 +19,16 @@ pytestmark = pytest.mark.cpu_test
[
("Qwen/Qwen2-0.5B-Instruct", {}, False),
("Qwen/Qwen2.5-VL-3B-Instruct", {}, True),
("Qwen/Qwen2.5-VL-3B-Instruct", {
"image": 0,
"video": 0
}, False),
("Qwen/Qwen2.5-VL-3B-Instruct", {
"image": 0
}, True),
("Qwen/Qwen2.5-VL-3B-Instruct", {"image": 0, "video": 0}, False),
("Qwen/Qwen2.5-VL-3B-Instruct", {"image": 0}, True),
],
)
@pytest.mark.core_model
def test_supports_multimodal_inputs(model_id, limit_mm_per_prompt, expected):
"""Test supports_multimodal_inputs returns correct boolean for various
"""Test supports_multimodal_inputs returns correct boolean for various
configs."""
ctx = build_model_context(
model_id,
limit_mm_per_prompt=limit_mm_per_prompt,
)
assert MULTIMODAL_REGISTRY.supports_multimodal_inputs(
ctx.model_config) is expected
assert MULTIMODAL_REGISTRY.supports_multimodal_inputs(ctx.model_config) is expected

View File

@@ -30,7 +30,6 @@ TEST_VIDEO_URLS = [
@pytest.fixture(scope="module")
def url_images(local_asset_server) -> dict[str, Image.Image]:
return {
image_url: local_asset_server.get_image_asset(image_url)
for image_url in TEST_IMAGE_ASSETS
@@ -39,10 +38,10 @@ def url_images(local_asset_server) -> dict[str, Image.Image]:
def get_supported_suffixes() -> tuple[str, ...]:
# We should at least test the file types mentioned in GPT-4 with Vision
OPENAI_SUPPORTED_SUFFIXES = ('.png', '.jpeg', '.jpg', '.webp', '.gif')
OPENAI_SUPPORTED_SUFFIXES = (".png", ".jpeg", ".jpg", ".webp", ".gif")
# Additional file types that are supported by us
EXTRA_SUPPORTED_SUFFIXES = ('.bmp', '.tiff')
EXTRA_SUPPORTED_SUFFIXES = (".bmp", ".tiff")
return OPENAI_SUPPORTED_SUFFIXES + EXTRA_SUPPORTED_SUFFIXES
@@ -64,14 +63,16 @@ async def test_fetch_image_http(image_url: str):
@pytest.mark.asyncio
@pytest.mark.parametrize("raw_image_url", TEST_IMAGE_ASSETS)
@pytest.mark.parametrize("suffix", get_supported_suffixes())
async def test_fetch_image_base64(url_images: dict[str, Image.Image],
raw_image_url: str, suffix: str):
async def test_fetch_image_base64(
url_images: dict[str, Image.Image], raw_image_url: str, suffix: str
):
connector = MediaConnector(
# Domain restriction should not apply to data URLs.
allowed_media_domains=[
"www.bogotobogo.com",
"github.com",
])
]
)
url_image = url_images[raw_image_url]
try:
@@ -80,14 +81,14 @@ async def test_fetch_image_base64(url_images: dict[str, Image.Image],
try:
mime_type = mimetypes.types_map[suffix]
except KeyError:
pytest.skip('No MIME type')
pytest.skip("No MIME type")
with NamedTemporaryFile(suffix=suffix) as f:
try:
url_image.save(f.name)
except Exception as e:
if e.args[0] == 'cannot write mode RGBA as JPEG':
pytest.skip('Conversion not supported')
if e.args[0] == "cannot write mode RGBA as JPEG":
pytest.skip("Conversion not supported")
raise
@@ -113,30 +114,36 @@ async def test_fetch_image_local_files(image_url: str):
local_connector = MediaConnector(allowed_local_media_path=temp_dir)
origin_image = connector.fetch_image(image_url)
origin_image.save(os.path.join(temp_dir, os.path.basename(image_url)),
quality=100,
icc_profile=origin_image.info.get('icc_profile'))
origin_image.save(
os.path.join(temp_dir, os.path.basename(image_url)),
quality=100,
icc_profile=origin_image.info.get("icc_profile"),
)
image_async = await local_connector.fetch_image_async(
f"file://{temp_dir}/{os.path.basename(image_url)}")
f"file://{temp_dir}/{os.path.basename(image_url)}"
)
image_sync = local_connector.fetch_image(
f"file://{temp_dir}/{os.path.basename(image_url)}")
f"file://{temp_dir}/{os.path.basename(image_url)}"
)
# Check that the images are equal
assert not ImageChops.difference(image_sync, image_async).getbbox()
with pytest.raises(ValueError, match="must be a subpath"):
await local_connector.fetch_image_async(
f"file://{temp_dir}/../{os.path.basename(image_url)}")
f"file://{temp_dir}/../{os.path.basename(image_url)}"
)
with pytest.raises(RuntimeError, match="Cannot load local files"):
await connector.fetch_image_async(
f"file://{temp_dir}/../{os.path.basename(image_url)}")
f"file://{temp_dir}/../{os.path.basename(image_url)}"
)
with pytest.raises(ValueError, match="must be a subpath"):
local_connector.fetch_image(
f"file://{temp_dir}/../{os.path.basename(image_url)}")
f"file://{temp_dir}/../{os.path.basename(image_url)}"
)
with pytest.raises(RuntimeError, match="Cannot load local files"):
connector.fetch_image(
f"file://{temp_dir}/../{os.path.basename(image_url)}")
connector.fetch_image(f"file://{temp_dir}/../{os.path.basename(image_url)}")
@pytest.mark.asyncio
@@ -149,18 +156,19 @@ async def test_fetch_image_local_files_with_space_in_name(image_url: str):
origin_image = connector.fetch_image(image_url)
filename = "file name with space.jpg"
origin_image.save(os.path.join(temp_dir, filename),
quality=100,
icc_profile=origin_image.info.get('icc_profile'))
origin_image.save(
os.path.join(temp_dir, filename),
quality=100,
icc_profile=origin_image.info.get("icc_profile"),
)
try:
image_async = await local_connector.fetch_image_async(
f"file://{temp_dir}/{filename}")
image_sync = local_connector.fetch_image(
f"file://{temp_dir}/{filename}")
f"file://{temp_dir}/{filename}"
)
image_sync = local_connector.fetch_image(f"file://{temp_dir}/{filename}")
except FileNotFoundError as e:
pytest.fail(
"Failed to fetch image with space in name: {}".format(e))
pytest.fail("Failed to fetch image with space in name: {}".format(e))
# Check that the images are equal
assert not ImageChops.difference(image_sync, image_async).getbbox()
@@ -183,9 +191,12 @@ async def test_fetch_image_error_conversion():
@pytest.mark.parametrize("num_frames", [-1, 32, 1800])
async def test_fetch_video_http(video_url: str, num_frames: int):
connector = MediaConnector(
media_io_kwargs={"video": {
"num_frames": num_frames,
}})
media_io_kwargs={
"video": {
"num_frames": num_frames,
}
}
)
video_sync, metadata_sync = connector.fetch_video(video_url)
video_async, metadata_async = await connector.fetch_video_async(video_url)
@@ -198,8 +209,11 @@ async def test_fetch_video_http(video_url: str, num_frames: int):
@pytest.mark.parametrize("max_duration", [1, 60, 1800])
@pytest.mark.parametrize("requested_fps", [2, 24])
async def test_fetch_video_http_with_dynamic_loader(
video_url: str, max_duration: int, requested_fps: int,
monkeypatch: pytest.MonkeyPatch):
video_url: str,
max_duration: int,
requested_fps: int,
monkeypatch: pytest.MonkeyPatch,
):
with monkeypatch.context() as m:
m.setenv("VLLM_VIDEO_LOADER_BACKEND", "opencv_dynamic")
connector = MediaConnector(
@@ -208,11 +222,11 @@ async def test_fetch_video_http_with_dynamic_loader(
"max_duration": max_duration,
"requested_fps": requested_fps,
}
})
}
)
video_sync, metadata_sync = connector.fetch_video(video_url)
video_async, metadata_async = await connector.fetch_video_async(
video_url)
video_async, metadata_async = await connector.fetch_video_async(video_url)
assert np.array_equal(video_sync, video_async)
assert metadata_sync == metadata_async

View File

@@ -12,8 +12,7 @@ from PIL import Image
from vllm.assets.base import get_vllm_public_assets
from vllm.assets.video import video_to_ndarrays, video_to_pil_images_list
from vllm.multimodal.image import ImageMediaIO
from vllm.multimodal.video import (VIDEO_LOADER_REGISTRY, VideoLoader,
VideoMediaIO)
from vllm.multimodal.video import VIDEO_LOADER_REGISTRY, VideoLoader, VideoMediaIO
from .utils import cosine_similarity, create_video_from_image, normalize_image
@@ -26,7 +25,6 @@ FAKE_OUTPUT_2 = np.random.rand(NUM_FRAMES, 1280, 720, 3)
@VIDEO_LOADER_REGISTRY.register("test_video_loader_1")
class TestVideoLoader1(VideoLoader):
@classmethod
def load_bytes(cls, data: bytes, num_frames: int = -1) -> npt.NDArray:
return FAKE_OUTPUT_1
@@ -34,7 +32,6 @@ class TestVideoLoader1(VideoLoader):
@VIDEO_LOADER_REGISTRY.register("test_video_loader_2")
class TestVideoLoader2(VideoLoader):
@classmethod
def load_bytes(cls, data: bytes, num_frames: int = -1) -> npt.NDArray:
return FAKE_OUTPUT_2
@@ -57,13 +54,10 @@ def test_video_loader_type_doesnt_exist():
@VIDEO_LOADER_REGISTRY.register("assert_10_frames_1_fps")
class Assert10Frames1FPSVideoLoader(VideoLoader):
@classmethod
def load_bytes(cls,
data: bytes,
num_frames: int = -1,
fps: float = -1.0,
**kwargs) -> npt.NDArray:
def load_bytes(
cls, data: bytes, num_frames: int = -1, fps: float = -1.0, **kwargs
) -> npt.NDArray:
assert num_frames == 10, "bad num_frames"
assert fps == 1.0, "bad fps"
return FAKE_OUTPUT_2
@@ -79,11 +73,8 @@ def test_video_media_io_kwargs(monkeypatch: pytest.MonkeyPatch):
_ = videoio.load_bytes(b"test")
videoio = VideoMediaIO(
imageio, **{
"num_frames": 10,
"fps": 1.0,
"not_used": "not_used"
})
imageio, **{"num_frames": 10, "fps": 1.0, "not_used": "not_used"}
)
_ = videoio.load_bytes(b"test")
with pytest.raises(AssertionError, match="bad num_frames"):
@@ -106,8 +97,9 @@ def test_opencv_video_io_colorspace(is_color: bool, fourcc: str, ext: str):
Test all functions that use OpenCV for video I/O return RGB format.
Both RGB and grayscale videos are tested.
"""
image_path = get_vllm_public_assets(filename="stop_sign.jpg",
s3_prefix="vision_model_images")
image_path = get_vllm_public_assets(
filename="stop_sign.jpg", s3_prefix="vision_model_images"
)
image = Image.open(image_path)
with tempfile.TemporaryDirectory() as tmpdir:
if not is_color:
@@ -127,21 +119,24 @@ def test_opencv_video_io_colorspace(is_color: bool, fourcc: str, ext: str):
frames = video_to_ndarrays(video_path)
for frame in frames:
sim = cosine_similarity(normalize_image(np.array(frame)),
normalize_image(np.array(image)))
sim = cosine_similarity(
normalize_image(np.array(frame)), normalize_image(np.array(image))
)
assert np.sum(np.isnan(sim)) / sim.size < 0.001
assert np.nanmean(sim) > 0.99
pil_frames = video_to_pil_images_list(video_path)
for frame in pil_frames:
sim = cosine_similarity(normalize_image(np.array(frame)),
normalize_image(np.array(image)))
sim = cosine_similarity(
normalize_image(np.array(frame)), normalize_image(np.array(image))
)
assert np.sum(np.isnan(sim)) / sim.size < 0.001
assert np.nanmean(sim) > 0.99
io_frames, _ = VideoMediaIO(ImageMediaIO()).load_file(Path(video_path))
for frame in io_frames:
sim = cosine_similarity(normalize_image(np.array(frame)),
normalize_image(np.array(image)))
sim = cosine_similarity(
normalize_image(np.array(frame)), normalize_image(np.array(image))
)
assert np.sum(np.isnan(sim)) / sim.size < 0.001
assert np.nanmean(sim) > 0.99

View File

@@ -8,7 +8,7 @@ from PIL import Image
def random_image(rng: np.random.RandomState, min_wh: int, max_wh: int):
w, h = rng.randint(min_wh, max_wh, size=(2, ))
w, h = rng.randint(min_wh, max_wh, size=(2,))
arr = rng.randint(0, 255, size=(w, h, 3), dtype=np.uint8)
return Image.fromarray(arr)
@@ -21,7 +21,7 @@ def random_video(
max_wh: int,
):
num_frames = rng.randint(min_frames, max_frames)
w, h = rng.randint(min_wh, max_wh, size=(2, ))
w, h = rng.randint(min_wh, max_wh, size=(2,))
return rng.randint(0, 255, size=(num_frames, w, h, 3), dtype=np.uint8)
@@ -66,14 +66,13 @@ def create_video_from_image(
return video_path
def cosine_similarity(A: npt.NDArray,
B: npt.NDArray,
axis: int = -1) -> npt.NDArray:
def cosine_similarity(A: npt.NDArray, B: npt.NDArray, axis: int = -1) -> npt.NDArray:
"""Compute cosine similarity between two vectors."""
return (np.sum(A * B, axis=axis) /
(np.linalg.norm(A, axis=axis) * np.linalg.norm(B, axis=axis)))
return np.sum(A * B, axis=axis) / (
np.linalg.norm(A, axis=axis) * np.linalg.norm(B, axis=axis)
)
def normalize_image(image: npt.NDArray) -> npt.NDArray:
"""Normalize image to [0, 1] range."""
return image.astype(np.float32) / 255.0
return image.astype(np.float32) / 255.0