[Feature] Prefill Context Parallel (PCP) basic support (#28718)

Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
Signed-off-by: FENP <yuanyongjie.yyj@antgroup.com>
Signed-off-by: LookAround <lixushi@huawei.com>
Signed-off-by: Jingchun Gao <gaojingchun1@huawei.com>
Signed-off-by: zhenwenqi2024 <zhenwenqi_2022@qq.com>
Co-authored-by: FENP <yuanyongjie.yyj@antgroup.com>
Co-authored-by: LookAround <lixushi@huawei.com>
Co-authored-by: Jingchun Gao <gaojingchun1@huawei.com>
Co-authored-by: zhenwenqi2024 <zhenwenqi_2022@qq.com>
Co-authored-by: Jingchun Gao <63247409+gjc0824@users.noreply.github.com>
Authored by Qiu on 2025-11-20 04:52:44 +08:00, committed by GitHub
parent 02f5903b84
commit 2fd893b4ce
27 changed files with 399 additions and 114 deletions


@@ -8,7 +8,11 @@ import torch
import vllm.envs as envs
from vllm.config import ParallelConfig
from vllm.distributed import get_dp_group, get_tensor_model_parallel_rank
from vllm.distributed import (
get_dp_group,
get_pcp_group,
get_tensor_model_parallel_rank,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.ocp_mx_utils import (
OCP_MX_DTYPES,
@@ -684,9 +688,11 @@ FUSED_MOE_UNQUANTIZED_CONFIG: FusedMoEQuantConfig = FusedMoEQuantConfig.make()
@dataclass
class FusedMoEParallelConfig:
tp_size: int
pcp_size: int
dp_size: int
ep_size: int
tp_rank: int
pcp_rank: int
dp_rank: int
ep_rank: int
@@ -713,19 +719,22 @@ class FusedMoEParallelConfig:
return self.use_all2all_kernels and self.all2all_backend == "deepep_low_latency"
@staticmethod
def flatten_tp_across_dp(
tp_size: int, dp_size: int, dp_rank: int
def flatten_tp_across_dp_and_pcp(
tp_size: int, dp_size: int, dp_rank: int, pcp_size: int, pcp_rank: int
) -> tuple[int, int]:
tp_rank = 0 if tp_size == 1 else get_tensor_model_parallel_rank()
# There are actually dp_size * tp_size devices. Update tp_size
# and tp_rank so we shard across all devices.
flatten_tp_size = dp_size * tp_size
flatten_tp_rank = dp_rank * tp_size + tp_rank
# There are actually dp_size * pcp_size * tp_size devices.
# Update tp_size and tp_rank so we shard across all devices.
flatten_tp_size = dp_size * pcp_size * tp_size
flatten_tp_rank = dp_rank * pcp_size * tp_size + pcp_rank * tp_size + tp_rank
return flatten_tp_size, flatten_tp_rank
@staticmethod
def make(
tp_size_: int, dp_size_: int, vllm_parallel_config: ParallelConfig
tp_size_: int,
pcp_size_: int,
dp_size_: int,
vllm_parallel_config: ParallelConfig,
) -> "FusedMoEParallelConfig":
"""
Determine MoE parallel configuration. Based on the input `tp_size_`,
@@ -734,19 +743,22 @@ class FusedMoEParallelConfig:
Args:
tp_size_ (int): `tp_size` passed into the FusedMoE constructor.
pcp_size_ (int): `pcp_size` passed into the FusedMoE constructor.
dp_size_ (int): `dp_size` passed into the FusedMoE constructor.
vllm_parallel_config (ParallelConfig): vLLM's parallel config
object which contains the `enable_expert_parallel` flag.
Examples:
When there is no parallelism requested,
i.e. `tp_size_` = `dp_size_` = 1, we simply return the sizes
i.e. `tp_size_` = `pcp_size_` = `dp_size_` = 1, we simply return the sizes
unaltered and the ranks set to 0.
Expert Parallelism is considered only when either `dp_size_` or
Expert Parallelism is considered only when either `dp_size_`, `pcp_size_` or
`tp_size_` is non-trivial.
When TP = 2, DP = 1 and EP = False, the configuration on different
Note that PCP serves the same function as DP here.
When TP = 2, DP(PCP) = 1 and EP = False, the configuration on different
devices:
- device 0 : TP = {2, 0} DP = {1, 0} EP = {1, 0} //
@@ -754,7 +766,7 @@ class FusedMoEParallelConfig:
- device 1 : TP = {2, 1} DP = {1, 0} EP = {1, 0}
- Comment : Tensors are sharded across 2 devices.
When TP = 1, DP = 2 and EP = False, the configuration on different
When TP = 1, DP(PCP) = 2 and EP = False, the configuration on different
devices:
- device 0 : TP = {2, 0} DP = {2, 0} EP = {1, 0}
@@ -762,7 +774,7 @@ class FusedMoEParallelConfig:
- Comment: There are 2 engine instances and the tensors are sharded
across 2 devices.
When TP = 2, DP = 2 and EP = False, the configuration on different
When TP = 2, DP(PCP) = 2 and EP = False, the configuration on different
devices:
- device 0: TP = {4, 0} DP = {2, 0} EP = {1, 0}
@@ -772,14 +784,14 @@ class FusedMoEParallelConfig:
- Comment: There are 2 engine instances and the tensors are sharded
across 4 devices.
When TP = 2, DP = 1 and EP = True, the configuration on different
When TP = 2, DP(PCP) = 1 and EP = True, the configuration on different
devices:
- device 0: TP = {1, 0} DP = {1, 0} EP = {2, 0}
- device 1: TP = {1, 0} DP = {1, 0} EP = {2, 1}
- Comment: The experts are split between the 2 devices.
When TP = 1, DP = 2 and EP = True, the configuration on different
When TP = 1, DP(PCP) = 2 and EP = True, the configuration on different
devices:
- device 0: TP = {1, 0} DP = {2, 0} EP = {2, 0}
@@ -787,7 +799,7 @@ class FusedMoEParallelConfig:
- Comment: There are 2 engine instances and the experts are split
between the 2 devices.
When TP = 2, DP = 2 and EP = True, the configuration on different
When TP = 2, DP(PCP) = 2 and EP = True, the configuration on different
devices:
- device 0: TP = {1, 0} DP = {2, 0} EP = {4, 0}
@@ -798,18 +810,25 @@ class FusedMoEParallelConfig:
between the 4 devices.
"""
use_ep = dp_size_ * tp_size_ > 1 and vllm_parallel_config.enable_expert_parallel
use_ep = (
dp_size_ * pcp_size_ * tp_size_ > 1
and vllm_parallel_config.enable_expert_parallel
)
dp_size = dp_size_
dp_rank = get_dp_group().rank_in_group if dp_size > 1 else 0
tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp(
tp_size_, dp_size_, dp_rank
pcp_size = pcp_size_
pcp_rank = get_pcp_group().rank_in_group if pcp_size > 1 else 0
tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp_and_pcp(
tp_size_, dp_size_, dp_rank, pcp_size_, pcp_rank
)
if not use_ep:
return FusedMoEParallelConfig(
tp_size=tp_size,
tp_rank=tp_rank,
pcp_size=pcp_size,
pcp_rank=pcp_rank,
dp_size=dp_size,
dp_rank=dp_rank,
ep_size=1,
@@ -826,6 +845,8 @@ class FusedMoEParallelConfig:
return FusedMoEParallelConfig(
tp_size=1,
tp_rank=0,
pcp_size=pcp_size,
pcp_rank=pcp_rank,
dp_size=dp_size,
dp_rank=dp_rank,
ep_size=ep_size,
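
For reference, the flattening arithmetic introduced above can be read in isolation: the effective TP group now spans dp_size * pcp_size * tp_size devices, with flat ranks ordered DP-major, then PCP, then TP. Below is a minimal standalone sketch of that mapping (a hypothetical helper mirroring flatten_tp_across_dp_and_pcp, not the vLLM code itself):

```python
# Standalone sketch of the rank flattening used above (illustrative only).
# The flat TP group spans dp_size * pcp_size * tp_size devices, ordered
# DP-major, then PCP, then TP -- mirroring flatten_tp_across_dp_and_pcp.

def flatten_ranks(tp_size: int, tp_rank: int,
                  dp_size: int, dp_rank: int,
                  pcp_size: int, pcp_rank: int) -> tuple[int, int]:
    flat_size = dp_size * pcp_size * tp_size
    flat_rank = dp_rank * pcp_size * tp_size + pcp_rank * tp_size + tp_rank
    return flat_size, flat_rank

if __name__ == "__main__":
    # Example: DP=2, PCP=2, TP=2 -> 8-way flat sharding.
    # Every (dp, pcp, tp) tuple maps to a unique flat rank in [0, 8).
    seen = set()
    for dp in range(2):
        for pcp in range(2):
            for tp in range(2):
                size, rank = flatten_ranks(2, tp, 2, dp, 2, pcp)
                assert size == 8
                seen.add(rank)
    assert seen == set(range(8))
```

Because PCP is folded into the flat TP rank exactly like DP, the non-EP weight sharding behaves as if TP had simply been widened, which is why the docstring treats DP and PCP interchangeably in its examples.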


@@ -18,6 +18,7 @@ from vllm.config.parallel import ExpertPlacementStrategy
from vllm.distributed import (
get_dp_group,
get_ep_group,
get_pcp_group,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
)
@@ -343,6 +344,7 @@ class FusedMoE(CustomOp):
tp_size: int | None = None,
ep_size: int | None = None,
dp_size: int | None = None,
pcp_size: int | None = None,
prefix: str = "",
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
@@ -398,12 +400,14 @@ class FusedMoE(CustomOp):
tp_size if tp_size is not None else get_tensor_model_parallel_world_size()
)
dp_size_ = dp_size if dp_size is not None else get_dp_group().world_size
pcp_size_ = pcp_size if pcp_size is not None else get_pcp_group().world_size
self.is_sequence_parallel = is_sequence_parallel
self.sp_size = tp_size_ if is_sequence_parallel else 1
self.moe_parallel_config: FusedMoEParallelConfig = FusedMoEParallelConfig.make(
tp_size_=tp_size_,
pcp_size_=pcp_size_,
dp_size_=dp_size_,
vllm_parallel_config=vllm_config.parallel_config,
)
@@ -679,6 +683,10 @@ class FusedMoE(CustomOp):
def dp_size(self):
return self.moe_parallel_config.dp_size
@property
def pcp_size(self):
return self.moe_parallel_config.pcp_size
@property
def ep_size(self):
return self.moe_parallel_config.ep_size
@@ -691,6 +699,10 @@ class FusedMoE(CustomOp):
def dp_rank(self):
return self.moe_parallel_config.dp_rank
@property
def pcp_rank(self):
return self.moe_parallel_config.pcp_rank
@property
def ep_rank(self):
return self.moe_parallel_config.ep_rank
@@ -1871,6 +1883,19 @@ class FusedMoE(CustomOp):
assert self.shared_experts is not None
shared_output = self.shared_experts(hidden_states)
# NOTE: Similar with DP, PCP also needs dispatch and combine. For
# simplicity, AgRsAll2All was added separately for PCP here. Maybe
# we should modify All2AllManager abstract to better support PCP.
if self.pcp_size > 1:
hidden_states = get_pcp_group().all_gather(
hidden_states,
dim=0,
)
router_logits = get_pcp_group().all_gather(
router_logits,
dim=0,
)
# Matrix multiply.
final_hidden_states = self.quant_method.apply(
layer=self,
@@ -1925,6 +1950,13 @@ class FusedMoE(CustomOp):
def combine_output(states: torch.Tensor) -> torch.Tensor:
if do_naive_dispatch_combine:
states = get_ep_group().combine(states, self.is_sequence_parallel)
if self.pcp_size > 1:
states = get_pcp_group().reduce_scatter(
states,
dim=0,
)
return states
if self.shared_experts is not None:
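
The forward-path change follows the usual sequence-sharded pattern: each PCP rank holds a slice of the tokens, so hidden states and router logits are all-gathered along dim 0 before the expert computation, and the expert output is reduce-scattered along dim 0 so every rank keeps only its own token slice. The sketch below simulates that round trip with plain tensor ops standing in for the get_pcp_group() collectives (an illustrative assumption, not the vLLM implementation):

```python
import torch

def all_gather_dim0(shards: list[torch.Tensor]) -> torch.Tensor:
    # Stand-in for get_pcp_group().all_gather(x, dim=0).
    return torch.cat(shards, dim=0)

def reduce_scatter_dim0(partials: list[torch.Tensor], pcp_size: int) -> list[torch.Tensor]:
    # Stand-in for get_pcp_group().reduce_scatter(x, dim=0):
    # elementwise-sum across ranks, then hand each rank its token slice.
    summed = torch.stack(partials).sum(dim=0)
    return list(summed.chunk(pcp_size, dim=0))

pcp_size = 2
tokens_per_rank, hidden = 3, 4
shards = [torch.randn(tokens_per_rank, hidden) for _ in range(pcp_size)]

# Dispatch: after the all-gather every rank holds all tokens.
full = all_gather_dim0(shards)

# Each rank computes a *partial* output for all tokens
# (e.g. because the expert/TP weights are sharded across PCP ranks).
rank_weights = [torch.randn(hidden, hidden) for _ in range(pcp_size)]
partials = [full @ w for w in rank_weights]

# Combine: sum the partials and return each rank only its own tokens.
outputs = reduce_scatter_dim0(partials, pcp_size)

# Reference: the unsharded result is the sum over the rank partials.
reference = sum(full @ w for w in rank_weights)
for r in range(pcp_size):
    expected = reference[r * tokens_per_rank:(r + 1) * tokens_per_rank]
    assert torch.allclose(outputs[r], expected)
```

This matches the AgRsAll2All note in the diff: PCP dispatch/combine is done with a plain all-gather/reduce-scatter pair rather than by extending the All2AllManager abstraction.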


@@ -13,6 +13,7 @@ from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (
get_dp_group,
get_ep_group,
get_pcp_group,
get_pp_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
@@ -322,10 +323,12 @@ class GptOssModel(nn.Module):
# In MoE, we need to flatten the tensor parallel size across the data
# parallel size when EP is disabled.
tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp(
tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp_and_pcp(
tp_size=get_tensor_model_parallel_world_size(),
dp_size=get_dp_group().world_size,
dp_rank=get_dp_group().rank_in_group,
pcp_size=get_pcp_group().world_size,
pcp_rank=get_pcp_group().rank_in_group,
)
intermediate_size = self.config.intermediate_size
@@ -507,10 +510,12 @@ class GptOssModel(nn.Module):
# In MoE, we need to flatten the tensor parallel size across the data
# parallel size when EP is disabled.
tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp(
tp_size, tp_rank = FusedMoEParallelConfig.flatten_tp_across_dp_and_pcp(
tp_size=get_tensor_model_parallel_world_size(),
dp_size=get_dp_group().world_size,
dp_rank=get_dp_group().rank_in_group,
pcp_size=get_pcp_group().world_size,
pcp_rank=get_pcp_group().rank_in_group,
)
intermediate_size = self.config.intermediate_size
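
In gpt_oss the flattened (tp_size, tp_rank) pair is then used to shard expert weights along the intermediate dimension. A generic sketch of that sharding under the flattened ranks (hypothetical helper, not the actual gpt_oss loader):

```python
import torch

def shard_intermediate(weight: torch.Tensor, tp_size: int, tp_rank: int) -> torch.Tensor:
    # Slice the intermediate (first) dimension evenly across the flat TP group.
    intermediate = weight.shape[0]
    assert intermediate % tp_size == 0, "intermediate size must divide evenly"
    per_rank = intermediate // tp_size
    return weight[tp_rank * per_rank:(tp_rank + 1) * per_rank]

full_weight = torch.randn(16, 8)   # (intermediate_size, hidden_size)
tp_size, tp_rank = 4, 1            # e.g. DP=2 x PCP=1 x TP=2 flattened to 4
local = shard_intermediate(full_weight, tp_size, tp_rank)
assert local.shape == (4, 8)
```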