[Model Runner V2] Init cuda graph pool when necessary (#33217)
Signed-off-by: Xinyu Chen <xinyu1.chen@intel.com>
This commit is contained in:
@@ -45,7 +45,9 @@ class CudaGraphManager:
|
||||
)
|
||||
|
||||
self.graphs: dict[int, torch.cuda.CUDAGraph] = {}
|
||||
self.pool = torch.cuda.graph_pool_handle()
|
||||
self.pool = None
|
||||
if self.cudagraph_mode != CUDAGraphMode.NONE:
|
||||
self.pool = torch.cuda.graph_pool_handle()
|
||||
self.hidden_states: torch.Tensor | None = None
|
||||
|
||||
def needs_capture(self) -> bool:
|
||||
|
||||
@@ -44,7 +44,9 @@ class EagleCudaGraphManager:
|
||||
)
|
||||
|
||||
self.graphs: dict[int, torch.cuda.CUDAGraph] = {}
|
||||
self.pool = torch.cuda.graph_pool_handle()
|
||||
self.pool = None
|
||||
if self.cudagraph_mode != CUDAGraphMode.NONE:
|
||||
self.pool = torch.cuda.graph_pool_handle()
|
||||
|
||||
def get_cudagraph_size(self, num_tokens: int) -> int | None:
|
||||
return self.cudagraph_sizes.get(num_tokens)
|
||||
|
||||
Reference in New Issue
Block a user