[Core][Distributed] refactor custom allreduce to support multiple tp groups (#4754)

This commit is contained in:
youkaichao
2024-05-12 17:47:59 -07:00
committed by GitHub
parent a7be4d0072
commit 702bee461f
10 changed files with 327 additions and 226 deletions

View File

@@ -5,7 +5,7 @@ import pytest
import torch
from vllm.distributed.communication_op import ( # noqa
graph_capture_mode, tensor_model_parallel_all_reduce)
graph_mode, tensor_model_parallel_all_reduce)
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary
from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
@@ -103,7 +103,7 @@ def multiple_tp_with_vllm_worker_fn():
device = torch.device(f"cuda:{torch.distributed.get_rank()}")
ensure_model_parallel_initialized(2, 2)
tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
with graph_capture_mode():
with graph_mode():
# two tp groups can communicate independently
if torch.distributed.get_rank() in [0, 1]:
tensor = tensor_model_parallel_all_reduce(tensor)