[compile] Fix aot test failures with torch 2.12. (#37604)

Signed-off-by: zhxchen17 <zhxchen17@fb.com>
This commit is contained in:
Zhengxu Chen
2026-03-20 12:06:29 -04:00
committed by GitHub
parent aa84e43ccb
commit c0f5fae601
2 changed files with 54 additions and 35 deletions

View File

@@ -14,6 +14,7 @@ from unittest.mock import Mock, patch
import pytest
import torch
import vllm.envs as envs
import vllm.model_executor.layers.activation
from vllm.compilation.backends import VllmBackend
from vllm.compilation.caching import (
@@ -162,6 +163,9 @@ def test_save_and_load(monkeypatch: pytest.MonkeyPatch):
@pytest.mark.skipif(not is_torch_equal_or_newer("2.10.0"), reason="requires torch 2.10")
def test_save_and_load_slice(monkeypatch: pytest.MonkeyPatch):
from torch._subclasses import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv
def foo(x: torch.Tensor):
return x[slice(0, x.shape[0])]
@@ -172,12 +176,13 @@ def test_save_and_load_slice(monkeypatch: pytest.MonkeyPatch):
gm = torch.fx.symbolic_trace(foo)
assert "getitem_1 = x[slice(0, getitem, None)]" in gm.code
with use_vllm_config(vllm_config):
payload = VllmSerializableFunction.serialize_compile_artifacts(
VllmSerializableFunction(gm, (example_input,), "", foo)
payload = VllmSerializableFunction.serialize_graph_module(gm)
fake_mode = FakeTensorMode(shape_env=ShapeEnv())
loaded_gm = VllmSerializableFunction.deserialize_graph_module(
payload, fake_mode
)
fn = VllmSerializableFunction.deserialize_compile_artifacts(payload)
assert gm.code == fn.graph_module.code
assert gm.code == loaded_gm.code
@pytest.mark.skipif(not is_torch_equal_or_newer("2.10.0"), reason="requires torch 2.10")
@@ -725,6 +730,10 @@ class TestStandaloneCompiledArtifactsIntegration:
]:
assert cache.get(submod, shape) == shared_data
@pytest.mark.skipif(
envs.VLLM_USE_MEGA_AOT_ARTIFACT,
reason="There's no AOT Autograd run with mega artifact",
)
def test_functorch_config(self):
vllm_config = make_vllm_config()
example_inputs = (torch.randn(10, 10),)

View File

@@ -11,6 +11,8 @@ from typing import Any, Literal
from unittest.mock import patch
import torch
from torch._subclasses import FakeTensorMode
from torch.fx._graph_pickler import GraphPickler, Options
from torch.utils import _pytree as pytree
import vllm.envs as envs
@@ -206,26 +208,8 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
return self.optimized_call(*args, **kwargs)
@classmethod
def serialize_compile_artifacts(
cls, compiled_fn: "VllmSerializableFunction"
) -> bytes:
def serialize_graph_module(cls, graph_module: torch.fx.GraphModule) -> bytes:
import sympy
from torch._subclasses import FakeTensorMode
from torch.fx._graph_pickler import GraphPickler, Options
state = compiled_fn.__dict__.copy()
state.pop("optimized_call")
state.pop("shape_env")
state.pop("vllm_backend", None)
state.pop("_fake_mode", None)
for node in state["graph_module"].graph.nodes:
node.meta.pop("source_fn_stack", None)
node.meta.pop("nn_module_stack", None)
for name, submod in state["graph_module"].named_children():
if hasattr(submod, "graph"):
for node in submod.graph.nodes:
node.meta.pop("source_fn_stack", None)
node.meta.pop("nn_module_stack", None)
graph_reducer_override = GraphPickler.reducer_override
@@ -242,6 +226,37 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
return type(None), ()
return graph_reducer_override(self, obj)
with (
patch.object(GraphPickler, "reducer_override", _graph_reducer_override),
patch_pytree_map_over_slice(),
):
return GraphPickler.dumps(graph_module, Options(ops_filter=None))
# Counterpart of serialize_graph_module: turns the bytes produced there by
# GraphPickler.dumps back into a torch.fx.GraphModule.
#
# data:      pickled graph module payload.
# fake_mode: FakeTensorMode handed straight through to GraphPickler.loads;
#            the caller is responsible for constructing it.
@classmethod
def deserialize_graph_module(
cls, data: bytes, fake_mode: FakeTensorMode
) -> torch.fx.GraphModule:
# Unpickle under the same patch_pytree_map_over_slice() context that the
# serialization path wraps around GraphPickler.dumps.
with patch_pytree_map_over_slice():
return GraphPickler.loads(data, fake_mode)
@classmethod
def serialize_compile_artifacts(
cls, compiled_fn: "VllmSerializableFunction"
) -> bytes:
state = compiled_fn.__dict__.copy()
state.pop("optimized_call")
state.pop("shape_env")
state.pop("vllm_backend", None)
state.pop("_fake_mode", None)
for node in state["graph_module"].graph.nodes:
node.meta.pop("source_fn_stack", None)
node.meta.pop("nn_module_stack", None)
for name, submod in state["graph_module"].named_children():
if hasattr(submod, "graph"):
for node in submod.graph.nodes:
node.meta.pop("source_fn_stack", None)
node.meta.pop("nn_module_stack", None)
if state.get("sym_tensor_indices"):
# put tensor inputs on meta device since their data
# isn't needed, yet we need the meta for make_copy_and_call
@@ -257,14 +272,9 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
lambda inp: torch.empty_like(inp, device="meta"),
state["example_inputs"],
)
with (
patch.object(GraphPickler, "reducer_override", _graph_reducer_override),
patch_pytree_map_over_slice(),
):
state["graph_module"] = GraphPickler.dumps(
state["graph_module"], Options(ops_filter=None)
)
state["example_inputs"] = GraphPickler.dumps(state["example_inputs"])
state["graph_module"] = cls.serialize_graph_module(state["graph_module"])
state["example_inputs"] = GraphPickler.dumps(state["example_inputs"])
if compiled_fn.vllm_backend:
(
@@ -280,14 +290,14 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
@classmethod
def deserialize_compile_artifacts(cls, data: bytes) -> "VllmSerializableFunction":
from torch._guards import TracingContext, tracing
from torch._subclasses import FakeTensorMode
from torch.fx._graph_pickler import GraphPickler
from torch.fx.experimental.symbolic_shapes import ShapeEnv
state = pickle.loads(data)
fake_mode = FakeTensorMode(shape_env=ShapeEnv())
with patch_pytree_map_over_slice():
state["graph_module"] = GraphPickler.loads(state["graph_module"], fake_mode)
state["graph_module"] = cls.deserialize_graph_module(
state["graph_module"], fake_mode
)
state["graph_module"].recompile()
state["example_inputs"] = GraphPickler.loads(state["example_inputs"], fake_mode)