Use aiohttp connection pool for benchmarking (#21981)

Signed-off-by: Seiji Eicher <seiji@anyscale.com>
This commit is contained in:
Seiji Eicher
2025-08-03 19:23:32 -07:00
committed by GitHub
parent 6a39ba85fe
commit 6f5478298d
3 changed files with 271 additions and 242 deletions

View File

@@ -50,6 +50,7 @@ class RequestFuncOutput:
async def async_request_openai_completions(
request_func_input: RequestFuncInput,
session: aiohttp.ClientSession,
pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
"""The async request function for the OpenAI Completions API.
@@ -66,8 +67,6 @@ async def async_request_openai_completions(
("completions", "profile")
), "OpenAI Completions API URL must end with 'completions' or 'profile'."
async with aiohttp.ClientSession(trust_env=True,
timeout=AIOHTTP_TIMEOUT) as session:
payload = {
"model": request_func_input.model_name \
if request_func_input.model_name else request_func_input.model,
@@ -164,14 +163,13 @@ async def async_request_openai_completions(
async def async_request_openai_chat_completions(
request_func_input: RequestFuncInput,
session: aiohttp.ClientSession,
pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
api_url = request_func_input.api_url
assert api_url.endswith(("chat/completions", "profile")), (
"OpenAI Chat Completions API URL must end with 'chat/completions'.")
async with aiohttp.ClientSession(trust_env=True,
timeout=AIOHTTP_TIMEOUT) as session:
content = [{"type": "text", "text": request_func_input.prompt}]
if request_func_input.multi_modal_content:
content.append(request_func_input.multi_modal_content)
@@ -269,6 +267,7 @@ async def async_request_openai_chat_completions(
async def async_request_openai_audio(
request_func_input: RequestFuncInput,
session: aiohttp.ClientSession,
pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
# Lazy import without PlaceholderModule to avoid vllm dep.
@@ -279,8 +278,6 @@ async def async_request_openai_audio(
"OpenAI Chat Completions API URL must end with 'transcriptions' "
"or `translations`.")
async with aiohttp.ClientSession(trust_env=True,
timeout=AIOHTTP_TIMEOUT) as session:
content = [{"type": "text", "text": request_func_input.prompt}]
payload = {
"model":

View File

@@ -14,6 +14,7 @@ from .endpoint_request_func import RequestFuncInput, RequestFuncOutput
async def wait_for_endpoint(
request_func,
test_input: RequestFuncInput,
session: aiohttp.ClientSession,
timeout_seconds: int = 600,
retry_interval: int = 5,
) -> RequestFuncOutput:
@@ -55,7 +56,8 @@ async def wait_for_endpoint(
# ping the endpoint using request_func
try:
output = await request_func(request_func_input=test_input)
output = await request_func(
request_func_input=test_input, session=session)
if output.success:
pbar.close()
return output

View File

@@ -28,6 +28,7 @@ from dataclasses import dataclass
from datetime import datetime
from typing import Any, Literal, Optional
import aiohttp
import numpy as np
from tqdm.asyncio import tqdm
from transformers import PreTrainedTokenizerBase
@@ -338,6 +339,24 @@ async def benchmark(
else:
raise ValueError(f"Unknown endpoint_type: {endpoint_type}")
# Reuses connections across requests to reduce TLS handshake overhead.
connector = aiohttp.TCPConnector(
limit=max_concurrency or 0,
limit_per_host=max_concurrency or 0,
ttl_dns_cache=300,
use_dns_cache=True,
keepalive_timeout=60,
enable_cleanup_closed=True,
force_close=False,
ssl=("https://" in api_url),
)
session = aiohttp.ClientSession(
connector=connector,
trust_env=True,
timeout=aiohttp.ClientTimeout(total=6 * 60 * 60),
)
print("Starting initial single prompt test run...")
test_prompt, test_prompt_len, test_output_len, test_mm_content = (
input_requests[0].prompt,
@@ -361,7 +380,11 @@ async def benchmark(
)
test_output = await wait_for_endpoint(
request_func, test_input, timeout_seconds=ready_check_timeout_sec)
request_func,
test_input,
session,
timeout_seconds=ready_check_timeout_sec,
)
if not test_output.success:
raise ValueError(
"Initial test run failed - Please make sure benchmark arguments "
@@ -386,7 +409,8 @@ async def benchmark(
multi_modal_content=test_mm_content,
ignore_eos=ignore_eos,
extra_body=extra_body)
profile_output = await request_func(request_func_input=profile_input)
profile_output = await request_func(
request_func_input=profile_input, session=session)
if profile_output.success:
print("Profiler started")
@@ -412,12 +436,14 @@ async def benchmark(
semaphore = (asyncio.Semaphore(max_concurrency)
if max_concurrency else None)
async def limited_request_func(request_func_input, pbar):
async def limited_request_func(request_func_input, session, pbar):
if semaphore is None:
return await request_func(request_func_input=request_func_input,
session=session,
pbar=pbar)
async with semaphore:
return await request_func(request_func_input=request_func_input,
session=session,
pbar=pbar)
benchmark_start_time = time.perf_counter()
@@ -469,6 +495,7 @@ async def benchmark(
tasks.append(
asyncio.create_task(
limited_request_func(request_func_input=request_func_input,
session=session,
pbar=pbar)))
outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks)
@@ -580,9 +607,12 @@ async def benchmark(
output_len=test_output_len,
logprobs=logprobs,
)
profile_output = await request_func(request_func_input=profile_input)
profile_output = await request_func(
request_func_input=profile_input, session=session)
if profile_output.success:
print("Profiler stopped")
await session.close()
return result