[Frontend] Support embeddings in the run_batch API (#7132)
Co-authored-by: Simon Mo <simon.mo@hey.com>
@@ -1,18 +1,21 @@
 import asyncio
 from io import StringIO
-from typing import Awaitable, List
+from typing import Awaitable, Callable, List

 import aiohttp

 from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.entrypoints.logger import RequestLogger
+# yapf: disable
 from vllm.entrypoints.openai.protocol import (BatchRequestInput,
                                               BatchRequestOutput,
                                               BatchResponseData,
                                               ChatCompletionResponse,
-                                              ErrorResponse)
+                                              EmbeddingResponse, ErrorResponse)
+# yapf: enable
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
+from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
 from vllm.logger import init_logger
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import FlexibleArgumentParser, random_uuid
@@ -82,27 +85,26 @@ async def write_file(path_or_url: str, data: str) -> None:
             f.write(data)


-async def run_request(chat_serving: OpenAIServingChat,
+async def run_request(serving_engine_func: Callable,
                       request: BatchRequestInput) -> BatchRequestOutput:
-    chat_request = request.body
-    chat_response = await chat_serving.create_chat_completion(chat_request)
+    response = await serving_engine_func(request.body)

-    if isinstance(chat_response, ChatCompletionResponse):
+    if isinstance(response, (ChatCompletionResponse, EmbeddingResponse)):
         batch_output = BatchRequestOutput(
             id=f"vllm-{random_uuid()}",
             custom_id=request.custom_id,
             response=BatchResponseData(
-                body=chat_response, request_id=f"vllm-batch-{random_uuid()}"),
+                body=response, request_id=f"vllm-batch-{random_uuid()}"),
             error=None,
         )
-    elif isinstance(chat_response, ErrorResponse):
+    elif isinstance(response, ErrorResponse):
         batch_output = BatchRequestOutput(
             id=f"vllm-{random_uuid()}",
             custom_id=request.custom_id,
             response=BatchResponseData(
-                status_code=chat_response.code,
+                status_code=response.code,
                 request_id=f"vllm-batch-{random_uuid()}"),
-            error=chat_response,
+            error=response,
         )
     else:
         raise ValueError("Request must not be sent in stream mode")
@@ -128,6 +130,7 @@ async def main(args):
     else:
         request_logger = RequestLogger(max_log_len=args.max_log_len)

+    # Create the openai serving objects.
     openai_serving_chat = OpenAIServingChat(
         engine,
         model_config,
@@ -138,12 +141,35 @@ async def main(args):
         request_logger=request_logger,
         chat_template=None,
     )
+    openai_serving_embedding = OpenAIServingEmbedding(
+        engine,
+        model_config,
+        served_model_names,
+        request_logger=request_logger,
+    )

     # Submit all requests in the file to the engine "concurrently".
     response_futures: List[Awaitable[BatchRequestOutput]] = []
     for request_json in (await read_file(args.input_file)).strip().split("\n"):
+        # Skip empty lines.
+        request_json = request_json.strip()
+        if not request_json:
+            continue
+
         request = BatchRequestInput.model_validate_json(request_json)
-        response_futures.append(run_request(openai_serving_chat, request))
+
+        # Determine the type of request and run it.
+        if request.url == "/v1/chat/completions":
+            response_futures.append(
+                run_request(openai_serving_chat.create_chat_completion,
+                            request))
+        elif request.url == "/v1/embeddings":
+            response_futures.append(
+                run_request(openai_serving_embedding.create_embedding,
+                            request))
+        else:
+            raise ValueError("Only /v1/chat/completions and /v1/embeddings are"
+                             " supported in the batch endpoint.")

     responses = await asyncio.gather(*response_futures)
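Example (editor's sketch, not part of the commit): with this change, run_batch dispatches each input line by its "url" field, so /v1/embeddings requests can be batched the same way as /v1/chat/completions. The sketch below builds a small embeddings batch file using the OpenAI-style batch fields that BatchRequestInput parses via model_validate_json (custom_id, method, url, body); the model name is a placeholder and should match the embedding model run_batch is started with.

# Editor's sketch: write a JSONL batch input file for the new /v1/embeddings route.
import json

requests = [
    {
        "custom_id": "embedding-1",
        "method": "POST",
        "url": "/v1/embeddings",  # routed to openai_serving_embedding.create_embedding
        "body": {
            "model": "intfloat/e5-mistral-7b-instruct",  # placeholder embedding model
            "input": "Hello world!",
        },
    },
    {
        "custom_id": "embedding-2",
        "method": "POST",
        "url": "/v1/embeddings",
        "body": {
            "model": "intfloat/e5-mistral-7b-instruct",
            "input": ["The OpenAI embeddings body also accepts a list of inputs."],
        },
    },
]

# One JSON object per line; blank lines are now skipped, and any url other than
# /v1/chat/completions or /v1/embeddings makes main() raise ValueError.
with open("batch_input.jsonl", "w") as f:
    for req in requests:
        f.write(json.dumps(req) + "\n")

Assuming the usual run_batch invocation (something like python -m vllm.entrypoints.openai.run_batch -i batch_input.jsonl -o results.jsonl --model intfloat/e5-mistral-7b-instruct), each output line is then a BatchRequestOutput whose response body is an EmbeddingResponse, mirroring the isinstance check added to run_request above.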