use the same stream for cuda graph catpure and replay for NCCL (#29207)

Signed-off-by: Amir Samani <asamani@nvidia.com>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
Amir Samani
2025-12-25 03:10:03 -08:00
committed by GitHub
parent 2532f437ee
commit 030fc44914
4 changed files with 23 additions and 27 deletions

View File

@@ -293,7 +293,7 @@ class CommunicatorBenchmark:
graph = torch.cuda.CUDAGraph()
graph_pool = torch.cuda.graph_pool_handle()
set_graph_pool_id(graph_pool)
with torch.cuda.graph(graph, pool=graph_pool):
with torch.cuda.graph(graph, pool=graph_pool, stream=stream):
for _ in range(CUDA_GRAPH_CAPTURE_CYCLES):
allreduce_fn(graph_input)