[V1] AsyncLLM data parallel (#13923)

Signed-off-by: Nick Hill <nhill@redhat.com>
Author: Nick Hill
Date: 2025-03-27 16:14:41 -07:00
Committed by: GitHub
Parent: 112b3e5b3b
Commit: 15dac210f0
18 changed files with 722 additions and 156 deletions

@@ -114,6 +114,7 @@ class EngineArgs:
     # number of P/D disaggregation (or other disaggregation) workers
     pipeline_parallel_size: int = 1
     tensor_parallel_size: int = 1
+    data_parallel_size: int = 1
     enable_expert_parallel: bool = False
     max_parallel_loading_workers: Optional[int] = None
     block_size: Optional[int] = None
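
A minimal sketch of setting the new field programmatically, assuming the EngineArgs dataclass shown above is importable from vllm.engine.arg_utils as elsewhere in the tree; the model name and parallel sizes are placeholders for illustration.

from vllm.engine.arg_utils import EngineArgs

# Hypothetical configuration: 2-way tensor parallel combined with
# 2-way data parallel via the field added in this commit.
engine_args = EngineArgs(
    model="facebook/opt-125m",   # placeholder model
    tensor_parallel_size=2,
    data_parallel_size=2,
)
print(engine_args.data_parallel_size)  # -> 2
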
@@ -442,6 +443,14 @@ class EngineArgs:
                             type=int,
                             default=EngineArgs.tensor_parallel_size,
                             help='Number of tensor parallel replicas.')
+        parser.add_argument('--data-parallel-size',
+                            '-dp',
+                            type=int,
+                            default=EngineArgs.data_parallel_size,
+                            help='Number of data parallel replicas. '
+                            'MoE layers will be sharded according to the '
+                            'product of the tensor-parallel-size and '
+                            'data-parallel-size.')
         parser.add_argument(
             '--enable-expert-parallel',
             action='store_true',
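
A sketch of exercising the new flag through the CLI plumbing, assuming EngineArgs.add_cli_args, EngineArgs.from_cli_args, and FlexibleArgumentParser behave as in the rest of the vLLM code base; the model name and sizes are placeholders.

from vllm.engine.arg_utils import EngineArgs
from vllm.utils import FlexibleArgumentParser  # assumed import path

# Register every engine flag, including the new --data-parallel-size / -dp.
parser = FlexibleArgumentParser()
EngineArgs.add_cli_args(parser)

# Per the help text above, MoE layers are sharded across
# tensor_parallel_size * data_parallel_size ranks.
cli_args = parser.parse_args([
    "--model", "some/moe-model",   # placeholder model name
    "-tp", "2",
    "-dp", "4",
    "--enable-expert-parallel",
])
engine_args = EngineArgs.from_cli_args(cli_args)
assert engine_args.data_parallel_size == 4
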
@@ -1359,6 +1368,7 @@ class EngineArgs:
         parallel_config = ParallelConfig(
             pipeline_parallel_size=self.pipeline_parallel_size,
             tensor_parallel_size=self.tensor_parallel_size,
+            data_parallel_size=self.data_parallel_size,
             enable_expert_parallel=self.enable_expert_parallel,
             max_parallel_loading_workers=self.max_parallel_loading_workers,
             disable_custom_all_reduce=self.disable_custom_all_reduce,
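
Finally, a sketch of the forwarding shown in the last hunk: the new value flows straight into ParallelConfig. The import path and keyword defaults are assumed, and the sizes are illustrative only.

from vllm.config import ParallelConfig  # assumed import path

# Mirrors the call above: EngineArgs.create_engine_config() now passes
# data_parallel_size through to ParallelConfig.
parallel_config = ParallelConfig(
    pipeline_parallel_size=1,
    tensor_parallel_size=2,
    data_parallel_size=4,          # keyword added by this commit
    enable_expert_parallel=True,
)
# With expert parallelism enabled, MoE experts span
# tensor_parallel_size * data_parallel_size = 8 ranks (per the CLI help text).
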