[Feature] Prefill Context Parallel (PCP) basic support (#28718)

Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
Signed-off-by: FENP <yuanyongjie.yyj@antgroup.com>
Signed-off-by: LookAround <lixushi@huawei.com>
Signed-off-by: Jingchun Gao <gaojingchun1@huawei.com>
Signed-off-by: zhenwenqi2024 <zhenwenqi_2022@qq.com>
Co-authored-by: FENP <yuanyongjie.yyj@antgroup.com>
Co-authored-by: LookAround <lixushi@huawei.com>
Co-authored-by: Jingchun Gao <gaojingchun1@huawei.com>
Co-authored-by: zhenwenqi2024 <zhenwenqi_2022@qq.com>
Co-authored-by: Jingchun Gao <63247409+gjc0824@users.noreply.github.com>
This commit is contained in:
Qiu
2025-11-20 04:52:44 +08:00
committed by GitHub
parent 02f5903b84
commit 2fd893b4ce
27 changed files with 399 additions and 114 deletions

View File

@@ -4,7 +4,7 @@
import numpy as np
import torch
from vllm.distributed import get_dcp_group
from vllm.distributed import get_dcp_group, get_pcp_group
from vllm.logger import init_logger
from vllm.utils.math_utils import cdiv
from vllm.v1.utils import CpuGpuBuffer
@@ -22,7 +22,7 @@ class BlockTable:
pin_memory: bool,
device: torch.device,
kernel_block_size: int,
dcp_kv_cache_interleave_size: int,
cp_kv_cache_interleave_size: int,
):
"""
Args:
@@ -80,6 +80,13 @@ class BlockTable:
else:
self._kernel_block_arange = None
try:
self.pcp_world_size = get_pcp_group().world_size
self.pcp_rank = get_pcp_group().rank_in_group
except AssertionError:
# PCP might not be initialized in testing
self.pcp_world_size = 1
self.pcp_rank = 0
try:
self.dcp_world_size = get_dcp_group().world_size
self.dcp_rank = get_dcp_group().rank_in_group
@@ -87,7 +94,7 @@ class BlockTable:
# DCP might not be initialized in testing
self.dcp_world_size = 1
self.dcp_rank = 0
self.dcp_kv_cache_interleave_size = dcp_kv_cache_interleave_size
self.cp_kv_cache_interleave_size = cp_kv_cache_interleave_size
def append_row(
self,
@@ -131,14 +138,16 @@ class BlockTable:
# NOTE(woosuk): We can't simply use `token_indices // block_size`
# here because M (max_model_len) is not necessarily divisible by
# block_size.
if self.dcp_world_size > 1:
total_cp_world_size = self.pcp_world_size * self.dcp_world_size
total_cp_rank = self.pcp_rank * self.dcp_world_size + self.dcp_rank
if total_cp_world_size > 1:
# Note(hc): With context parallelism (PCP * DCP), the kvcache is
# stored in an interleaved style: the kvcache for the token whose
# token_idx is i is always stored on the GPU whose total_cp_rank
# equals i % total_cp_world_size.
# Use a "virtual block" which equals to world_size * block_size
# for block_table_indices calculation.
virtual_block_size = self.block_size * self.dcp_world_size
virtual_block_size = self.block_size * total_cp_world_size
block_table_indices = (
req_indices * self.max_num_blocks_per_req
+ positions // virtual_block_size
@@ -150,16 +159,16 @@ class BlockTable:
virtual_block_offsets = positions % virtual_block_size
mask = (
virtual_block_offsets
// self.dcp_kv_cache_interleave_size
% self.dcp_world_size
== self.dcp_rank
// self.cp_kv_cache_interleave_size
% total_cp_world_size
== total_cp_rank
)
# Calculate local block_offsets
block_offsets = (
virtual_block_offsets
// (self.dcp_world_size * self.dcp_kv_cache_interleave_size)
* self.dcp_kv_cache_interleave_size
+ virtual_block_offsets % self.dcp_kv_cache_interleave_size
// (total_cp_world_size * self.cp_kv_cache_interleave_size)
* self.cp_kv_cache_interleave_size
+ virtual_block_offsets % self.cp_kv_cache_interleave_size
)
# Calculate slot_mapping
slot_mapping = block_numbers * self.block_size + block_offsets
@@ -253,7 +262,7 @@ class MultiGroupBlockTable:
block_sizes: list[int],
kernel_block_sizes: list[int],
num_speculative_tokens: int = 0,
dcp_kv_cache_interleave_size: int = 1,
cp_kv_cache_interleave_size: int = 1,
) -> None:
# Note(hc): each dcp rank only stores
# (max_model_len // dcp_world_size) tokens in kvcache,
@@ -283,7 +292,7 @@ class MultiGroupBlockTable:
pin_memory,
device,
kernel_block_size,
dcp_kv_cache_interleave_size,
cp_kv_cache_interleave_size,
)
for block_size, kernel_block_size in zip(block_sizes, kernel_block_sizes)
]