[Fix] [torch.compile] Improve UUID system for custom passes (#15249)

Signed-off-by: luka <luka@neuralmagic.com>
This commit is contained in:
Luka Govedič
2025-03-23 21:54:07 -04:00
committed by GitHub
parent dccf535f8e
commit f622dbcf39
5 changed files with 132 additions and 91 deletions

View File

@@ -1,8 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List
from typing import List
import torch
from torch import fx as fx
from vllm.config import CompilationConfig
@@ -10,29 +9,18 @@ from vllm.logger import init_logger
from .fix_functionalization import FixFunctionalizationPass
from .fusion import FusionPass
from .inductor_pass import InductorPass
from .inductor_pass import CustomGraphPass, InductorPass
from .noop_elimination import NoOpEliminationPass
logger = init_logger(__name__)
class PlaceHolder:
pass
if torch.__version__ < "2.6":
Parent = PlaceHolder # type: ignore
else:
Parent = torch._inductor.custom_graph_pass.CustomGraphPass # type: ignore
class PostGradPassManager(Parent):
class PostGradPassManager(CustomGraphPass):
"""
The pass manager for post-grad passes.
It handles configuration, adding custom passes, and running passes.
It also supports pickling, which is used by the Inductor code cache.
TODO(torch==2.6), use CustomGraphPass
(torch._inductor.custom_graph_pass.CustomGraphPass)
It supports uuid for the Inductor code cache. That includes torch<2.6
support using pickling (in .inductor_pass.CustomGraphPass).
The order of the post-grad post-passes is:
1. passes (constructor parameter)
@@ -67,27 +55,13 @@ class PostGradPassManager(Parent):
self.passes.append(pass_)
def uuid(self):
return self.__getstate__()
def __getstate__(self) -> Dict[str, List[Any]]:
"""
Custom pickling for the pass manager, as some passes cannot be pickled.
Pickling occurs because the pass manager is set as the value of
`config["post_grad_custom_post_pass"]` in the Inductor config.
The config is pickled to act as a key in the Inductor code cache.
Any other passes in the config are pickled as well.
TODO(torch==2.6), use the `uuid` method in CustomGraphPass instead.
The PostGradPassManager is set as a custom pass in the Inductor and
affects compilation caching. Its uuid depends on the UUIDs of all
dependent passes and the pass config. See InductorPass for more info.
"""
state = {"pass_config": self.pass_config.uuid(), "passes": []}
for pass_ in self.passes:
state["passes"].append(pass_.uuid())
state["passes"].append(self.fix_functionalization.uuid())
return state
def __setstate__(self, state):
"""
Do not allow unpickling of the pass manager.
If this is needed in the future, it should properly pickle the passes.
"""
raise ValueError("Cannot unpickle PostGradPassManager")
return InductorPass.hash_dict(state)