[Frontend] run-batch supports V1 (#21541)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -3,6 +3,7 @@
|
||||
|
||||
import asyncio
|
||||
import tempfile
|
||||
from argparse import Namespace
|
||||
from collections.abc import Awaitable
|
||||
from http import HTTPStatus
|
||||
from io import StringIO
|
||||
@@ -13,10 +14,12 @@ import torch
|
||||
from prometheus_client import start_http_server
|
||||
from tqdm import tqdm
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs, optional_type
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
# yapf: disable
|
||||
from vllm.entrypoints.openai.api_server import build_async_engine_client
|
||||
from vllm.entrypoints.openai.protocol import (BatchRequestInput,
|
||||
BatchRequestOutput,
|
||||
BatchResponseData,
|
||||
@@ -310,36 +313,37 @@ async def run_request(serving_engine_func: Callable,
|
||||
return batch_output
|
||||
|
||||
|
||||
async def main(args):
|
||||
async def run_batch(
|
||||
engine_client: EngineClient,
|
||||
vllm_config: VllmConfig,
|
||||
args: Namespace,
|
||||
) -> None:
|
||||
if args.served_model_name is not None:
|
||||
served_model_names = args.served_model_name
|
||||
else:
|
||||
served_model_names = [args.model]
|
||||
|
||||
engine_args = AsyncEngineArgs.from_cli_args(args)
|
||||
engine = AsyncLLMEngine.from_engine_args(
|
||||
engine_args, usage_context=UsageContext.OPENAI_BATCH_RUNNER)
|
||||
|
||||
model_config = await engine.get_model_config()
|
||||
base_model_paths = [
|
||||
BaseModelPath(name=name, model_path=args.model)
|
||||
for name in served_model_names
|
||||
]
|
||||
|
||||
if args.disable_log_requests:
|
||||
request_logger = None
|
||||
else:
|
||||
request_logger = RequestLogger(max_log_len=args.max_log_len)
|
||||
|
||||
base_model_paths = [
|
||||
BaseModelPath(name=name, model_path=args.model)
|
||||
for name in served_model_names
|
||||
]
|
||||
|
||||
model_config = vllm_config.model_config
|
||||
|
||||
# Create the openai serving objects.
|
||||
openai_serving_models = OpenAIServingModels(
|
||||
engine_client=engine,
|
||||
engine_client=engine_client,
|
||||
model_config=model_config,
|
||||
base_model_paths=base_model_paths,
|
||||
lora_modules=None,
|
||||
)
|
||||
openai_serving_chat = OpenAIServingChat(
|
||||
engine,
|
||||
engine_client,
|
||||
model_config,
|
||||
openai_serving_models,
|
||||
args.response_role,
|
||||
@@ -349,7 +353,7 @@ async def main(args):
|
||||
enable_prompt_tokens_details=args.enable_prompt_tokens_details,
|
||||
) if "generate" in model_config.supported_tasks else None
|
||||
openai_serving_embedding = OpenAIServingEmbedding(
|
||||
engine,
|
||||
engine_client,
|
||||
model_config,
|
||||
openai_serving_models,
|
||||
request_logger=request_logger,
|
||||
@@ -362,7 +366,7 @@ async def main(args):
|
||||
"num_labels", 0) == 1)
|
||||
|
||||
openai_serving_scores = ServingScores(
|
||||
engine,
|
||||
engine_client,
|
||||
model_config,
|
||||
openai_serving_models,
|
||||
request_logger=request_logger,
|
||||
@@ -457,6 +461,17 @@ async def main(args):
|
||||
await write_file(args.output_file, responses, args.output_tmp_dir)
|
||||
|
||||
|
||||
async def main(args: Namespace):
|
||||
async with build_async_engine_client(
|
||||
args,
|
||||
usage_context=UsageContext.OPENAI_BATCH_RUNNER,
|
||||
disable_frontend_multiprocessing=False,
|
||||
) as engine_client:
|
||||
vllm_config = await engine_client.get_vllm_config()
|
||||
|
||||
await run_batch(engine_client, vllm_config, args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user