[Fix] [torch.compile] Improve UUID system for custom passes (#15249)

Signed-off-by: luka <luka@neuralmagic.com>
This commit is contained in:
Luka Govedič
2025-03-23 21:54:07 -04:00
committed by GitHub
parent dccf535f8e
commit f622dbcf39
5 changed files with 132 additions and 91 deletions

View File

@@ -1,26 +1,29 @@
# SPDX-License-Identifier: Apache-2.0
import hashlib
import importlib.metadata
import inspect
import json
import types
from abc import ABC, abstractmethod
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Dict, Optional, Union
import torch
from packaging.version import Version
from torch import fx
if Version(importlib.metadata.version('torch')) >= Version("2.6"):
from torch._inductor.custom_graph_pass import CustomGraphPass
else:
# CustomGraphPass is not present in 2.5 or lower, import our version
from .torch25_custom_graph_pass import ( # noqa: yapf
Torch25CustomGraphPass as CustomGraphPass)
class InductorPass(ABC):
"""
General custom inductor pass interface.
"""
@abstractmethod
def __call__(self, graph: torch.fx.Graph):
"""
Execute the pass on the given graph.
"""
raise NotImplementedError
class InductorPass(CustomGraphPass):
"""
A custom graph pass that uses a hash of its source as the UUID.
This is defined as a convenience and should work in most cases.
"""
def uuid(self) -> Any:
"""
@@ -48,7 +51,16 @@ class InductorPass(ABC):
else:
src_str = inspect.getsource(src.__class__)
hasher.update(src_str.encode("utf-8"))
return hasher.digest()
return hasher.hexdigest()
@staticmethod
def hash_dict(dict_: Dict[Any, Any]):
"""
Utility method to hash a dictionary, can alternatively be used for uuid.
:return: A sha256 hash of the json rep of the dictionary.
"""
encoded = json.dumps(dict_, sort_keys=True).encode("utf-8")
return hashlib.sha256(encoded).hexdigest()
class CallableInductorPass(InductorPass):
@@ -61,25 +73,10 @@ class CallableInductorPass(InductorPass):
callable: Callable[[fx.Graph], None],
uuid: Optional[Any] = None):
self.callable = callable
if uuid is None:
uuid = InductorPass.hash_source(callable)
self._uuid = uuid
self._uuid = self.hash_source(callable) if uuid is None else uuid
def __call__(self, graph: torch.fx.Graph):
self.callable(graph)
def uuid(self) -> Any:
return self._uuid
def __getstate__(self):
"""
Pickling occurs in the Inductor code cache if a pass is not given to
the pass manager but is instead directly added to config as a pass.
See PostGradPassManager for more.
TODO(torch==2.6), use the `uuid` method in CustomGraphPass instead.
"""
return self._uuid
def __setstate__(self, state):
raise ValueError("Cannot unpickle CallableInductorPass")