[Misc] Abstract the logic for reading and writing media content (#11527)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2024-12-27 19:21:23 +08:00
committed by GitHub
parent 2c9b8ea2b0
commit 7af553ea30
10 changed files with 495 additions and 389 deletions

View File

@@ -9,7 +9,7 @@ import pytest
from PIL import Image, ImageChops
from transformers import AutoConfig, AutoTokenizer
from vllm.multimodal.utils import (async_fetch_image, fetch_image,
from vllm.multimodal.utils import (MediaConnector,
repeat_and_pad_placeholder_tokens)
# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
@@ -23,7 +23,12 @@ TEST_IMAGE_URLS = [
@pytest.fixture(scope="module")
def url_images() -> Dict[str, Image.Image]:
return {image_url: fetch_image(image_url) for image_url in TEST_IMAGE_URLS}
connector = MediaConnector()
return {
image_url: connector.fetch_image(image_url)
for image_url in TEST_IMAGE_URLS
}
def get_supported_suffixes() -> Tuple[str, ...]:
@@ -43,8 +48,10 @@ def _image_equals(a: Image.Image, b: Image.Image) -> bool:
@pytest.mark.asyncio
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
async def test_fetch_image_http(image_url: str):
image_sync = fetch_image(image_url)
image_async = await async_fetch_image(image_url)
connector = MediaConnector()
image_sync = connector.fetch_image(image_url)
image_async = await connector.fetch_image_async(image_url)
assert _image_equals(image_sync, image_async)
@@ -53,6 +60,7 @@ async def test_fetch_image_http(image_url: str):
@pytest.mark.parametrize("suffix", get_supported_suffixes())
async def test_fetch_image_base64(url_images: Dict[str, Image.Image],
image_url: str, suffix: str):
connector = MediaConnector()
url_image = url_images[image_url]
try:
@@ -75,48 +83,49 @@ async def test_fetch_image_base64(url_images: Dict[str, Image.Image],
base64_image = base64.b64encode(f.read()).decode("utf-8")
data_url = f"data:{mime_type};base64,{base64_image}"
data_image_sync = fetch_image(data_url)
data_image_sync = connector.fetch_image(data_url)
if _image_equals(url_image, Image.open(f)):
assert _image_equals(url_image, data_image_sync)
else:
pass # Lossy format; only check that image can be opened
data_image_async = await async_fetch_image(data_url)
data_image_async = await connector.fetch_image_async(data_url)
assert _image_equals(data_image_sync, data_image_async)
@pytest.mark.asyncio
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
async def test_fetch_image_local_files(image_url: str):
connector = MediaConnector()
with TemporaryDirectory() as temp_dir:
origin_image = fetch_image(image_url)
local_connector = MediaConnector(allowed_local_media_path=temp_dir)
origin_image = connector.fetch_image(image_url)
origin_image.save(os.path.join(temp_dir, os.path.basename(image_url)),
quality=100,
icc_profile=origin_image.info.get('icc_profile'))
image_async = await async_fetch_image(
f"file://{temp_dir}/{os.path.basename(image_url)}",
allowed_local_media_path=temp_dir)
image_sync = fetch_image(
f"file://{temp_dir}/{os.path.basename(image_url)}",
allowed_local_media_path=temp_dir)
image_async = await local_connector.fetch_image_async(
f"file://{temp_dir}/{os.path.basename(image_url)}")
image_sync = local_connector.fetch_image(
f"file://{temp_dir}/{os.path.basename(image_url)}")
# Check that the images are equal
assert not ImageChops.difference(image_sync, image_async).getbbox()
with pytest.raises(ValueError):
await async_fetch_image(
f"file://{temp_dir}/../{os.path.basename(image_url)}",
allowed_local_media_path=temp_dir)
with pytest.raises(ValueError):
await async_fetch_image(
with pytest.raises(ValueError, match="must be a subpath"):
await local_connector.fetch_image_async(
f"file://{temp_dir}/../{os.path.basename(image_url)}")
with pytest.raises(RuntimeError, match="Cannot load local files"):
await connector.fetch_image_async(
f"file://{temp_dir}/../{os.path.basename(image_url)}")
with pytest.raises(ValueError):
fetch_image(f"file://{temp_dir}/../{os.path.basename(image_url)}",
allowed_local_media_path=temp_dir)
with pytest.raises(ValueError):
fetch_image(f"file://{temp_dir}/../{os.path.basename(image_url)}")
with pytest.raises(ValueError, match="must be a subpath"):
local_connector.fetch_image(
f"file://{temp_dir}/../{os.path.basename(image_url)}")
with pytest.raises(RuntimeError, match="Cannot load local files"):
connector.fetch_image(
f"file://{temp_dir}/../{os.path.basename(image_url)}")
@pytest.mark.parametrize("model", ["llava-hf/llava-v1.6-mistral-7b-hf"])