[Fix] [torch.compile] Improve UUID system for custom passes (#15249)
Signed-off-by: luka <luka@neuralmagic.com>
This commit is contained in:
@@ -1,26 +1,29 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import hashlib
|
||||
import importlib.metadata
|
||||
import inspect
|
||||
import json
|
||||
import types
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Optional, Union
|
||||
from typing import Any, Callable, Dict, Optional, Union
|
||||
|
||||
import torch
|
||||
from packaging.version import Version
|
||||
from torch import fx
|
||||
|
||||
if Version(importlib.metadata.version('torch')) >= Version("2.6"):
|
||||
from torch._inductor.custom_graph_pass import CustomGraphPass
|
||||
else:
|
||||
# CustomGraphPass is not present in 2.5 or lower, import our version
|
||||
from .torch25_custom_graph_pass import ( # noqa: yapf
|
||||
Torch25CustomGraphPass as CustomGraphPass)
|
||||
|
||||
class InductorPass(ABC):
|
||||
"""
|
||||
General custom inductor pass interface.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def __call__(self, graph: torch.fx.Graph):
|
||||
"""
|
||||
Execute the pass on the given graph.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
class InductorPass(CustomGraphPass):
|
||||
"""
|
||||
A custom graph pass that uses a hash of its source as the UUID.
|
||||
This is defined as a convenience and should work in most cases.
|
||||
"""
|
||||
|
||||
def uuid(self) -> Any:
|
||||
"""
|
||||
@@ -48,7 +51,16 @@ class InductorPass(ABC):
|
||||
else:
|
||||
src_str = inspect.getsource(src.__class__)
|
||||
hasher.update(src_str.encode("utf-8"))
|
||||
return hasher.digest()
|
||||
return hasher.hexdigest()
|
||||
|
||||
@staticmethod
|
||||
def hash_dict(dict_: Dict[Any, Any]):
|
||||
"""
|
||||
Utility method to hash a dictionary, can alternatively be used for uuid.
|
||||
:return: A sha256 hash of the json rep of the dictionary.
|
||||
"""
|
||||
encoded = json.dumps(dict_, sort_keys=True).encode("utf-8")
|
||||
return hashlib.sha256(encoded).hexdigest()
|
||||
|
||||
|
||||
class CallableInductorPass(InductorPass):
|
||||
@@ -61,25 +73,10 @@ class CallableInductorPass(InductorPass):
|
||||
callable: Callable[[fx.Graph], None],
|
||||
uuid: Optional[Any] = None):
|
||||
self.callable = callable
|
||||
if uuid is None:
|
||||
uuid = InductorPass.hash_source(callable)
|
||||
self._uuid = uuid
|
||||
self._uuid = self.hash_source(callable) if uuid is None else uuid
|
||||
|
||||
def __call__(self, graph: torch.fx.Graph):
|
||||
self.callable(graph)
|
||||
|
||||
def uuid(self) -> Any:
|
||||
return self._uuid
|
||||
|
||||
def __getstate__(self):
|
||||
"""
|
||||
Pickling occurs in the Inductor code cache if a pass is not given to
|
||||
the pass manager but is instead directly added to config as a pass.
|
||||
See PostGradPassManager for more.
|
||||
|
||||
TODO(torch==2.6), use the `uuid` method in CustomGraphPass instead.
|
||||
"""
|
||||
return self._uuid
|
||||
|
||||
def __setstate__(self, state):
|
||||
raise ValueError("Cannot unpickle CallableInductorPass")
|
||||
|
||||
Reference in New Issue
Block a user