[Model Runner V2] Init cuda graph pool when necessary (#33217)

Signed-off-by: Xinyu Chen <xinyu1.chen@intel.com>
This commit is contained in:
Xinyu Chen
2026-02-12 01:12:13 +08:00
committed by GitHub
parent fa7e0bfacf
commit ffb3d553cc
2 changed files with 6 additions and 2 deletions

View File

@@ -45,7 +45,9 @@ class CudaGraphManager:
         )
 
         self.graphs: dict[int, torch.cuda.CUDAGraph] = {}
-        self.pool = torch.cuda.graph_pool_handle()
+        self.pool = None
+        if self.cudagraph_mode != CUDAGraphMode.NONE:
+            self.pool = torch.cuda.graph_pool_handle()
         self.hidden_states: torch.Tensor | None = None
 
     def needs_capture(self) -> bool:

View File

@@ -44,7 +44,9 @@ class EagleCudaGraphManager:
         )
 
         self.graphs: dict[int, torch.cuda.CUDAGraph] = {}
-        self.pool = torch.cuda.graph_pool_handle()
+        self.pool = None
+        if self.cudagraph_mode != CUDAGraphMode.NONE:
+            self.pool = torch.cuda.graph_pool_handle()
 
     def get_cudagraph_size(self, num_tokens: int) -> int | None:
         return self.cudagraph_sizes.get(num_tokens)