From 030fc4491465d361e4bed626d76c184f8a7d8a07 Mon Sep 17 00:00:00 2001 From: Amir Samani Date: Thu, 25 Dec 2025 03:10:03 -0800 Subject: [PATCH] use the same stream for cuda graph capture and replay for NCCL (#29207) Signed-off-by: Amir Samani Signed-off-by: youkaichao Co-authored-by: youkaichao --- .../kernels/benchmark_device_communicators.py | 2 +- tests/utils_/test_torch_utils.py | 30 ++++++------------- vllm/compilation/cuda_graph.py | 8 +++-- vllm/utils/torch_utils.py | 10 +++++-- 4 files changed, 23 insertions(+), 27 deletions(-) diff --git a/benchmarks/kernels/benchmark_device_communicators.py b/benchmarks/kernels/benchmark_device_communicators.py index b414efa6e..7b453fe7b 100644 --- a/benchmarks/kernels/benchmark_device_communicators.py +++ b/benchmarks/kernels/benchmark_device_communicators.py @@ -293,7 +293,7 @@ class CommunicatorBenchmark: graph = torch.cuda.CUDAGraph() graph_pool = torch.cuda.graph_pool_handle() set_graph_pool_id(graph_pool) - with torch.cuda.graph(graph, pool=graph_pool): + with torch.cuda.graph(graph, pool=graph_pool, stream=stream): for _ in range(CUDA_GRAPH_CAPTURE_CYCLES): allreduce_fn(graph_input) diff --git a/tests/utils_/test_torch_utils.py b/tests/utils_/test_torch_utils.py index 0a30b9727..f6a9486a1 100644 --- a/tests/utils_/test_torch_utils.py +++ b/tests/utils_/test_torch_utils.py @@ -99,30 +99,18 @@ def _test_stream_thread(main_expected_stream: torch.cuda.Stream): def test_current_stream_multithread(): - from vllm.platforms import current_platform - if not torch.cuda.is_available(): pytest.skip("CUDA not available") - if current_platform.is_rocm(): - main_dedicated_stream = current_stream() + main_dedicated_stream = current_stream() - assert main_dedicated_stream.cuda_stream != 0, ( - "ROCm should create a dedicated stream, not use default stream (0x0)" - ) + assert main_dedicated_stream.cuda_stream != 0, ( + "ROCm/CUDA should create a dedicated stream, not use default stream (0x0)" + ) - main_stream_again = current_stream() - assert main_stream_again == main_dedicated_stream, ( - "Multiple calls to current_stream should return the same dedicated stream" - ) + main_stream_again = current_stream() + assert main_stream_again == main_dedicated_stream, ( + "Multiple calls to current_stream should return the same dedicated stream" + ) - _test_stream_thread(main_dedicated_stream) - else: - main_default_stream = torch.cuda.default_stream() - main_initial_stream = current_stream() - - assert main_initial_stream == main_default_stream, ( - "First call to current_stream should return default stream on CUDA" - ) - - _test_stream_thread(main_default_stream) + _test_stream_thread(main_dedicated_stream) diff --git a/vllm/compilation/cuda_graph.py b/vllm/compilation/cuda_graph.py index 0748643a5..08cae27b1 100644 --- a/vllm/compilation/cuda_graph.py +++ b/vllm/compilation/cuda_graph.py @@ -18,7 +18,7 @@ from vllm.distributed.device_communicators.pynccl_allocator import set_graph_poo from vllm.forward_context import BatchDescriptor, get_forward_context from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.utils.torch_utils import weak_ref_tensors +from vllm.utils.torch_utils import current_stream, weak_ref_tensors logger = init_logger(__name__) @@ -263,7 +263,11 @@ class CUDAGraphWrapper: else: set_graph_pool_id(current_platform.graph_pool_handle()) # mind-exploding: carefully manage the reference and memory. - with torch.cuda.graph(cudagraph, pool=self.graph_pool): + with torch.cuda.graph( + cudagraph, + pool=self.graph_pool, + stream=current_stream(), + ): # `output` is managed by pytorch's cudagraph pool output = self.runnable(*args, **kwargs) if self.cudagraph_options.weak_ref_output: diff --git a/vllm/utils/torch_utils.py b/vllm/utils/torch_utils.py index b82e0171b..db596052a 100644 --- a/vllm/utils/torch_utils.py +++ b/vllm/utils/torch_utils.py @@ -465,9 +465,13 @@ def current_stream() -> torch.cuda.Stream: # when this function is called before any stream is set, # we return the default stream. # On ROCm using the default 0 stream in combination with RCCL - # is hurting performance. Therefore creating a dedicated stream - # per process - if current_platform.is_rocm(): + # is hurting performance. + # On CUDA, we capture and replay cudagraph on the same stream, + # so we need to avoid using the default stream as well. The default + # stream cannot be used for cudagraph capture, see + # https://github.com/pytorch/pytorch/blob/42ad9edfb754743fdae3276ade43de000beb4f60/aten/src/ATen/cuda/CUDAGraph.cpp#L77 + # for more details. Therefore, we create a dedicated stream per process. + if current_platform.is_rocm() or current_platform.is_cuda(): # torch.cuda.set_stream here is the alias of _pathed_set_stream torch.cuda.set_stream(torch.cuda.Stream()) elif current_platform.is_cpu():