[UX] Default api_server_count to dp_size if not specified (#32525)

Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
This commit is contained in:
Tyler Michael Smith
2026-01-22 12:35:35 -05:00
committed by GitHub
parent 70917b1c55
commit 803e3f3f68
2 changed files with 61 additions and 9 deletions

View File

@@ -50,14 +50,65 @@ class ServeSubcommand(CLISubcommand):
if hasattr(args, "model_tag") and args.model_tag is not None: if hasattr(args, "model_tag") and args.model_tag is not None:
args.model = args.model_tag args.model = args.model_tag
if args.headless or args.api_server_count < 1: if args.headless:
run_headless(args) if args.api_server_count is not None and args.api_server_count > 0:
else: raise ValueError(
if args.api_server_count > 1: f"--api-server-count={args.api_server_count} cannot be "
run_multi_api_server(args) "used with --headless (no API servers are started in "
"headless mode)."
)
# Default to 0 in headless mode (no API servers)
args.api_server_count = 0
# Detect LB mode for defaulting api_server_count.
# External LB: --data-parallel-external-lb or --data-parallel-rank
# Hybrid LB: --data-parallel-hybrid-lb or --data-parallel-start-rank
is_external_lb = (
args.data_parallel_external_lb or args.data_parallel_rank is not None
)
is_hybrid_lb = (
args.data_parallel_hybrid_lb or args.data_parallel_start_rank is not None
)
if is_external_lb and is_hybrid_lb:
raise ValueError(
"Cannot use both external and hybrid data parallel load "
"balancing modes. External LB is enabled via "
"--data-parallel-external-lb or --data-parallel-rank. "
"Hybrid LB is enabled via --data-parallel-hybrid-lb or "
"--data-parallel-start-rank. Use one mode or the other."
)
# Default api_server_count if not explicitly set.
# - External LB: Leave as 1 (external LB handles distribution)
# - Hybrid LB: Use local DP size (internal LB for local ranks only)
# - Internal LB: Use full DP size
if args.api_server_count is None:
if is_external_lb:
args.api_server_count = 1
elif is_hybrid_lb:
args.api_server_count = args.data_parallel_size_local or 1
if args.api_server_count > 1:
logger.info(
"Defaulting api_server_count to data_parallel_size_local "
"(%d) for hybrid LB mode.",
args.api_server_count,
)
else: else:
# Single API server (this process). args.api_server_count = args.data_parallel_size
uvloop.run(run_server(args)) if args.api_server_count > 1:
logger.info(
"Defaulting api_server_count to data_parallel_size (%d).",
args.api_server_count,
)
if args.api_server_count < 1:
run_headless(args)
elif args.api_server_count > 1:
run_multi_api_server(args)
else:
# Single API server (this process).
uvloop.run(run_server(args))
def validate(self, args: argparse.Namespace) -> None: def validate(self, args: argparse.Namespace) -> None:
validate_parsed_serve_args(args) validate_parsed_serve_args(args)

View File

@@ -283,8 +283,9 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"--api-server-count", "--api-server-count",
"-asc", "-asc",
type=int, type=int,
default=1, default=None,
help="How many API server processes to run.", help="How many API server processes to run. "
"Defaults to data_parallel_size if not specified.",
) )
parser.add_argument( parser.add_argument(
"--config", "--config",