[torch.compile] rework compile control with piecewise cudagraph (#9715)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
@@ -1479,6 +1479,15 @@ class LazyDict(Mapping, Generic[T]):
|
||||
return len(self._factory)
|
||||
|
||||
|
||||
def combine_fx_passes(passes: List[Callable]) -> Callable:
|
||||
|
||||
def combined_fx(graph) -> None:
|
||||
for fx in passes:
|
||||
fx(graph)
|
||||
|
||||
return combined_fx
|
||||
|
||||
|
||||
def weak_ref_tensor(tensor: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Create a weak reference to a tensor.
|
||||
@@ -1486,3 +1495,19 @@ def weak_ref_tensor(tensor: torch.Tensor) -> torch.Tensor:
|
||||
but will not keep the original tensor alive.
|
||||
"""
|
||||
return torch.ops._C.weak_ref_tensor(tensor)
|
||||
|
||||
|
||||
def weak_ref_tensors(
|
||||
tensors: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]
|
||||
) -> Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]:
|
||||
"""
|
||||
Convenience function to create weak references to tensors,
|
||||
for single tensor, list of tensors or tuple of tensors.
|
||||
"""
|
||||
if isinstance(tensors, torch.Tensor):
|
||||
return weak_ref_tensor(tensors)
|
||||
if isinstance(tensors, list):
|
||||
return [weak_ref_tensor(t) for t in tensors]
|
||||
if isinstance(tensors, tuple):
|
||||
return tuple(weak_ref_tensor(t) for t in tensors)
|
||||
raise ValueError("Invalid type for tensors")
|
||||
|
||||
Reference in New Issue
Block a user