[torch.compile] Inductor code caching fix (#10273)
Signed-off-by: luka <luka@neuralmagic.com>
Signed-off-by: Luka Govedic <luka.govedic@gmail.com>
@@ -1,38 +1,84 @@
import hashlib
import inspect
import types
from abc import ABC, abstractmethod
from typing import Any, Callable, Optional, Union

import torch
from torch import fx

from vllm.config import CompilationConfig
# yapf: disable
from vllm.distributed import get_tensor_model_parallel_rank as get_tp_rank
from vllm.distributed import (
    get_tensor_model_parallel_world_size as get_tp_world_size)
from vllm.distributed import model_parallel_is_initialized as p_is_init
# yapf: enable
from vllm.logger import init_logger

logger = init_logger(__name__)

class InductorPass(ABC):
    """
    General custom inductor pass interface.
    TODO(torch==2.6) use torch._inductor.custom_graph_pass.CustomGraphPass
    """

    def __init__(self, config: CompilationConfig):
        self.config = config

    @abstractmethod
    def __call__(self, graph: torch.fx.Graph):
        """
        Execute the pass on the given graph.
        """
        raise NotImplementedError

    def uuid(self) -> Any:
        """
        Provide a unique identifier for the pass, used in the Inductor code
        cache. This should depend on the pass implementation, so that changes
        to the pass result in recompilation.
        By default, the source of the pass object is hashed.
        """
        return InductorPass.hash_source(self)

    @staticmethod
    def hash_source(*srcs: Union[str, Any]):
        """
        Utility method to hash the sources of functions or objects.
        :param srcs: strings or objects to add to the hash.
            Objects and functions have their source inspected.
        :return: the SHA-256 digest of the concatenated sources.
        """
        hasher = hashlib.sha256()
        for src in srcs:
            if isinstance(src, str):
                src_str = src
            elif isinstance(src, types.FunctionType):
                src_str = inspect.getsource(src)
            else:
                src_str = inspect.getsource(src.__class__)
            hasher.update(src_str.encode("utf-8"))
        return hasher.digest()

    def dump_graph(self, graph: torch.fx.Graph, stage: str):
        if stage in self.config.dump_graph_stages:
            # Make sure the filename includes the rank in the distributed
            # setting
            parallel = p_is_init() and get_tp_world_size() > 1
            rank = f"-{get_tp_rank()}" if parallel else ""
            filepath = self.config.dump_graph_dir / f"{stage}{rank}.py"

            logger.info("Printing graph to %s", filepath)
            with open(filepath, "w") as f:
                src = graph.python_code(root_module="self", verbose=True).src
                # Add imports so the dumped graph is not full of errors
                print("import torch; from torch import device", file=f)
                print(src, file=f)
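
For orientation, a minimal sketch of a concrete pass built on this interface
(the class name, stage names, and logging behavior are illustrative only, not
part of this commit). Since it does not override uuid(), any edit to its
source changes the default hash and invalidates Inductor's code cache, and
dump_graph brackets the transformation when the stages are enabled in config:

    class CountOpsPass(InductorPass):
        """Illustrative pass: log how many call_function nodes exist."""

        def __call__(self, graph: torch.fx.Graph):
            self.dump_graph(graph, "before_count_ops")
            n = sum(1 for node in graph.nodes if node.op == "call_function")
            logger.info("Graph has %d call_function nodes", n)
            self.dump_graph(graph, "after_count_ops")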

class CallableInductorPass(InductorPass):
    """
    This class is a wrapper for a callable that automatically provides an
    implementation of the uuid method.
    """

    def __init__(self,
                 callable: Callable[[fx.Graph], None],
                 uuid: Optional[Any] = None):
        self.callable = callable
        if uuid is None:
            uuid = InductorPass.hash_source(callable)
        self._uuid = uuid

    def __call__(self, graph: torch.fx.Graph):
        self.callable(graph)

    def uuid(self) -> Any:
        return self._uuid

    def __getstate__(self):
        """
        Pickling occurs in the Inductor code cache if a pass is not given to
        the pass manager but is instead directly added to the config as a
        pass. See PostGradPassManager for more.

        TODO(torch==2.6), use the `uuid` method in CustomGraphPass instead.
        """
        return self._uuid

    def __setstate__(self, state):
        raise ValueError("Cannot unpickle CallableInductorPass")
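
A usage sketch (illustrative, not part of this commit): wrapping a plain
function. When no uuid is given, one is derived from the function's source
via hash_source, so editing the function triggers recompilation; passing an
explicit uuid pins the cache key. Pickling such a pass, as the Inductor code
cache does, stores only the uuid, and unpickling is deliberately disallowed:

    def eliminate_dead_code(graph: fx.Graph) -> None:
        graph.eliminate_dead_code()

    # uuid derived from the wrapped function's source
    pass_a = CallableInductorPass(eliminate_dead_code)

    # explicit uuid ("dce-v1" is hypothetical): the cache key stays
    # stable across source edits
    pass_b = CallableInductorPass(eliminate_dead_code, uuid="dce-v1")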