[Feature] Prefill Context Parallel (PCP) basic support (#28718)
Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com> Signed-off-by: FENP <yuanyongjie.yyj@antgroup.com> Signed-off-by: LookAround <lixushi@huawei.com> Signed-off-by: Jingchun Gao <gaojingchun1@huawei.com> Signed-off-by: zhenwenqi2024 <zhenwenqi_2022@qq.com> Co-authored-by: FENP <yuanyongjie.yyj@antgroup.com> Co-authored-by: LookAround <lixushi@huawei.com> Co-authored-by: Jingchun Gao <gaojingchun1@huawei.com> Co-authored-by: zhenwenqi2024 <zhenwenqi_2022@qq.com> Co-authored-by: Jingchun Gao <63247409+gjc0824@users.noreply.github.com>
This commit is contained in:
@@ -4,7 +4,7 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.distributed import get_dcp_group
|
||||
from vllm.distributed import get_dcp_group, get_pcp_group
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.v1.utils import CpuGpuBuffer
|
||||
@@ -22,7 +22,7 @@ class BlockTable:
|
||||
pin_memory: bool,
|
||||
device: torch.device,
|
||||
kernel_block_size: int,
|
||||
dcp_kv_cache_interleave_size: int,
|
||||
cp_kv_cache_interleave_size: int,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
@@ -80,6 +80,13 @@ class BlockTable:
|
||||
else:
|
||||
self._kernel_block_arange = None
|
||||
|
||||
try:
|
||||
self.pcp_world_size = get_pcp_group().world_size
|
||||
self.pcp_rank = get_pcp_group().rank_in_group
|
||||
except AssertionError:
|
||||
# DCP might not be initialized in testing
|
||||
self.pcp_world_size = 1
|
||||
self.pcp_rank = 0
|
||||
try:
|
||||
self.dcp_world_size = get_dcp_group().world_size
|
||||
self.dcp_rank = get_dcp_group().rank_in_group
|
||||
@@ -87,7 +94,7 @@ class BlockTable:
|
||||
# DCP might not be initialized in testing
|
||||
self.dcp_world_size = 1
|
||||
self.dcp_rank = 0
|
||||
self.dcp_kv_cache_interleave_size = dcp_kv_cache_interleave_size
|
||||
self.cp_kv_cache_interleave_size = cp_kv_cache_interleave_size
|
||||
|
||||
def append_row(
|
||||
self,
|
||||
@@ -131,14 +138,16 @@ class BlockTable:
|
||||
# NOTE(woosuk): We can't simply use `token_indices // block_size`
|
||||
# here because M (max_model_len) is not necessarily divisible by
|
||||
# block_size.
|
||||
if self.dcp_world_size > 1:
|
||||
total_cp_world_size = self.pcp_world_size * self.dcp_world_size
|
||||
total_cp_rank = self.pcp_rank * self.dcp_world_size + self.dcp_rank
|
||||
if total_cp_world_size > 1:
|
||||
# Note(hc): The DCP implement store kvcache with an interleave
|
||||
# style, the kvcache for the token whose token_idx is i is
|
||||
# always stored on the GPU whose dcp_rank equals i % cp_world_size:
|
||||
|
||||
# Use a "virtual block" which equals to world_size * block_size
|
||||
# for block_table_indices calculation.
|
||||
virtual_block_size = self.block_size * self.dcp_world_size
|
||||
virtual_block_size = self.block_size * total_cp_world_size
|
||||
block_table_indices = (
|
||||
req_indices * self.max_num_blocks_per_req
|
||||
+ positions // virtual_block_size
|
||||
@@ -150,16 +159,16 @@ class BlockTable:
|
||||
virtual_block_offsets = positions % virtual_block_size
|
||||
mask = (
|
||||
virtual_block_offsets
|
||||
// self.dcp_kv_cache_interleave_size
|
||||
% self.dcp_world_size
|
||||
== self.dcp_rank
|
||||
// self.cp_kv_cache_interleave_size
|
||||
% total_cp_world_size
|
||||
== total_cp_rank
|
||||
)
|
||||
# Calculate local block_offsets
|
||||
block_offsets = (
|
||||
virtual_block_offsets
|
||||
// (self.dcp_world_size * self.dcp_kv_cache_interleave_size)
|
||||
* self.dcp_kv_cache_interleave_size
|
||||
+ virtual_block_offsets % self.dcp_kv_cache_interleave_size
|
||||
// (total_cp_world_size * self.cp_kv_cache_interleave_size)
|
||||
* self.cp_kv_cache_interleave_size
|
||||
+ virtual_block_offsets % self.cp_kv_cache_interleave_size
|
||||
)
|
||||
# Calculate slot_mapping
|
||||
slot_mapping = block_numbers * self.block_size + block_offsets
|
||||
@@ -253,7 +262,7 @@ class MultiGroupBlockTable:
|
||||
block_sizes: list[int],
|
||||
kernel_block_sizes: list[int],
|
||||
num_speculative_tokens: int = 0,
|
||||
dcp_kv_cache_interleave_size: int = 1,
|
||||
cp_kv_cache_interleave_size: int = 1,
|
||||
) -> None:
|
||||
# Note(hc): each dcp rank only store
|
||||
# (max_model_len//dcp_world_size) tokens in kvcache,
|
||||
@@ -283,7 +292,7 @@ class MultiGroupBlockTable:
|
||||
pin_memory,
|
||||
device,
|
||||
kernel_block_size,
|
||||
dcp_kv_cache_interleave_size,
|
||||
cp_kv_cache_interleave_size,
|
||||
)
|
||||
for block_size, kernel_block_size in zip(block_sizes, kernel_block_sizes)
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user