use the same stream for cuda graph catpure and replay for NCCL (#29207)

Signed-off-by: Amir Samani <asamani@nvidia.com>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
Amir Samani
2025-12-25 03:10:03 -08:00
committed by GitHub
parent 2532f437ee
commit 030fc44914
4 changed files with 23 additions and 27 deletions

View File

@@ -465,9 +465,13 @@ def current_stream() -> torch.cuda.Stream:
# when this function is called before any stream is set,
# we return the default stream.
# On ROCm using the default 0 stream in combination with RCCL
# is hurting performance. Therefore creating a dedicated stream
# per process
if current_platform.is_rocm():
# is hurting performance.
# On CUDA, we capture and replay cudagraph on the same stream,
# so we need to avoid using the default stream as well. The default
# stream cannot be used for cudagraph capture, see
# https://github.com/pytorch/pytorch/blob/42ad9edfb754743fdae3276ade43de000beb4f60/aten/src/ATen/cuda/CUDAGraph.cpp#L77
# for more details. Therefore, we create a dedicated stream per process.
if current_platform.is_rocm() or current_platform.is_cuda():
# torch.cuda.set_stream here is the alias of _pathed_set_stream
torch.cuda.set_stream(torch.cuda.Stream())
elif current_platform.is_cpu():