add --insecure arg to the vllm bench to skip TLS (#34026)

Signed-off-by: Fan Yang <yan9fan@meta.com>
Co-authored-by: Fan Yang <yan9fan@meta.com>
This commit is contained in:
Fan Yang
2026-02-10 06:23:52 -08:00
committed by GitHub
parent d0bc520569
commit a1946570d8
2 changed files with 139 additions and 5 deletions

View File

@@ -1,15 +1,76 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import subprocess
import tempfile
import time
from pathlib import Path

import pytest
import requests
import urllib3

from ..utils import RemoteOpenAIServer

MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
def generate_self_signed_cert(cert_dir: Path) -> tuple[Path, Path]:
    """Create a throwaway self-signed TLS certificate for test servers.

    Shells out to the ``openssl`` CLI to produce a 2048-bit RSA key and a
    matching one-day certificate issued for ``CN=localhost``.

    Args:
        cert_dir: Directory in which the PEM files are written.

    Returns:
        Tuple of ``(cert_file, key_file)`` paths.

    Raises:
        subprocess.CalledProcessError: If the openssl invocation fails
            (``check=True``).
    """
    cert_path = cert_dir / "cert.pem"
    key_path = cert_dir / "key.pem"
    # -nodes leaves the private key unencrypted so the server can load it
    # without a passphrase; -days 1 keeps the throwaway cert short-lived.
    openssl_cmd = (
        "openssl", "req", "-x509",
        "-newkey", "rsa:2048",
        "-keyout", str(key_path),
        "-out", str(cert_path),
        "-days", "1",
        "-nodes",
        "-subj", "/CN=localhost",
    )
    subprocess.run(openssl_cmd, check=True, capture_output=True)
    return cert_path, key_path
class RemoteOpenAIServerSSL(RemoteOpenAIServer):
    """RemoteOpenAIServer subclass that supports SSL with self-signed certs."""

    @property
    def url_root(self) -> str:
        # The base server is reached over plain HTTP; this variant fronts HTTPS.
        return f"https://{self.host}:{self.port}"

    def _wait_for_server(self, *, url: str, timeout: float):
        """Poll ``url`` over HTTPS (verification disabled) until it returns 200.

        Args:
            url: Health/readiness URL to poll.
            timeout: Maximum seconds to wait for a 200 response.

        Raises:
            RuntimeError: If the server process exits with a nonzero code, or
                if ``timeout`` seconds elapse before a 200 response is seen.
        """
        # Suppress InsecureRequestWarning for self-signed certs
        urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
        start = time.time()
        while True:
            try:
                # verify=False: the test cert is self-signed, so normal
                # certificate validation would always fail.
                if requests.get(url, verify=False).status_code == 200:
                    break
            except Exception:
                # Connection refused while starting up is expected; make sure
                # the server process hasn't died in the meantime.
                result = self._poll()
                if result is not None and result != 0:
                    raise RuntimeError("Server exited unexpectedly.") from None
            # Fix: enforce the deadline and back off on EVERY iteration. The
            # original only slept/checked inside the except branch, so a
            # server that kept answering with a non-200 status would
            # busy-loop forever with no timeout.
            if time.time() - start > timeout:
                raise RuntimeError("Server failed to start in time.") from None
            time.sleep(0.5)
@pytest.fixture(scope="function")
def server(): def server():
args = ["--max-model-len", "1024", "--enforce-eager", "--load-format", "dummy"] args = ["--max-model-len", "1024", "--enforce-eager", "--load-format", "dummy"]
@@ -17,6 +78,27 @@ def server():
yield remote_server yield remote_server
@pytest.fixture(scope="function")
def ssl_server():
    """Start a vLLM server with SSL enabled using a self-signed certificate."""
    with tempfile.TemporaryDirectory() as tmp:
        # Certificate lives only for the duration of this fixture.
        cert_file, key_file = generate_self_signed_cert(Path(tmp))
        server_args = [
            "--max-model-len", "1024",
            "--enforce-eager",
            "--load-format", "dummy",
            "--ssl-certfile", str(cert_file),
            "--ssl-keyfile", str(key_file),
        ]
        with RemoteOpenAIServerSSL(MODEL_NAME, server_args) as srv:
            yield srv
@pytest.mark.benchmark @pytest.mark.benchmark
def test_bench_serve(server): def test_bench_serve(server):
# Test default model detection and input/output len # Test default model detection and input/output len
@@ -42,6 +124,31 @@ def test_bench_serve(server):
assert result.returncode == 0, f"Benchmark failed: {result.stderr}" assert result.returncode == 0, f"Benchmark failed: {result.stderr}"
@pytest.mark.benchmark
def test_bench_serve_insecure(ssl_server):
    """Test --insecure flag with an HTTPS server using a self-signed certificate."""
    base_url = f"https://{ssl_server.host}:{ssl_server.port}"
    # Run a tiny benchmark against the HTTPS endpoint; --insecure must make
    # the client skip certificate verification for the self-signed cert.
    command = [
        "vllm", "bench", "serve",
        "--base-url", base_url,
        "--input-len", "32",
        "--output-len", "4",
        "--num-prompts", "5",
        "--insecure",
    ]
    proc = subprocess.run(command, capture_output=True, text=True)
    # Surface the CLI output in the pytest log for easier debugging.
    print(proc.stdout)
    print(proc.stderr)
    assert proc.returncode == 0, f"Benchmark failed: {proc.stderr}"
@pytest.mark.benchmark @pytest.mark.benchmark
def test_bench_serve_chat(server): def test_bench_serve_chat(server):
command = [ command = [

View File

@@ -26,6 +26,7 @@ import json
import os import os
import random import random
import shutil import shutil
import ssl
import time import time
import uuid import uuid
import warnings import warnings
@@ -60,11 +61,14 @@ TERM_PLOTLIB_AVAILABLE = (importlib.util.find_spec("termplotlib") is not None) a
async def get_first_model_from_server( async def get_first_model_from_server(
base_url: str, headers: dict | None = None base_url: str,
headers: dict | None = None,
ssl_context: ssl.SSLContext | bool | None = None,
) -> tuple[str, str]: ) -> tuple[str, str]:
"""Fetch the first model from the server's /v1/models endpoint.""" """Fetch the first model from the server's /v1/models endpoint."""
models_url = f"{base_url}/v1/models" models_url = f"{base_url}/v1/models"
async with aiohttp.ClientSession() as session: connector = aiohttp.TCPConnector(ssl=ssl_context)
async with aiohttp.ClientSession(connector=connector) as session:
try: try:
async with session.get(models_url, headers=headers) as response: async with session.get(models_url, headers=headers) as response:
response.raise_for_status() response.raise_for_status()
@@ -619,6 +623,7 @@ async def benchmark(
ramp_up_start_rps: int | None = None, ramp_up_start_rps: int | None = None,
ramp_up_end_rps: int | None = None, ramp_up_end_rps: int | None = None,
ready_check_timeout_sec: int = 600, ready_check_timeout_sec: int = 600,
ssl_context: ssl.SSLContext | bool | None = None,
): ):
try: try:
request_func = ASYNC_REQUEST_FUNCS[endpoint_type] request_func = ASYNC_REQUEST_FUNCS[endpoint_type]
@@ -626,6 +631,8 @@ async def benchmark(
raise ValueError(f"Unknown backend: {endpoint_type}") from None raise ValueError(f"Unknown backend: {endpoint_type}") from None
# Reuses connections across requests to reduce TLS handshake overhead. # Reuses connections across requests to reduce TLS handshake overhead.
# Use ssl_context if provided, otherwise default to True for https URLs
ssl_setting = ssl_context if ssl_context is not None else ("https://" in api_url)
connector = aiohttp.TCPConnector( connector = aiohttp.TCPConnector(
limit=max_concurrency or 0, limit=max_concurrency or 0,
limit_per_host=max_concurrency or 0, limit_per_host=max_concurrency or 0,
@@ -634,7 +641,7 @@ async def benchmark(
keepalive_timeout=60, keepalive_timeout=60,
enable_cleanup_closed=True, enable_cleanup_closed=True,
force_close=False, force_close=False,
ssl=("https://" in api_url), ssl=ssl_setting,
) )
session = aiohttp.ClientSession( session = aiohttp.ClientSession(
@@ -1513,6 +1520,14 @@ def add_cli_args(parser: argparse.ArgumentParser):
default=None, default=None,
) )
parser.add_argument(
"--insecure",
action="store_true",
default=False,
help="Disable SSL certificate verification. Use this option when "
"connecting to servers with self-signed certificates.",
)
def main(args: argparse.Namespace) -> dict[str, Any]: def main(args: argparse.Namespace) -> dict[str, Any]:
return asyncio.run(main_async(args)) return asyncio.run(main_async(args))
@@ -1564,10 +1579,21 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
else: else:
raise ValueError("Invalid header format. Please use KEY=VALUE format.") raise ValueError("Invalid header format. Please use KEY=VALUE format.")
# SSL context configuration
ssl_context: ssl.SSLContext | bool | None = None
if args.insecure:
# Disable SSL certificate verification
ssl_context = False
elif "https://" in base_url:
# Use default SSL context for HTTPS
ssl_context = True
# Fetch model from server if not specified # Fetch model from server if not specified
if args.model is None: if args.model is None:
print("Model not specified, fetching first model from server...") print("Model not specified, fetching first model from server...")
model_name, model_id = await get_first_model_from_server(base_url, headers) model_name, model_id = await get_first_model_from_server(
base_url, headers, ssl_context
)
print(f"First model name: {model_name}, first model id: {model_id}") print(f"First model name: {model_name}, first model id: {model_id}")
else: else:
model_name = args.served_model_name model_name = args.served_model_name
@@ -1691,6 +1717,7 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
ramp_up_start_rps=args.ramp_up_start_rps, ramp_up_start_rps=args.ramp_up_start_rps,
ramp_up_end_rps=args.ramp_up_end_rps, ramp_up_end_rps=args.ramp_up_end_rps,
ready_check_timeout_sec=args.ready_check_timeout_sec, ready_check_timeout_sec=args.ready_check_timeout_sec,
ssl_context=ssl_context,
) )
# Save config and results to json # Save config and results to json