[compile] Fix aot test failures with torch 2.12. (#37604)
Signed-off-by: zhxchen17 <zhxchen17@fb.com>
This commit is contained in:
@@ -14,6 +14,7 @@ from unittest.mock import Mock, patch
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
import vllm.model_executor.layers.activation
|
||||
from vllm.compilation.backends import VllmBackend
|
||||
from vllm.compilation.caching import (
|
||||
@@ -162,6 +163,9 @@ def test_save_and_load(monkeypatch: pytest.MonkeyPatch):
|
||||
|
||||
@pytest.mark.skipif(not is_torch_equal_or_newer("2.10.0"), reason="requires torch 2.10")
|
||||
def test_save_and_load_slice(monkeypatch: pytest.MonkeyPatch):
|
||||
from torch._subclasses import FakeTensorMode
|
||||
from torch.fx.experimental.symbolic_shapes import ShapeEnv
|
||||
|
||||
def foo(x: torch.Tensor):
|
||||
return x[slice(0, x.shape[0])]
|
||||
|
||||
@@ -172,12 +176,13 @@ def test_save_and_load_slice(monkeypatch: pytest.MonkeyPatch):
|
||||
gm = torch.fx.symbolic_trace(foo)
|
||||
assert "getitem_1 = x[slice(0, getitem, None)]" in gm.code
|
||||
with use_vllm_config(vllm_config):
|
||||
payload = VllmSerializableFunction.serialize_compile_artifacts(
|
||||
VllmSerializableFunction(gm, (example_input,), "", foo)
|
||||
payload = VllmSerializableFunction.serialize_graph_module(gm)
|
||||
fake_mode = FakeTensorMode(shape_env=ShapeEnv())
|
||||
loaded_gm = VllmSerializableFunction.deserialize_graph_module(
|
||||
payload, fake_mode
|
||||
)
|
||||
fn = VllmSerializableFunction.deserialize_compile_artifacts(payload)
|
||||
|
||||
assert gm.code == fn.graph_module.code
|
||||
assert gm.code == loaded_gm.code
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_torch_equal_or_newer("2.10.0"), reason="requires torch 2.10")
|
||||
@@ -725,6 +730,10 @@ class TestStandaloneCompiledArtifactsIntegration:
|
||||
]:
|
||||
assert cache.get(submod, shape) == shared_data
|
||||
|
||||
@pytest.mark.skipif(
|
||||
envs.VLLM_USE_MEGA_AOT_ARTIFACT,
|
||||
reason="There's no AOT Autograd run with mega artifact",
|
||||
)
|
||||
def test_functorch_config(self):
|
||||
vllm_config = make_vllm_config()
|
||||
example_inputs = (torch.randn(10, 10),)
|
||||
|
||||
@@ -11,6 +11,8 @@ from typing import Any, Literal
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
from torch._subclasses import FakeTensorMode
|
||||
from torch.fx._graph_pickler import GraphPickler, Options
|
||||
from torch.utils import _pytree as pytree
|
||||
|
||||
import vllm.envs as envs
|
||||
@@ -206,26 +208,8 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
|
||||
return self.optimized_call(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def serialize_compile_artifacts(
|
||||
cls, compiled_fn: "VllmSerializableFunction"
|
||||
) -> bytes:
|
||||
def serialize_graph_module(cls, graph_module: torch.fx.GraphModule) -> bytes:
|
||||
import sympy
|
||||
from torch._subclasses import FakeTensorMode
|
||||
from torch.fx._graph_pickler import GraphPickler, Options
|
||||
|
||||
state = compiled_fn.__dict__.copy()
|
||||
state.pop("optimized_call")
|
||||
state.pop("shape_env")
|
||||
state.pop("vllm_backend", None)
|
||||
state.pop("_fake_mode", None)
|
||||
for node in state["graph_module"].graph.nodes:
|
||||
node.meta.pop("source_fn_stack", None)
|
||||
node.meta.pop("nn_module_stack", None)
|
||||
for name, submod in state["graph_module"].named_children():
|
||||
if hasattr(submod, "graph"):
|
||||
for node in submod.graph.nodes:
|
||||
node.meta.pop("source_fn_stack", None)
|
||||
node.meta.pop("nn_module_stack", None)
|
||||
|
||||
graph_reducer_override = GraphPickler.reducer_override
|
||||
|
||||
@@ -242,6 +226,37 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
|
||||
return type(None), ()
|
||||
return graph_reducer_override(self, obj)
|
||||
|
||||
with (
|
||||
patch.object(GraphPickler, "reducer_override", _graph_reducer_override),
|
||||
patch_pytree_map_over_slice(),
|
||||
):
|
||||
return GraphPickler.dumps(graph_module, Options(ops_filter=None))
|
||||
|
||||
@classmethod
|
||||
def deserialize_graph_module(
|
||||
cls, data: bytes, fake_mode: FakeTensorMode
|
||||
) -> torch.fx.GraphModule:
|
||||
with patch_pytree_map_over_slice():
|
||||
return GraphPickler.loads(data, fake_mode)
|
||||
|
||||
@classmethod
|
||||
def serialize_compile_artifacts(
|
||||
cls, compiled_fn: "VllmSerializableFunction"
|
||||
) -> bytes:
|
||||
state = compiled_fn.__dict__.copy()
|
||||
state.pop("optimized_call")
|
||||
state.pop("shape_env")
|
||||
state.pop("vllm_backend", None)
|
||||
state.pop("_fake_mode", None)
|
||||
for node in state["graph_module"].graph.nodes:
|
||||
node.meta.pop("source_fn_stack", None)
|
||||
node.meta.pop("nn_module_stack", None)
|
||||
for name, submod in state["graph_module"].named_children():
|
||||
if hasattr(submod, "graph"):
|
||||
for node in submod.graph.nodes:
|
||||
node.meta.pop("source_fn_stack", None)
|
||||
node.meta.pop("nn_module_stack", None)
|
||||
|
||||
if state.get("sym_tensor_indices"):
|
||||
# put tensor inputs on meta device since their data
|
||||
# isn't needed, yet we need the meta for make_copy_and_call
|
||||
@@ -257,14 +272,9 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
|
||||
lambda inp: torch.empty_like(inp, device="meta"),
|
||||
state["example_inputs"],
|
||||
)
|
||||
with (
|
||||
patch.object(GraphPickler, "reducer_override", _graph_reducer_override),
|
||||
patch_pytree_map_over_slice(),
|
||||
):
|
||||
state["graph_module"] = GraphPickler.dumps(
|
||||
state["graph_module"], Options(ops_filter=None)
|
||||
)
|
||||
state["example_inputs"] = GraphPickler.dumps(state["example_inputs"])
|
||||
|
||||
state["graph_module"] = cls.serialize_graph_module(state["graph_module"])
|
||||
state["example_inputs"] = GraphPickler.dumps(state["example_inputs"])
|
||||
|
||||
if compiled_fn.vllm_backend:
|
||||
(
|
||||
@@ -280,14 +290,14 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
|
||||
@classmethod
|
||||
def deserialize_compile_artifacts(cls, data: bytes) -> "VllmSerializableFunction":
|
||||
from torch._guards import TracingContext, tracing
|
||||
from torch._subclasses import FakeTensorMode
|
||||
from torch.fx._graph_pickler import GraphPickler
|
||||
from torch.fx.experimental.symbolic_shapes import ShapeEnv
|
||||
|
||||
state = pickle.loads(data)
|
||||
fake_mode = FakeTensorMode(shape_env=ShapeEnv())
|
||||
with patch_pytree_map_over_slice():
|
||||
state["graph_module"] = GraphPickler.loads(state["graph_module"], fake_mode)
|
||||
|
||||
state["graph_module"] = cls.deserialize_graph_module(
|
||||
state["graph_module"], fake_mode
|
||||
)
|
||||
state["graph_module"].recompile()
|
||||
state["example_inputs"] = GraphPickler.loads(state["example_inputs"], fake_mode)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user