Add request timeout override for multi-turn benchmarks (#28386)
Signed-off-by: Ido Segev <idos@pliops.com>
This commit is contained in:
@@ -63,6 +63,7 @@ class RequestArgs(NamedTuple):
|
|||||||
stream: bool
|
stream: bool
|
||||||
limit_min_tokens: int # Use negative value for no limit
|
limit_min_tokens: int # Use negative value for no limit
|
||||||
limit_max_tokens: int # Use negative value for no limit
|
limit_max_tokens: int # Use negative value for no limit
|
||||||
|
timeout_sec: int
|
||||||
|
|
||||||
|
|
||||||
class BenchmarkArgs(NamedTuple):
|
class BenchmarkArgs(NamedTuple):
|
||||||
@@ -214,6 +215,7 @@ async def send_request(
|
|||||||
stream: bool = True,
|
stream: bool = True,
|
||||||
min_tokens: int | None = None,
|
min_tokens: int | None = None,
|
||||||
max_tokens: int | None = None,
|
max_tokens: int | None = None,
|
||||||
|
timeout_sec: int = 120,
|
||||||
) -> ServerResponse:
|
) -> ServerResponse:
|
||||||
payload = {
|
payload = {
|
||||||
"model": model,
|
"model": model,
|
||||||
@@ -235,10 +237,16 @@ async def send_request(
|
|||||||
headers = {"Content-Type": "application/json"}
|
headers = {"Content-Type": "application/json"}
|
||||||
|
|
||||||
# Calculate the timeout for the request
|
# Calculate the timeout for the request
|
||||||
timeout_sec = 120
|
|
||||||
if max_tokens is not None:
|
if max_tokens is not None:
|
||||||
# Assume TPOT of 200ms and use max_tokens to determine timeout
|
# Assume TPOT of 200ms and use max_tokens to determine timeout
|
||||||
timeout_sec = max(timeout_sec, int(max_tokens * 0.2))
|
token_based_timeout = int(max_tokens * 0.2)
|
||||||
|
if token_based_timeout > timeout_sec:
|
||||||
|
timeout_sec = token_based_timeout
|
||||||
|
logger.info(
|
||||||
|
"Using timeout of %ds based on max_tokens %d",
|
||||||
|
timeout_sec,
|
||||||
|
max_tokens,
|
||||||
|
)
|
||||||
timeout = aiohttp.ClientTimeout(total=timeout_sec)
|
timeout = aiohttp.ClientTimeout(total=timeout_sec)
|
||||||
|
|
||||||
valid_response = True
|
valid_response = True
|
||||||
@@ -409,6 +417,7 @@ async def send_turn(
|
|||||||
req_args.stream,
|
req_args.stream,
|
||||||
min_tokens,
|
min_tokens,
|
||||||
max_tokens,
|
max_tokens,
|
||||||
|
req_args.timeout_sec,
|
||||||
)
|
)
|
||||||
|
|
||||||
if response.valid is False:
|
if response.valid is False:
|
||||||
@@ -676,8 +685,18 @@ async def client_main(
|
|||||||
|
|
||||||
except asyncio.exceptions.TimeoutError:
|
except asyncio.exceptions.TimeoutError:
|
||||||
num_failures += 1
|
num_failures += 1
|
||||||
logger.exception(
|
logger.error(
|
||||||
f"{Color.RED}Client {client_id} - Timeout during conversation ID {conv_id} (turn: {current_turn}){Color.RESET}" # noqa: E501
|
"%sClient %d - Timeout during conversation ID %s (turn: %d). "
|
||||||
|
"Base timeout is %ss (set with --request-timeout-sec), but the "
|
||||||
|
"effective timeout may be longer based on max_tokens. If this "
|
||||||
|
"is unexpected, consider increasing the timeout or checking "
|
||||||
|
"model performance.%s",
|
||||||
|
Color.RED,
|
||||||
|
client_id,
|
||||||
|
conv_id,
|
||||||
|
current_turn,
|
||||||
|
req_args.timeout_sec,
|
||||||
|
Color.RESET,
|
||||||
)
|
)
|
||||||
break # Exit gracefully instead of raising an error
|
break # Exit gracefully instead of raising an error
|
||||||
|
|
||||||
@@ -815,6 +834,9 @@ def get_client_config(
|
|||||||
"Invalid min/max tokens limits (min should not be larger than max)"
|
"Invalid min/max tokens limits (min should not be larger than max)"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if args.request_timeout_sec <= 0:
|
||||||
|
raise ValueError("Request timeout must be a positive number")
|
||||||
|
|
||||||
# Arguments for API requests
|
# Arguments for API requests
|
||||||
chat_url = f"{args.url}/v1/chat/completions"
|
chat_url = f"{args.url}/v1/chat/completions"
|
||||||
model_name = args.served_model_name if args.served_model_name else args.model
|
model_name = args.served_model_name if args.served_model_name else args.model
|
||||||
@@ -825,6 +847,7 @@ def get_client_config(
|
|||||||
stream=not args.no_stream,
|
stream=not args.no_stream,
|
||||||
limit_min_tokens=args.limit_min_tokens,
|
limit_min_tokens=args.limit_min_tokens,
|
||||||
limit_max_tokens=args.limit_max_tokens,
|
limit_max_tokens=args.limit_max_tokens,
|
||||||
|
timeout_sec=args.request_timeout_sec,
|
||||||
)
|
)
|
||||||
|
|
||||||
return client_args, req_args
|
return client_args, req_args
|
||||||
@@ -968,7 +991,7 @@ async def main_mp(
|
|||||||
f"(is alive: {client.is_alive()}){Color.RESET}"
|
f"(is alive: {client.is_alive()}){Color.RESET}"
|
||||||
)
|
)
|
||||||
|
|
||||||
client.join(timeout=120)
|
client.join(timeout=req_args.timeout_sec + 1)
|
||||||
|
|
||||||
if client.is_alive():
|
if client.is_alive():
|
||||||
logger.warning(
|
logger.warning(
|
||||||
@@ -1351,6 +1374,13 @@ async def main() -> None:
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Verify the LLM output (compare to the answers in the input JSON file)",
|
help="Verify the LLM output (compare to the answers in the input JSON file)",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--request-timeout-sec",
|
||||||
|
type=int,
|
||||||
|
default=120,
|
||||||
|
help="Timeout in seconds for each API request (default: 120). "
|
||||||
|
"Automatically increased if max tokens imply longer decoding.",
|
||||||
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--no-stream",
|
"--no-stream",
|
||||||
|
|||||||
Reference in New Issue
Block a user