[torch.compile] rework compile control with piecewise cudagraph (#9715)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao
2024-10-29 23:03:49 -07:00
committed by GitHub
parent 7b0365efef
commit ff5ed6e1bc
17 changed files with 979 additions and 102 deletions

View File

@@ -1479,6 +1479,15 @@ class LazyDict(Mapping, Generic[T]):
return len(self._factory)
def combine_fx_passes(passes: List[Callable]) -> Callable:
def combined_fx(graph) -> None:
for fx in passes:
fx(graph)
return combined_fx
def weak_ref_tensor(tensor: torch.Tensor) -> torch.Tensor:
"""
Create a weak reference to a tensor.
@@ -1486,3 +1495,19 @@ def weak_ref_tensor(tensor: torch.Tensor) -> torch.Tensor:
but will not keep the original tensor alive.
"""
return torch.ops._C.weak_ref_tensor(tensor)
def weak_ref_tensors(
tensors: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]
) -> Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]]:
"""
Convenience function to create weak references to tensors,
for single tensor, list of tensors or tuple of tensors.
"""
if isinstance(tensors, torch.Tensor):
return weak_ref_tensor(tensors)
if isinstance(tensors, list):
return [weak_ref_tensor(t) for t in tensors]
if isinstance(tensors, tuple):
return tuple(weak_ref_tensor(t) for t in tensors)
raise ValueError("Invalid type for tensors")