[DCP] Support Decode Context Parallel (DCP) for GQA with Flashinfer (#25438)
Signed-off-by: gaojc <1055866782@qq.com> Signed-off-by: Jingchun Gao <gaojingchun1@huawei.com> Signed-off-by: Jingchun Gao <63247409+gjc0824@users.noreply.github.com> Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com> Co-authored-by: gaojingchun (A) <g00955623@china.huawei.com> Co-authored-by: Jingchun Gao <gaojingchun1@huawei.com> Co-authored-by: QiuChunshuo <qiuchunshuo@huawei.com>
This commit is contained in:
@@ -31,6 +31,7 @@ from vllm.distributed import destroy_distributed_environment, destroy_model_para
|
||||
from vllm.distributed.device_communicators.shm_broadcast import Handle, MessageQueue
|
||||
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_dcp_group,
|
||||
get_dp_group,
|
||||
get_ep_group,
|
||||
get_pp_group,
|
||||
@@ -726,6 +727,8 @@ class WorkerProc:
|
||||
pp_rank = get_pp_group().rank_in_group
|
||||
tp_size = get_tp_group().world_size
|
||||
tp_rank = get_tp_group().rank_in_group
|
||||
dcp_size = get_dcp_group().world_size
|
||||
dcp_rank = get_dcp_group().rank_in_group
|
||||
process_name = "Worker"
|
||||
if dp_size > 1:
|
||||
process_name += f"_DP{dp_rank}"
|
||||
@@ -733,6 +736,8 @@ class WorkerProc:
|
||||
process_name += f"_PP{pp_rank}"
|
||||
if tp_size > 1:
|
||||
process_name += f"_TP{tp_rank}"
|
||||
if dcp_size > 1:
|
||||
process_name += f"_DCP{dcp_rank}"
|
||||
if enable_ep:
|
||||
ep_rank = get_ep_group().rank_in_group
|
||||
process_name += f"_EP{ep_rank}"
|
||||
|
||||
Reference in New Issue
Block a user