[Feature] Add --distributed-timeout-seconds CLI option (#36047)

Signed-off-by: Shiyan Deng <dsy842974287@meta.com>
Co-authored-by: Lu Fang <30275821+houseroad@users.noreply.github.com>
This commit is contained in:
Shiyan Deng
2026-03-05 20:57:51 -08:00
committed by GitHub
parent 8e87cc57f1
commit 03a49bb8f0
3 changed files with 25 additions and 2 deletions

View File

@@ -234,9 +234,15 @@ class ParallelConfig:
"""distributed node rank for multi-node distributed
inference when distributed_executor_backend is mp."""
nnodes: int = 1
"""num of nodes for multi-node distributed
"""num of nodes for multi-node distributed
inference when distributed_executor_backend is mp."""
distributed_timeout_seconds: int | None = None
"""Timeout in seconds for distributed operations (e.g., init_process_group).
If set, this value is passed to torch.distributed.init_process_group as the
timeout parameter. If None, PyTorch's default timeout is used (600s for NCCL).
Increase this for multi-node setups where model downloads may be slow."""
world_size: int = Field(init=False)
"""world_size is TPxPP, it affects the number of workers we create."""

View File

@@ -403,6 +403,7 @@ class EngineArgs:
master_port: int = ParallelConfig.master_port
nnodes: int = ParallelConfig.nnodes
node_rank: int = ParallelConfig.node_rank
distributed_timeout_seconds: int | None = ParallelConfig.distributed_timeout_seconds
tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
prefill_context_parallel_size: int = ParallelConfig.prefill_context_parallel_size
decode_context_parallel_size: int = ParallelConfig.decode_context_parallel_size
@@ -814,6 +815,10 @@ class EngineArgs:
parallel_group.add_argument("--master-port", **parallel_kwargs["master_port"])
parallel_group.add_argument("--nnodes", "-n", **parallel_kwargs["nnodes"])
parallel_group.add_argument("--node-rank", "-r", **parallel_kwargs["node_rank"])
parallel_group.add_argument(
"--distributed-timeout-seconds",
**parallel_kwargs["distributed_timeout_seconds"],
)
parallel_group.add_argument(
"--tensor-parallel-size", "-tp", **parallel_kwargs["tensor_parallel_size"]
)
@@ -1701,6 +1706,7 @@ class EngineArgs:
master_port=self.master_port,
nnodes=self.nnodes,
node_rank=self.node_rank,
distributed_timeout_seconds=self.distributed_timeout_seconds,
data_parallel_master_ip=data_parallel_address,
data_parallel_rpc_port=data_parallel_rpc_port,
data_parallel_backend=self.data_parallel_backend,

View File

@@ -6,6 +6,7 @@ import gc
import os
from collections.abc import Callable
from contextlib import AbstractContextManager, nullcontext
from datetime import timedelta
from types import NoneType
from typing import TYPE_CHECKING, Any
@@ -942,8 +943,18 @@ def init_worker_distributed_environment(
set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
init_method = distributed_init_method or "env://"
timeout = None
if parallel_config.distributed_timeout_seconds is not None:
timeout = timedelta(seconds=parallel_config.distributed_timeout_seconds)
init_distributed_environment(
parallel_config.world_size, rank, init_method, local_rank, backend
parallel_config.world_size,
rank,
init_method,
local_rank,
backend,
timeout,
)
ensure_model_parallel_initialized(