Split generic IO Processor plugins tests from Terratorch specific ones (#35756)
Signed-off-by: Christian Pinto <christian.pinto@ibm.com>
This commit is contained in:
@@ -1,154 +1,98 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import base64
|
||||
import io
|
||||
from collections.abc import Sequence
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import imagehash
|
||||
import pytest
|
||||
import requests
|
||||
from PIL import Image
|
||||
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.entrypoints.pooling.pooling.protocol import IOProcessorResponse
|
||||
from vllm.inputs.data import PromptType
|
||||
from vllm.outputs import PoolingRequestOutput
|
||||
from vllm.plugins.io_processors import get_io_processor
|
||||
|
||||
# Models under test, keyed by HF model name. Each entry carries a sample
# input image URL, the IO-processor plugin to load, and the expected
# perceptual hash of the plugin's output image (golden value).
models_config = {
    "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11": {
        "image_url": "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/valencia_example_2024-10-26.tiff",  # noqa: E501
        "out_hash": "aa6d92ad25926a5e",
        "plugin": "prithvi_to_tiff",
    },
    "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-BurnScars": {
        "image_url": "https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-BurnScars/resolve/main/examples/subsetted_512x512_HLS.S30.T10SEH.2018190.v1.4_merged.tif",  # noqa: E501
        "out_hash": "c07f4f602da73552",
        "plugin": "prithvi_to_tiff",
    },
}
|
||||
from vllm.plugins.io_processors.interface import IOProcessor
|
||||
from vllm.renderers import BaseRenderer
|
||||
|
||||
|
||||
def _compute_image_hash(base64_data: str) -> str:
    """Return the perceptual hash (phash) of a base64-encoded image.

    Used to compare plugin output images against golden hash values
    without storing the full expected image.
    """
    raw_bytes = base64.b64decode(base64_data)
    image = Image.open(io.BytesIO(raw_bytes))
    return str(imagehash.phash(image))
class DummyIOProcessor(IOProcessor):
    """Minimal IOProcessor used as the target of the mocked plugin entry point."""

    def pre_process(
        self, prompt: object, request_id: str | None = None, **kwargs
    ) -> PromptType | Sequence[PromptType]:
        # Never invoked by the plugin-loading tests; only construction matters.
        raise NotImplementedError

    def post_process(
        self,
        model_output: Sequence[PoolingRequestOutput],
        request_id: str | None = None,
        **kwargs,
    ) -> object:
        # Never invoked by the plugin-loading tests; only construction matters.
        raise NotImplementedError
|
||||
@pytest.fixture
def my_plugin_entry_points():
    """Expose a fake 'my_plugin' entry point backed by DummyIOProcessor.

    Patches importlib.metadata.entry_points so the full plugin-loading
    code path runs: entry_points → plugin.load() → func() →
    resolve_obj_by_qualname → IOProcessor.__init__.
    """
    target = f"{DummyIOProcessor.__module__}.{DummyIOProcessor.__qualname__}"
    entry_point = MagicMock()
    # .name must be assigned after construction: MagicMock(name=...) would
    # configure the mock's repr name, not the attribute.
    entry_point.name = "my_plugin"
    entry_point.value = target
    entry_point.load.return_value = lambda: target
    with patch("importlib.metadata.entry_points", return_value=[entry_point]):
        yield
||||
|
||||
def test_loading_missing_plugin():
    """Requesting an unknown plugin name must raise ValueError."""
    renderer = MagicMock(spec=BaseRenderer)
    vllm_config = VllmConfig()
    with pytest.raises(ValueError):
        get_io_processor(
            vllm_config, renderer=renderer, plugin_from_init="wrong_plugin"
        )
||||
|
||||
@pytest.fixture(scope="function")
def server(model_name, plugin):
    """Start a RemoteOpenAIServer for *model_name* with *plugin* enabled."""
    cli_args = [
        "--runner",
        "pooling",
        "--enforce-eager",
        "--skip-tokenizer-init",
        # Cap parallel requests so the model does not go OOM in CI.
        "--max-num-seqs",
        "32",
        "--io-processor-plugin",
        plugin,
        "--enable-mm-embeds",
    ]
    with RemoteOpenAIServer(model_name, cli_args) as remote_server:
        yield remote_server
def test_loading_plugin(my_plugin_entry_points):
    """Plugin name supplied via plugin_from_init resolves to the dummy plugin."""
    config = MagicMock(spec=VllmConfig)
    renderer = MagicMock(spec=BaseRenderer)

    loaded = get_io_processor(
        config, renderer=renderer, plugin_from_init="my_plugin"
    )

    assert isinstance(loaded, DummyIOProcessor)


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name, image_url, plugin, expected_hash",
    [
        (name, cfg["image_url"], cfg["plugin"], cfg["out_hash"])
        for name, cfg in models_config.items()
    ],
)
async def test_prithvi_mae_plugin_online(
    server: RemoteOpenAIServer,
    model_name: str,
    image_url: str | dict,
    plugin: str,
    expected_hash: str,
):
    """End-to-end check of the IO-processor plugin through the HTTP pooling API."""
    payload = {
        "data": {
            "data": image_url,
            "data_format": "url",
            "image_format": "tiff",
            "out_data_format": "b64_json",
        },
        "priority": 0,
        "model": model_name,
        "softmax": False,
    }

    raw = requests.post(server.url_for("pooling"), json=payload)

    # The response must parse as a well-formed IOProcessorResponse.
    parsed_response = IOProcessorResponse(**raw.json())
    assert parsed_response

    # The plugin output must carry the type/format/data triplet.
    plugin_data = parsed_response.data
    assert all(plugin_data.get(attr) for attr in ["type", "format", "data"])

    # Perceptual hash of the decoded image must match the golden value.
    image_hash = _compute_image_hash(plugin_data["data"])
    assert image_hash == expected_hash, (
        f"Image hash mismatch: expected {expected_hash}, got {image_hash}"
    )
|
||||
@pytest.mark.parametrize(
    "model_name, image_url, plugin, expected_hash",
    [
        (name, cfg["image_url"], cfg["plugin"], cfg["out_hash"])
        for name, cfg in models_config.items()
    ],
)
def test_prithvi_mae_plugin_offline(
    vllm_runner, model_name: str, image_url: str | dict, plugin: str, expected_hash: str
):
    """Run the plugin offline through LLM.encode and check the output hash."""
    prompt = {
        "data": {
            "data": image_url,
            "data_format": "url",
            "image_format": "tiff",
            "out_data_format": "b64_json",
        }
    }

    with vllm_runner(
        model_name,
        runner="pooling",
        skip_tokenizer_init=True,
        enable_mm_embeds=True,
        enforce_eager=True,
        # Cap parallel requests so the model does not go OOM in CI.
        max_num_seqs=32,
        io_processor_plugin=plugin,
        default_torch_num_threads=1,
    ) as llm_runner:
        pooler_output = llm_runner.get_llm().encode(prompt, pooling_task="plugin")
        output = pooler_output[0].outputs

    # The plugin output must expose the type/format/data triplet.
    assert all(hasattr(output, attr) for attr in ["type", "format", "data"])

    # Perceptual hash of the decoded image must match the golden value.
    image_hash = _compute_image_hash(output.data)
    assert image_hash == expected_hash, (
        f"Image hash mismatch: expected {expected_hash}, got {image_hash}"
    )


def test_loading_missing_plugin_from_model_config():
    """A bogus plugin name advertised by hf_config must raise ValueError."""
    # Mock a VllmConfig whose hf_config advertises the plugin, exercising
    # the model-config code path without loading a real model.
    mock_hf_config = MagicMock()
    mock_hf_config.to_dict.return_value = {"io_processor_plugin": "wrong_plugin"}

    config = MagicMock(spec=VllmConfig)
    config.model_config.hf_config = mock_hf_config

    renderer = MagicMock(spec=BaseRenderer)
    with pytest.raises(ValueError):
        get_io_processor(config, renderer=renderer)
|
||||
def test_loading_plugin_from_model_config(my_plugin_entry_points):
    """Plugin name supplied via the model's hf_config resolves to the dummy."""
    mock_hf_config = MagicMock()
    mock_hf_config.to_dict.return_value = {"io_processor_plugin": "my_plugin"}

    config = MagicMock(spec=VllmConfig)
    config.model_config.hf_config = mock_hf_config
    renderer = MagicMock(spec=BaseRenderer)

    loaded = get_io_processor(config, renderer=renderer)

    assert isinstance(loaded, DummyIOProcessor)
|
||||
|
||||
Reference in New Issue
Block a user