[Feature] Prefill Context Parallel (PCP) basic support (#28718)
Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com> Signed-off-by: FENP <yuanyongjie.yyj@antgroup.com> Signed-off-by: LookAround <lixushi@huawei.com> Signed-off-by: Jingchun Gao <gaojingchun1@huawei.com> Signed-off-by: zhenwenqi2024 <zhenwenqi_2022@qq.com> Co-authored-by: FENP <yuanyongjie.yyj@antgroup.com> Co-authored-by: LookAround <lixushi@huawei.com> Co-authored-by: Jingchun Gao <gaojingchun1@huawei.com> Co-authored-by: zhenwenqi2024 <zhenwenqi_2022@qq.com> Co-authored-by: Jingchun Gao <63247409+gjc0824@users.noreply.github.com>
This commit is contained in:
@@ -27,6 +27,7 @@ class KVCacheCoordinator(ABC):
|
||||
enable_caching: bool,
|
||||
enable_kv_cache_events: bool,
|
||||
dcp_world_size: int,
|
||||
pcp_world_size: int,
|
||||
):
|
||||
self.kv_cache_config = kv_cache_config
|
||||
self.max_model_len = max_model_len
|
||||
@@ -44,6 +45,7 @@ class KVCacheCoordinator(ABC):
|
||||
block_pool=self.block_pool,
|
||||
kv_cache_group_id=i,
|
||||
dcp_world_size=dcp_world_size,
|
||||
pcp_world_size=pcp_world_size,
|
||||
)
|
||||
for i, kv_cache_group in enumerate(self.kv_cache_config.kv_cache_groups)
|
||||
)
|
||||
@@ -210,6 +212,7 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator):
|
||||
use_eagle: bool,
|
||||
enable_kv_cache_events: bool,
|
||||
dcp_world_size: int,
|
||||
pcp_world_size: int,
|
||||
):
|
||||
super().__init__(
|
||||
kv_cache_config,
|
||||
@@ -218,6 +221,7 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator):
|
||||
False,
|
||||
enable_kv_cache_events,
|
||||
dcp_world_size=dcp_world_size,
|
||||
pcp_world_size=pcp_world_size,
|
||||
)
|
||||
self.num_single_type_manager = len(self.single_type_managers)
|
||||
|
||||
@@ -250,6 +254,7 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
|
||||
enable_caching: bool,
|
||||
enable_kv_cache_events: bool,
|
||||
dcp_world_size: int,
|
||||
pcp_world_size: int,
|
||||
):
|
||||
super().__init__(
|
||||
kv_cache_config,
|
||||
@@ -258,12 +263,16 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
|
||||
enable_caching,
|
||||
enable_kv_cache_events,
|
||||
dcp_world_size=dcp_world_size,
|
||||
pcp_world_size=pcp_world_size,
|
||||
)
|
||||
self.kv_cache_spec = self.kv_cache_config.kv_cache_groups[0].kv_cache_spec
|
||||
self.block_size = self.kv_cache_spec.block_size
|
||||
self.dcp_world_size = dcp_world_size
|
||||
self.pcp_world_size = pcp_world_size
|
||||
if dcp_world_size > 1:
|
||||
self.block_size *= dcp_world_size
|
||||
if pcp_world_size > 1:
|
||||
self.block_size *= pcp_world_size
|
||||
assert len(self.kv_cache_config.kv_cache_groups) == 1, (
|
||||
"UnitaryKVCacheCoordinator assumes only one kv cache group"
|
||||
)
|
||||
@@ -281,6 +290,7 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
|
||||
kv_cache_spec=self.kv_cache_spec,
|
||||
use_eagle=self.use_eagle,
|
||||
dcp_world_size=self.dcp_world_size,
|
||||
pcp_world_size=self.pcp_world_size,
|
||||
)
|
||||
return hit_blocks, len(hit_blocks[0]) * self.block_size
|
||||
|
||||
@@ -302,6 +312,7 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
|
||||
enable_caching: bool,
|
||||
enable_kv_cache_events: bool,
|
||||
dcp_world_size: int,
|
||||
pcp_world_size: int,
|
||||
):
|
||||
super().__init__(
|
||||
kv_cache_config,
|
||||
@@ -310,8 +321,10 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
|
||||
enable_caching,
|
||||
enable_kv_cache_events,
|
||||
dcp_world_size=dcp_world_size,
|
||||
pcp_world_size=pcp_world_size,
|
||||
)
|
||||
assert dcp_world_size == 1, "DCP not support hybrid attn now."
|
||||
assert pcp_world_size == 1, "PCP not support hybrid attn now."
|
||||
self.verify_and_split_kv_cache_groups()
|
||||
|
||||
def verify_and_split_kv_cache_groups(self) -> None:
|
||||
@@ -452,6 +465,7 @@ def get_kv_cache_coordinator(
|
||||
enable_caching: bool,
|
||||
enable_kv_cache_events: bool,
|
||||
dcp_world_size: int,
|
||||
pcp_world_size: int,
|
||||
) -> KVCacheCoordinator:
|
||||
if not enable_caching:
|
||||
return KVCacheCoordinatorNoPrefixCache(
|
||||
@@ -460,6 +474,7 @@ def get_kv_cache_coordinator(
|
||||
use_eagle,
|
||||
enable_kv_cache_events,
|
||||
dcp_world_size=dcp_world_size,
|
||||
pcp_world_size=pcp_world_size,
|
||||
)
|
||||
if len(kv_cache_config.kv_cache_groups) == 1:
|
||||
return UnitaryKVCacheCoordinator(
|
||||
@@ -469,6 +484,7 @@ def get_kv_cache_coordinator(
|
||||
enable_caching,
|
||||
enable_kv_cache_events,
|
||||
dcp_world_size=dcp_world_size,
|
||||
pcp_world_size=pcp_world_size,
|
||||
)
|
||||
return HybridKVCacheCoordinator(
|
||||
kv_cache_config,
|
||||
@@ -477,4 +493,5 @@ def get_kv_cache_coordinator(
|
||||
enable_caching,
|
||||
enable_kv_cache_events,
|
||||
dcp_world_size=dcp_world_size,
|
||||
pcp_world_size=pcp_world_size,
|
||||
)
|
||||
|
||||
@@ -100,6 +100,7 @@ class KVCacheManager:
|
||||
log_stats: bool = False,
|
||||
enable_kv_cache_events: bool = False,
|
||||
dcp_world_size: int = 1,
|
||||
pcp_world_size: int = 1,
|
||||
) -> None:
|
||||
self.max_model_len = max_model_len
|
||||
|
||||
@@ -124,12 +125,9 @@ class KVCacheManager:
|
||||
0
|
||||
].kv_cache_spec.block_size
|
||||
|
||||
if dcp_world_size > 1:
|
||||
if dcp_world_size * pcp_world_size > 1:
|
||||
assert len(kv_cache_config.kv_cache_groups) == 1
|
||||
# Note(hc): need revisit. When both DCP and any future
|
||||
# PCP are enabled, the block_size may need to be scaled
|
||||
# by a factor of dcp_size × pcp_size?
|
||||
self.block_size *= dcp_world_size
|
||||
self.block_size *= dcp_world_size * pcp_world_size
|
||||
|
||||
self.coordinator = get_kv_cache_coordinator(
|
||||
kv_cache_config=kv_cache_config,
|
||||
@@ -138,6 +136,7 @@ class KVCacheManager:
|
||||
enable_caching=self.enable_caching,
|
||||
enable_kv_cache_events=enable_kv_cache_events,
|
||||
dcp_world_size=dcp_world_size,
|
||||
pcp_world_size=pcp_world_size,
|
||||
)
|
||||
self.num_kv_cache_groups = len(kv_cache_config.kv_cache_groups)
|
||||
self.block_pool = self.coordinator.block_pool
|
||||
|
||||
@@ -1219,11 +1219,16 @@ def _report_kv_cache_config(
|
||||
// len(kv_cache_config.kv_cache_groups)
|
||||
* min_block_size
|
||||
)
|
||||
if vllm_config.parallel_config.decode_context_parallel_size > 1:
|
||||
num_tokens *= vllm_config.parallel_config.decode_context_parallel_size
|
||||
dcp_size = vllm_config.parallel_config.decode_context_parallel_size
|
||||
pcp_size = vllm_config.parallel_config.prefill_context_parallel_size
|
||||
if pcp_size * dcp_size > 1:
|
||||
num_tokens *= pcp_size * dcp_size
|
||||
logger.info(
|
||||
"Multiplying the GPU KV cache size by the dcp_world_size %d.",
|
||||
vllm_config.parallel_config.decode_context_parallel_size,
|
||||
"Multiplying the GPU KV cache size by the cp_world_size %d "
|
||||
"(pcp_world_size %d * dcp_world_size %d).",
|
||||
pcp_size * dcp_size,
|
||||
pcp_size,
|
||||
dcp_size,
|
||||
)
|
||||
num_tokens_str = f"{num_tokens:,}"
|
||||
logger.info_once("GPU KV cache size: %s tokens", num_tokens_str, scope="local")
|
||||
|
||||
@@ -121,6 +121,7 @@ class Scheduler(SchedulerInterface):
|
||||
|
||||
self.block_size = block_size
|
||||
self.dcp_world_size = vllm_config.parallel_config.decode_context_parallel_size
|
||||
self.pcp_world_size = vllm_config.parallel_config.prefill_context_parallel_size
|
||||
|
||||
# req_id -> Request
|
||||
self.requests: dict[str, Request] = {}
|
||||
@@ -183,6 +184,7 @@ class Scheduler(SchedulerInterface):
|
||||
log_stats=self.log_stats,
|
||||
enable_kv_cache_events=self.enable_kv_cache_events,
|
||||
dcp_world_size=self.dcp_world_size,
|
||||
pcp_world_size=self.pcp_world_size,
|
||||
)
|
||||
self.use_pp = self.parallel_config.pipeline_parallel_size > 1
|
||||
|
||||
|
||||
@@ -32,6 +32,7 @@ class SingleTypeKVCacheManager(ABC):
|
||||
block_pool: BlockPool,
|
||||
kv_cache_group_id: int,
|
||||
dcp_world_size: int = 1,
|
||||
pcp_world_size: int = 1,
|
||||
) -> None:
|
||||
"""
|
||||
Initializes the SingleTypeKVCacheManager.
|
||||
@@ -42,8 +43,9 @@ class SingleTypeKVCacheManager(ABC):
|
||||
"""
|
||||
self.block_size = kv_cache_spec.block_size
|
||||
self.dcp_world_size = dcp_world_size
|
||||
if self.dcp_world_size > 1:
|
||||
self.block_size *= dcp_world_size
|
||||
self.pcp_world_size = pcp_world_size
|
||||
if dcp_world_size * pcp_world_size > 1:
|
||||
self.block_size *= dcp_world_size * pcp_world_size
|
||||
self.kv_cache_spec = kv_cache_spec
|
||||
self.block_pool = block_pool
|
||||
|
||||
@@ -212,6 +214,7 @@ class SingleTypeKVCacheManager(ABC):
|
||||
kv_cache_spec: KVCacheSpec,
|
||||
use_eagle: bool,
|
||||
dcp_world_size: int = 1,
|
||||
pcp_world_size: int = 1,
|
||||
) -> tuple[list[KVCacheBlock], ...]:
|
||||
"""
|
||||
Get the longest cache hit prefix of the blocks that is not longer than
|
||||
@@ -303,6 +306,7 @@ class FullAttentionManager(SingleTypeKVCacheManager):
|
||||
kv_cache_spec: KVCacheSpec,
|
||||
use_eagle: bool,
|
||||
dcp_world_size: int = 1,
|
||||
pcp_world_size: int = 1,
|
||||
) -> tuple[list[KVCacheBlock], ...]:
|
||||
assert isinstance(
|
||||
kv_cache_spec, (FullAttentionSpec, ChunkedLocalAttentionSpec)
|
||||
@@ -314,8 +318,8 @@ class FullAttentionManager(SingleTypeKVCacheManager):
|
||||
[] for _ in range(len(kv_cache_group_ids))
|
||||
)
|
||||
block_size = kv_cache_spec.block_size
|
||||
if dcp_world_size > 1:
|
||||
block_size *= dcp_world_size
|
||||
if dcp_world_size * pcp_world_size > 1:
|
||||
block_size *= dcp_world_size * pcp_world_size
|
||||
max_num_blocks = max_length // block_size
|
||||
for block_hash in itertools.islice(block_hashes, max_num_blocks):
|
||||
# block_hashes is a chain of block hashes. If a block hash is not
|
||||
@@ -362,11 +366,13 @@ class SlidingWindowManager(SingleTypeKVCacheManager):
|
||||
kv_cache_spec: KVCacheSpec,
|
||||
use_eagle: bool,
|
||||
dcp_world_size: int = 1,
|
||||
pcp_world_size: int = 1,
|
||||
) -> tuple[list[KVCacheBlock], ...]:
|
||||
assert isinstance(kv_cache_spec, SlidingWindowSpec), (
|
||||
"SlidingWindowManager can only be used for sliding window groups"
|
||||
)
|
||||
assert dcp_world_size == 1, "DCP not support sliding window attn now."
|
||||
assert pcp_world_size == 1, "PCP not support sliding window attn now."
|
||||
|
||||
# The number of contiguous blocks needed for prefix cache hit.
|
||||
# -1 since the input token itself is also included in the window
|
||||
@@ -476,6 +482,7 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
|
||||
kv_cache_spec: KVCacheSpec,
|
||||
use_eagle: bool,
|
||||
dcp_world_size: int = 1,
|
||||
pcp_world_size: int = 1,
|
||||
) -> tuple[list[KVCacheBlock], ...]:
|
||||
"""
|
||||
For chunked local attention, we need to find the longest cache hit
|
||||
@@ -516,6 +523,7 @@ class ChunkedLocalAttentionManager(SingleTypeKVCacheManager):
|
||||
"Hybrid KV cache is not supported for " + "eagle + chunked local attention."
|
||||
)
|
||||
assert dcp_world_size == 1, "DCP not support chunked local attn now."
|
||||
assert pcp_world_size == 1, "PCP not support chunked local attn now."
|
||||
max_num_blocks = max_length // kv_cache_spec.block_size
|
||||
if max_length > 0:
|
||||
local_attention_start_idx = (
|
||||
@@ -611,11 +619,13 @@ class MambaManager(SingleTypeKVCacheManager):
|
||||
kv_cache_spec: KVCacheSpec,
|
||||
use_eagle: bool,
|
||||
dcp_world_size: int = 1,
|
||||
pcp_world_size: int = 1,
|
||||
) -> tuple[list[KVCacheBlock], ...]:
|
||||
assert isinstance(kv_cache_spec, MambaSpec), (
|
||||
"MambaManager can only be used for mamba groups"
|
||||
)
|
||||
assert dcp_world_size == 1, "DCP not support mamba now."
|
||||
assert pcp_world_size == 1, "PCP not support mamba now."
|
||||
computed_blocks: tuple[list[KVCacheBlock], ...] = tuple(
|
||||
[] for _ in range(len(kv_cache_group_ids))
|
||||
)
|
||||
@@ -705,6 +715,7 @@ class CrossAttentionManager(SingleTypeKVCacheManager):
|
||||
kv_cache_spec: KVCacheSpec,
|
||||
use_eagle: bool,
|
||||
dcp_world_size: int = 1,
|
||||
pcp_world_size: int = 1,
|
||||
) -> tuple[list[KVCacheBlock], ...]:
|
||||
assert isinstance(kv_cache_spec, CrossAttentionSpec), (
|
||||
"CrossAttentionManager can only be used for cross-attention groups"
|
||||
|
||||
Reference in New Issue
Block a user