use the same stream for cuda graph capture and replay for NCCL (#29207)

Signed-off-by: Amir Samani <asamani@nvidia.com>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: youkaichao <youkaichao@gmail.com>

@@ -293,7 +293,7 @@ class CommunicatorBenchmark:
         graph = torch.cuda.CUDAGraph()
         graph_pool = torch.cuda.graph_pool_handle()
         set_graph_pool_id(graph_pool)
-        with torch.cuda.graph(graph, pool=graph_pool):
+        with torch.cuda.graph(graph, pool=graph_pool, stream=stream):
             for _ in range(CUDA_GRAPH_CAPTURE_CYCLES):
                 allreduce_fn(graph_input)