use the same stream for cuda graph catpure and replay for NCCL (#29207)
Signed-off-by: Amir Samani <asamani@nvidia.com> Signed-off-by: youkaichao <youkaichao@gmail.com> Co-authored-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
@@ -293,7 +293,7 @@ class CommunicatorBenchmark:
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
graph_pool = torch.cuda.graph_pool_handle()
|
||||
set_graph_pool_id(graph_pool)
|
||||
with torch.cuda.graph(graph, pool=graph_pool):
|
||||
with torch.cuda.graph(graph, pool=graph_pool, stream=stream):
|
||||
for _ in range(CUDA_GRAPH_CAPTURE_CYCLES):
|
||||
allreduce_fn(graph_input)
|
||||
|
||||
|
||||
@@ -99,30 +99,18 @@ def _test_stream_thread(main_expected_stream: torch.cuda.Stream):
|
||||
|
||||
|
||||
def test_current_stream_multithread():
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if not torch.cuda.is_available():
|
||||
pytest.skip("CUDA not available")
|
||||
|
||||
if current_platform.is_rocm():
|
||||
main_dedicated_stream = current_stream()
|
||||
main_dedicated_stream = current_stream()
|
||||
|
||||
assert main_dedicated_stream.cuda_stream != 0, (
|
||||
"ROCm should create a dedicated stream, not use default stream (0x0)"
|
||||
)
|
||||
assert main_dedicated_stream.cuda_stream != 0, (
|
||||
"ROCm/CUDA should create a dedicated stream, not use default stream (0x0)"
|
||||
)
|
||||
|
||||
main_stream_again = current_stream()
|
||||
assert main_stream_again == main_dedicated_stream, (
|
||||
"Multiple calls to current_stream should return the same dedicated stream"
|
||||
)
|
||||
main_stream_again = current_stream()
|
||||
assert main_stream_again == main_dedicated_stream, (
|
||||
"Multiple calls to current_stream should return the same dedicated stream"
|
||||
)
|
||||
|
||||
_test_stream_thread(main_dedicated_stream)
|
||||
else:
|
||||
main_default_stream = torch.cuda.default_stream()
|
||||
main_initial_stream = current_stream()
|
||||
|
||||
assert main_initial_stream == main_default_stream, (
|
||||
"First call to current_stream should return default stream on CUDA"
|
||||
)
|
||||
|
||||
_test_stream_thread(main_default_stream)
|
||||
_test_stream_thread(main_dedicated_stream)
|
||||
|
||||
@@ -18,7 +18,7 @@ from vllm.distributed.device_communicators.pynccl_allocator import set_graph_poo
|
||||
from vllm.forward_context import BatchDescriptor, get_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import weak_ref_tensors
|
||||
from vllm.utils.torch_utils import current_stream, weak_ref_tensors
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -263,7 +263,11 @@ class CUDAGraphWrapper:
|
||||
else:
|
||||
set_graph_pool_id(current_platform.graph_pool_handle())
|
||||
# mind-exploding: carefully manage the reference and memory.
|
||||
with torch.cuda.graph(cudagraph, pool=self.graph_pool):
|
||||
with torch.cuda.graph(
|
||||
cudagraph,
|
||||
pool=self.graph_pool,
|
||||
stream=current_stream(),
|
||||
):
|
||||
# `output` is managed by pytorch's cudagraph pool
|
||||
output = self.runnable(*args, **kwargs)
|
||||
if self.cudagraph_options.weak_ref_output:
|
||||
|
||||
@@ -465,9 +465,13 @@ def current_stream() -> torch.cuda.Stream:
|
||||
# when this function is called before any stream is set,
|
||||
# we return the default stream.
|
||||
# On ROCm using the default 0 stream in combination with RCCL
|
||||
# is hurting performance. Therefore creating a dedicated stream
|
||||
# per process
|
||||
if current_platform.is_rocm():
|
||||
# is hurting performance.
|
||||
# On CUDA, we capture and replay cudagraph on the same stream,
|
||||
# so we need to avoid using the default stream as well. The default
|
||||
# stream cannot be used for cudagraph capture, see
|
||||
# https://github.com/pytorch/pytorch/blob/42ad9edfb754743fdae3276ade43de000beb4f60/aten/src/ATen/cuda/CUDAGraph.cpp#L77
|
||||
# for more details. Therefore, we create a dedicated stream per process.
|
||||
if current_platform.is_rocm() or current_platform.is_cuda():
|
||||
# torch.cuda.set_stream here is the alias of _pathed_set_stream
|
||||
torch.cuda.set_stream(torch.cuda.Stream())
|
||||
elif current_platform.is_cpu():
|
||||
|
||||
Reference in New Issue
Block a user