[Plugin] Simplify IO Processor Plugin interface (#34236)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -18,18 +18,10 @@ from einops import rearrange
|
||||
from terratorch.datamodules import Sen1Floods11NonGeoDataModule
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.entrypoints.pooling.pooling.protocol import (
|
||||
IOProcessorRequest,
|
||||
IOProcessorResponse,
|
||||
)
|
||||
from vllm.inputs.data import PromptType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import PoolingRequestOutput
|
||||
from vllm.plugins.io_processors.interface import (
|
||||
IOProcessor,
|
||||
IOProcessorInput,
|
||||
IOProcessorOutput,
|
||||
)
|
||||
from vllm.plugins.io_processors.interface import IOProcessor
|
||||
|
||||
from .types import DataModuleConfig, ImagePrompt, ImageRequestOutput
|
||||
|
||||
@@ -227,7 +219,7 @@ def load_image(
|
||||
return imgs, temporal_coords, location_coords, metas
|
||||
|
||||
|
||||
class PrithviMultimodalDataProcessor(IOProcessor):
|
||||
class PrithviMultimodalDataProcessor(IOProcessor[ImagePrompt, ImageRequestOutput]):
|
||||
indices = [0, 1, 2, 3, 4, 5]
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig):
|
||||
@@ -251,34 +243,15 @@ class PrithviMultimodalDataProcessor(IOProcessor):
|
||||
self.requests_cache: dict[str, dict[str, Any]] = {}
|
||||
self.indices = DEFAULT_INPUT_INDICES
|
||||
|
||||
def parse_request(self, request: Any) -> IOProcessorInput:
|
||||
if type(request) is dict:
|
||||
image_prompt = ImagePrompt(**request)
|
||||
return image_prompt
|
||||
if isinstance(request, IOProcessorRequest):
|
||||
if not hasattr(request, "data"):
|
||||
raise ValueError("missing 'data' field in OpenAIBaseModel Request")
|
||||
def parse_data(self, data: object) -> ImagePrompt:
|
||||
if isinstance(data, dict):
|
||||
return ImagePrompt(**data)
|
||||
|
||||
request_data = request.data
|
||||
|
||||
if type(request_data) is dict:
|
||||
return ImagePrompt(**request_data)
|
||||
else:
|
||||
raise ValueError("Unable to parse the request data")
|
||||
|
||||
raise ValueError("Unable to parse request")
|
||||
|
||||
def output_to_response(
|
||||
self, plugin_output: IOProcessorOutput
|
||||
) -> IOProcessorResponse:
|
||||
return IOProcessorResponse(
|
||||
request_id=plugin_output.request_id,
|
||||
data=plugin_output,
|
||||
)
|
||||
raise ValueError("Prompt data should be an `ImagePrompt`")
|
||||
|
||||
def pre_process(
|
||||
self,
|
||||
prompt: IOProcessorInput,
|
||||
prompt: ImagePrompt,
|
||||
request_id: str | None = None,
|
||||
**kwargs,
|
||||
) -> PromptType | Sequence[PromptType]:
|
||||
@@ -364,7 +337,7 @@ class PrithviMultimodalDataProcessor(IOProcessor):
|
||||
model_output: Sequence[PoolingRequestOutput],
|
||||
request_id: str | None = None,
|
||||
**kwargs,
|
||||
) -> IOProcessorOutput:
|
||||
) -> ImageRequestOutput:
|
||||
pred_imgs_list = []
|
||||
|
||||
if request_id and (request_id in self.requests_cache):
|
||||
@@ -409,5 +382,7 @@ class PrithviMultimodalDataProcessor(IOProcessor):
|
||||
)
|
||||
|
||||
return ImageRequestOutput(
|
||||
type=out_format, format="tiff", data=out_data, request_id=request_id
|
||||
type=out_format,
|
||||
format="tiff",
|
||||
data=out_data,
|
||||
)
|
||||
|
||||
@@ -38,9 +38,6 @@ class ImagePrompt(BaseModel):
|
||||
"""
|
||||
|
||||
|
||||
MultiModalPromptType = ImagePrompt
|
||||
|
||||
|
||||
class ImageRequestOutput(BaseModel):
|
||||
"""
|
||||
The output data of an image request to vLLM.
|
||||
@@ -54,4 +51,3 @@ class ImageRequestOutput(BaseModel):
|
||||
type: Literal["path", "b64_json"]
|
||||
format: str
|
||||
data: str
|
||||
request_id: str | None = None
|
||||
|
||||
Reference in New Issue
Block a user