[Plugin] Simplify IO Processor Plugin interface (#34236)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2026-02-11 11:47:39 +08:00
committed by GitHub
parent b482f71e9f
commit c9a1923bb4
9 changed files with 164 additions and 148 deletions

View File

@@ -18,18 +18,10 @@ from einops import rearrange
from terratorch.datamodules import Sen1Floods11NonGeoDataModule
from vllm.config import VllmConfig
from vllm.entrypoints.pooling.pooling.protocol import (
IOProcessorRequest,
IOProcessorResponse,
)
from vllm.inputs.data import PromptType
from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput
from vllm.plugins.io_processors.interface import (
IOProcessor,
IOProcessorInput,
IOProcessorOutput,
)
from vllm.plugins.io_processors.interface import IOProcessor
from .types import DataModuleConfig, ImagePrompt, ImageRequestOutput
@@ -227,7 +219,7 @@ def load_image(
return imgs, temporal_coords, location_coords, metas
class PrithviMultimodalDataProcessor(IOProcessor):
class PrithviMultimodalDataProcessor(IOProcessor[ImagePrompt, ImageRequestOutput]):
indices = [0, 1, 2, 3, 4, 5]
def __init__(self, vllm_config: VllmConfig):
@@ -251,34 +243,15 @@ class PrithviMultimodalDataProcessor(IOProcessor):
self.requests_cache: dict[str, dict[str, Any]] = {}
self.indices = DEFAULT_INPUT_INDICES
def parse_request(self, request: Any) -> IOProcessorInput:
if type(request) is dict:
image_prompt = ImagePrompt(**request)
return image_prompt
if isinstance(request, IOProcessorRequest):
if not hasattr(request, "data"):
raise ValueError("missing 'data' field in OpenAIBaseModel Request")
def parse_data(self, data: object) -> ImagePrompt:
if isinstance(data, dict):
return ImagePrompt(**data)
request_data = request.data
if type(request_data) is dict:
return ImagePrompt(**request_data)
else:
raise ValueError("Unable to parse the request data")
raise ValueError("Unable to parse request")
def output_to_response(
self, plugin_output: IOProcessorOutput
) -> IOProcessorResponse:
return IOProcessorResponse(
request_id=plugin_output.request_id,
data=plugin_output,
)
raise ValueError("Prompt data should be an `ImagePrompt`")
def pre_process(
self,
prompt: IOProcessorInput,
prompt: ImagePrompt,
request_id: str | None = None,
**kwargs,
) -> PromptType | Sequence[PromptType]:
@@ -364,7 +337,7 @@ class PrithviMultimodalDataProcessor(IOProcessor):
model_output: Sequence[PoolingRequestOutput],
request_id: str | None = None,
**kwargs,
) -> IOProcessorOutput:
) -> ImageRequestOutput:
pred_imgs_list = []
if request_id and (request_id in self.requests_cache):
@@ -409,5 +382,7 @@ class PrithviMultimodalDataProcessor(IOProcessor):
)
return ImageRequestOutput(
type=out_format, format="tiff", data=out_data, request_id=request_id
type=out_format,
format="tiff",
data=out_data,
)

View File

@@ -38,9 +38,6 @@ class ImagePrompt(BaseModel):
"""
MultiModalPromptType = ImagePrompt
class ImageRequestOutput(BaseModel):
"""
The output data of an image request to vLLM.
@@ -54,4 +51,3 @@ class ImageRequestOutput(BaseModel):
type: Literal["path", "b64_json"]
format: str
data: str
request_id: str | None = None