[compile] Fix caching error over pytree slice node. (#35308)
Signed-off-by: zhxchen17 <zhxchen17@fb.com>
This commit is contained in:
@@ -16,6 +16,7 @@ import torch
|
||||
import vllm.model_executor.layers.activation
|
||||
from vllm.compilation.caching import (
|
||||
StandaloneCompiledArtifacts,
|
||||
VllmSerializableFunction,
|
||||
)
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import (
|
||||
@@ -156,6 +157,26 @@ def test_save_and_load(monkeypatch: pytest.MonkeyPatch):
|
||||
assert torch.allclose(ret, expected)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_torch_equal_or_newer("2.10.0"), reason="requires torch 2.10")
|
||||
def test_save_and_load_slice(monkeypatch: pytest.MonkeyPatch):
|
||||
def foo(x: torch.Tensor):
|
||||
return x[slice(0, x.shape[0])]
|
||||
|
||||
vllm_config = make_vllm_config()
|
||||
|
||||
example_input = torch.randn(10, 10)
|
||||
torch._dynamo.mark_dynamic(example_input, 0)
|
||||
gm = torch.fx.symbolic_trace(foo)
|
||||
assert "getitem_1 = x[slice(0, getitem, None)]" in gm.code
|
||||
with use_vllm_config(vllm_config):
|
||||
payload = VllmSerializableFunction.serialize_compile_artifacts(
|
||||
VllmSerializableFunction(gm, (example_input,), "", foo)
|
||||
)
|
||||
fn = VllmSerializableFunction.deserialize_compile_artifacts(payload)
|
||||
|
||||
assert gm.code == fn.graph_module.code
|
||||
|
||||
|
||||
@pytest.mark.skipif(not is_torch_equal_or_newer("2.10.0"), reason="requires torch 2.10")
|
||||
def test_cache_load_returns_tuple_consistency(monkeypatch: pytest.MonkeyPatch):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user