add --insecure arg to the vllm bench to skip TLS (#34026)
Signed-off-by: Fan Yang <yan9fan@meta.com> Co-authored-by: Fan Yang <yan9fan@meta.com>
This commit is contained in:
@@ -1,15 +1,76 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import subprocess
|
||||
import tempfile
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
import urllib3
|
||||
|
||||
from ..utils import RemoteOpenAIServer
|
||||
|
||||
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
|
||||
|
||||
|
||||
def generate_self_signed_cert(cert_dir: Path) -> tuple[Path, Path]:
    """Generate a self-signed TLS certificate/key pair for testing.

    NOTE: this is a plain helper, not a pytest fixture. It is invoked
    directly (e.g. from the ``ssl_server`` fixture), and pytest raises
    "Fixtures are not meant to be called directly" when a
    fixture-decorated function is called as a normal function, so the
    previous ``@pytest.fixture`` decorator has been removed.

    Args:
        cert_dir: Directory in which to write the PEM files.

    Returns:
        Tuple of ``(cert_file, key_file)`` paths.

    Raises:
        subprocess.CalledProcessError: if the ``openssl`` invocation fails
            (``check=True``).
    """
    cert_file = cert_dir / "cert.pem"
    key_file = cert_dir / "key.pem"

    # Generate a throwaway self-signed certificate valid for 1 day.
    # ``-nodes`` leaves the key unencrypted so the server can load it
    # without a passphrase; CN=localhost matches the test host.
    subprocess.run(
        [
            "openssl",
            "req",
            "-x509",
            "-newkey",
            "rsa:2048",
            "-keyout",
            str(key_file),
            "-out",
            str(cert_file),
            "-days",
            "1",
            "-nodes",
            "-subj",
            "/CN=localhost",
        ],
        check=True,
        capture_output=True,
    )
    return cert_file, key_file
|
||||
|
||||
|
||||
class RemoteOpenAIServerSSL(RemoteOpenAIServer):
    """RemoteOpenAIServer subclass that supports SSL with self-signed certs."""

    @property
    def url_root(self) -> str:
        # Serve over HTTPS instead of the parent's plain HTTP.
        return f"https://{self.host}:{self.port}"

    def _wait_for_server(self, *, url: str, timeout: float):
        """Poll *url* over HTTPS until it returns 200 or *timeout* expires.

        Certificate verification is disabled (``verify=False``) because the
        test server uses a self-signed certificate.

        Raises:
            RuntimeError: if the server process exits with a non-zero
                status, or the server does not become healthy in time.
        """
        # Suppress InsecureRequestWarning for self-signed certs
        urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

        start = time.time()
        while True:
            try:
                # Bound each probe so a stalled connection cannot block
                # forever and defeat the outer deadline check below.
                if requests.get(url, verify=False, timeout=5).status_code == 200:
                    break
            except Exception:
                result = self._poll()
                if result is not None and result != 0:
                    raise RuntimeError("Server exited unexpectedly.") from None

                time.sleep(0.5)
                if time.time() - start > timeout:
                    raise RuntimeError("Server failed to start in time.") from None
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def server():
|
||||
args = ["--max-model-len", "1024", "--enforce-eager", "--load-format", "dummy"]
|
||||
|
||||
@@ -17,6 +78,27 @@ def server():
|
||||
yield remote_server
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
def ssl_server():
    """Yield a vLLM server running behind HTTPS with a self-signed cert."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Mint a throwaway cert/key pair inside the temp dir; both are
        # cleaned up automatically when the fixture tears down.
        cert_path, key_path = generate_self_signed_cert(Path(tmp_dir))

        server_args = ["--max-model-len", "1024"]
        server_args.append("--enforce-eager")
        server_args += ["--load-format", "dummy"]
        server_args += ["--ssl-certfile", str(cert_path)]
        server_args += ["--ssl-keyfile", str(key_path)]

        with RemoteOpenAIServerSSL(MODEL_NAME, server_args) as srv:
            yield srv
|
||||
|
||||
|
||||
@pytest.mark.benchmark
|
||||
def test_bench_serve(server):
|
||||
# Test default model detection and input/output len
|
||||
@@ -42,6 +124,31 @@ def test_bench_serve(server):
|
||||
assert result.returncode == 0, f"Benchmark failed: {result.stderr}"
|
||||
|
||||
|
||||
@pytest.mark.benchmark
def test_bench_serve_insecure(ssl_server):
    """Test --insecure flag with an HTTPS server using a self-signed certificate."""
    https_url = f"https://{ssl_server.host}:{ssl_server.port}"

    # Assemble the CLI invocation piecewise for readability.
    cmd = ["vllm", "bench", "serve"]
    cmd += ["--base-url", https_url]
    cmd += ["--input-len", "32"]
    cmd += ["--output-len", "4"]
    cmd += ["--num-prompts", "5"]
    cmd.append("--insecure")

    proc = subprocess.run(cmd, capture_output=True, text=True)
    # Surface the benchmark's output to aid debugging on failure.
    print(proc.stdout)
    print(proc.stderr)

    assert proc.returncode == 0, f"Benchmark failed: {proc.stderr}"
|
||||
|
||||
|
||||
@pytest.mark.benchmark
|
||||
def test_bench_serve_chat(server):
|
||||
command = [
|
||||
|
||||
@@ -26,6 +26,7 @@ import json
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
import ssl
|
||||
import time
|
||||
import uuid
|
||||
import warnings
|
||||
@@ -60,11 +61,14 @@ TERM_PLOTLIB_AVAILABLE = (importlib.util.find_spec("termplotlib") is not None) a
|
||||
|
||||
|
||||
async def get_first_model_from_server(
|
||||
base_url: str, headers: dict | None = None
|
||||
base_url: str,
|
||||
headers: dict | None = None,
|
||||
ssl_context: ssl.SSLContext | bool | None = None,
|
||||
) -> tuple[str, str]:
|
||||
"""Fetch the first model from the server's /v1/models endpoint."""
|
||||
models_url = f"{base_url}/v1/models"
|
||||
async with aiohttp.ClientSession() as session:
|
||||
connector = aiohttp.TCPConnector(ssl=ssl_context)
|
||||
async with aiohttp.ClientSession(connector=connector) as session:
|
||||
try:
|
||||
async with session.get(models_url, headers=headers) as response:
|
||||
response.raise_for_status()
|
||||
@@ -619,6 +623,7 @@ async def benchmark(
|
||||
ramp_up_start_rps: int | None = None,
|
||||
ramp_up_end_rps: int | None = None,
|
||||
ready_check_timeout_sec: int = 600,
|
||||
ssl_context: ssl.SSLContext | bool | None = None,
|
||||
):
|
||||
try:
|
||||
request_func = ASYNC_REQUEST_FUNCS[endpoint_type]
|
||||
@@ -626,6 +631,8 @@ async def benchmark(
|
||||
raise ValueError(f"Unknown backend: {endpoint_type}") from None
|
||||
|
||||
# Reuses connections across requests to reduce TLS handshake overhead.
|
||||
# Use ssl_context if provided, otherwise default to True for https URLs
|
||||
ssl_setting = ssl_context if ssl_context is not None else ("https://" in api_url)
|
||||
connector = aiohttp.TCPConnector(
|
||||
limit=max_concurrency or 0,
|
||||
limit_per_host=max_concurrency or 0,
|
||||
@@ -634,7 +641,7 @@ async def benchmark(
|
||||
keepalive_timeout=60,
|
||||
enable_cleanup_closed=True,
|
||||
force_close=False,
|
||||
ssl=("https://" in api_url),
|
||||
ssl=ssl_setting,
|
||||
)
|
||||
|
||||
session = aiohttp.ClientSession(
|
||||
@@ -1513,6 +1520,14 @@ def add_cli_args(parser: argparse.ArgumentParser):
|
||||
default=None,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--insecure",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Disable SSL certificate verification. Use this option when "
|
||||
"connecting to servers with self-signed certificates.",
|
||||
)
|
||||
|
||||
|
||||
def main(args: argparse.Namespace) -> dict[str, Any]:
    """Synchronous CLI entry point: drive ``main_async`` to completion via ``asyncio.run`` and return its result dict."""
    return asyncio.run(main_async(args))
|
||||
@@ -1564,10 +1579,21 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
|
||||
else:
|
||||
raise ValueError("Invalid header format. Please use KEY=VALUE format.")
|
||||
|
||||
# SSL context configuration
|
||||
ssl_context: ssl.SSLContext | bool | None = None
|
||||
if args.insecure:
|
||||
# Disable SSL certificate verification
|
||||
ssl_context = False
|
||||
elif "https://" in base_url:
|
||||
# Use default SSL context for HTTPS
|
||||
ssl_context = True
|
||||
|
||||
# Fetch model from server if not specified
|
||||
if args.model is None:
|
||||
print("Model not specified, fetching first model from server...")
|
||||
model_name, model_id = await get_first_model_from_server(base_url, headers)
|
||||
model_name, model_id = await get_first_model_from_server(
|
||||
base_url, headers, ssl_context
|
||||
)
|
||||
print(f"First model name: {model_name}, first model id: {model_id}")
|
||||
else:
|
||||
model_name = args.served_model_name
|
||||
@@ -1691,6 +1717,7 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
|
||||
ramp_up_start_rps=args.ramp_up_start_rps,
|
||||
ramp_up_end_rps=args.ramp_up_end_rps,
|
||||
ready_check_timeout_sec=args.ready_check_timeout_sec,
|
||||
ssl_context=ssl_context,
|
||||
)
|
||||
|
||||
# Save config and results to json
|
||||
|
||||
Reference in New Issue
Block a user