[Misc] Begin deprecation of get_tensor_model_*_group (#22494)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung
2025-08-08 16:11:54 +08:00
committed by GitHub
parent 1712543df6
commit 43c4f3d77c
3 changed files with 16 additions and 10 deletions

@@ -10,8 +10,7 @@ import torch.distributed as dist
 
 from vllm.distributed.communication_op import (  # noqa
     tensor_model_parallel_all_reduce)
-from vllm.distributed.parallel_state import (get_tensor_model_parallel_group,
-                                             get_tp_group, graph_capture)
+from vllm.distributed.parallel_state import get_tp_group, graph_capture
 
 from ..utils import (ensure_model_parallel_initialized,
                      init_test_distributed_environment, multi_process_parallel)
@@ -37,7 +36,7 @@ def graph_allreduce(
     init_test_distributed_environment(tp_size, pp_size, rank,
                                       distributed_init_port)
     ensure_model_parallel_initialized(tp_size, pp_size)
-    group = get_tensor_model_parallel_group().device_group
+    group = get_tp_group().device_group
 
     # A small all_reduce for warmup.
     # this is needed because device communicators might be created lazily
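
Only the test-file hunks are reproduced above; the deprecation itself lands in vllm/distributed/parallel_state.py, whose hunk is not shown. Below is a minimal sketch of the warn-and-delegate pattern the commit title implies, assuming the deprecated accessor simply forwards to get_tp_group; the exact warning text and placement in the real commit may differ.

import warnings

from vllm.distributed.parallel_state import get_tp_group


def get_tensor_model_parallel_group():
    # Sketch only: emit a DeprecationWarning pointing callers at the
    # replacement, then delegate so existing code keeps working during
    # the deprecation window.
    warnings.warn(
        "get_tensor_model_parallel_group is deprecated; use get_tp_group",
        DeprecationWarning,
        stacklevel=2,
    )
    return get_tp_group()

Call sites then migrate exactly as the second hunk does: get_tensor_model_parallel_group().device_group becomes get_tp_group().device_group.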