[compile] Consistent compiler config for saved/loaded vllm backends. (#35810)

Signed-off-by: zhxchen17 <zhxchen17@fb.com>
Author: Zhengxu Chen
Date: 2026-03-05 15:08:12 -05:00
Committed by: GitHub
Parent: a911f4dd20
Commit: a97954b6a8
2 changed files with 65 additions and 8 deletions
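Summary: when a VllmSerializableFunction is constructed, the current AOTAutograd (functorch) compiler config is snapshotted with torch._functorch.config.save_config_portable(); on deserialization the snapshot is re-applied with torch._functorch.config.patch(...) around reconstruction and the VllmBackend call, so a loaded artifact runs under the same compiler settings it was saved with instead of whatever the ambient process happens to have. Below is a minimal standalone sketch of that save/restore pattern (illustrative only, not the vLLM code itself; saved_config and restore_ctx are made-up names):

import contextlib

import torch._functorch.config as functorch_config

# Save time: snapshot the current functorch/AOTAutograd config as a plain dict.
saved_config = functorch_config.save_config_portable()

# Load time: re-apply the snapshot around whatever reconstructs/runs the
# compiled function, falling back to a no-op context if nothing was saved.
restore_ctx = (
    functorch_config.patch(saved_config)
    if saved_config is not None
    else contextlib.nullcontext()
)
with restore_ctx:
    # ... reconstruct and call the deserialized compiled function here ...
    pass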

@@ -14,6 +14,7 @@ import pytest
 import torch
 import vllm.model_executor.layers.activation
+from vllm.compilation.backends import VllmBackend
 from vllm.compilation.caching import (
     StandaloneCompiledArtifacts,
     VllmSerializableFunction,
@@ -721,3 +722,44 @@ class TestStandaloneCompiledArtifactsIntegration:
             ("mod3", "shape3"),
         ]:
             assert cache.get(submod, shape) == shared_data
+
+    def test_functorch_config(self):
+        vllm_config = make_vllm_config()
+        example_inputs = (torch.randn(10, 10),)
+
+        def add_1(x: torch.Tensor):
+            return x + 1
+
+        gm = torch._dynamo.functional_export.dynamo_graph_capture_for_export(add_1)(
+            *example_inputs
+        )
+        gm.graph._codegen = torch.fx.graph.CodeGen()
+        gm._dynamo_bytecode_flatten = None
+        gm._dynamo_bytecode_unflatten = None
+        with (
+            torch._functorch.config.patch(bundled_autograd_cache=False),
+            set_current_vllm_config(vllm_config),
+        ):
+            with torch._functorch.config.patch(bundled_autograd_cache=True):
+                fn = VllmSerializableFunction(gm, example_inputs, "", add_1)
+                payload = VllmSerializableFunction.serialize_compile_artifacts(fn)
+
+            config = None
+
+            def backend(*args, **kwargs) -> VllmSerializableFunction:
+                nonlocal config
+                # bundled_autograd_cache should be True even if the compiler
+                # backend runs with bundled_autograd_cache=False in the ambient context.
+                config = torch._functorch.config.save_config_portable()
+                return fn
+
+            loaded_fn = VllmSerializableFunction.deserialize_compile_artifacts(payload)
+            with patch.object(VllmBackend, "__call__", backend):
+                loaded_fn(*example_inputs)
+
+            assert isinstance(config, dict)
+            assert "bundled_autograd_cache" in config
+            assert config["bundled_autograd_cache"] is True

@@ -178,6 +178,7 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
         is_encoder: bool = False,
         vllm_backend: Any | None = None,
         sym_tensor_indices: list[int] | None = None,
+        aot_autograd_config: dict[str, Any] | None = None,
     ) -> None:
         assert isinstance(graph_module, torch.fx.GraphModule)
         self.graph_module = graph_module
@@ -188,6 +189,13 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
         self.shape_env = None
         self.vllm_backend = vllm_backend
         self.sym_tensor_indices = sym_tensor_indices
+
+        import torch._functorch.config as functorch_config
+
+        self.aot_autograd_config = (
+            aot_autograd_config or functorch_config.save_config_portable()
+        )
+
         sym_input = next(
             (i for i in self.example_inputs if isinstance(i, torch.SymInt)), None
         )
@@ -286,6 +294,12 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
         sym_shape_indices_map = state.pop("sym_shape_indices_map", {})
         returns_tuple_map = state.pop("returns_tuple_map", {})
+        saved_aot_autograd_config = state["aot_autograd_config"]
+        if saved_aot_autograd_config is not None:
+            functorch_ctx = torch._functorch.config.patch(saved_aot_autograd_config)
+        else:
+            functorch_ctx = contextlib.nullcontext()
+
         if envs.VLLM_USE_MEGA_AOT_ARTIFACT:
             assert standalone_compile_artifacts is not None
             submod_names = standalone_compile_artifacts.submodule_names()
@@ -299,13 +313,14 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
                 num_submods,
             )
-            fn = reconstruct_serializable_fn_from_mega_artifact(
-                state=state,
-                standalone_compile_artifacts=standalone_compile_artifacts,
-                vllm_config=get_current_vllm_config(),
-                sym_shape_indices_map=sym_shape_indices_map,
-                returns_tuple_map=returns_tuple_map,
-            )
+            with functorch_ctx:
+                fn = reconstruct_serializable_fn_from_mega_artifact(
+                    state=state,
+                    standalone_compile_artifacts=standalone_compile_artifacts,
+                    vllm_config=get_current_vllm_config(),
+                    sym_shape_indices_map=sym_shape_indices_map,
+                    returns_tuple_map=returns_tuple_map,
+                )
             logger.info(
                 "reconstructed serializable fn from standalone compile artifacts"
@@ -328,7 +343,7 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
             vllm_backend: VllmBackend = VllmBackend(
                 vllm_config, state["prefix"], is_encoder
             )
-            with tracing(TracingContext(fake_mode)):
+            with tracing(TracingContext(fake_mode)), functorch_ctx:
                 fn.optimized_call = vllm_backend(
                     state["graph_module"], compile_inputs
                 ).optimized_call