[compile] Consistent compiler config for saved/loaded vllm backends. (#35810)

Signed-off-by: zhxchen17 <zhxchen17@fb.com>
This commit is contained in:
Zhengxu Chen
2026-03-05 15:08:12 -05:00
committed by GitHub
parent a911f4dd20
commit a97954b6a8
2 changed files with 65 additions and 8 deletions

View File

@@ -14,6 +14,7 @@ import pytest
import torch
import vllm.model_executor.layers.activation
from vllm.compilation.backends import VllmBackend
from vllm.compilation.caching import (
StandaloneCompiledArtifacts,
VllmSerializableFunction,
@@ -721,3 +722,44 @@ class TestStandaloneCompiledArtifactsIntegration:
("mod3", "shape3"),
]:
assert cache.get(submod, shape) == shared_data
def test_functorch_config(self):
vllm_config = make_vllm_config()
example_inputs = (torch.randn(10, 10),)
def add_1(x: torch.Tensor):
return x + 1
gm = torch._dynamo.functional_export.dynamo_graph_capture_for_export(add_1)(
*example_inputs
)
gm.graph._codegen = torch.fx.graph.CodeGen()
gm._dynamo_bytecode_flatten = None
gm._dynamo_bytecode_unflatten = None
with (
torch._functorch.config.patch(bundled_autograd_cache=False),
set_current_vllm_config(vllm_config),
):
with torch._functorch.config.patch(bundled_autograd_cache=True):
fn = VllmSerializableFunction(gm, example_inputs, "", add_1)
payload = VllmSerializableFunction.serialize_compile_artifacts(fn)
config = None
def backend(*args, **kwargs) -> VllmSerializableFunction:
nonlocal config
# bundled_autograd_cache should be True even compiler backend
# runs with bundled_autograd_cache=False in ambient context.
config = torch._functorch.config.save_config_portable()
return fn
loaded_fn = VllmSerializableFunction.deserialize_compile_artifacts(payload)
with patch.object(VllmBackend, "__call__", backend):
loaded_fn(*example_inputs)
assert isinstance(config, dict)
assert "bundled_autograd_cache" in config
assert config["bundled_autograd_cache"] is True