use the same stream for cuda graph capture and replay for NCCL (#29207)

Signed-off-by: Amir Samani <asamani@nvidia.com>
Signed-off-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: youkaichao <youkaichao@gmail.com>
Authored by Amir Samani on 2025-12-25 03:10:03 -08:00, committed by GitHub
parent 2532f437ee
commit 030fc44914
4 changed files with 23 additions and 27 deletions


@@ -18,7 +18,7 @@ from vllm.distributed.device_communicators.pynccl_allocator import set_graph_poo
 from vllm.forward_context import BatchDescriptor, get_forward_context
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
-from vllm.utils.torch_utils import weak_ref_tensors
+from vllm.utils.torch_utils import current_stream, weak_ref_tensors
 
 logger = init_logger(__name__)
@@ -263,7 +263,11 @@ class CUDAGraphWrapper:
                 else:
                     set_graph_pool_id(current_platform.graph_pool_handle())
                 # mind-exploding: carefully manage the reference and memory.
-                with torch.cuda.graph(cudagraph, pool=self.graph_pool):
+                with torch.cuda.graph(
+                    cudagraph,
+                    pool=self.graph_pool,
+                    stream=current_stream(),
+                ):
                     # `output` is managed by pytorch's cudagraph pool
                     output = self.runnable(*args, **kwargs)
                     if self.cudagraph_options.weak_ref_output:
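For context, the behavioral change is that capture is now pinned to the stream the wrapper later replays on; without the `stream=` argument, `torch.cuda.graph` switches to an internal side stream for capture. Below is a minimal, self-contained sketch of the same pattern, not the vLLM code: a plain `torch.cuda.Stream` stands in for vLLM's `current_stream()` helper, and a hypothetical `fn` stands in for `self.runnable`.

import torch

def capture_and_replay(fn, *args):
    # Sketch only: one explicit stream is used for BOTH capture and replay,
    # mirroring the patch above. `fn` is a hypothetical graph-safe callable.
    stream = torch.cuda.Stream()
    graph = torch.cuda.CUDAGraph()

    # Warm up on the capture stream so one-time initialization (e.g. NCCL
    # communicator setup, cuBLAS workspaces) is not recorded into the graph.
    stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(stream):
        fn(*args)
    torch.cuda.current_stream().wait_stream(stream)

    # Capture on the explicit stream, as the patch does with
    # stream=current_stream(); by default torch.cuda.graph would capture
    # on an internal side stream instead.
    with torch.cuda.graph(graph, stream=stream):
        output = fn(*args)

    # Launch replay from the same stream the graph was captured on, so
    # collectives see a consistent stream across capture and replay.
    with torch.cuda.stream(stream):
        graph.replay()
    return output

Keeping capture and replay on one stream is what keeps NCCL's per-communicator stream state consistent between the graphed and eager paths; the `pool=self.graph_pool` argument in the patch is orthogonal to this and is unchanged.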