# SPDX-License-Identifier: Apache-2.0
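
# This module implements vLLM's OpenAI-compatible batch runner: it reads a
# JSONL file of API-style requests, runs them against an AsyncLLMEngine, and
# writes one JSONL result line per request. A typical invocation might look
# like the following sketch (file names and model are illustrative
# placeholders, not defined in this file):
#
#   python -m vllm.entrypoints.openai.run_batch \
#       -i batch_input.jsonl -o batch_results.jsonl --model <model-name>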

import asyncio
from http import HTTPStatus
from io import StringIO
from typing import Awaitable, Callable, List, Optional

import aiohttp
import torch
from prometheus_client import start_http_server
from tqdm import tqdm

from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.logger import RequestLogger, logger
# yapf: disable
from vllm.entrypoints.openai.protocol import (BatchRequestInput,
                                              BatchRequestOutput,
                                              BatchResponseData,
                                              ChatCompletionResponse,
                                              EmbeddingResponse, ErrorResponse,
                                              ScoreResponse)
# yapf: enable
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
                                                    OpenAIServingModels)
from vllm.entrypoints.openai.serving_score import OpenAIServingScores
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, random_uuid
from vllm.version import __version__ as VLLM_VERSION


def parse_args():
    parser = FlexibleArgumentParser(
        description="vLLM OpenAI-Compatible batch runner.")
    parser.add_argument(
        "-i",
        "--input-file",
        required=True,
        type=str,
        help=
        "The path or url to a single input file. Currently supports local file "
        "paths, or the http protocol (http or https). If a URL is specified, "
        "the file should be available via HTTP GET.")
    parser.add_argument(
        "-o",
        "--output-file",
        required=True,
        type=str,
        help="The path or url to a single output file. Currently supports "
        "local file paths, or web (http or https) urls. If a URL is specified,"
        " the file should be available via HTTP PUT.")
    parser.add_argument("--response-role",
                        type=nullable_str,
                        default="assistant",
                        help="The role name to return if "
                        "`request.add_generation_prompt=True`.")

    parser = AsyncEngineArgs.add_cli_args(parser)

    parser.add_argument('--max-log-len',
                        type=int,
                        default=None,
                        help='Max number of prompt characters or prompt '
                        'ID numbers being printed in log.'
                        '\n\nDefault: Unlimited')

    parser.add_argument("--enable-metrics",
                        action="store_true",
                        help="Enable Prometheus metrics")
    parser.add_argument(
        "--url",
        type=str,
        default="0.0.0.0",
        help="URL to the Prometheus metrics server "
        "(only needed if enable-metrics is set).",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=8000,
        help="Port number for the Prometheus metrics server "
        "(only needed if enable-metrics is set).",
    )
    parser.add_argument(
        "--enable-prompt-tokens-details",
        action='store_true',
        default=False,
        help="If set to True, enable prompt_tokens_details in usage.")

    return parser.parse_args()


# Explicitly use pure text format, with a newline at the end. This makes it
# impossible to see the animation in the progress bar, but avoids interfering
# with ray or multiprocessing, which wrap each line of output with a prefix.
_BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n"  # noqa: E501


class BatchProgressTracker:

    def __init__(self):
        self._total = 0
        self._pbar: Optional[tqdm] = None

    def submitted(self):
        self._total += 1

    def completed(self):
        if self._pbar:
            self._pbar.update()

    def pbar(self) -> tqdm:
        # Show the progress bar only on rank 0 when torch.distributed is
        # initialized; otherwise always show it.
        enable_tqdm = not torch.distributed.is_initialized(
        ) or torch.distributed.get_rank() == 0
        self._pbar = tqdm(total=self._total,
                          unit="req",
                          desc="Running batch",
                          mininterval=5,
                          disable=not enable_tqdm,
                          bar_format=_BAR_FORMAT)
        return self._pbar


async def read_file(path_or_url: str) -> str:
    if path_or_url.startswith("http://") or path_or_url.startswith("https://"):
        async with aiohttp.ClientSession() as session, \
                   session.get(path_or_url) as resp:
            return await resp.text()
    else:
        with open(path_or_url, encoding="utf-8") as f:
            return f.read()


async def write_file(path_or_url: str, data: str) -> None:
    if path_or_url.startswith("http://") or path_or_url.startswith("https://"):
        async with aiohttp.ClientSession() as session, \
                   session.put(path_or_url, data=data.encode("utf-8")):
            pass
    else:
        # We should make this async, but as long as this is always run as a
        # standalone program, blocking the event loop won't affect performance
        # in this particular case.
        with open(path_or_url, "w", encoding="utf-8") as f:
            f.write(data)


def make_error_request_output(request: BatchRequestInput,
                              error_msg: str) -> BatchRequestOutput:
    batch_output = BatchRequestOutput(
        id=f"vllm-{random_uuid()}",
        custom_id=request.custom_id,
        response=BatchResponseData(
            status_code=HTTPStatus.BAD_REQUEST,
            request_id=f"vllm-batch-{random_uuid()}",
        ),
        error=error_msg,
    )
    return batch_output


async def make_async_error_request_output(
        request: BatchRequestInput, error_msg: str) -> BatchRequestOutput:
    return make_error_request_output(request, error_msg)


async def run_request(serving_engine_func: Callable,
                      request: BatchRequestInput,
                      tracker: BatchProgressTracker) -> BatchRequestOutput:
    response = await serving_engine_func(request.body)

    if isinstance(response,
                  (ChatCompletionResponse, EmbeddingResponse, ScoreResponse)):
        batch_output = BatchRequestOutput(
            id=f"vllm-{random_uuid()}",
            custom_id=request.custom_id,
            response=BatchResponseData(
                body=response, request_id=f"vllm-batch-{random_uuid()}"),
            error=None,
        )
    elif isinstance(response, ErrorResponse):
        batch_output = BatchRequestOutput(
            id=f"vllm-{random_uuid()}",
            custom_id=request.custom_id,
            response=BatchResponseData(
                status_code=response.code,
                request_id=f"vllm-batch-{random_uuid()}"),
            error=response,
        )
    else:
        batch_output = make_error_request_output(
            request, error_msg="Request must not be sent in stream mode")

    tracker.completed()
    return batch_output


async def main(args):
    if args.served_model_name is not None:
        served_model_names = args.served_model_name
    else:
        served_model_names = [args.model]

    engine_args = AsyncEngineArgs.from_cli_args(args)
    engine = AsyncLLMEngine.from_engine_args(
        engine_args, usage_context=UsageContext.OPENAI_BATCH_RUNNER)

    model_config = await engine.get_model_config()
    base_model_paths = [
        BaseModelPath(name=name, model_path=args.model)
        for name in served_model_names
    ]

    if args.disable_log_requests:
        request_logger = None
    else:
        request_logger = RequestLogger(max_log_len=args.max_log_len)

    # Create the openai serving objects.
    openai_serving_models = OpenAIServingModels(
        engine_client=engine,
        model_config=model_config,
        base_model_paths=base_model_paths,
        lora_modules=None,
        prompt_adapters=None,
    )
    openai_serving_chat = OpenAIServingChat(
        engine,
        model_config,
        openai_serving_models,
        args.response_role,
        request_logger=request_logger,
        chat_template=None,
        chat_template_content_format="auto",
        enable_prompt_tokens_details=args.enable_prompt_tokens_details,
    ) if model_config.runner_type == "generate" else None
    openai_serving_embedding = OpenAIServingEmbedding(
        engine,
        model_config,
        openai_serving_models,
        request_logger=request_logger,
        chat_template=None,
        chat_template_content_format="auto",
    ) if model_config.task == "embed" else None
    openai_serving_scores = (OpenAIServingScores(
        engine,
        model_config,
        openai_serving_models,
        request_logger=request_logger,
    ) if model_config.task == "score" else None)

    tracker = BatchProgressTracker()
    logger.info("Reading batch from %s...", args.input_file)

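    # Illustrative sketch of the line format handled below (values are made
    # up and abridged; only the fields this script actually reads are shown).
    # Each input line is one JSON-encoded request, e.g.
    #   {"custom_id": "request-1", "url": "/v1/chat/completions",
    #    "body": {"messages": [{"role": "user", "content": "Hello!"}]}}
    # and each output line is one serialized BatchRequestOutput carrying the
    # matching custom_id plus either a response body or an error.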
    # Submit all requests in the file to the engine "concurrently".
    response_futures: List[Awaitable[BatchRequestOutput]] = []
    for request_json in (await read_file(args.input_file)).strip().split("\n"):
        # Skip empty lines.
        request_json = request_json.strip()
        if not request_json:
            continue

        request = BatchRequestInput.model_validate_json(request_json)

        # Determine the type of request and run it.
        if request.url == "/v1/chat/completions":
            handler_fn = (None if openai_serving_chat is None else
                          openai_serving_chat.create_chat_completion)
            if handler_fn is None:
                response_futures.append(
                    make_async_error_request_output(
                        request,
                        error_msg=
                        "The model does not support Chat Completions API",
                    ))
                continue

            response_futures.append(run_request(handler_fn, request, tracker))
            tracker.submitted()
        elif request.url == "/v1/embeddings":
            handler_fn = (None if openai_serving_embedding is None else
                          openai_serving_embedding.create_embedding)
            if handler_fn is None:
                response_futures.append(
                    make_async_error_request_output(
                        request,
                        error_msg="The model does not support Embeddings API",
                    ))
                continue

            response_futures.append(run_request(handler_fn, request, tracker))
            tracker.submitted()
        elif request.url == "/v1/score":
            handler_fn = (None if openai_serving_scores is None else
                          openai_serving_scores.create_score)
            if handler_fn is None:
                response_futures.append(
                    make_async_error_request_output(
                        request,
                        error_msg="The model does not support Scores API",
                    ))
                continue

            response_futures.append(run_request(handler_fn, request, tracker))
            tracker.submitted()
        else:
            response_futures.append(
                make_async_error_request_output(
                    request,
                    error_msg=
                    "Only /v1/chat/completions, /v1/embeddings, and /v1/score "
                    "are supported in the batch endpoint.",
                ))

    with tracker.pbar():
        responses = await asyncio.gather(*response_futures)

    output_buffer = StringIO()
    for response in responses:
        print(response.model_dump_json(), file=output_buffer)

    output_buffer.seek(0)
    await write_file(args.output_file, output_buffer.read().strip())


if __name__ == "__main__":
    args = parse_args()

    logger.info("vLLM batch processing API version %s", VLLM_VERSION)
    logger.info("args: %s", args)

    # Start the Prometheus metrics server. LLMEngine uses the Prometheus client
    # to publish metrics at the /metrics endpoint.
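    # For example, with the --url/--port defaults defined above, the metrics
    # could then be scraped with a command along these lines (illustrative):
    #   curl http://0.0.0.0:8000/metrics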
    if args.enable_metrics:
        logger.info("Prometheus metrics enabled")
        start_http_server(port=args.port, addr=args.url)
    else:
        logger.info("Prometheus metrics disabled")

    asyncio.run(main(args))