[DCP] Support Decode Context Parallel (DCP) for GQA with Flashinfer (#25438)

Signed-off-by: gaojc <1055866782@qq.com>
Signed-off-by: Jingchun Gao <gaojingchun1@huawei.com>
Signed-off-by: Jingchun Gao <63247409+gjc0824@users.noreply.github.com>
Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
Co-authored-by: gaojingchun (A) <g00955623@china.huawei.com>
Co-authored-by: Jingchun Gao <gaojingchun1@huawei.com>
Co-authored-by: QiuChunshuo <qiuchunshuo@huawei.com>
This commit is contained in:
Jingchun Gao
2025-11-14 19:24:10 +08:00
committed by GitHub
parent 41b92f7d38
commit 4516d44b7f
5 changed files with 331 additions and 51 deletions

View File

@@ -31,6 +31,7 @@ from vllm.distributed import destroy_distributed_environment, destroy_model_para
from vllm.distributed.device_communicators.shm_broadcast import Handle, MessageQueue
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
from vllm.distributed.parallel_state import (
get_dcp_group,
get_dp_group,
get_ep_group,
get_pp_group,
@@ -726,6 +727,8 @@ class WorkerProc:
pp_rank = get_pp_group().rank_in_group
tp_size = get_tp_group().world_size
tp_rank = get_tp_group().rank_in_group
dcp_size = get_dcp_group().world_size
dcp_rank = get_dcp_group().rank_in_group
process_name = "Worker"
if dp_size > 1:
process_name += f"_DP{dp_rank}"
@@ -733,6 +736,8 @@ class WorkerProc:
process_name += f"_PP{pp_rank}"
if tp_size > 1:
process_name += f"_TP{tp_rank}"
if dcp_size > 1:
process_name += f"_DCP{dcp_rank}"
if enable_ep:
ep_rank = get_ep_group().rank_in_group
process_name += f"_EP{ep_rank}"