[Feature] Prefill Context Parallel (PCP) basic support (#28718)

Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
Signed-off-by: FENP <yuanyongjie.yyj@antgroup.com>
Signed-off-by: LookAround <lixushi@huawei.com>
Signed-off-by: Jingchun Gao <gaojingchun1@huawei.com>
Signed-off-by: zhenwenqi2024 <zhenwenqi_2022@qq.com>
Co-authored-by: FENP <yuanyongjie.yyj@antgroup.com>
Co-authored-by: LookAround <lixushi@huawei.com>
Co-authored-by: Jingchun Gao <gaojingchun1@huawei.com>
Co-authored-by: zhenwenqi2024 <zhenwenqi_2022@qq.com>
Co-authored-by: Jingchun Gao <63247409+gjc0824@users.noreply.github.com>
Author: Qiu
Committed: 2025-11-20 04:52:44 +08:00 (committed by GitHub)
Parent: 02f5903b84
Commit: 2fd893b4ce
27 changed files with 399 additions and 114 deletions

View File

@@ -27,6 +27,7 @@ class KVCacheCoordinator(ABC):
         enable_caching: bool,
         enable_kv_cache_events: bool,
         dcp_world_size: int,
+        pcp_world_size: int,
     ):
         self.kv_cache_config = kv_cache_config
         self.max_model_len = max_model_len
@@ -44,6 +45,7 @@ class KVCacheCoordinator(ABC):
                 block_pool=self.block_pool,
                 kv_cache_group_id=i,
                 dcp_world_size=dcp_world_size,
+                pcp_world_size=pcp_world_size,
             )
             for i, kv_cache_group in enumerate(self.kv_cache_config.kv_cache_groups)
         )
@@ -210,6 +212,7 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator):
         use_eagle: bool,
         enable_kv_cache_events: bool,
         dcp_world_size: int,
+        pcp_world_size: int,
     ):
         super().__init__(
             kv_cache_config,
@@ -218,6 +221,7 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator):
             False,
             enable_kv_cache_events,
             dcp_world_size=dcp_world_size,
+            pcp_world_size=pcp_world_size,
         )
         self.num_single_type_manager = len(self.single_type_managers)
@@ -250,6 +254,7 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
         enable_caching: bool,
         enable_kv_cache_events: bool,
         dcp_world_size: int,
+        pcp_world_size: int,
     ):
         super().__init__(
             kv_cache_config,
@@ -258,12 +263,16 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
             enable_caching,
             enable_kv_cache_events,
             dcp_world_size=dcp_world_size,
+            pcp_world_size=pcp_world_size,
         )
         self.kv_cache_spec = self.kv_cache_config.kv_cache_groups[0].kv_cache_spec
         self.block_size = self.kv_cache_spec.block_size
         self.dcp_world_size = dcp_world_size
+        self.pcp_world_size = pcp_world_size
         if dcp_world_size > 1:
             self.block_size *= dcp_world_size
+        if pcp_world_size > 1:
+            self.block_size *= pcp_world_size
         assert len(self.kv_cache_config.kv_cache_groups) == 1, (
             "UnitaryKVCacheCoordinator assumes only one kv cache group"
         )
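Worth noting in the hunk above: a logical block now spans the DCP and PCP shards together, so the coordinator's effective block size is the spec's block size multiplied by both world sizes. A minimal standalone sketch of that composition (plain Python, not vLLM API):

# Standalone sketch: effective block size under context parallelism.
# A logical KV block covers more tokens when its KV entries are
# sharded across dcp_world_size * pcp_world_size ranks.
def effective_block_size(block_size: int, dcp_world_size: int, pcp_world_size: int) -> int:
    if dcp_world_size > 1:
        block_size *= dcp_world_size
    if pcp_world_size > 1:
        block_size *= pcp_world_size
    return block_size

assert effective_block_size(16, 1, 1) == 16
assert effective_block_size(16, 2, 1) == 32   # DCP only
assert effective_block_size(16, 2, 4) == 128  # DCP * PCP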
@@ -281,6 +290,7 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
             kv_cache_spec=self.kv_cache_spec,
             use_eagle=self.use_eagle,
             dcp_world_size=self.dcp_world_size,
+            pcp_world_size=self.pcp_world_size,
         )
         return hit_blocks, len(hit_blocks[0]) * self.block_size
 
@@ -302,6 +312,7 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
         enable_caching: bool,
         enable_kv_cache_events: bool,
         dcp_world_size: int,
+        pcp_world_size: int,
     ):
         super().__init__(
             kv_cache_config,
@@ -310,8 +321,10 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
             enable_caching,
             enable_kv_cache_events,
             dcp_world_size=dcp_world_size,
+            pcp_world_size=pcp_world_size,
         )
         assert dcp_world_size == 1, "DCP does not support hybrid attention yet."
+        assert pcp_world_size == 1, "PCP does not support hybrid attention yet."
         self.verify_and_split_kv_cache_groups()
 
     def verify_and_split_kv_cache_groups(self) -> None:
@@ -452,6 +465,7 @@ def get_kv_cache_coordinator(
     enable_caching: bool,
     enable_kv_cache_events: bool,
     dcp_world_size: int,
+    pcp_world_size: int,
 ) -> KVCacheCoordinator:
     if not enable_caching:
         return KVCacheCoordinatorNoPrefixCache(
@@ -460,6 +474,7 @@ def get_kv_cache_coordinator(
             use_eagle,
             enable_kv_cache_events,
             dcp_world_size=dcp_world_size,
+            pcp_world_size=pcp_world_size,
         )
     if len(kv_cache_config.kv_cache_groups) == 1:
         return UnitaryKVCacheCoordinator(
@@ -469,6 +484,7 @@ def get_kv_cache_coordinator(
             enable_caching,
             enable_kv_cache_events,
             dcp_world_size=dcp_world_size,
+            pcp_world_size=pcp_world_size,
         )
     return HybridKVCacheCoordinator(
         kv_cache_config,
@@ -477,4 +493,5 @@ def get_kv_cache_coordinator(
         enable_caching,
         enable_kv_cache_events,
         dcp_world_size=dcp_world_size,
+        pcp_world_size=pcp_world_size,
     )
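The factory above picks one of three coordinators. A condensed sketch of the dispatch, with constructor bodies elided and only the class names from the diff assumed:

# Condensed sketch of the dispatch in get_kv_cache_coordinator:
# no prefix caching -> KVCacheCoordinatorNoPrefixCache;
# a single KV cache group -> UnitaryKVCacheCoordinator;
# otherwise -> HybridKVCacheCoordinator (which rejects dcp/pcp > 1).
def pick_coordinator(enable_caching: bool, num_groups: int) -> str:
    if not enable_caching:
        return "KVCacheCoordinatorNoPrefixCache"
    if num_groups == 1:
        return "UnitaryKVCacheCoordinator"
    return "HybridKVCacheCoordinator"

assert pick_coordinator(False, 2) == "KVCacheCoordinatorNoPrefixCache"
assert pick_coordinator(True, 1) == "UnitaryKVCacheCoordinator"
assert pick_coordinator(True, 3) == "HybridKVCacheCoordinator"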

View File

@@ -100,6 +100,7 @@ class KVCacheManager:
         log_stats: bool = False,
         enable_kv_cache_events: bool = False,
         dcp_world_size: int = 1,
+        pcp_world_size: int = 1,
     ) -> None:
 
         self.max_model_len = max_model_len
@@ -124,12 +125,9 @@ class KVCacheManager:
                 0
             ].kv_cache_spec.block_size
-        if dcp_world_size > 1:
+        if dcp_world_size * pcp_world_size > 1:
             assert len(kv_cache_config.kv_cache_groups) == 1
-            # Note(hc): need revisit. When both DCP and any future
-            # PCP are enabled, the block_size may need to be scaled
-            # by a factor of dcp_size × pcp_size?
-            self.block_size *= dcp_world_size
+            self.block_size *= dcp_world_size * pcp_world_size
 
         self.coordinator = get_kv_cache_coordinator(
             kv_cache_config=kv_cache_config,
@@ -138,6 +136,7 @@ class KVCacheManager:
             enable_caching=self.enable_caching,
             enable_kv_cache_events=enable_kv_cache_events,
             dcp_world_size=dcp_world_size,
+            pcp_world_size=pcp_world_size,
         )
         self.num_kv_cache_groups = len(kv_cache_config.kv_cache_groups)
         self.block_pool = self.coordinator.block_pool
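This hunk also retires the earlier Note(hc) TODO: scaling once by the product covers DCP-only, PCP-only, and combined deployments with a single guard. The practical effect on block accounting, as a standalone sketch (hypothetical helper, illustrative numbers):

# Sketch: block accounting with the scaled block size.
# A 1,000-token prompt with block_size=16, dcp=2, pcp=2 occupies
# ceil(1000 / (16 * 2 * 2)) = 16 logical blocks instead of 63.
import math

def num_blocks_needed(num_tokens: int, block_size: int, dcp: int, pcp: int) -> int:
    return math.ceil(num_tokens / (block_size * dcp * pcp))

assert num_blocks_needed(1000, 16, 1, 1) == 63
assert num_blocks_needed(1000, 16, 2, 2) == 16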

View File

@@ -1219,11 +1219,16 @@ def _report_kv_cache_config(
         // len(kv_cache_config.kv_cache_groups)
         * min_block_size
     )
-    if vllm_config.parallel_config.decode_context_parallel_size > 1:
-        num_tokens *= vllm_config.parallel_config.decode_context_parallel_size
+    dcp_size = vllm_config.parallel_config.decode_context_parallel_size
+    pcp_size = vllm_config.parallel_config.prefill_context_parallel_size
+    if pcp_size * dcp_size > 1:
+        num_tokens *= pcp_size * dcp_size
         logger.info(
-            "Multiplying the GPU KV cache size by the dcp_world_size %d.",
-            vllm_config.parallel_config.decode_context_parallel_size,
+            "Multiplying the GPU KV cache size by the cp_world_size %d "
+            "(pcp_world_size %d * dcp_world_size %d).",
+            pcp_size * dcp_size,
+            pcp_size,
+            dcp_size,
         )
     num_tokens_str = f"{num_tokens:,}"
     logger.info_once("GPU KV cache size: %s tokens", num_tokens_str, scope="local")
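The reporting change is plain arithmetic: the per-rank token count is multiplied by the combined context-parallel world size, since each rank holds only its shard of every block. A worked example with assumed numbers:

# Worked example of the reporting change (illustrative numbers).
# Suppose one rank's blocks hold 100,000 tokens' worth of KV shards.
num_tokens = 100_000
pcp_size, dcp_size = 2, 2
if pcp_size * dcp_size > 1:
    num_tokens *= pcp_size * dcp_size  # tokens sharded across all CP ranks
print(f"GPU KV cache size: {num_tokens:,} tokens")  # -> 400,000 tokens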

View File

@@ -121,6 +121,7 @@ class Scheduler(SchedulerInterface):
         self.block_size = block_size
         self.dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size
+        self.pcp_world_size = vllm_config.parallel_config.prefill_context_parallel_size
 
         # req_id -> Request
         self.requests: dict[str, Request] = {}
@@ -183,6 +184,7 @@ class Scheduler(SchedulerInterface):
             log_stats=self.log_stats,
             enable_kv_cache_events=self.enable_kv_cache_events,
             dcp_world_size=self.dcp_world_size,
+            pcp_world_size=self.pcp_world_size,
         )
 
         self.use_pp = self.parallel_config.pipeline_parallel_size > 1

View File

@@ -32,6 +32,7 @@ class SingleTypeKVCacheManager(ABC):
         block_pool: BlockPool,
         kv_cache_group_id: int,
         dcp_world_size: int = 1,
+        pcp_world_size: int = 1,
     ) -> None:
         """
         Initializes the SingleTypeKVCacheManager.
@@ -42,8 +43,9 @@ class SingleTypeKVCacheManager(ABC):
         """
         self.block_size = kv_cache_spec.block_size
         self.dcp_world_size = dcp_world_size
-        if self.dcp_world_size > 1:
-            self.block_size *= dcp_world_size
+        self.pcp_world_size = pcp_world_size
+        if dcp_world_size * pcp_world_size > 1:
+            self.block_size *= dcp_world_size * pcp_world_size
 
         self.kv_cache_spec = kv_cache_spec
         self.block_pool = block_pool
@@ -212,6 +214,7 @@ class SingleTypeKVCacheManager(ABC):
         kv_cache_spec: KVCacheSpec,
         use_eagle: bool,
         dcp_world_size: int = 1,
+        pcp_world_size: int = 1,
     ) -> tuple[list[KVCacheBlock], ...]:
         """
         Get the longest cache hit prefix of the blocks that is not longer than
@@ -303,6 +306,7 @@ class FullAttentionManager(SingleTypeKVCacheManager):
         kv_cache_spec: KVCacheSpec,
         use_eagle: bool,
         dcp_world_size: int = 1,
+        pcp_world_size: int = 1,
     ) -> tuple[list[KVCacheBlock], ...]:
         assert isinstance(
             kv_cache_spec, (FullAttentionSpec, ChunkedLocalAttentionSpec)
@@ -314,8 +318,8 @@ class FullAttentionManager(SingleTypeKVCacheManager):
             [] for _ in range(len(kv_cache_group_ids))
         )
         block_size = kv_cache_spec.block_size
-        if dcp_world_size > 1:
-            block_size *= dcp_world_size
+        if dcp_world_size * pcp_world_size > 1:
+            block_size *= dcp_world_size * pcp_world_size
         max_num_blocks = max_length // block_size
         for block_hash in itertools.islice(block_hashes, max_num_blocks):
             # block_hashes is a chain of block hashes. If a block hash is not
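With the scaled block size, the prefix-hit scan walks fewer, larger blocks. A simplified standalone sketch of the loop above, with the block-hash chain replaced by plain integers and a set standing in for the cached-block table:

# Simplified sketch of FullAttentionManager's prefix-hit scan:
# walk the block-hash chain until a miss or the max_length cap,
# where the cap is expressed in blocks of the scaled size.
import itertools

def longest_hit_blocks(block_hashes, cached, max_length, block_size, dcp=1, pcp=1):
    if dcp * pcp > 1:
        block_size *= dcp * pcp
    max_num_blocks = max_length // block_size
    hit = []
    for block_hash in itertools.islice(block_hashes, max_num_blocks):
        if block_hash not in cached:
            break
        hit.append(block_hash)
    return hit

# 3 of 4 leading blocks cached; a 200-token cap allows 200 // 64 = 3 blocks.
assert longest_hit_blocks([1, 2, 3, 4], {1, 2, 3}, 200, 16, dcp=2, pcp=2) == [1, 2, 3]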
@@ -362,11 +366,13 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
         kv_cache_spec: KVCacheSpec,
         use_eagle: bool,
         dcp_world_size: int = 1,
+        pcp_world_size: int = 1,
     ) -> tuple[list[KVCacheBlock], ...]:
         assert isinstance(kv_cache_spec, SlidingWindowSpec), (
             "SlidingWindowManager can only be used for sliding window groups"
         )
         assert dcp_world_size == 1, "DCP does not support sliding window attention yet."
+        assert pcp_world_size == 1, "PCP does not support sliding window attention yet."
 
         # The number of contiguous blocks needed for prefix cache hit.
         # -1 since the input token itself is also included in the window
@@ -476,6 +482,7 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
         kv_cache_spec: KVCacheSpec,
         use_eagle: bool,
         dcp_world_size: int = 1,
+        pcp_world_size: int = 1,
     ) -> tuple[list[KVCacheBlock], ...]:
         """
         For chunked local attention, we need to find the longest cache hit
@@ -516,6 +523,7 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
             "Hybrid KV cache is not supported for eagle + chunked local attention."
         )
         assert dcp_world_size == 1, "DCP does not support chunked local attention yet."
+        assert pcp_world_size == 1, "PCP does not support chunked local attention yet."
         max_num_blocks = max_length // kv_cache_spec.block_size
         if max_length > 0:
             local_attention_start_idx = (
@@ -611,11 +619,13 @@ class MambaManager(SingleTypeKVCacheManager):
         kv_cache_spec: KVCacheSpec,
         use_eagle: bool,
         dcp_world_size: int = 1,
+        pcp_world_size: int = 1,
     ) -> tuple[list[KVCacheBlock], ...]:
         assert isinstance(kv_cache_spec, MambaSpec), (
             "MambaManager can only be used for mamba groups"
         )
         assert dcp_world_size == 1, "DCP does not support Mamba yet."
+        assert pcp_world_size == 1, "PCP does not support Mamba yet."
         computed_blocks: tuple[list[KVCacheBlock], ...] = tuple(
             [] for _ in range(len(kv_cache_group_ids))
         )
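The same pair of guards recurs in the sliding-window, chunked-local-attention, and Mamba managers. A hypothetical consolidation, purely illustrative and not part of this PR, showing the pattern in one place:

# Hypothetical helper consolidating the recurring compatibility guards.
# Each specialized manager currently asserts this inline.
def check_cp_support(attn_kind: str, dcp_world_size: int, pcp_world_size: int) -> None:
    supported = attn_kind == "full"  # only full attention shards KV across CP ranks
    if not supported:
        assert dcp_world_size == 1, f"DCP does not support {attn_kind} yet."
        assert pcp_world_size == 1, f"PCP does not support {attn_kind} yet."

check_cp_support("full", 2, 2)            # OK
check_cp_support("sliding window", 1, 1)  # OK
# check_cp_support("mamba", 1, 2)         # would raise AssertionError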
@@ -705,6 +715,7 @@ class CrossAttentionManager(SingleTypeKVCacheManager):
         kv_cache_spec: KVCacheSpec,
         use_eagle: bool,
         dcp_world_size: int = 1,
+        pcp_world_size: int = 1,
     ) -> tuple[list[KVCacheBlock], ...]:
         assert isinstance(kv_cache_spec, CrossAttentionSpec), (
             "CrossAttentionManager can only be used for cross-attention groups"