[Misc] IO Processor plugins for pooling models (#22820)
Signed-off-by: Christian Pinto <christian.pinto@ibm.com> Signed-off-by: Max de Bayser <mbayser@br.ibm.com> Co-authored-by: Max de Bayser <mbayser@br.ibm.com>
This commit is contained in:
@@ -4,7 +4,7 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from collections.abc import AsyncGenerator, Sequence
|
||||
from typing import Final, Literal, Optional, Union, cast
|
||||
|
||||
import jinja2
|
||||
@@ -13,19 +13,25 @@ import torch
|
||||
from fastapi import Request
|
||||
from typing_extensions import assert_never
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
# yapf: disable
|
||||
from vllm.entrypoints.openai.protocol import (ErrorResponse,
|
||||
IOProcessorRequest,
|
||||
IOProcessorResponse,
|
||||
PoolingChatRequest,
|
||||
PoolingCompletionRequest,
|
||||
PoolingRequest, PoolingResponse,
|
||||
PoolingResponseData, UsageInfo)
|
||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing
|
||||
# yapf: enable
|
||||
from vllm.entrypoints.openai.serving_engine import OpenAIServing, RequestPrompt
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.entrypoints.utils import _validate_truncation_size
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import PoolingOutput, PoolingRequestOutput
|
||||
from vllm.plugins.io_processors import get_io_processor
|
||||
from vllm.utils import merge_async_iterators
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -52,7 +58,7 @@ class OpenAIServingPooling(OpenAIServing):
|
||||
def __init__(
|
||||
self,
|
||||
engine_client: EngineClient,
|
||||
model_config: ModelConfig,
|
||||
vllm_config: VllmConfig,
|
||||
models: OpenAIServingModels,
|
||||
*,
|
||||
request_logger: Optional[RequestLogger],
|
||||
@@ -61,19 +67,21 @@ class OpenAIServingPooling(OpenAIServing):
|
||||
log_error_stack: bool = False,
|
||||
) -> None:
|
||||
super().__init__(engine_client=engine_client,
|
||||
model_config=model_config,
|
||||
model_config=vllm_config.model_config,
|
||||
models=models,
|
||||
request_logger=request_logger,
|
||||
log_error_stack=log_error_stack)
|
||||
|
||||
self.chat_template = chat_template
|
||||
self.chat_template_content_format: Final = chat_template_content_format
|
||||
io_processor_plugin = self.model_config.io_processor_plugin
|
||||
self.io_processor = get_io_processor(vllm_config, io_processor_plugin)
|
||||
|
||||
async def create_pooling(
|
||||
self,
|
||||
request: PoolingRequest,
|
||||
raw_request: Optional[Request] = None,
|
||||
) -> Union[PoolingResponse, ErrorResponse]:
|
||||
) -> Union[PoolingResponse, IOProcessorResponse, ErrorResponse]:
|
||||
"""
|
||||
See https://platform.openai.com/docs/api-reference/embeddings/create
|
||||
for the API specification. This API mimics the OpenAI Embedding API.
|
||||
@@ -82,20 +90,13 @@ class OpenAIServingPooling(OpenAIServing):
|
||||
if error_check_ret is not None:
|
||||
return error_check_ret
|
||||
|
||||
encoding_format = request.encoding_format
|
||||
if request.dimensions is not None:
|
||||
return self.create_error_response(
|
||||
"dimensions is currently not supported")
|
||||
|
||||
model_name = self._get_model_name(request.model)
|
||||
|
||||
request_id = f"pool-{self._base_request_id(raw_request)}"
|
||||
created_time = int(time.time())
|
||||
|
||||
truncate_prompt_tokens = request.truncate_prompt_tokens
|
||||
|
||||
is_io_processor_request = isinstance(request, IOProcessorRequest)
|
||||
try:
|
||||
truncate_prompt_tokens = _validate_truncation_size(
|
||||
self.max_model_len, truncate_prompt_tokens)
|
||||
lora_request = self._maybe_get_adapters(request)
|
||||
|
||||
if self.model_config.skip_tokenizer_init:
|
||||
@@ -104,7 +105,32 @@ class OpenAIServingPooling(OpenAIServing):
|
||||
tokenizer = await self.engine_client.get_tokenizer(lora_request
|
||||
)
|
||||
|
||||
if isinstance(request, PoolingChatRequest):
|
||||
if getattr(request, "dimensions", None) is not None:
|
||||
return self.create_error_response(
|
||||
"dimensions is currently not supported")
|
||||
|
||||
truncate_prompt_tokens = getattr(request, "truncate_prompt_tokens",
|
||||
None)
|
||||
truncate_prompt_tokens = _validate_truncation_size(
|
||||
self.max_model_len, truncate_prompt_tokens)
|
||||
|
||||
if is_io_processor_request:
|
||||
if self.io_processor is None:
|
||||
raise ValueError(
|
||||
"No IOProcessor plugin installed. Please refer "
|
||||
"to the documentation and to the "
|
||||
"'prithvi_geospatial_mae_io_processor' "
|
||||
"offline inference example for more details.")
|
||||
|
||||
validated_prompt = self.io_processor.parse_request(request)
|
||||
|
||||
engine_prompts = await self.io_processor.pre_process_async(
|
||||
prompt=validated_prompt, request_id=request_id)
|
||||
request_prompts: Sequence[RequestPrompt] = [
|
||||
""
|
||||
] * len(engine_prompts)
|
||||
|
||||
elif isinstance(request, PoolingChatRequest):
|
||||
(
|
||||
_,
|
||||
request_prompts,
|
||||
@@ -122,7 +148,7 @@ class OpenAIServingPooling(OpenAIServing):
|
||||
continue_final_message=False,
|
||||
add_special_tokens=request.add_special_tokens,
|
||||
)
|
||||
else:
|
||||
elif isinstance(request, PoolingCompletionRequest):
|
||||
(request_prompts,
|
||||
engine_prompts) = await self._preprocess_completion(
|
||||
request,
|
||||
@@ -130,6 +156,9 @@ class OpenAIServingPooling(OpenAIServing):
|
||||
request.input,
|
||||
add_special_tokens=request.add_special_tokens,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported request of type {type(request)}")
|
||||
except (ValueError, TypeError, jinja2.TemplateError) as e:
|
||||
logger.exception("Error in preprocessing prompt inputs")
|
||||
return self.create_error_response(str(e))
|
||||
@@ -171,6 +200,16 @@ class OpenAIServingPooling(OpenAIServing):
|
||||
|
||||
result_generator = merge_async_iterators(*generators)
|
||||
|
||||
if is_io_processor_request:
|
||||
assert self.io_processor is not None
|
||||
output = await self.io_processor.post_process_async(
|
||||
model_output=result_generator,
|
||||
request_id=request_id,
|
||||
)
|
||||
return self.io_processor.output_to_response(output)
|
||||
|
||||
assert isinstance(request,
|
||||
(PoolingCompletionRequest, PoolingChatRequest))
|
||||
num_prompts = len(engine_prompts)
|
||||
|
||||
# Non-streaming response
|
||||
@@ -190,7 +229,7 @@ class OpenAIServingPooling(OpenAIServing):
|
||||
request_id,
|
||||
created_time,
|
||||
model_name,
|
||||
encoding_format,
|
||||
request.encoding_format,
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
return self.create_error_response("Client disconnected")
|
||||
|
||||
Reference in New Issue
Block a user