From 29b35477b0661f527d2b951ff5125f5c58fce3fe Mon Sep 17 00:00:00 2001
From: Zhengxu Chen
Date: Fri, 27 Feb 2026 14:34:16 -0500
Subject: [PATCH] [compile] Fix caching error over pytree slice node. (#35308)

Signed-off-by: zhxchen17
---
 tests/compile/test_aot_compile.py | 21 +++++++++++++++++++++
 vllm/compilation/caching.py       | 21 +++++++++++++++++++--
 2 files changed, 40 insertions(+), 2 deletions(-)

diff --git a/tests/compile/test_aot_compile.py b/tests/compile/test_aot_compile.py
index fbacbb6bf..4cfdc1b2e 100644
--- a/tests/compile/test_aot_compile.py
+++ b/tests/compile/test_aot_compile.py
@@ -16,6 +16,7 @@ import torch
 import vllm.model_executor.layers.activation
 from vllm.compilation.caching import (
     StandaloneCompiledArtifacts,
+    VllmSerializableFunction,
 )
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import (
@@ -156,6 +157,26 @@ def test_save_and_load(monkeypatch: pytest.MonkeyPatch):
             assert torch.allclose(ret, expected)
 
 
+@pytest.mark.skipif(not is_torch_equal_or_newer("2.10.0"), reason="requires torch 2.10")
+def test_save_and_load_slice(monkeypatch: pytest.MonkeyPatch):
+    def foo(x: torch.Tensor):
+        return x[slice(0, x.shape[0])]
+
+    vllm_config = make_vllm_config()
+
+    example_input = torch.randn(10, 10)
+    torch._dynamo.mark_dynamic(example_input, 0)
+    gm = torch.fx.symbolic_trace(foo)
+    assert "getitem_1 = x[slice(0, getitem, None)]" in gm.code
+    with use_vllm_config(vllm_config):
+        payload = VllmSerializableFunction.serialize_compile_artifacts(
+            VllmSerializableFunction(gm, (example_input,), "", foo)
+        )
+        fn = VllmSerializableFunction.deserialize_compile_artifacts(payload)
+
+    assert gm.code == fn.graph_module.code
+
+
 @pytest.mark.skipif(not is_torch_equal_or_newer("2.10.0"), reason="requires torch 2.10")
 def test_cache_load_returns_tuple_consistency(monkeypatch: pytest.MonkeyPatch):
     """
diff --git a/vllm/compilation/caching.py b/vllm/compilation/caching.py
index 07f9db419..3917a4f28 100644
--- a/vllm/compilation/caching.py
+++ b/vllm/compilation/caching.py
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+import contextlib
 import hashlib
 import inspect
 import os
@@ -144,6 +145,18 @@ class StandaloneCompiledArtifacts:
         self.loaded_submodule_store = {}
 
 
+@contextlib.contextmanager
+def patch_pytree_map_over_slice():
+    pytree._private_register_pytree_node(
+        slice, lambda x: ([x.start, x.stop, x.step], None), lambda x, c: slice(*x)
+    )
+
+    try:
+        yield
+    finally:
+        pytree._deregister_pytree_node(slice)
+
+
 class VllmSerializableFunction(SerializableCallable):  # type: ignore[misc]
     """
     A wrapper around a compiled function by vllm. It will forward the tensor
@@ -235,7 +248,10 @@ class VllmSerializableFunction(SerializableCallable):  # type: ignore[misc]
             lambda inp: torch.empty_like(inp, device="meta"),
             state["example_inputs"],
         )
-        with patch.object(GraphPickler, "reducer_override", _graph_reducer_override):
+        with (
+            patch.object(GraphPickler, "reducer_override", _graph_reducer_override),
+            patch_pytree_map_over_slice(),
+        ):
             state["graph_module"] = GraphPickler.dumps(
                 state["graph_module"], Options(ops_filter=None)
             )
@@ -261,7 +277,8 @@ class VllmSerializableFunction(SerializableCallable):  # type: ignore[misc]
 
         state = pickle.loads(data)
         fake_mode = FakeTensorMode(shape_env=ShapeEnv())
-        state["graph_module"] = GraphPickler.loads(state["graph_module"], fake_mode)
+        with patch_pytree_map_over_slice():
+            state["graph_module"] = GraphPickler.loads(state["graph_module"], fake_mode)
         state["graph_module"].recompile()
         state["example_inputs"] = GraphPickler.loads(state["example_inputs"], fake_mode)
 