[Misc] IO Processor plugins for pooling models (#22820)

Signed-off-by: Christian Pinto <christian.pinto@ibm.com>
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Co-authored-by: Max de Bayser <mbayser@br.ibm.com>
Christian Pinto, 2025-09-01 07:07:12 +01:00, committed by GitHub
parent 437c3ce026
commit 1cb39dbcdd
25 changed files with 1183 additions and 43 deletions
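
The diff below teaches the OpenAI-compatible pooling endpoint (OpenAIServingPooling.create_pooling) to route requests through an optional IO Processor plugin: the plugin validates the incoming payload, turns it into engine prompts before inference, and converts the pooled model output back into a response afterwards. For orientation, this is the plugin surface the new code calls into; the four method names appear in the diff itself, but the signatures and class shape sketched here are assumptions, not the actual vLLM interface definition:

# Sketch of the plugin contract exercised by this diff (signatures assumed).
class MyIOProcessor:
    def parse_request(self, request):
        # Validate the plugin-specific payload of an IOProcessorRequest
        # and return a plugin-defined prompt object.
        ...

    async def pre_process_async(self, prompt, request_id=None):
        # Map the validated prompt to a list of engine prompts.
        ...

    async def post_process_async(self, model_output, request_id=None):
        # Consume the pooled model outputs and build the plugin result.
        ...

    def output_to_response(self, output):
        # Wrap the result in an IOProcessorResponse.
        ...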

vllm/entrypoints/openai/serving_pooling.py

@@ -4,7 +4,7 @@
 import asyncio
 import base64
 import time
-from collections.abc import AsyncGenerator
+from collections.abc import AsyncGenerator, Sequence
 from typing import Final, Literal, Optional, Union, cast

 import jinja2
@@ -13,19 +13,25 @@ import torch
 from fastapi import Request
 from typing_extensions import assert_never

-from vllm.config import ModelConfig
+from vllm.config import VllmConfig
 from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
 from vllm.entrypoints.logger import RequestLogger
+# yapf: disable
 from vllm.entrypoints.openai.protocol import (ErrorResponse,
+                                              IOProcessorRequest,
+                                              IOProcessorResponse,
                                               PoolingChatRequest,
+                                              PoolingCompletionRequest,
                                               PoolingRequest, PoolingResponse,
                                               PoolingResponseData, UsageInfo)
-from vllm.entrypoints.openai.serving_engine import OpenAIServing
+# yapf: enable
+from vllm.entrypoints.openai.serving_engine import OpenAIServing, RequestPrompt
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.entrypoints.utils import _validate_truncation_size
 from vllm.logger import init_logger
 from vllm.outputs import PoolingOutput, PoolingRequestOutput
+from vllm.plugins.io_processors import get_io_processor
 from vllm.utils import merge_async_iterators

 logger = init_logger(__name__)
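
The newly imported get_io_processor is the lookup hook: at construction time (next hunks) it resolves the plugin named by model_config.io_processor_plugin against the installed plugin packages. How a plugin package advertises itself is outside this diff; below is a minimal sketch of one plausible registration path, assuming vLLM discovers IO processors through a setuptools entry-point group — the group name is a guess, not something this diff confirms:

# setup.py of a hypothetical plugin package. The entry-point group name
# "vllm.io_processor_plugins" is assumed, not shown in this diff.
from setuptools import setup

setup(
    name="my-io-processor-plugin",
    py_modules=["my_plugin"],
    entry_points={
        "vllm.io_processor_plugins": [
            "my_io_processor = my_plugin:MyIOProcessor",
        ],
    },
)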
@@ -52,7 +58,7 @@ class OpenAIServingPooling(OpenAIServing):
     def __init__(
         self,
         engine_client: EngineClient,
-        model_config: ModelConfig,
+        vllm_config: VllmConfig,
         models: OpenAIServingModels,
         *,
         request_logger: Optional[RequestLogger],
@@ -61,19 +67,21 @@ class OpenAIServingPooling(OpenAIServing):
         log_error_stack: bool = False,
     ) -> None:
         super().__init__(engine_client=engine_client,
-                         model_config=model_config,
+                         model_config=vllm_config.model_config,
                          models=models,
                          request_logger=request_logger,
                          log_error_stack=log_error_stack)

         self.chat_template = chat_template
         self.chat_template_content_format: Final = chat_template_content_format
+        io_processor_plugin = self.model_config.io_processor_plugin
+        self.io_processor = get_io_processor(vllm_config, io_processor_plugin)

     async def create_pooling(
         self,
         request: PoolingRequest,
         raw_request: Optional[Request] = None,
-    ) -> Union[PoolingResponse, ErrorResponse]:
+    ) -> Union[PoolingResponse, IOProcessorResponse, ErrorResponse]:
         """
         See https://platform.openai.com/docs/api-reference/embeddings/create
         for the API specification. This API mimics the OpenAI Embedding API.
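
Two things change in the constructor: it now receives the whole VllmConfig (so the plugin loader can see more than just the model config), and the IO processor is instantiated once per serving object rather than per request. A construction-site sketch, where the local variable names are assumptions but the keyword arguments match the new signature:

# Hypothetical wiring of the updated constructor.
serving_pooling = OpenAIServingPooling(
    engine_client=engine_client,      # EngineClient for the running engine
    vllm_config=vllm_config,          # full VllmConfig, not just ModelConfig
    models=openai_serving_models,     # OpenAIServingModels registry
    request_logger=None,
    chat_template=None,
    chat_template_content_format="auto",
)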
@@ -82,20 +90,13 @@ class OpenAIServingPooling(OpenAIServing):
         if error_check_ret is not None:
             return error_check_ret

-        encoding_format = request.encoding_format
-        if request.dimensions is not None:
-            return self.create_error_response(
-                "dimensions is currently not supported")
-
         model_name = self._get_model_name(request.model)
         request_id = f"pool-{self._base_request_id(raw_request)}"
         created_time = int(time.time())

-        truncate_prompt_tokens = request.truncate_prompt_tokens
+        is_io_processor_request = isinstance(request, IOProcessorRequest)
         try:
-            truncate_prompt_tokens = _validate_truncation_size(
-                self.max_model_len, truncate_prompt_tokens)
-
             lora_request = self._maybe_get_adapters(request)

             if self.model_config.skip_tokenizer_init:
@@ -104,7 +105,32 @@
                 tokenizer = await self.engine_client.get_tokenizer(lora_request
                                                                    )

-            if isinstance(request, PoolingChatRequest):
+            if getattr(request, "dimensions", None) is not None:
+                return self.create_error_response(
+                    "dimensions is currently not supported")
+
+            truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens",
+                                             None)
+            truncate_prompt_tokens = _validate_truncation_size(
+                self.max_model_len, truncate_prompt_tokens)
+
+            if is_io_processor_request:
+                if self.io_processor is None:
+                    raise ValueError(
+                        "No IOProcessor plugin installed. Please refer "
+                        "to the documentation and to the "
+                        "'prithvi_geospatial_mae_io_processor' "
+                        "offline inference example for more details.")
+
+                validated_prompt = self.io_processor.parse_request(request)
+                engine_prompts = await self.io_processor.pre_process_async(
+                    prompt=validated_prompt, request_id=request_id)
+
+                request_prompts: Sequence[RequestPrompt] = [
+                    ""
+                ] * len(engine_prompts)
+            elif isinstance(request, PoolingChatRequest):
                 (
                     _,
                     request_prompts,
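
The io-processor branch above bypasses tokenizer-based preprocessing entirely: the plugin validates the request, emits engine prompts itself, and empty strings stand in for request_prompts since there is no text prompt to log. A toy plugin's preprocessing half might look like this — a sketch only; the payload field (request.data) and the prompt construction are assumptions:

from vllm.inputs import TokensPrompt


class MyIOProcessor:
    def parse_request(self, request):
        # Hypothetical: the plugin defines what lives inside the
        # IOProcessorRequest payload and how to validate it.
        return request.data

    async def pre_process_async(self, prompt, request_id=None):
        # Produce one engine prompt per input item; a pooling model that
        # consumes raw token ids could be fed TokensPrompt dicts directly.
        return [TokensPrompt(prompt_token_ids=[0, 1, 2])]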
@@ -122,7 +148,7 @@
                     continue_final_message=False,
                     add_special_tokens=request.add_special_tokens,
                 )
-            else:
+            elif isinstance(request, PoolingCompletionRequest):
                 (request_prompts,
                  engine_prompts) = await self._preprocess_completion(
                      request,
@@ -130,6 +156,9 @@
                     request.input,
                     add_special_tokens=request.add_special_tokens,
                 )
+            else:
+                raise ValueError(
+                    f"Unsupported request of type {type(request)}")
         except (ValueError, TypeError, jinja2.TemplateError) as e:
             logger.exception("Error in preprocessing prompt inputs")
             return self.create_error_response(str(e))
@@ -171,6 +200,16 @@

         result_generator = merge_async_iterators(*generators)

+        if is_io_processor_request:
+            assert self.io_processor is not None
+            output = await self.io_processor.post_process_async(
+                model_output=result_generator,
+                request_id=request_id,
+            )
+            return self.io_processor.output_to_response(output)
+
+        assert isinstance(request,
+                          (PoolingCompletionRequest, PoolingChatRequest))
         num_prompts = len(engine_prompts)

         # Non-streaming response
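
On the output side, the plugin consumes the merged generator before the standard pooling response assembly runs. merge_async_iterators yields (index, item) pairs, so a plugin's post-processing half can reassemble per-prompt results in order; a minimal sketch, where the final stacking step is an assumed plugin-specific choice:

import torch


class MyIOProcessor:
    async def post_process_async(self, model_output, request_id=None):
        # model_output is an async iterator of (prompt_index,
        # PoolingRequestOutput) pairs; .outputs.data holds the pooled tensor.
        results = {}
        async for i, res in model_output:
            results[i] = res.outputs.data
        # Assumed aggregation: stack per-prompt tensors in prompt order.
        return torch.stack([results[i] for i in sorted(results)])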
@@ -190,7 +229,7 @@
                     request_id,
                     created_time,
                     model_name,
-                    encoding_format,
+                    request.encoding_format,
                 )
         except asyncio.CancelledError:
             return self.create_error_response("Client disconnected")