[Feature] Add --distributed-timeout-seconds CLI option (#36047)
Signed-off-by: Shiyan Deng <dsy842974287@meta.com>
Co-authored-by: Lu Fang <30275821+houseroad@users.noreply.github.com>
This commit is contained in:
@@ -234,9 +234,15 @@ class ParallelConfig:
|
||||
"""distributed node rank for multi-node distributed
|
||||
inference when distributed_executor_backend is mp."""
|
||||
nnodes: int = 1
|
||||
"""num of nodes for multi-node distributed
|
||||
"""num of nodes for multi-node distributed
|
||||
inference when distributed_executor_backend is mp."""
|
||||
|
||||
distributed_timeout_seconds: int | None = None
|
||||
"""Timeout in seconds for distributed operations (e.g., init_process_group).
|
||||
If set, this value is passed to torch.distributed.init_process_group as the
|
||||
timeout parameter. If None, PyTorch's default timeout is used (600s for NCCL).
|
||||
Increase this for multi-node setups where model downloads may be slow."""
|
||||
|
||||
world_size: int = Field(init=False)
|
||||
"""world_size is TPxPP, it affects the number of workers we create."""
|
||||
|
||||
|
||||
@@ -403,6 +403,7 @@ class EngineArgs:
|
||||
master_port: int = ParallelConfig.master_port
|
||||
nnodes: int = ParallelConfig.nnodes
|
||||
node_rank: int = ParallelConfig.node_rank
|
||||
distributed_timeout_seconds: int | None = ParallelConfig.distributed_timeout_seconds
|
||||
tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
|
||||
prefill_context_parallel_size: int = ParallelConfig.prefill_context_parallel_size
|
||||
decode_context_parallel_size: int = ParallelConfig.decode_context_parallel_size
|
||||
@@ -814,6 +815,10 @@ class EngineArgs:
|
||||
parallel_group.add_argument("--master-port", **parallel_kwargs["master_port"])
|
||||
parallel_group.add_argument("--nnodes", "-n", **parallel_kwargs["nnodes"])
|
||||
parallel_group.add_argument("--node-rank", "-r", **parallel_kwargs["node_rank"])
|
||||
parallel_group.add_argument(
|
||||
"--distributed-timeout-seconds",
|
||||
**parallel_kwargs["distributed_timeout_seconds"],
|
||||
)
|
||||
parallel_group.add_argument(
|
||||
"--tensor-parallel-size", "-tp", **parallel_kwargs["tensor_parallel_size"]
|
||||
)
|
||||
@@ -1701,6 +1706,7 @@ class EngineArgs:
|
||||
master_port=self.master_port,
|
||||
nnodes=self.nnodes,
|
||||
node_rank=self.node_rank,
|
||||
distributed_timeout_seconds=self.distributed_timeout_seconds,
|
||||
data_parallel_master_ip=data_parallel_address,
|
||||
data_parallel_rpc_port=data_parallel_rpc_port,
|
||||
data_parallel_backend=self.data_parallel_backend,
|
||||
|
||||
@@ -6,6 +6,7 @@ import gc
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from contextlib import AbstractContextManager, nullcontext
|
||||
from datetime import timedelta
|
||||
from types import NoneType
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
@@ -942,8 +943,18 @@ def init_worker_distributed_environment(
|
||||
set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
|
||||
|
||||
init_method = distributed_init_method or "env://"
|
||||
|
||||
timeout = None
|
||||
if parallel_config.distributed_timeout_seconds is not None:
|
||||
timeout = timedelta(seconds=parallel_config.distributed_timeout_seconds)
|
||||
|
||||
init_distributed_environment(
|
||||
parallel_config.world_size, rank, init_method, local_rank, backend
|
||||
parallel_config.world_size,
|
||||
rank,
|
||||
init_method,
|
||||
local_rank,
|
||||
backend,
|
||||
timeout,
|
||||
)
|
||||
|
||||
ensure_model_parallel_initialized(
|
||||
|
||||
Reference in New Issue
Block a user