[compile] Consistent compiler config for saved/loaded vllm backends. (#35810)
Signed-off-by: zhxchen17 <zhxchen17@fb.com>
This commit is contained in:
@@ -14,6 +14,7 @@ import pytest
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.activation
|
||||
from vllm.compilation.backends import VllmBackend
|
||||
from vllm.compilation.caching import (
|
||||
StandaloneCompiledArtifacts,
|
||||
VllmSerializableFunction,
|
||||
@@ -721,3 +722,44 @@ class TestStandaloneCompiledArtifactsIntegration:
|
||||
("mod3", "shape3"),
|
||||
]:
|
||||
assert cache.get(submod, shape) == shared_data
|
||||
|
||||
def test_functorch_config(self):
    """Check that a reloaded function compiles under its saved functorch config.

    A ``VllmSerializableFunction`` is built while ``bundled_autograd_cache``
    is patched to ``True``, then serialized and deserialized while the
    ambient functorch config has ``bundled_autograd_cache=False``. A stub
    replacing ``VllmBackend.__call__`` records the functorch config that is
    active when the loaded function invokes the backend; the recorded value
    must be ``True``.
    """
    vllm_config = make_vllm_config()
    example_inputs = (torch.randn(10, 10),)

    def add_1(x: torch.Tensor):
        return x + 1

    capture = torch._dynamo.functional_export.dynamo_graph_capture_for_export(add_1)
    graph_module = capture(*example_inputs)

    # Use the default FX code generator and clear the dynamo bytecode
    # flatten/unflatten attributes on the captured module.
    graph_module.graph._codegen = torch.fx.graph.CodeGen()
    graph_module._dynamo_bytecode_flatten = None
    graph_module._dynamo_bytecode_unflatten = None

    with (
        torch._functorch.config.patch(bundled_autograd_cache=False),
        set_current_vllm_config(vllm_config),
    ):
        # Only the construction of the function sees bundled_autograd_cache=True.
        with torch._functorch.config.patch(bundled_autograd_cache=True):
            serializable_fn = VllmSerializableFunction(
                graph_module, example_inputs, "", add_1
            )

        serialized = VllmSerializableFunction.serialize_compile_artifacts(
            serializable_fn
        )

        # Filled in by the stub backend below; stays empty if it is never called.
        observed = {}

        def recording_backend(*args, **kwargs) -> VllmSerializableFunction:
            # bundled_autograd_cache should be True here even though the
            # compiler backend runs with bundled_autograd_cache=False in the
            # ambient context.
            observed["config"] = torch._functorch.config.save_config_portable()
            return serializable_fn

        loaded_fn = VllmSerializableFunction.deserialize_compile_artifacts(serialized)
        with patch.object(VllmBackend, "__call__", recording_backend):
            loaded_fn(*example_inputs)

        recorded_config = observed.get("config")
        assert isinstance(recorded_config, dict)
        assert "bundled_autograd_cache" in recorded_config
        assert recorded_config["bundled_autograd_cache"] is True
|
||||
|
||||
@@ -178,6 +178,7 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
|
||||
is_encoder: bool = False,
|
||||
vllm_backend: Any | None = None,
|
||||
sym_tensor_indices: list[int] | None = None,
|
||||
aot_autograd_config: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
assert isinstance(graph_module, torch.fx.GraphModule)
|
||||
self.graph_module = graph_module
|
||||
@@ -188,6 +189,13 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
|
||||
self.shape_env = None
|
||||
self.vllm_backend = vllm_backend
|
||||
self.sym_tensor_indices = sym_tensor_indices
|
||||
|
||||
import torch._functorch.config as functorch_config
|
||||
|
||||
self.aot_autograd_config = (
|
||||
aot_autograd_config or functorch_config.save_config_portable()
|
||||
)
|
||||
|
||||
sym_input = next(
|
||||
(i for i in self.example_inputs if isinstance(i, torch.SymInt)), None
|
||||
)
|
||||
@@ -286,6 +294,12 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
|
||||
sym_shape_indices_map = state.pop("sym_shape_indices_map", {})
|
||||
returns_tuple_map = state.pop("returns_tuple_map", {})
|
||||
|
||||
saved_aot_autograd_config = state["aot_autograd_config"]
|
||||
if saved_aot_autograd_config is not None:
|
||||
functorch_ctx = torch._functorch.config.patch(saved_aot_autograd_config)
|
||||
else:
|
||||
functorch_ctx = contextlib.nullcontext()
|
||||
|
||||
if envs.VLLM_USE_MEGA_AOT_ARTIFACT:
|
||||
assert standalone_compile_artifacts is not None
|
||||
submod_names = standalone_compile_artifacts.submodule_names()
|
||||
@@ -299,13 +313,14 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
|
||||
num_submods,
|
||||
)
|
||||
|
||||
fn = reconstruct_serializable_fn_from_mega_artifact(
|
||||
state=state,
|
||||
standalone_compile_artifacts=standalone_compile_artifacts,
|
||||
vllm_config=get_current_vllm_config(),
|
||||
sym_shape_indices_map=sym_shape_indices_map,
|
||||
returns_tuple_map=returns_tuple_map,
|
||||
)
|
||||
with functorch_ctx:
|
||||
fn = reconstruct_serializable_fn_from_mega_artifact(
|
||||
state=state,
|
||||
standalone_compile_artifacts=standalone_compile_artifacts,
|
||||
vllm_config=get_current_vllm_config(),
|
||||
sym_shape_indices_map=sym_shape_indices_map,
|
||||
returns_tuple_map=returns_tuple_map,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"reconstructed serializable fn from standalone compile artifacts"
|
||||
@@ -328,7 +343,7 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
|
||||
vllm_backend: VllmBackend = VllmBackend(
|
||||
vllm_config, state["prefix"], is_encoder
|
||||
)
|
||||
with tracing(TracingContext(fake_mode)):
|
||||
with tracing(TracingContext(fake_mode)), functorch_ctx:
|
||||
fn.optimized_call = vllm_backend(
|
||||
state["graph_module"], compile_inputs
|
||||
).optimized_call
|
||||
|
||||
Reference in New Issue
Block a user