[Core][Distributed] refactor custom allreduce to support multiple tp groups (#4754)
This commit is contained in:
@@ -5,7 +5,7 @@ import pytest
|
||||
import torch
|
||||
|
||||
from vllm.distributed.communication_op import ( # noqa
|
||||
- graph_capture_mode, tensor_model_parallel_all_reduce)
|
||||
+ graph_mode, tensor_model_parallel_all_reduce)
|
||||
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
|
||||
from vllm.distributed.device_communicators.pynccl_wrapper import NCCLLibrary
|
||||
from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
|
||||
@@ -103,7 +103,7 @@ def multiple_tp_with_vllm_worker_fn():
|
||||
device = torch.device(f"cuda:{torch.distributed.get_rank()}")
|
||||
ensure_model_parallel_initialized(2, 2)
|
||||
tensor = torch.ones(16, 1024, 1024, dtype=torch.float32, device=device)
|
||||
- with graph_capture_mode():
|
||||
+ with graph_mode():
|
||||
# two tp groups can communicate independently
|
||||
if torch.distributed.get_rank() in [0, 1]:
|
||||
tensor = tensor_model_parallel_all_reduce(tensor)
|
||||
|
||||
Reference in New Issue
Block a user