add --insecure arg to the vllm bench to skip TLS (#34026)
Signed-off-by: Fan Yang <yan9fan@meta.com> Co-authored-by: Fan Yang <yan9fan@meta.com>
This commit is contained in:
@@ -1,15 +1,76 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
import subprocess
|
import subprocess
|
||||||
|
import tempfile
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
import requests
|
||||||
|
import urllib3
|
||||||
|
|
||||||
from ..utils import RemoteOpenAIServer
|
from ..utils import RemoteOpenAIServer
|
||||||
|
|
||||||
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
|
MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct"
|
||||||
|
|
||||||
|
|
||||||
def generate_self_signed_cert(cert_dir: Path) -> tuple[Path, Path]:
    """Generate a throwaway self-signed TLS certificate for testing.

    This is a plain helper, NOT a pytest fixture: it is called directly
    (e.g. from the ``ssl_server`` fixture), and a ``@pytest.fixture``-decorated
    function raises "Fixtures are not meant to be called directly" when
    invoked that way.

    Args:
        cert_dir: Directory in which to write the certificate and key.

    Returns:
        Tuple of ``(cert_file, key_file)`` paths.

    Raises:
        subprocess.CalledProcessError: If the ``openssl`` invocation fails
            (``check=True``).
    """
    cert_file = cert_dir / "cert.pem"
    key_file = cert_dir / "key.pem"

    # Generate a short-lived (1 day) self-signed cert for CN=localhost.
    # "-nodes" leaves the key unencrypted so the server can load it
    # without a passphrase prompt.
    subprocess.run(
        [
            "openssl",
            "req",
            "-x509",
            "-newkey",
            "rsa:2048",
            "-keyout",
            str(key_file),
            "-out",
            str(cert_file),
            "-days",
            "1",
            "-nodes",
            "-subj",
            "/CN=localhost",
        ],
        check=True,
        capture_output=True,
    )
    return cert_file, key_file
||||||
|
class RemoteOpenAIServerSSL(RemoteOpenAIServer):
    """RemoteOpenAIServer subclass that supports SSL with self-signed certs."""

    @property
    def url_root(self) -> str:
        # Serve over HTTPS instead of the parent's HTTP scheme.
        return f"https://{self.host}:{self.port}"

    def _wait_for_server(self, *, url: str, timeout: float):
        """Override to use HTTPS with SSL verification disabled."""
        # Self-signed certs trigger InsecureRequestWarning; silence it
        # for the readiness-probe requests below.
        urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

        deadline_start = time.time()
        while True:
            try:
                response = requests.get(url, verify=False)
                if response.status_code == 200:
                    break
            except Exception:
                # Server not accepting connections yet; check whether the
                # child process died instead of merely being slow.
                exit_code = self._poll()
                if exit_code is not None and exit_code != 0:
                    raise RuntimeError("Server exited unexpectedly.") from None

                time.sleep(0.5)
                if time.time() - deadline_start > timeout:
                    raise RuntimeError("Server failed to start in time.") from None
@pytest.fixture(scope="function")
|
||||||
def server():
|
def server():
|
||||||
args = ["--max-model-len", "1024", "--enforce-eager", "--load-format", "dummy"]
|
args = ["--max-model-len", "1024", "--enforce-eager", "--load-format", "dummy"]
|
||||||
|
|
||||||
@@ -17,6 +78,27 @@ def server():
|
|||||||
yield remote_server
|
yield remote_server
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function")
def ssl_server():
    """Start a vLLM server with SSL enabled using a self-signed certificate."""
    # The temp dir (and the cert/key inside it) lives only for the
    # duration of the test using this fixture.
    with tempfile.TemporaryDirectory() as cert_dir:
        cert_file, key_file = generate_self_signed_cert(Path(cert_dir))
        server_args = [
            "--max-model-len",
            "1024",
            "--enforce-eager",
            "--load-format",
            "dummy",
            "--ssl-certfile",
            str(cert_file),
            "--ssl-keyfile",
            str(key_file),
        ]

        with RemoteOpenAIServerSSL(MODEL_NAME, server_args) as remote_server:
            yield remote_server
@pytest.mark.benchmark
|
@pytest.mark.benchmark
|
||||||
def test_bench_serve(server):
|
def test_bench_serve(server):
|
||||||
# Test default model detection and input/output len
|
# Test default model detection and input/output len
|
||||||
@@ -42,6 +124,31 @@ def test_bench_serve(server):
|
|||||||
assert result.returncode == 0, f"Benchmark failed: {result.stderr}"
|
assert result.returncode == 0, f"Benchmark failed: {result.stderr}"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.benchmark
def test_bench_serve_insecure(ssl_server):
    """Test --insecure flag with an HTTPS server using a self-signed certificate."""
    base_url = f"https://{ssl_server.host}:{ssl_server.port}"

    # Tiny workload: the point is TLS handling, not throughput numbers.
    command = [
        "vllm",
        "bench",
        "serve",
        "--base-url",
        base_url,
        "--input-len",
        "32",
        "--output-len",
        "4",
        "--num-prompts",
        "5",
        "--insecure",
    ]
    result = subprocess.run(command, capture_output=True, text=True)
    # Echo output so failures are diagnosable from the pytest log.
    print(result.stdout)
    print(result.stderr)

    assert result.returncode == 0, f"Benchmark failed: {result.stderr}"
||||||
@pytest.mark.benchmark
|
@pytest.mark.benchmark
|
||||||
def test_bench_serve_chat(server):
|
def test_bench_serve_chat(server):
|
||||||
command = [
|
command = [
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ import json
|
|||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import shutil
|
import shutil
|
||||||
|
import ssl
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
import warnings
|
import warnings
|
||||||
@@ -60,11 +61,14 @@ TERM_PLOTLIB_AVAILABLE = (importlib.util.find_spec("termplotlib") is not None) a
|
|||||||
|
|
||||||
|
|
||||||
async def get_first_model_from_server(
|
async def get_first_model_from_server(
|
||||||
base_url: str, headers: dict | None = None
|
base_url: str,
|
||||||
|
headers: dict | None = None,
|
||||||
|
ssl_context: ssl.SSLContext | bool | None = None,
|
||||||
) -> tuple[str, str]:
|
) -> tuple[str, str]:
|
||||||
"""Fetch the first model from the server's /v1/models endpoint."""
|
"""Fetch the first model from the server's /v1/models endpoint."""
|
||||||
models_url = f"{base_url}/v1/models"
|
models_url = f"{base_url}/v1/models"
|
||||||
async with aiohttp.ClientSession() as session:
|
connector = aiohttp.TCPConnector(ssl=ssl_context)
|
||||||
|
async with aiohttp.ClientSession(connector=connector) as session:
|
||||||
try:
|
try:
|
||||||
async with session.get(models_url, headers=headers) as response:
|
async with session.get(models_url, headers=headers) as response:
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
@@ -619,6 +623,7 @@ async def benchmark(
|
|||||||
ramp_up_start_rps: int | None = None,
|
ramp_up_start_rps: int | None = None,
|
||||||
ramp_up_end_rps: int | None = None,
|
ramp_up_end_rps: int | None = None,
|
||||||
ready_check_timeout_sec: int = 600,
|
ready_check_timeout_sec: int = 600,
|
||||||
|
ssl_context: ssl.SSLContext | bool | None = None,
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
request_func = ASYNC_REQUEST_FUNCS[endpoint_type]
|
request_func = ASYNC_REQUEST_FUNCS[endpoint_type]
|
||||||
@@ -626,6 +631,8 @@ async def benchmark(
|
|||||||
raise ValueError(f"Unknown backend: {endpoint_type}") from None
|
raise ValueError(f"Unknown backend: {endpoint_type}") from None
|
||||||
|
|
||||||
# Reuses connections across requests to reduce TLS handshake overhead.
|
# Reuses connections across requests to reduce TLS handshake overhead.
|
||||||
|
# Use ssl_context if provided, otherwise default to True for https URLs
|
||||||
|
ssl_setting = ssl_context if ssl_context is not None else ("https://" in api_url)
|
||||||
connector = aiohttp.TCPConnector(
|
connector = aiohttp.TCPConnector(
|
||||||
limit=max_concurrency or 0,
|
limit=max_concurrency or 0,
|
||||||
limit_per_host=max_concurrency or 0,
|
limit_per_host=max_concurrency or 0,
|
||||||
@@ -634,7 +641,7 @@ async def benchmark(
|
|||||||
keepalive_timeout=60,
|
keepalive_timeout=60,
|
||||||
enable_cleanup_closed=True,
|
enable_cleanup_closed=True,
|
||||||
force_close=False,
|
force_close=False,
|
||||||
ssl=("https://" in api_url),
|
ssl=ssl_setting,
|
||||||
)
|
)
|
||||||
|
|
||||||
session = aiohttp.ClientSession(
|
session = aiohttp.ClientSession(
|
||||||
@@ -1513,6 +1520,14 @@ def add_cli_args(parser: argparse.ArgumentParser):
|
|||||||
default=None,
|
default=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--insecure",
|
||||||
|
action="store_true",
|
||||||
|
default=False,
|
||||||
|
help="Disable SSL certificate verification. Use this option when "
|
||||||
|
"connecting to servers with self-signed certificates.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def main(args: argparse.Namespace) -> dict[str, Any]:
|
def main(args: argparse.Namespace) -> dict[str, Any]:
|
||||||
return asyncio.run(main_async(args))
|
return asyncio.run(main_async(args))
|
||||||
@@ -1564,10 +1579,21 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
|
|||||||
else:
|
else:
|
||||||
raise ValueError("Invalid header format. Please use KEY=VALUE format.")
|
raise ValueError("Invalid header format. Please use KEY=VALUE format.")
|
||||||
|
|
||||||
|
# SSL context configuration
|
||||||
|
ssl_context: ssl.SSLContext | bool | None = None
|
||||||
|
if args.insecure:
|
||||||
|
# Disable SSL certificate verification
|
||||||
|
ssl_context = False
|
||||||
|
elif "https://" in base_url:
|
||||||
|
# Use default SSL context for HTTPS
|
||||||
|
ssl_context = True
|
||||||
|
|
||||||
# Fetch model from server if not specified
|
# Fetch model from server if not specified
|
||||||
if args.model is None:
|
if args.model is None:
|
||||||
print("Model not specified, fetching first model from server...")
|
print("Model not specified, fetching first model from server...")
|
||||||
model_name, model_id = await get_first_model_from_server(base_url, headers)
|
model_name, model_id = await get_first_model_from_server(
|
||||||
|
base_url, headers, ssl_context
|
||||||
|
)
|
||||||
print(f"First model name: {model_name}, first model id: {model_id}")
|
print(f"First model name: {model_name}, first model id: {model_id}")
|
||||||
else:
|
else:
|
||||||
model_name = args.served_model_name
|
model_name = args.served_model_name
|
||||||
@@ -1691,6 +1717,7 @@ async def main_async(args: argparse.Namespace) -> dict[str, Any]:
|
|||||||
ramp_up_start_rps=args.ramp_up_start_rps,
|
ramp_up_start_rps=args.ramp_up_start_rps,
|
||||||
ramp_up_end_rps=args.ramp_up_end_rps,
|
ramp_up_end_rps=args.ramp_up_end_rps,
|
||||||
ready_check_timeout_sec=args.ready_check_timeout_sec,
|
ready_check_timeout_sec=args.ready_check_timeout_sec,
|
||||||
|
ssl_context=ssl_context,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Save config and results to json
|
# Save config and results to json
|
||||||
|
|||||||
Reference in New Issue
Block a user