[torch.compile] PyTorch 2.6 and nightly compatibility (#12393)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao
2025-02-07 01:09:07 +08:00
committed by GitHub
parent 85ac82d228
commit 09b95e36ab
8 changed files with 489 additions and 316 deletions

View File

@@ -2,6 +2,7 @@
from typing import Any, Dict, List
import torch
from torch import fx as fx
from vllm.config import CompilationConfig
@@ -15,7 +16,17 @@ from .reshapes import RedundantReshapesPass
logger = init_logger(__name__)
class PostGradPassManager:
class PlaceHolder:
pass
if torch.__version__ < "2.6":
Parent = PlaceHolder # type: ignore
else:
Parent = torch._inductor.custom_graph_pass.CustomGraphPass # type: ignore
class PostGradPassManager(Parent):
"""
The pass manager for post-grad passes.
It handles configuration, adding custom passes, and running passes.
@@ -55,6 +66,9 @@ class PostGradPassManager:
assert isinstance(pass_, InductorPass)
self.passes.append(pass_)
def uuid(self):
return self.__getstate__()
def __getstate__(self) -> Dict[str, List[Any]]:
"""
Custom pickling for the pass manager, as some passes cannot be pickled.