[ci][distributed] add tests for custom allreduce (#5689)
@@ -11,7 +11,8 @@ from vllm.distributed.communication_op import ( # noqa
 from vllm.distributed.parallel_state import (get_tensor_model_parallel_group,
                                              get_tp_group, graph_capture)
 
-from ..utils import (init_test_distributed_environment,
+from ..utils import (ensure_model_parallel_initialized,
+                     init_test_distributed_environment,
                      multi_process_tensor_parallel)
 
 random.seed(42)
@@ -27,8 +28,8 @@ def graph_allreduce(tp_size, pp_size, rank, distributed_init_port):
     torch.cuda.set_device(device)
     init_test_distributed_environment(tp_size, pp_size, rank,
                                       distributed_init_port)
 
-    group = get_tensor_model_parallel_group()
+    ensure_model_parallel_initialized(tp_size, pp_size)
+    group = get_tensor_model_parallel_group().device_group
 
     # A small all_reduce for warmup.
     # this is needed because device communicators might be created lazily
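For orientation, the changed lines follow the usual pattern in vLLM's distributed tests: set the CUDA device, bring up the test distributed environment, make sure the tensor-/pipeline-parallel groups exist, take the underlying torch.distributed process group via .device_group, and issue a small warm-up all-reduce so lazily created communicators (e.g. NCCL) are initialized before any CUDA graph capture. The sketch below is a simplified illustration of that pattern, not the actual test body; the helper names come from the diff above, while the function name, device string, tensor shape, and the closing assertion are assumptions added for the example.

# Hypothetical sketch of the pattern shown in the diff (not the real test body).
import torch
import torch.distributed as dist

from vllm.distributed.parallel_state import get_tensor_model_parallel_group

from ..utils import (ensure_model_parallel_initialized,
                     init_test_distributed_environment)


def warmup_allreduce_sketch(tp_size, pp_size, rank, distributed_init_port):
    # Pin this worker to its GPU (device string is an assumption for the sketch).
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)

    # Bring up the test-only distributed environment (helper from the tests package).
    init_test_distributed_environment(tp_size, pp_size, rank,
                                      distributed_init_port)

    # Make sure TP/PP groups are initialized, then grab the raw torch group.
    ensure_model_parallel_initialized(tp_size, pp_size)
    group = get_tensor_model_parallel_group().device_group

    # A small all_reduce for warmup: device communicators (e.g. NCCL) may be
    # created lazily, so force initialization before any graph capture happens.
    data = torch.zeros(1, device=device)
    dist.all_reduce(data, group=group)
    torch.cuda.synchronize()

    # Sanity check: every rank contributed zero, so the reduced value is zero.
    assert data.item() == 0.0

The test itself then runs its all-reduces under graph_capture() (imported in the hunk above), presumably so the custom allreduce path is exercised during CUDA graph capture.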