[Core][Distributed] remove graph mode function (#4818)

This commit is contained in:
youkaichao
2024-05-16 10:59:52 -07:00
committed by GitHub
parent b5853f9963
commit e08188081b
4 changed files with 63 additions and 54 deletions

View File

@@ -50,7 +50,7 @@ def graph_allreduce(tp_size, pp_size, rank, distributed_init_port):
for sz in test_sizes:
for dtype in [torch.float32, torch.float16, torch.bfloat16]:
with graph_capture():
with graph_capture() as graph_capture_context:
# use integers so result matches NCCL exactly
inp1 = torch.randint(1,
16, (sz, ),
@@ -62,7 +62,8 @@ def graph_allreduce(tp_size, pp_size, rank, distributed_init_port):
device=torch.cuda.current_device())
torch.cuda.synchronize()
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
with torch.cuda.graph(graph,
stream=graph_capture_context.stream):
for i in range(num_communication):
out1 = tensor_model_parallel_all_reduce(inp1)
# the input buffer is immediately modified to test