add --insecure arg to the vllm bench to skip TLS (#34026)

Signed-off-by: Fan Yang <yan9fan@meta.com>
Co-authored-by: Fan Yang <yan9fan@meta.com>
This commit is contained in:
Fan Yang
2026-02-10 06:23:52 -08:00
committed by GitHub
parent d0bc520569
commit a1946570d8
2 changed files with 139 additions and 5 deletions

View File

@@ -1,15 +1,76 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import subprocess
import tempfile
import time
from pathlib import Path
import pytest
import requests
import urllib3
from ..utils import RemoteOpenAIServer
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
def generate_self_signed_cert(cert_dir: Path) -> tuple[Path, Path]:
    """Generate a throwaway self-signed certificate for HTTPS tests.

    NOTE: this is a plain helper, not a pytest fixture. It is called
    directly by the ``ssl_server`` fixture, and pytest raises an error
    when a fixture function is invoked directly, so it must not carry a
    ``@pytest.fixture`` decorator.

    Args:
        cert_dir: Directory in which to write the key/cert pair.

    Returns:
        Tuple of ``(cert_file, key_file)`` paths.

    Raises:
        subprocess.CalledProcessError: if the ``openssl`` invocation fails
            (``check=True``).
    """
    cert_file = cert_dir / "cert.pem"
    key_file = cert_dir / "key.pem"
    # Self-signed RSA-2048 cert valid for 1 day, unencrypted key (-nodes),
    # CN=localhost so it matches the host the test server binds to.
    subprocess.run(
        [
            "openssl",
            "req",
            "-x509",
            "-newkey",
            "rsa:2048",
            "-keyout",
            str(key_file),
            "-out",
            str(cert_file),
            "-days",
            "1",
            "-nodes",
            "-subj",
            "/CN=localhost",
        ],
        check=True,
        capture_output=True,
    )
    return cert_file, key_file
class RemoteOpenAIServerSSL(RemoteOpenAIServer):
    """RemoteOpenAIServer subclass that supports SSL with self-signed certs."""

    @property
    def url_root(self) -> str:
        # The parent class builds an http:// root; this server only
        # answers over TLS, so health checks and clients must use https.
        return f"https://{self.host}:{self.port}"

    def _wait_for_server(self, *, url: str, timeout: float):
        """Poll *url* over HTTPS (verification disabled) until it returns 200.

        Raises:
            RuntimeError: if the server process exits with a non-zero
                return code, or does not become healthy within *timeout*
                seconds.
        """
        # Suppress InsecureRequestWarning for self-signed certs
        urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
        start = time.time()
        while True:
            try:
                if requests.get(url, verify=False).status_code == 200:
                    break
            except Exception:
                # Connection refused while the server is still booting is
                # expected; only bail out if the process actually died.
                result = self._poll()
                if result is not None and result != 0:
                    raise RuntimeError("Server exited unexpectedly.") from None
            # Sleep and check the deadline on EVERY iteration (previously
            # only on the exception path), so a server that keeps answering
            # non-200 cannot busy-spin or evade the timeout.
            time.sleep(0.5)
            if time.time() - start > timeout:
                raise RuntimeError("Server failed to start in time.") from None
@pytest.fixture(scope="function")
def server():
args = ["--max-model-len", "1024", "--enforce-eager", "--load-format", "dummy"]
@@ -17,6 +78,27 @@ def server():
yield remote_server
@pytest.fixture(scope="function")
def ssl_server():
    """Start a vLLM server with SSL enabled using a self-signed certificate."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        cert_path, key_path = generate_self_signed_cert(Path(tmp_dir))
        # Dummy weights + small context keep startup cheap; the TLS flags
        # point the server at the freshly generated certificate pair.
        server_args = [
            "--max-model-len",
            "1024",
            "--enforce-eager",
            "--load-format",
            "dummy",
            "--ssl-certfile",
            str(cert_path),
            "--ssl-keyfile",
            str(key_path),
        ]
        with RemoteOpenAIServerSSL(MODEL_NAME, server_args) as srv:
            yield srv
@pytest.mark.benchmark
def test_bench_serve(server):
# Test default model detection and input/output len
@@ -42,6 +124,31 @@ def test_bench_serve(server):
assert result.returncode == 0, f"Benchmark failed: {result.stderr}"
@pytest.mark.benchmark
def test_bench_serve_insecure(ssl_server):
    """Test --insecure flag with an HTTPS server using a self-signed certificate."""
    target = f"https://{ssl_server.host}:{ssl_server.port}"
    # Tiny workload: 5 prompts of 32 input / 4 output tokens — just enough
    # to prove the benchmark client tolerates the untrusted certificate.
    command = [
        "vllm",
        "bench",
        "serve",
        "--base-url",
        target,
        "--input-len",
        "32",
        "--output-len",
        "4",
        "--num-prompts",
        "5",
        "--insecure",
    ]
    result = subprocess.run(command, capture_output=True, text=True)
    # Surface benchmark output in the pytest log for debugging failures.
    print(result.stdout)
    print(result.stderr)
    assert result.returncode == 0, f"Benchmark failed: {result.stderr}"
@pytest.mark.benchmark
def test_bench_serve_chat(server):
command = [

View File

@@ -26,6 +26,7 @@ import json
import os
import random
import shutil
import ssl
import time
import uuid
import warnings
@@ -60,11 +61,14 @@ TERM_PLOTLIB_AVAILABLE = (importlib.util.find_spec("termplotlib") is not None) a
async def get_first_model_from_server(
base_url: str, headers: dict | None = None
base_url: str,
headers: dict | None = None,
ssl_context: ssl.SSLContext | bool | None = None,
) -> tuple[str, str]:
"""Fetch the first model from the server's /v1/models endpoint."""
models_url = f"{base_url}/v1/models"
async with aiohttp.ClientSession() as session:
connector = aiohttp.TCPConnector(ssl=ssl_context)
async with aiohttp.ClientSession(connector=connector) as session:
try:
async with session.get(models_url, headers=headers) as response:
response.raise_for_status()
@@ -619,6 +623,7 @@ async def benchmark(
ramp_up_start_rps: int | None = None,
ramp_up_end_rps: int | None = None,
ready_check_timeout_sec: int = 600,
ssl_context: ssl.SSLContext | bool | None = None,
):
try:
request_func = ASYNC_REQUEST_FUNCS[endpoint_type]
@@ -626,6 +631,8 @@ async def benchmark(
raise ValueError(f"Unknown backend: {endpoint_type}") from None
# Reuses connections across requests to reduce TLS handshake overhead.
# Use ssl_context if provided, otherwise default to True for https URLs
ssl_setting = ssl_context if ssl_context is not None else ("https://" in api_url)
connector = aiohttp.TCPConnector(
limit=max_concurrency or 0,
limit_per_host=max_concurrency or 0,
@@ -634,7 +641,7 @@ async def benchmark(
keepalive_timeout=60,
enable_cleanup_closed=True,
force_close=False,
ssl=("https://" in api_url),
ssl=ssl_setting,
)
session = aiohttp.ClientSession(
@@ -1513,6 +1520,14 @@ def add_cli_args(parser: argparse.ArgumentParser):
default=None,
)
parser.add_argument(
"--insecure",
action="store_true",
default=False,
help="Disable SSL certificate verification. Use this option when "
"connecting to servers with self-signed certificates.",
)
def main(args: argparse.Namespace) -> dict[str, Any]:
    """Synchronous CLI entry point: run the async benchmark to completion."""
    coro = main_async(args)
    return asyncio.run(coro)
@@ -1564,10 +1579,21 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
else:
raise ValueError("Invalid header format. Please use KEY=VALUE format.")
# SSL context configuration
ssl_context: ssl.SSLContext | bool | None = None
if args.insecure:
# Disable SSL certificate verification
ssl_context = False
elif "https://" in base_url:
# Use default SSL context for HTTPS
ssl_context = True
# Fetch model from server if not specified
if args.model is None:
print("Model not specified, fetching first model from server...")
model_name, model_id = await get_first_model_from_server(base_url, headers)
model_name, model_id = await get_first_model_from_server(
base_url, headers, ssl_context
)
print(f"First model name: {model_name}, first model id: {model_id}")
else:
model_name = args.served_model_name
@@ -1691,6 +1717,7 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
ramp_up_start_rps=args.ramp_up_start_rps,
ramp_up_end_rps=args.ramp_up_end_rps,
ready_check_timeout_sec=args.ready_check_timeout_sec,
ssl_context=ssl_context,
)
# Save config and results to json