[compile] Fix caching error over pytree slice node. (#35308)

Signed-off-by: zhxchen17 <zhxchen17@fb.com>
This commit is contained in:
Zhengxu Chen
2026-02-27 14:34:16 -05:00
committed by GitHub
parent b1d9f5372d
commit 29b35477b0
2 changed files with 40 additions and 2 deletions

View File

@@ -16,6 +16,7 @@ import torch
import vllm.model_executor.layers.activation import vllm.model_executor.layers.activation
from vllm.compilation.caching import ( from vllm.compilation.caching import (
StandaloneCompiledArtifacts, StandaloneCompiledArtifacts,
VllmSerializableFunction,
) )
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import ( from vllm.config import (
@@ -156,6 +157,26 @@ def test_save_and_load(monkeypatch: pytest.MonkeyPatch):
assert torch.allclose(ret, expected) assert torch.allclose(ret, expected)
@pytest.mark.skipif(not is_torch_equal_or_newer("2.10.0"), reason="requires torch 2.10")
def test_save_and_load_slice(monkeypatch: pytest.MonkeyPatch):
def foo(x: torch.Tensor):
return x[slice(0, x.shape[0])]
vllm_config = make_vllm_config()
example_input = torch.randn(10, 10)
torch._dynamo.mark_dynamic(example_input, 0)
gm = torch.fx.symbolic_trace(foo)
assert "getitem_1 = x[slice(0, getitem, None)]" in gm.code
with use_vllm_config(vllm_config):
payload = VllmSerializableFunction.serialize_compile_artifacts(
VllmSerializableFunction(gm, (example_input,), "", foo)
)
fn = VllmSerializableFunction.deserialize_compile_artifacts(payload)
assert gm.code == fn.graph_module.code
@pytest.mark.skipif(not is_torch_equal_or_newer("2.10.0"), reason="requires torch 2.10") @pytest.mark.skipif(not is_torch_equal_or_newer("2.10.0"), reason="requires torch 2.10")
def test_cache_load_returns_tuple_consistency(monkeypatch: pytest.MonkeyPatch): def test_cache_load_returns_tuple_consistency(monkeypatch: pytest.MonkeyPatch):
""" """

View File

@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import hashlib import hashlib
import inspect import inspect
import os import os
@@ -144,6 +145,18 @@ class StandaloneCompiledArtifacts:
self.loaded_submodule_store = {} self.loaded_submodule_store = {}
@contextlib.contextmanager
def patch_pytree_map_over_slice():
pytree._private_register_pytree_node(
slice, lambda x: ([x.start, x.stop, x.step], None), lambda x, c: slice(*x)
)
try:
yield
finally:
pytree._deregister_pytree_node(slice)
class VllmSerializableFunction(SerializableCallable): # type: ignore[misc] class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
""" """
A wrapper around a compiled function by vllm. It will forward the tensor A wrapper around a compiled function by vllm. It will forward the tensor
@@ -235,7 +248,10 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
lambda inp: torch.empty_like(inp, device="meta"), lambda inp: torch.empty_like(inp, device="meta"),
state["example_inputs"], state["example_inputs"],
) )
with patch.object(GraphPickler, "reducer_override", _graph_reducer_override): with (
patch.object(GraphPickler, "reducer_override", _graph_reducer_override),
patch_pytree_map_over_slice(),
):
state["graph_module"] = GraphPickler.dumps( state["graph_module"] = GraphPickler.dumps(
state["graph_module"], Options(ops_filter=None) state["graph_module"], Options(ops_filter=None)
) )
@@ -261,6 +277,7 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
state = pickle.loads(data) state = pickle.loads(data)
fake_mode = FakeTensorMode(shape_env=ShapeEnv()) fake_mode = FakeTensorMode(shape_env=ShapeEnv())
with patch_pytree_map_over_slice():
state["graph_module"] = GraphPickler.loads(state["graph_module"], fake_mode) state["graph_module"] = GraphPickler.loads(state["graph_module"], fake_mode)
state["graph_module"].recompile() state["graph_module"].recompile()
state["example_inputs"] = GraphPickler.loads(state["example_inputs"], fake_mode) state["example_inputs"] = GraphPickler.loads(state["example_inputs"], fake_mode)