[Model Runner V2] Init cuda graph pool when necessary (#33217)

Signed-off-by: Xinyu Chen <xinyu1.chen@intel.com>
This commit is contained in:
Xinyu Chen
2026-02-12 01:12:13 +08:00
committed by GitHub
parent fa7e0bfacf
commit ffb3d553cc
2 changed files with 6 additions and 2 deletions

View File

@@ -45,6 +45,8 @@ class CudaGraphManager:
) )
self.graphs: dict[int, torch.cuda.CUDAGraph] = {} self.graphs: dict[int, torch.cuda.CUDAGraph] = {}
self.pool = None
if self.cudagraph_mode != CUDAGraphMode.NONE:
self.pool = torch.cuda.graph_pool_handle() self.pool = torch.cuda.graph_pool_handle()
self.hidden_states: torch.Tensor | None = None self.hidden_states: torch.Tensor | None = None

View File

@@ -44,6 +44,8 @@ class EagleCudaGraphManager:
) )
self.graphs: dict[int, torch.cuda.CUDAGraph] = {} self.graphs: dict[int, torch.cuda.CUDAGraph] = {}
self.pool = None
if self.cudagraph_mode != CUDAGraphMode.NONE:
self.pool = torch.cuda.graph_pool_handle() self.pool = torch.cuda.graph_pool_handle()
def get_cudagraph_size(self, num_tokens: int) -> int | None: def get_cudagraph_size(self, num_tokens: int) -> int | None: