[Misc] Abstract the logic for reading and writing media content (#11527)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -9,7 +9,7 @@ import pytest
|
||||
from PIL import Image, ImageChops
|
||||
from transformers import AutoConfig, AutoTokenizer
|
||||
|
||||
from vllm.multimodal.utils import (async_fetch_image, fetch_image,
|
||||
from vllm.multimodal.utils import (MediaConnector,
|
||||
repeat_and_pad_placeholder_tokens)
|
||||
|
||||
# Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA)
|
||||
@@ -23,7 +23,12 @@ TEST_IMAGE_URLS = [
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def url_images() -> Dict[str, Image.Image]:
|
||||
return {image_url: fetch_image(image_url) for image_url in TEST_IMAGE_URLS}
|
||||
connector = MediaConnector()
|
||||
|
||||
return {
|
||||
image_url: connector.fetch_image(image_url)
|
||||
for image_url in TEST_IMAGE_URLS
|
||||
}
|
||||
|
||||
|
||||
def get_supported_suffixes() -> Tuple[str, ...]:
|
||||
@@ -43,8 +48,10 @@ def _image_equals(a: Image.Image, b: Image.Image) -> bool:
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
|
||||
async def test_fetch_image_http(image_url: str):
|
||||
image_sync = fetch_image(image_url)
|
||||
image_async = await async_fetch_image(image_url)
|
||||
connector = MediaConnector()
|
||||
|
||||
image_sync = connector.fetch_image(image_url)
|
||||
image_async = await connector.fetch_image_async(image_url)
|
||||
assert _image_equals(image_sync, image_async)
|
||||
|
||||
|
||||
@@ -53,6 +60,7 @@ async def test_fetch_image_http(image_url: str):
|
||||
@pytest.mark.parametrize("suffix", get_supported_suffixes())
|
||||
async def test_fetch_image_base64(url_images: Dict[str, Image.Image],
|
||||
image_url: str, suffix: str):
|
||||
connector = MediaConnector()
|
||||
url_image = url_images[image_url]
|
||||
|
||||
try:
|
||||
@@ -75,48 +83,49 @@ async def test_fetch_image_base64(url_images: Dict[str, Image.Image],
|
||||
base64_image = base64.b64encode(f.read()).decode("utf-8")
|
||||
data_url = f"data:{mime_type};base64,{base64_image}"
|
||||
|
||||
data_image_sync = fetch_image(data_url)
|
||||
data_image_sync = connector.fetch_image(data_url)
|
||||
if _image_equals(url_image, Image.open(f)):
|
||||
assert _image_equals(url_image, data_image_sync)
|
||||
else:
|
||||
pass # Lossy format; only check that image can be opened
|
||||
|
||||
data_image_async = await async_fetch_image(data_url)
|
||||
data_image_async = await connector.fetch_image_async(data_url)
|
||||
assert _image_equals(data_image_sync, data_image_async)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
|
||||
async def test_fetch_image_local_files(image_url: str):
|
||||
connector = MediaConnector()
|
||||
|
||||
with TemporaryDirectory() as temp_dir:
|
||||
origin_image = fetch_image(image_url)
|
||||
local_connector = MediaConnector(allowed_local_media_path=temp_dir)
|
||||
|
||||
origin_image = connector.fetch_image(image_url)
|
||||
origin_image.save(os.path.join(temp_dir, os.path.basename(image_url)),
|
||||
quality=100,
|
||||
icc_profile=origin_image.info.get('icc_profile'))
|
||||
|
||||
image_async = await async_fetch_image(
|
||||
f"file://{temp_dir}/{os.path.basename(image_url)}",
|
||||
allowed_local_media_path=temp_dir)
|
||||
|
||||
image_sync = fetch_image(
|
||||
f"file://{temp_dir}/{os.path.basename(image_url)}",
|
||||
allowed_local_media_path=temp_dir)
|
||||
image_async = await local_connector.fetch_image_async(
|
||||
f"file://{temp_dir}/{os.path.basename(image_url)}")
|
||||
image_sync = local_connector.fetch_image(
|
||||
f"file://{temp_dir}/{os.path.basename(image_url)}")
|
||||
# Check that the images are equal
|
||||
assert not ImageChops.difference(image_sync, image_async).getbbox()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await async_fetch_image(
|
||||
f"file://{temp_dir}/../{os.path.basename(image_url)}",
|
||||
allowed_local_media_path=temp_dir)
|
||||
with pytest.raises(ValueError):
|
||||
await async_fetch_image(
|
||||
with pytest.raises(ValueError, match="must be a subpath"):
|
||||
await local_connector.fetch_image_async(
|
||||
f"file://{temp_dir}/../{os.path.basename(image_url)}")
|
||||
with pytest.raises(RuntimeError, match="Cannot load local files"):
|
||||
await connector.fetch_image_async(
|
||||
f"file://{temp_dir}/../{os.path.basename(image_url)}")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
fetch_image(f"file://{temp_dir}/../{os.path.basename(image_url)}",
|
||||
allowed_local_media_path=temp_dir)
|
||||
with pytest.raises(ValueError):
|
||||
fetch_image(f"file://{temp_dir}/../{os.path.basename(image_url)}")
|
||||
with pytest.raises(ValueError, match="must be a subpath"):
|
||||
local_connector.fetch_image(
|
||||
f"file://{temp_dir}/../{os.path.basename(image_url)}")
|
||||
with pytest.raises(RuntimeError, match="Cannot load local files"):
|
||||
connector.fetch_image(
|
||||
f"file://{temp_dir}/../{os.path.basename(image_url)}")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", ["llava-hf/llava-v1.6-mistral-7b-hf"])
|
||||
|
||||
Reference in New Issue
Block a user