diff --git a/tests/v1/test_hybrid_lb_dp.py b/tests/v1/test_hybrid_lb_dp.py index 08336489a..74708b617 100644 --- a/tests/v1/test_hybrid_lb_dp.py +++ b/tests/v1/test_hybrid_lb_dp.py @@ -147,7 +147,7 @@ def default_server_args(): ] -@pytest.fixture(scope="module", params=[1]) # Only 1 API server for now +@pytest.fixture(scope="module", params=[1, 4]) def servers(request, default_server_args): api_server_count = request.param with HybridLBServerManager(MODEL_NAME, DP_SIZE, api_server_count, diff --git a/vllm/entrypoints/cli/serve.py b/vllm/entrypoints/cli/serve.py index b144431de..68eb25809 100644 --- a/vllm/entrypoints/cli/serve.py +++ b/vllm/entrypoints/cli/serve.py @@ -165,18 +165,14 @@ def run_multi_api_server(args: argparse.Namespace): " api_server_count > 1") model_config.disable_mm_preprocessor_cache = True - if vllm_config.parallel_config.data_parallel_hybrid_lb: - raise NotImplementedError( - "Hybrid load balancing with --api-server-count > 0" - "is not yet supported.") - executor_class = Executor.get_class(vllm_config) log_stats = not engine_args.disable_log_stats parallel_config = vllm_config.parallel_config dp_rank = parallel_config.data_parallel_rank external_dp_lb = parallel_config.data_parallel_external_lb - assert external_dp_lb or dp_rank == 0 + hybrid_dp_lb = parallel_config.data_parallel_hybrid_lb + assert external_dp_lb or hybrid_dp_lb or dp_rank == 0 api_server_manager: Optional[APIServerProcessManager] = None @@ -196,12 +192,12 @@ def run_multi_api_server(args: argparse.Namespace): stats_update_address=coordinator.get_stats_publish_address() if coordinator else None) - # For dp ranks > 0 in external DP LB mode, we must delay the + # For dp ranks > 0 in external/hybrid DP LB modes, we must delay the # start of the API servers until the local engine is started # (after the launcher context manager exits), # since we get the front-end stats update address from the coordinator # via the handshake with the local engine. - if dp_rank == 0 or not external_dp_lb: + if dp_rank == 0 or not (external_dp_lb or hybrid_dp_lb): # Start API servers using the manager. api_server_manager = APIServerProcessManager( **api_server_manager_kwargs)