use the same stream for cuda graph capture and replay for NCCL (#29207)
Signed-off-by: Amir Samani <asamani@nvidia.com> Signed-off-by: youkaichao <youkaichao@gmail.com> Co-authored-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
@@ -293,7 +293,7 @@ class CommunicatorBenchmark:
|
|||||||
graph = torch.cuda.CUDAGraph()
|
graph = torch.cuda.CUDAGraph()
|
||||||
graph_pool = torch.cuda.graph_pool_handle()
|
graph_pool = torch.cuda.graph_pool_handle()
|
||||||
set_graph_pool_id(graph_pool)
|
set_graph_pool_id(graph_pool)
|
||||||
with torch.cuda.graph(graph, pool=graph_pool):
|
with torch.cuda.graph(graph, pool=graph_pool, stream=stream):
|
||||||
for _ in range(CUDA_GRAPH_CAPTURE_CYCLES):
|
for _ in range(CUDA_GRAPH_CAPTURE_CYCLES):
|
||||||
allreduce_fn(graph_input)
|
allreduce_fn(graph_input)
|
||||||
|
|
||||||
|
|||||||
@@ -99,30 +99,18 @@ def _test_stream_thread(main_expected_stream: torch.cuda.Stream):
|
|||||||
|
|
||||||
|
|
||||||
def test_current_stream_multithread():
|
def test_current_stream_multithread():
|
||||||
from vllm.platforms import current_platform
|
|
||||||
|
|
||||||
if not torch.cuda.is_available():
|
if not torch.cuda.is_available():
|
||||||
pytest.skip("CUDA not available")
|
pytest.skip("CUDA not available")
|
||||||
|
|
||||||
if current_platform.is_rocm():
|
main_dedicated_stream = current_stream()
|
||||||
main_dedicated_stream = current_stream()
|
|
||||||
|
|
||||||
assert main_dedicated_stream.cuda_stream != 0, (
|
assert main_dedicated_stream.cuda_stream != 0, (
|
||||||
"ROCm should create a dedicated stream, not use default stream (0x0)"
|
"ROCm/CUDA should create a dedicated stream, not use default stream (0x0)"
|
||||||
)
|
)
|
||||||
|
|
||||||
main_stream_again = current_stream()
|
main_stream_again = current_stream()
|
||||||
assert main_stream_again == main_dedicated_stream, (
|
assert main_stream_again == main_dedicated_stream, (
|
||||||
"Multiple calls to current_stream should return the same dedicated stream"
|
"Multiple calls to current_stream should return the same dedicated stream"
|
||||||
)
|
)
|
||||||
|
|
||||||
_test_stream_thread(main_dedicated_stream)
|
_test_stream_thread(main_dedicated_stream)
|
||||||
else:
|
|
||||||
main_default_stream = torch.cuda.default_stream()
|
|
||||||
main_initial_stream = current_stream()
|
|
||||||
|
|
||||||
assert main_initial_stream == main_default_stream, (
|
|
||||||
"First call to current_stream should return default stream on CUDA"
|
|
||||||
)
|
|
||||||
|
|
||||||
_test_stream_thread(main_default_stream)
|
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ from vllm.distributed.device_communicators.pynccl_allocator import set_graph_poo
|
|||||||
from vllm.forward_context import BatchDescriptor, get_forward_context
|
from vllm.forward_context import BatchDescriptor, get_forward_context
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils.torch_utils import weak_ref_tensors
|
from vllm.utils.torch_utils import current_stream, weak_ref_tensors
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@@ -263,7 +263,11 @@ class CUDAGraphWrapper:
|
|||||||
else:
|
else:
|
||||||
set_graph_pool_id(current_platform.graph_pool_handle())
|
set_graph_pool_id(current_platform.graph_pool_handle())
|
||||||
# mind-exploding: carefully manage the reference and memory.
|
# mind-exploding: carefully manage the reference and memory.
|
||||||
with torch.cuda.graph(cudagraph, pool=self.graph_pool):
|
with torch.cuda.graph(
|
||||||
|
cudagraph,
|
||||||
|
pool=self.graph_pool,
|
||||||
|
stream=current_stream(),
|
||||||
|
):
|
||||||
# `output` is managed by pytorch's cudagraph pool
|
# `output` is managed by pytorch's cudagraph pool
|
||||||
output = self.runnable(*args, **kwargs)
|
output = self.runnable(*args, **kwargs)
|
||||||
if self.cudagraph_options.weak_ref_output:
|
if self.cudagraph_options.weak_ref_output:
|
||||||
|
|||||||
@@ -465,9 +465,13 @@ def current_stream() -> torch.cuda.Stream:
|
|||||||
# when this function is called before any stream is set,
|
# when this function is called before any stream is set,
|
||||||
# we return the default stream.
|
# we return the default stream.
|
||||||
# On ROCm using the default 0 stream in combination with RCCL
|
# On ROCm using the default 0 stream in combination with RCCL
|
||||||
# is hurting performance. Therefore creating a dedicated stream
|
# is hurting performance.
|
||||||
# per process
|
# On CUDA, we capture and replay cudagraph on the same stream,
|
||||||
if current_platform.is_rocm():
|
# so we need to avoid using the default stream as well. The default
|
||||||
|
# stream cannot be used for cudagraph capture, see
|
||||||
|
# https://github.com/pytorch/pytorch/blob/42ad9edfb754743fdae3276ade43de000beb4f60/aten/src/ATen/cuda/CUDAGraph.cpp#L77
|
||||||
|
# for more details. Therefore, we create a dedicated stream per process.
|
||||||
|
if current_platform.is_rocm() or current_platform.is_cuda():
|
||||||
# torch.cuda.set_stream here is the alias of _patched_set_stream
|
# torch.cuda.set_stream here is the alias of _patched_set_stream
|
||||||
torch.cuda.set_stream(torch.cuda.Stream())
|
torch.cuda.set_stream(torch.cuda.Stream())
|
||||||
elif current_platform.is_cpu():
|
elif current_platform.is_cpu():
|
||||||
|
|||||||
Reference in New Issue
Block a user