[Misc] IO Processor plugins for pooling models (#22820)
Signed-off-by: Christian Pinto <christian.pinto@ibm.com> Signed-off-by: Max de Bayser <mbayser@br.ibm.com> Co-authored-by: Max de Bayser <mbayser@br.ibm.com>
This commit is contained in:
137
tests/plugins_tests/test_io_processor_plugins.py
Normal file
137
tests/plugins_tests/test_io_processor_plugins.py
Normal file
@@ -0,0 +1,137 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import base64
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.entrypoints.llm import LLM
|
||||
from vllm.entrypoints.openai.protocol import IOProcessorResponse
|
||||
from vllm.plugins.io_processors import get_io_processor
|
||||
from vllm.pooling_params import PoolingParams
|
||||
|
||||
MODEL_NAME = "christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM"
|
||||
|
||||
image_url = "https://huggingface.co/christian-pinto/Prithvi-EO-2.0-300M-TL-VLLM/resolve/main/valencia_example_2024-10-26.tiff" # noqa: E501
|
||||
|
||||
|
||||
def test_loading_missing_plugin():
|
||||
vllm_config = VllmConfig()
|
||||
with pytest.raises(ValueError):
|
||||
get_io_processor(vllm_config, "wrong_plugin")
|
||||
|
||||
|
||||
def test_loading_engine_with_wrong_plugin():
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
LLM(
|
||||
model=MODEL_NAME,
|
||||
skip_tokenizer_init=True,
|
||||
trust_remote_code=True,
|
||||
enforce_eager=True,
|
||||
# Limit the maximum number of parallel requests
|
||||
# to avoid the model going OOM in CI.
|
||||
max_num_seqs=32,
|
||||
io_processor_plugin="wrong_plugin",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
def test_prithvi_mae_plugin_offline(vllm_runner, model_name: str):
|
||||
|
||||
img_prompt = dict(
|
||||
data=image_url,
|
||||
data_format="url",
|
||||
image_format="tiff",
|
||||
out_data_format="b64_json",
|
||||
)
|
||||
|
||||
pooling_params = PoolingParams(task="encode", softmax=False)
|
||||
|
||||
with vllm_runner(
|
||||
model_name,
|
||||
runner="pooling",
|
||||
skip_tokenizer_init=True,
|
||||
trust_remote_code=True,
|
||||
enforce_eager=True,
|
||||
# Limit the maximum number of parallel requests
|
||||
# to avoid the model going OOM in CI.
|
||||
max_num_seqs=1,
|
||||
io_processor_plugin="prithvi_to_tiff_valencia",
|
||||
) as llm_runner:
|
||||
pooler_output = llm_runner.get_llm().encode(
|
||||
img_prompt,
|
||||
pooling_params=pooling_params,
|
||||
)
|
||||
output = pooler_output[0].outputs
|
||||
|
||||
# verify the output is formatted as expected for this plugin
|
||||
assert all(
|
||||
hasattr(output, attr)
|
||||
for attr in ["type", "format", "data", "request_id"])
|
||||
|
||||
# We just check that the output is a valid base64 string.
|
||||
# Raises an exception and fails the test if the string is corrupted.
|
||||
base64.b64decode(output.data)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server():
|
||||
args = [
|
||||
"--runner",
|
||||
"pooling",
|
||||
"--enforce-eager",
|
||||
"--trust-remote-code",
|
||||
"--skip-tokenizer-init",
|
||||
# Limit the maximum number of parallel requests
|
||||
# to avoid the model going OOM in CI.
|
||||
"--max-num-seqs",
|
||||
"32",
|
||||
"--io-processor-plugin",
|
||||
"prithvi_to_tiff_valencia"
|
||||
]
|
||||
|
||||
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
|
||||
yield remote_server
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
async def test_prithvi_mae_plugin_online(
|
||||
server: RemoteOpenAIServer,
|
||||
model_name: str,
|
||||
):
|
||||
|
||||
request_payload_url = {
|
||||
"data": {
|
||||
"data": image_url,
|
||||
"data_format": "url",
|
||||
"image_format": "tiff",
|
||||
"out_data_format": "b64_json",
|
||||
},
|
||||
"priority": 0,
|
||||
"model": model_name,
|
||||
}
|
||||
|
||||
ret = requests.post(
|
||||
server.url_for("pooling"),
|
||||
json=request_payload_url,
|
||||
)
|
||||
|
||||
response = ret.json()
|
||||
|
||||
# verify the request response is in the correct format
|
||||
assert (parsed_response := IOProcessorResponse(**response))
|
||||
|
||||
# verify the output is formatted as expected for this plugin
|
||||
plugin_data = parsed_response.data
|
||||
|
||||
assert all(
|
||||
plugin_data.get(attr)
|
||||
for attr in ["type", "format", "data", "request_id"])
|
||||
|
||||
# We just check that the output is a valid base64 string.
|
||||
# Raises an exception and fails the test if the string is corrupted.
|
||||
base64.b64decode(plugin_data["data"])
|
||||
Reference in New Issue
Block a user