Split generic IO Processor plugins tests from Terratorch specific ones (#35756)

Signed-off-by: Christian Pinto <christian.pinto@ibm.com>
This commit is contained in:
Christian Pinto
2026-03-04 16:01:03 +00:00
committed by GitHub
parent 18e01a0a10
commit 2f2212e6cc
4 changed files with 233 additions and 130 deletions

View File

@@ -1,154 +1,98 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import base64
import io
from collections.abc import Sequence
from unittest.mock import MagicMock, patch
import imagehash
import pytest
import requests
from PIL import Image
from tests.utils import RemoteOpenAIServer
from vllm.config import VllmConfig
from vllm.entrypoints.pooling.pooling.protocol import IOProcessorResponse
from vllm.inputs.data import PromptType
from vllm.outputs import PoolingRequestOutput
from vllm.plugins.io_processors import get_io_processor
# Per-model test configuration for the Prithvi geospatial checkpoints.
#   plugin:    name of the IO processor plugin to load for this model
#   image_url: sample input image fed through the plugin
#   out_hash:  expected perceptual hash of the plugin's output image
models_config = {
    "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL-Sen1Floods11": {
        "plugin": "prithvi_to_tiff",
        "image_url": "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/valencia_example_2024-10-26.tiff",  # noqa: E501
        "out_hash": "aa6d92ad25926a5e",
    },
    "ibm-nasa-geospatial/Prithvi-EO-2.0-300M-BurnScars": {
        "plugin": "prithvi_to_tiff",
        "image_url": "https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-BurnScars/resolve/main/examples/subsetted_512x512_HLS.S30.T10SEH.2018190.v1.4_merged.tif",  # noqa: E501
        "out_hash": "c07f4f602da73552",
    },
}
from vllm.plugins.io_processors.interface import IOProcessor
from vllm.renderers import BaseRenderer
def _compute_image_hash(base64_data: str) -> str:
    """Return the perceptual hash (phash) of a base64-encoded image.

    Used by the Prithvi plugin tests to compare plugin output against a
    known-good hash without storing the full reference image.
    """
    # Decode the base64 output and create image from byte stream
    decoded_image = base64.b64decode(base64_data)
    image = Image.open(io.BytesIO(decoded_image))

    # Compute perceptual hash of the output image
    return str(imagehash.phash(image))


class DummyIOProcessor(IOProcessor):
    """Minimal IOProcessor used as the target of the mocked plugin entry point."""

    # NOTE(review): the diff rendering had interleaved this class with
    # _compute_image_hash above; both are reconstructed here as separate,
    # properly indented definitions.

    def pre_process(
        self,
        prompt: object,
        request_id: str | None = None,
        **kwargs,
    ) -> PromptType | Sequence[PromptType]:
        # Stub: the plugin-loading tests never invoke the processing hooks.
        raise NotImplementedError

    def post_process(
        self,
        model_output: Sequence[PoolingRequestOutput],
        request_id: str | None = None,
        **kwargs,
    ) -> object:
        # Stub: the plugin-loading tests never invoke the processing hooks.
        raise NotImplementedError
@pytest.fixture
def my_plugin_entry_points():
    """Expose a single mocked 'my_plugin' entry point backed by DummyIOProcessor.

    Patching importlib.metadata.entry_points exercises the full
    plugin-loading code path: entry_points -> plugin.load() -> func() ->
    resolve_obj_by_qualname -> IOProcessor.__init__.
    """
    dummy_qualname = (
        f"{DummyIOProcessor.__module__}.{DummyIOProcessor.__qualname__}"
    )

    fake_ep = MagicMock()
    fake_ep.name = "my_plugin"
    fake_ep.value = dummy_qualname
    fake_ep.load.return_value = lambda: dummy_qualname

    with patch("importlib.metadata.entry_points", return_value=[fake_ep]):
        yield
def test_loading_missing_plugin():
    """Requesting an unknown plugin name must raise ValueError.

    Fix: the diff rendering left both the stale positional-argument call and
    the new keyword-argument call inside the `pytest.raises` block; only the
    final keyword-argument form is kept (the second call was dead code, as
    the first one would already have raised).
    """
    vllm_config = VllmConfig()
    renderer = MagicMock(spec=BaseRenderer)
    with pytest.raises(ValueError):
        get_io_processor(
            vllm_config, renderer=renderer, plugin_from_init="wrong_plugin"
        )
@pytest.fixture(scope="function")
def server(model_name, plugin):
    """Start a RemoteOpenAIServer running `model_name` with the given
    IO processor plugin enabled, and yield it to the test."""
    args = [
        "--runner",
        "pooling",
        "--enforce-eager",
        "--skip-tokenizer-init",
        # Limit the maximum number of parallel requests
        # to avoid the model going OOM in CI.
        "--max-num-seqs",
        "32",
        "--io-processor-plugin",
        plugin,
        "--enable-mm-embeds",
    ]

    with RemoteOpenAIServer(model_name, args) as remote_server:
        yield remote_server


def test_loading_plugin(my_plugin_entry_points):
    """The mocked 'my_plugin' entry point resolves to DummyIOProcessor."""
    # Plugin name supplied via plugin_from_init.
    vllm_config = MagicMock(spec=VllmConfig)
    renderer = MagicMock(spec=BaseRenderer)
    result = get_io_processor(
        vllm_config, renderer=renderer, plugin_from_init="my_plugin"
    )
    assert isinstance(result, DummyIOProcessor)


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name, image_url, plugin, expected_hash",
    [
        (model_name, config["image_url"], config["plugin"], config["out_hash"])
        for model_name, config in models_config.items()
    ],
)
async def test_prithvi_mae_plugin_online(
    server: RemoteOpenAIServer,
    model_name: str,
    image_url: str | dict,
    plugin: str,
    expected_hash: str,
):
    """End-to-end check of the Prithvi IO processor plugin through the
    server's /pooling HTTP endpoint."""
    request_payload_url = {
        "data": {
            "data": image_url,
            "data_format": "url",
            "image_format": "tiff",
            "out_data_format": "b64_json",
        },
        "priority": 0,
        "model": model_name,
        "softmax": False,
    }

    ret = requests.post(
        server.url_for("pooling"),
        json=request_payload_url,
    )
    response = ret.json()

    # verify the request response is in the correct format
    assert (parsed_response := IOProcessorResponse(**response))

    # verify the output is formatted as expected for this plugin
    plugin_data = parsed_response.data
    assert all(plugin_data.get(attr) for attr in ["type", "format", "data"])

    # Compute the output image hash and compare it against the expected hash
    image_hash = _compute_image_hash(plugin_data["data"])
    assert image_hash == expected_hash, (
        f"Image hash mismatch: expected {expected_hash}, got {image_hash}"
    )
@pytest.mark.parametrize(
    "model_name, image_url, plugin, expected_hash",
    [
        (model_name, config["image_url"], config["plugin"], config["out_hash"])
        for model_name, config in models_config.items()
    ],
)
def test_prithvi_mae_plugin_offline(
    vllm_runner, model_name: str, image_url: str | dict, plugin: str, expected_hash: str
):
    """Offline (in-process) counterpart of the online Prithvi plugin test."""
    img_data = dict(
        data=image_url,
        data_format="url",
        image_format="tiff",
        out_data_format="b64_json",
    )
    prompt = dict(data=img_data)

    with vllm_runner(
        model_name,
        runner="pooling",
        skip_tokenizer_init=True,
        enable_mm_embeds=True,
        enforce_eager=True,
        # Limit the maximum number of parallel requests
        # to avoid the model going OOM in CI.
        max_num_seqs=32,
        io_processor_plugin=plugin,
        default_torch_num_threads=1,
    ) as llm_runner:
        pooler_output = llm_runner.get_llm().encode(prompt, pooling_task="plugin")
    output = pooler_output[0].outputs

    # verify the output is formatted as expected for this plugin
    assert all(hasattr(output, attr) for attr in ["type", "format", "data"])

    # Compute the output image hash and compare it against the expected hash
    image_hash = _compute_image_hash(output.data)
    assert image_hash == expected_hash, (
        f"Image hash mismatch: expected {expected_hash}, got {image_hash}"
    )


def test_loading_missing_plugin_from_model_config():
    """An unknown plugin name advertised via hf_config raises ValueError."""
    # Build a mock VllmConfig whose hf_config advertises a plugin name,
    # exercising the model-config code path without loading a real model.
    mock_hf_config = MagicMock()
    mock_hf_config.to_dict.return_value = {"io_processor_plugin": "wrong_plugin"}

    vllm_config = MagicMock(spec=VllmConfig)
    vllm_config.model_config.hf_config = mock_hf_config

    renderer = MagicMock(spec=BaseRenderer)
    with pytest.raises(ValueError):
        get_io_processor(vllm_config, renderer=renderer)
def test_loading_plugin_from_model_config(my_plugin_entry_points):
    """The plugin name can also be supplied via the model's hf_config."""
    hf_config_mock = MagicMock()
    hf_config_mock.to_dict.return_value = {"io_processor_plugin": "my_plugin"}

    config_mock = MagicMock(spec=VllmConfig)
    config_mock.model_config.hf_config = hf_config_mock

    renderer_mock = MagicMock(spec=BaseRenderer)
    loaded = get_io_processor(config_mock, renderer=renderer_mock)
    assert isinstance(loaded, DummyIOProcessor)