Use aiohttp connection pool for benchmarking (#21981)
Signed-off-by: Seiji Eicher <seiji@anyscale.com>
This commit is contained in:
@@ -50,6 +50,7 @@ class RequestFuncOutput:
|
||||
|
||||
async def async_request_openai_completions(
|
||||
request_func_input: RequestFuncInput,
|
||||
session: aiohttp.ClientSession,
|
||||
pbar: Optional[tqdm] = None,
|
||||
) -> RequestFuncOutput:
|
||||
"""The async request function for the OpenAI Completions API.
|
||||
@@ -66,8 +67,6 @@ async def async_request_openai_completions(
|
||||
("completions", "profile")
|
||||
), "OpenAI Completions API URL must end with 'completions' or 'profile'."
|
||||
|
||||
async with aiohttp.ClientSession(trust_env=True,
|
||||
timeout=AIOHTTP_TIMEOUT) as session:
|
||||
payload = {
|
||||
"model": request_func_input.model_name \
|
||||
if request_func_input.model_name else request_func_input.model,
|
||||
@@ -164,14 +163,13 @@ async def async_request_openai_completions(
|
||||
|
||||
async def async_request_openai_chat_completions(
|
||||
request_func_input: RequestFuncInput,
|
||||
session: aiohttp.ClientSession,
|
||||
pbar: Optional[tqdm] = None,
|
||||
) -> RequestFuncOutput:
|
||||
api_url = request_func_input.api_url
|
||||
assert api_url.endswith(("chat/completions", "profile")), (
|
||||
"OpenAI Chat Completions API URL must end with 'chat/completions'.")
|
||||
|
||||
async with aiohttp.ClientSession(trust_env=True,
|
||||
timeout=AIOHTTP_TIMEOUT) as session:
|
||||
content = [{"type": "text", "text": request_func_input.prompt}]
|
||||
if request_func_input.multi_modal_content:
|
||||
content.append(request_func_input.multi_modal_content)
|
||||
@@ -269,6 +267,7 @@ async def async_request_openai_chat_completions(
|
||||
|
||||
async def async_request_openai_audio(
|
||||
request_func_input: RequestFuncInput,
|
||||
session: aiohttp.ClientSession,
|
||||
pbar: Optional[tqdm] = None,
|
||||
) -> RequestFuncOutput:
|
||||
# Lazy import without PlaceholderModule to avoid vllm dep.
|
||||
@@ -279,8 +278,6 @@ async def async_request_openai_audio(
|
||||
"OpenAI Chat Completions API URL must end with 'transcriptions' ")
|
||||
"or `translations`."
|
||||
|
||||
async with aiohttp.ClientSession(trust_env=True,
|
||||
timeout=AIOHTTP_TIMEOUT) as session:
|
||||
content = [{"type": "text", "text": request_func_input.prompt}]
|
||||
payload = {
|
||||
"model":
|
||||
|
||||
@@ -14,6 +14,7 @@ from .endpoint_request_func import RequestFuncInput, RequestFuncOutput
|
||||
async def wait_for_endpoint(
|
||||
request_func,
|
||||
test_input: RequestFuncInput,
|
||||
session: aiohttp.ClientSession,
|
||||
timeout_seconds: int = 600,
|
||||
retry_interval: int = 5,
|
||||
) -> RequestFuncOutput:
|
||||
@@ -55,7 +56,8 @@ async def wait_for_endpoint(
|
||||
|
||||
# ping the endpoint using request_func
|
||||
try:
|
||||
output = await request_func(request_func_input=test_input)
|
||||
output = await request_func(
|
||||
request_func_input=test_input, session=session)
|
||||
if output.success:
|
||||
pbar.close()
|
||||
return output
|
||||
|
||||
@@ -28,6 +28,7 @@ from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
import aiohttp
|
||||
import numpy as np
|
||||
from tqdm.asyncio import tqdm
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
@@ -338,6 +339,24 @@ async def benchmark(
|
||||
else:
|
||||
raise ValueError(f"Unknown endpoint_type: {endpoint_type}")
|
||||
|
||||
# Reuses connections across requests to reduce TLS handshake overhead.
|
||||
connector = aiohttp.TCPConnector(
|
||||
limit=max_concurrency or 0,
|
||||
limit_per_host=max_concurrency or 0,
|
||||
ttl_dns_cache=300,
|
||||
use_dns_cache=True,
|
||||
keepalive_timeout=60,
|
||||
enable_cleanup_closed=True,
|
||||
force_close=False,
|
||||
ssl=("https://" in api_url),
|
||||
)
|
||||
|
||||
session = aiohttp.ClientSession(
|
||||
connector=connector,
|
||||
trust_env=True,
|
||||
timeout=aiohttp.ClientTimeout(total=6 * 60 * 60),
|
||||
)
|
||||
|
||||
print("Starting initial single prompt test run...")
|
||||
test_prompt, test_prompt_len, test_output_len, test_mm_content = (
|
||||
input_requests[0].prompt,
|
||||
@@ -361,7 +380,11 @@ async def benchmark(
|
||||
)
|
||||
|
||||
test_output = await wait_for_endpoint(
|
||||
request_func, test_input, timeout_seconds=ready_check_timeout_sec)
|
||||
request_func,
|
||||
test_input,
|
||||
session,
|
||||
timeout_seconds=ready_check_timeout_sec,
|
||||
)
|
||||
if not test_output.success:
|
||||
raise ValueError(
|
||||
"Initial test run failed - Please make sure benchmark arguments "
|
||||
@@ -386,7 +409,8 @@ async def benchmark(
|
||||
multi_modal_content=test_mm_content,
|
||||
ignore_eos=ignore_eos,
|
||||
extra_body=extra_body)
|
||||
profile_output = await request_func(request_func_input=profile_input)
|
||||
profile_output = await request_func(
|
||||
request_func_input=profile_input, session=session)
|
||||
if profile_output.success:
|
||||
print("Profiler started")
|
||||
|
||||
@@ -412,12 +436,14 @@ async def benchmark(
|
||||
semaphore = (asyncio.Semaphore(max_concurrency)
|
||||
if max_concurrency else None)
|
||||
|
||||
async def limited_request_func(request_func_input, pbar):
|
||||
async def limited_request_func(request_func_input, session, pbar):
|
||||
if semaphore is None:
|
||||
return await request_func(request_func_input=request_func_input,
|
||||
session=session,
|
||||
pbar=pbar)
|
||||
async with semaphore:
|
||||
return await request_func(request_func_input=request_func_input,
|
||||
session=session,
|
||||
pbar=pbar)
|
||||
|
||||
benchmark_start_time = time.perf_counter()
|
||||
@@ -469,6 +495,7 @@ async def benchmark(
|
||||
tasks.append(
|
||||
asyncio.create_task(
|
||||
limited_request_func(request_func_input=request_func_input,
|
||||
session=session,
|
||||
pbar=pbar)))
|
||||
outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks)
|
||||
|
||||
@@ -580,9 +607,12 @@ async def benchmark(
|
||||
output_len=test_output_len,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
profile_output = await request_func(request_func_input=profile_input)
|
||||
profile_output = await request_func(
|
||||
request_func_input=profile_input, session=session)
|
||||
if profile_output.success:
|
||||
print("Profiler stopped")
|
||||
|
||||
await session.close()
|
||||
return result
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user