[Feature] Support Decode Context Parallel (DCP) for MLA (#23734)
Signed-off-by: hongchao <hongchao@msh.team>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: hongchao <hongchao@msh.team>
Co-authored-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
@@ -306,6 +306,8 @@ class EngineArgs:
    # number of P/D disaggregation (or other disaggregation) workers
    pipeline_parallel_size: int = ParallelConfig.pipeline_parallel_size
    tensor_parallel_size: int = ParallelConfig.tensor_parallel_size
    decode_context_parallel_size: int = \
        ParallelConfig.decode_context_parallel_size
    data_parallel_size: int = ParallelConfig.data_parallel_size
    data_parallel_rank: Optional[int] = None
    data_parallel_start_rank: Optional[int] = None
@@ -636,6 +638,9 @@ class EngineArgs:
                                    **parallel_kwargs["pipeline_parallel_size"])
        parallel_group.add_argument("--tensor-parallel-size", "-tp",
                                    **parallel_kwargs["tensor_parallel_size"])
        parallel_group.add_argument(
            "--decode-context-parallel-size", "-dcp",
            **parallel_kwargs["decode_context_parallel_size"])
        parallel_group.add_argument("--data-parallel-size", "-dp",
                                    **parallel_kwargs["data_parallel_size"])
        parallel_group.add_argument(
@@ -1156,6 +1161,17 @@ class EngineArgs:
        # global layers in interleaved sliding window models.
        sliding_window = model_config.get_sliding_window()

        # Note(hc): In the current implementation of decode context
        # parallel(DCP), tp_size needs to be divisible by dcp_size,
        # because the world size does not change by dcp, it simply
        # reuse the GPUs of TP group, and split one TP group into
        # tp_size//dcp_size DCP groups.
        assert self.tensor_parallel_size % self.decode_context_parallel_size \
            == 0, (
            f"tp_size={self.tensor_parallel_size} must be divisible by"
            f"dcp_size={self.decode_context_parallel_size}."
        )

        cache_config = CacheConfig(
            block_size=self.block_size,
            gpu_memory_utilization=self.gpu_memory_utilization,
@@ -1306,6 +1322,7 @@ class EngineArgs:
            distributed_executor_backend=self.distributed_executor_backend,
            worker_cls=self.worker_cls,
            worker_extension_cls=self.worker_extension_cls,
            decode_context_parallel_size=self.decode_context_parallel_size,
        )

        speculative_config = self.create_speculative_config(
Reference in New Issue
Block a user