[Model Runner V2] Initialize the CUDA graph pool only when necessary (#33217)
Signed-off-by: Xinyu Chen <xinyu1.chen@intel.com>
This commit is contained in:
@@ -45,6 +45,8 @@ class CudaGraphManager:
|
||||
)
|
||||
|
||||
self.graphs: dict[int, torch.cuda.CUDAGraph] = {}
|
||||
self.pool = None
|
||||
if self.cudagraph_mode != CUDAGraphMode.NONE:
|
||||
self.pool = torch.cuda.graph_pool_handle()
|
||||
self.hidden_states: torch.Tensor | None = None
|
||||
|
||||
|
||||
@@ -44,6 +44,8 @@ class EagleCudaGraphManager:
|
||||
)
|
||||
|
||||
self.graphs: dict[int, torch.cuda.CUDAGraph] = {}
|
||||
self.pool = None
|
||||
if self.cudagraph_mode != CUDAGraphMode.NONE:
|
||||
self.pool = torch.cuda.graph_pool_handle()
|
||||
|
||||
def get_cudagraph_size(self, num_tokens: int) -> int | None:
|
||||
|
||||
Reference in New Issue
Block a user