diff --git a/vllm/v1/worker/gpu/cudagraph_utils.py b/vllm/v1/worker/gpu/cudagraph_utils.py index bf55b99af..d5a22d6a0 100644 --- a/vllm/v1/worker/gpu/cudagraph_utils.py +++ b/vllm/v1/worker/gpu/cudagraph_utils.py @@ -45,7 +45,9 @@ class CudaGraphManager: ) self.graphs: dict[int, torch.cuda.CUDAGraph] = {} - self.pool = torch.cuda.graph_pool_handle() + self.pool = None + if self.cudagraph_mode != CUDAGraphMode.NONE: + self.pool = torch.cuda.graph_pool_handle() self.hidden_states: torch.Tensor | None = None def needs_capture(self) -> bool: diff --git a/vllm/v1/worker/gpu/spec_decode/eagle_cudagraph.py b/vllm/v1/worker/gpu/spec_decode/eagle_cudagraph.py index 48e7cb110..1ea7ffcb5 100644 --- a/vllm/v1/worker/gpu/spec_decode/eagle_cudagraph.py +++ b/vllm/v1/worker/gpu/spec_decode/eagle_cudagraph.py @@ -44,7 +44,9 @@ class EagleCudaGraphManager: ) self.graphs: dict[int, torch.cuda.CUDAGraph] = {} - self.pool = torch.cuda.graph_pool_handle() + self.pool = None + if self.cudagraph_mode != CUDAGraphMode.NONE: + self.pool = torch.cuda.graph_pool_handle() def get_cudagraph_size(self, num_tokens: int) -> int | None: return self.cudagraph_sizes.get(num_tokens)