[AOT compilation] support torch.compile inductor artifacts in VllmCompiledFunction (#25205)

Signed-off-by: dolpm <34420038+dolpm@users.noreply.github.com>
dolpm
2026-01-20 11:45:59 -08:00
committed by GitHub
parent 193069d129
commit 7c5dedc247
8 changed files with 1169 additions and 113 deletions

View File

@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ast
import contextvars
import dataclasses
import hashlib
import json
@@ -34,7 +35,6 @@ from vllm.platforms import current_platform
from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.utils.torch_utils import is_torch_equal_or_newer
from .caching import VllmSerializableFunction
from .compiler_interface import (
CompilerInterface,
EagerAdaptor,
@@ -49,7 +49,48 @@ from .pass_manager import PostGradPassManager
logger = init_logger(__name__)
def make_copy_and_call(
sym_tensor_indices: list[int],
input_buffers: list[torch.Tensor | None],
callable_fn: Callable[..., Any],
) -> Callable[..., Any]:
"""Create a wrapper that copies inputs to static buffers before calling.
This is used for cudagraph input copying where we need to copy dynamic
tensors to static buffers before invoking the compiled graph.
Args:
sym_tensor_indices: Indices of tensors with symbolic shapes
input_buffers: List of static buffers (can contain None for lazy init)
callable_fn: The compiled function to call
Returns:
A wrapper function that copies inputs and calls the compiled function
"""
def copy_and_call(*args):
list_args = list(args)
for i, index in enumerate(sym_tensor_indices):
runtime_tensor = list_args[index]
runtime_shape = runtime_tensor.shape[0]
# lazy initialization of buffer on first call
if input_buffers[i] is None:
input_buffers[i] = runtime_tensor.clone()
static_tensor = input_buffers[i][:runtime_shape] # type: ignore[index]
static_tensor.copy_(runtime_tensor)
list_args[index] = static_tensor
return callable_fn(*list_args)
return copy_and_call
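For illustration, a minimal CPU sketch of how this wrapper behaves, assuming the make_copy_and_call defined above; fake_compiled and the tensor shapes are made up for the example:

import torch

# hypothetical stand-in for the compiled piecewise graph
def fake_compiled(x):
    return x.sum(dim=0)

buffers: list[torch.Tensor | None] = [None]  # filled lazily on the first call
runner = make_copy_and_call(
    sym_tensor_indices=[0], input_buffers=buffers, callable_fn=fake_compiled
)

big = torch.randn(8, 4)    # first call allocates the static buffer (8, 4)
runner(big)
small = torch.randn(3, 4)  # later calls copy into a slice of that buffer
runner(small)
assert buffers[0].shape == (8, 4)  # sized by the first call, assumed largest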
def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
assert not envs.VLLM_USE_MEGA_AOT_ARTIFACT or envs.VLLM_USE_STANDALONE_COMPILE, (
"VLLM_USE_MEGA_AOT_ARTIFACT=1 requires VLLM_USE_STANDALONE_COMPILE=1"
)
if compilation_config.backend == "inductor":
# Use standalone compile only if requested, version is new enough,
# and the symbol actually exists in this PyTorch build.
@@ -355,6 +396,60 @@ def split_graph(
compilation_start_time = 0.0
def wrap_with_cudagraph_if_needed(
piecewise_backend: Any,
vllm_config: VllmConfig,
compilation_config: CompilationConfig,
is_first_graph: bool,
is_last_graph: bool,
) -> Any:
"""
Wrap a piecewise backend with CUDA graph wrapper if needed.
This function is shared between PiecewiseCompileInterpreter (below) and
reconstruct_serializable_fn_from_mega_artifact in caching.py.
Args:
piecewise_backend: The backend to wrap
vllm_config: The vLLM configuration
compilation_config: The compilation configuration
is_first_graph: Whether this is the first graph in the sequence
is_last_graph: Whether this is the last graph in the sequence
Returns:
The wrapped backend if CUDA graphs are enabled, otherwise the original backend
"""
if (
not compilation_config.cudagraph_mode.has_piecewise_cudagraphs()
or compilation_config.use_inductor_graph_partition
):
return piecewise_backend
# We're using Dynamo-based piecewise splitting, so we wrap
# the whole subgraph with a static graph wrapper.
from .cuda_graph import CUDAGraphOptions
# resolve the static graph wrapper class (e.g. CUDAGraphWrapper
# class) as platform dependent.
static_graph_wrapper_class = resolve_obj_by_qualname(
current_platform.get_static_graph_wrapper_cls()
)
# Always assign PIECEWISE runtime mode to the CUDAGraphWrapper
# for piecewise_backend, to distinguish it from the FULL cudagraph
# runtime mode, regardless of whether it wraps a full or piecewise
# fx graph.
return static_graph_wrapper_class(
runnable=piecewise_backend,
vllm_config=vllm_config,
runtime_mode=CUDAGraphMode.PIECEWISE,
cudagraph_options=CUDAGraphOptions(
debug_log_enable=is_first_graph,
gc_disable=not is_first_graph,
weak_ref_output=is_last_graph,
),
)
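As background for the resolve_obj_by_qualname call above, a minimal standalone sketch of resolving a class from a dotted qualified-name string; the helper and the example string are illustrative, not vLLM's actual values:

import importlib

def resolve_by_qualname(qualname: str):
    # "package.module.ClassName" -> import the module, then fetch the attribute
    module_name, _, attr = qualname.rpartition(".")
    return getattr(importlib.import_module(module_name), attr)

# the returned class can then be instantiated with whatever kwargs the caller needs
OrderedDict = resolve_by_qualname("collections.OrderedDict")
assert OrderedDict(a=1)["a"] == 1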
class PiecewiseCompileInterpreter(torch.fx.Interpreter): # type: ignore[misc]
"""Code adapted from `torch.fx.passes.shape_prop.ShapeProp`.
It runs the given graph with fake inputs, and compile some
@@ -365,6 +460,18 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): # type: ignore[misc]
it will be used to determine the order of the compiled piecewise
graphs. The first graph will handle logging, and the last graph
has some special cudagraph output handling.
Note: This class shares similar logic with
reconstruct_serializable_fn_from_mega_artifact in caching.py.
Both create PiecewiseBackend instances and wrap them with cudagraph.
The key difference is:
- reconstruct_serializable_fn_from_mega_artifact: PiecewiseBackend receives
pre-compiled runnables (compiled_runnables is set, graph is None)
- this class: PiecewiseBackend receives the FX graph to compile
(graph is set, compiled_runnables is None)
If modifying the backend creation/wrapping logic, consider updating both.
"""
def __init__(
@@ -413,6 +520,8 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): # type: ignore[misc]
]
# Lazy import here to avoid circular import
from torch._inductor.compile_fx import graph_returns_tuple
from .piecewise_backend import PiecewiseBackend
piecewise_backend = PiecewiseBackend(
@@ -422,38 +531,16 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): # type: ignore[misc]
len(self.compile_submod_names),
sym_shape_indices,
self.vllm_backend,
graph_returns_tuple(submod),
)
if (
self.compilation_config.cudagraph_mode.has_piecewise_cudagraphs()
and not self.compilation_config.use_inductor_graph_partition
):
# We're using Dynamo-based piecewise splitting, so we wrap
# the whole subgraph with a static graph wrapper.
from .cuda_graph import CUDAGraphOptions
# resolve the static graph wrapper class (e.g. CUDAGraphWrapper
# class) as platform dependent.
static_graph_wrapper_class = resolve_obj_by_qualname(
current_platform.get_static_graph_wrapper_cls()
)
# Always assign PIECEWISE runtime mode to the
# CUDAGraphWrapper for piecewise_backend, to distinguish
# it from the FULL cudagraph runtime mode, no matter it
# is wrapped on a full or piecewise fx graph.
self.module.__dict__[target] = static_graph_wrapper_class(
runnable=piecewise_backend,
vllm_config=self.vllm_config,
runtime_mode=CUDAGraphMode.PIECEWISE,
cudagraph_options=CUDAGraphOptions(
debug_log_enable=piecewise_backend.is_first_graph,
gc_disable=not piecewise_backend.is_first_graph,
weak_ref_output=piecewise_backend.is_last_graph,
),
)
else:
self.module.__dict__[target] = piecewise_backend
self.module.__dict__[target] = wrap_with_cudagraph_if_needed(
piecewise_backend,
self.vllm_config,
self.compilation_config,
piecewise_backend.is_first_graph,
piecewise_backend.is_last_graph,
)
compilation_counter.num_piecewise_capturable_graphs_seen += 1
@@ -465,6 +552,21 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): # type: ignore[misc]
model_tag: str = "backbone"
model_is_encoder: bool = False
_on_compilation_complete_callback: contextvars.ContextVar[Callable[[], None] | None] = (
contextvars.ContextVar("on_compilation_complete_callback", default=None)
)
@contextmanager
def set_on_compilation_complete(
callback: Callable[[], None],
) -> Generator[None, None, None]:
token = _on_compilation_complete_callback.set(callback)
try:
yield
finally:
_on_compilation_complete_callback.reset(token)
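A standalone sketch of the ContextVar handoff used here, with illustrative names: the callback is visible to any code that runs inside the with block (PiecewiseBackend reads it the same way via .get()) and the previous value is restored afterwards:

import contextvars
from collections.abc import Callable
from contextlib import contextmanager

_callback: contextvars.ContextVar[Callable[[], None] | None] = contextvars.ContextVar(
    "callback", default=None
)

@contextmanager
def set_callback(cb: Callable[[], None]):
    token = _callback.set(cb)
    try:
        yield
    finally:
        _callback.reset(token)  # restore the previous value, even on error

def somewhere_deep_in_compilation() -> None:
    cb = _callback.get()  # reads whatever the enclosing with block installed
    if cb is not None:
        cb()

with set_callback(lambda: print("compilation finished")):
    somewhere_deep_in_compilation()  # prints "compilation finished"
somewhere_deep_in_compilation()      # prints nothing: the default is back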
@contextmanager
def set_model_tag(tag: str, is_encoder: bool = False) -> Generator[None, None, None]:
@@ -509,8 +611,6 @@ class VllmBackend:
returned_callable: Callable[..., Any]
# Inductor passes to run on the graph pre-defunctionalization
post_grad_passes: Sequence[Callable[..., Any]]
sym_tensor_indices: list[int]
input_buffers: list[torch.Tensor]
compiler_manager: CompilerManager
# Copy of CompilationConfig.inductor_compile_config +
# an entry for PostGradPassManager
@@ -539,9 +639,6 @@ class VllmBackend:
)()
self.pass_key = current_platform.pass_key
self.sym_tensor_indices = []
self.input_buffers = []
self.vllm_config = vllm_config
self.compilation_config = vllm_config.compilation_config
@@ -558,6 +655,68 @@ class VllmBackend:
# `torch.compile` is JIT compiled, so we don't need to
# do anything here
def collect_standalone_compile_artifacts(
self,
) -> tuple[Any, dict[str, list[int]] | None, dict[str, bool] | None]:
"""Collect inductor cache artifacts from all piecewise backends.
Returns:
tuple: (standalone_compile_artifacts, sym_shape_indices_map,
returns_tuple_map)
- standalone_compile_artifacts: StandaloneCompiledArtifacts
with compiled artifacts
- sym_shape_indices_map: dict mapping submod_name to
sym_shape_indices
- returns_tuple_map: dict mapping submod_name to
returns_tuple
"""
if not envs.VLLM_USE_MEGA_AOT_ARTIFACT:
return None, None, None
from .caching import StandaloneCompiledArtifacts
from .piecewise_backend import PiecewiseBackend
standalone_compile_artifacts = StandaloneCompiledArtifacts()
sym_shape_indices_map = {}
returns_tuple_map = {}
for name, _ in self.split_gm.named_children():
# get the actual attribute (shadowed by PiecewiseBackend in __dict__)
child = getattr(self.split_gm, name)
# unwrap the static graph wrapper class if applicable
piecewise_backend = child.runnable if hasattr(child, "runnable") else child
if not isinstance(piecewise_backend, PiecewiseBackend):
continue
submod_name = name
sym_shape_indices_map[submod_name] = piecewise_backend.sym_shape_indices
returns_tuple_map[submod_name] = piecewise_backend.returns_tuple
for shape_str, bytes_data in piecewise_backend.to_bytes().items():
standalone_compile_artifacts.insert(submod_name, shape_str, bytes_data)
logger.debug(
"collected artifact for %s shape %s (%d bytes)",
submod_name,
shape_str,
len(bytes_data),
)
logger.info(
"collected artifacts: %d entries, %d artifacts, %d bytes total",
standalone_compile_artifacts.num_entries(),
standalone_compile_artifacts.num_artifacts(),
standalone_compile_artifacts.size_bytes(),
)
logger.debug(
"standalone compile artifact keys: %s",
list(standalone_compile_artifacts.submodule_bytes.keys()),
)
return standalone_compile_artifacts, sym_shape_indices_map, returns_tuple_map
def configure_post_pass(self) -> None:
self.pass_manager.configure(self.vllm_config)
@@ -579,9 +738,11 @@ class VllmBackend:
)
self.inductor_config[self.pass_key] = self.pass_manager
def __call__(
self, graph: fx.GraphModule, example_inputs: Sequence[Any]
) -> VllmSerializableFunction:
def __call__(self, graph: fx.GraphModule, example_inputs: Sequence[Any]) -> Any:
from .caching import (
VllmSerializableFunction,
)
vllm_config = self.vllm_config
# Minimal hashing here with existing utilities, reused below.
@@ -721,6 +882,12 @@ class VllmBackend:
self.split_gm, self.piecewise_graphs = split_graph(graph, fx_split_ops)
# keep a split_gm copy from BEFORE the interpreter replaces
# submodules with PiecewiseBackend -- used for serialization
original_split_gm = None
if envs.VLLM_USE_MEGA_AOT_ARTIFACT:
original_split_gm = deepcopy(self.split_gm)
from torch._dynamo.utils import lazy_format_graph_code
# depyf will hook lazy_format_graph_code and dump the graph
@@ -792,13 +959,21 @@ class VllmBackend:
)
self._called = True
graph_to_serialize = (
original_split_gm if envs.VLLM_USE_MEGA_AOT_ARTIFACT else self.graph
)
if (
self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
or not self.compilation_config.cudagraph_copy_inputs
):
return VllmSerializableFunction(
graph, example_inputs, self.prefix, self.split_gm, self.is_encoder
graph_to_serialize,
example_inputs,
self.prefix,
self.split_gm,
is_encoder=self.is_encoder,
vllm_backend=self,
)
# index of tensors that have symbolic shapes (batch size)
@@ -806,7 +981,7 @@ class VllmBackend:
# symbolic shape only happens for input tensors.
from torch.fx.experimental.symbolic_shapes import is_symbolic
self.sym_tensor_indices = [
sym_tensor_indices = [
i
for i, x in enumerate(fake_args)
if isinstance(x, torch._subclasses.fake_tensor.FakeTensor)
@@ -816,25 +991,18 @@ class VllmBackend:
# compiler managed cudagraph input buffers
# we assume the first run with symbolic shapes
# has the maximum size among all the tensors
self.input_buffers = [
example_inputs[x].clone() for x in self.sym_tensor_indices
]
# this is the callable we return to Dynamo to run
def copy_and_call(*args: Any) -> Any:
list_args = list(args)
for i, index in enumerate(self.sym_tensor_indices):
runtime_tensor = list_args[index]
runtime_shape = runtime_tensor.shape[0]
static_tensor = self.input_buffers[i][:runtime_shape]
# copy the tensor to the static buffer
static_tensor.copy_(runtime_tensor)
# replace the tensor in the list_args to the static buffer
list_args[index] = static_tensor
return self.split_gm(*list_args)
copy_and_call = make_copy_and_call(
sym_tensor_indices,
[example_inputs[x].clone() for x in sym_tensor_indices],
self.split_gm,
)
return VllmSerializableFunction(
graph, example_inputs, self.prefix, copy_and_call, self.is_encoder
graph_to_serialize,
example_inputs,
self.prefix,
copy_and_call,
is_encoder=self.is_encoder,
vllm_backend=self,
sym_tensor_indices=sym_tensor_indices,
)

View File

@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import hashlib
import inspect
import os
import pickle
@@ -12,6 +13,7 @@ import torch
from torch.utils import _pytree as pytree
import vllm.envs as envs
from vllm.compilation.compiler_interface import get_inductor_factors
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.config.utils import hash_factors
from vllm.logger import init_logger
@@ -27,6 +29,121 @@ assert isinstance(SerializableCallable, type)
logger = init_logger(__name__)
class StandaloneCompiledArtifacts:
"""Storage for standalone compiled artifacts with content-based deduplication.
Deduplication works via a two-level indirection:
1. `submodule_bytes` maps "{submod_name}_{shape}" -> SHA256 hash
2. `submodule_bytes_store` maps SHA256 hash -> actual bytes
When inserting, we compute the SHA256 hash of the bytes. If the hash
already exists in `submodule_bytes_store`, we reuse the existing entry
rather than storing duplicate bytes. This is common because submodules
often compile to identical artifacts (e.g., identical transformer layers
split at attention ops).
"""
def __init__(self):
# dict from "{submod_name}_{shape}" cache key to byte hash
self.submodule_bytes = {}
# dict from byte hash to bytes
self.submodule_bytes_store = {}
# dict from byte hash to the deserialized (loaded) artifact
self.loaded_submodule_store = {}
def insert(self, submod_name: str, shape: str, entry: bytes):
hasher = hashlib.sha256()
hasher.update(entry)
hex_digest = hasher.hexdigest()
self.submodule_bytes[f"{submod_name}_{shape}"] = hex_digest
if hex_digest not in self.submodule_bytes_store:
self.submodule_bytes_store[hex_digest] = entry
logger.debug(
"inserting new artifact for submod %s with shape %s "
"(%s bytes) at hash %s",
submod_name,
shape,
len(entry),
hex_digest,
)
else:
logger.debug(
"reusing existing cache artifact for submod %s "
"with shape %s (%s bytes) at hash %s",
submod_name,
shape,
len(entry),
hex_digest,
)
def get(self, submod_name: str, shape: str) -> bytes:
logger.debug(
"getting artifact for submod %s with shape %s",
submod_name,
shape,
)
return self.submodule_bytes_store[
self.submodule_bytes[f"{submod_name}_{shape}"]
]
def get_loaded(self, submod_name: str, shape: str):
logger.debug(
"getting artifact for submod %s with shape %s",
submod_name,
shape,
)
return self.loaded_submodule_store[
self.submodule_bytes[f"{submod_name}_{shape}"]
]
def size_bytes(self) -> int:
return sum(len(entry) for entry in self.submodule_bytes_store.values())
def num_artifacts(self) -> int:
return len(self.submodule_bytes_store)
def num_entries(self) -> int:
return len(self.submodule_bytes)
def submodule_names(self) -> list[str]:
# get unique "{submod_name}" from "{submod_name}_{shape}", preserving order
names = [cache_key.rsplit("_", 1)[0] for cache_key in self.submodule_bytes]
return list(dict.fromkeys(names))
def load_all(self) -> None:
import concurrent.futures
# check already loaded
if len(self.loaded_submodule_store) == len(self.submodule_bytes_store):
return
from torch._inductor.standalone_compile import AOTCompiledArtifact
def _load_entry(entry_bytes) -> AOTCompiledArtifact:
entry = pickle.loads(entry_bytes)
return AOTCompiledArtifact.deserialize(entry)
with concurrent.futures.ThreadPoolExecutor() as executor:
entries = list(self.submodule_bytes_store.values())
loaded_entries = list(executor.map(_load_entry, entries))
for i, k in enumerate(self.submodule_bytes_store.keys()):
self.loaded_submodule_store[k] = loaded_entries[i]
logger.debug("loaded all %s submodules", self.num_artifacts())
def __getstate__(self):
return {
"submodule_bytes": self.submodule_bytes,
"submodule_bytes_store": self.submodule_bytes_store,
}
def __setstate__(self, state):
self.submodule_bytes = state["submodule_bytes"]
self.submodule_bytes_store = state["submodule_bytes_store"]
self.loaded_submodule_store = {}
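To make the two-level deduplication concrete, a small usage sketch assuming the class above (the shape strings and byte payload are placeholders):

artifacts = StandaloneCompiledArtifacts()
payload = b"compiled bytes shared by two identical layers"

artifacts.insert("submod_0", "1-8192", payload)
artifacts.insert("submod_1", "1-8192", payload)

assert artifacts.num_entries() == 2     # two "{submod_name}_{shape}" keys
assert artifacts.num_artifacts() == 1   # the bytes are stored only once
assert artifacts.size_bytes() == len(payload)
assert artifacts.get("submod_1", "1-8192") == payload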
class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
"""
A wrapper around a compiled function by vllm. It will forward the tensor
@@ -46,6 +163,8 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
prefix: str,
optimized_call: Callable[..., Any],
is_encoder: bool = False,
vllm_backend: Any | None = None,
sym_tensor_indices: list[int] | None = None,
) -> None:
assert isinstance(graph_module, torch.fx.GraphModule)
self.graph_module = graph_module
@@ -54,6 +173,8 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
self.optimized_call = optimized_call
self.is_encoder = is_encoder
self.shape_env = None
self.vllm_backend = vllm_backend
self.sym_tensor_indices = sym_tensor_indices
sym_input = next(
(i for i in self.example_inputs if isinstance(i, torch.SymInt)), None
)
@@ -74,9 +195,15 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
state = compiled_fn.__dict__.copy()
state.pop("optimized_call")
state.pop("shape_env")
state.pop("vllm_backend", None)
for node in state["graph_module"].graph.nodes:
node.meta.pop("source_fn_stack", None)
node.meta.pop("nn_module_stack", None)
for name, submod in state["graph_module"].named_children():
if hasattr(submod, "graph"):
for node in submod.graph.nodes:
node.meta.pop("source_fn_stack", None)
node.meta.pop("nn_module_stack", None)
graph_reducer_override = GraphPickler.reducer_override
@@ -93,15 +220,36 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
return type(None), ()
return graph_reducer_override(self, obj)
# Mask off tensor inputs since they are large and not needed.
state["example_inputs"] = pytree.tree_map_only(
torch.Tensor, lambda _: None, state["example_inputs"]
)
if state.get("sym_tensor_indices"):
# put tensor inputs on meta device since their data
# isn't needed, yet we need the meta for make_copy_and_call
state["example_inputs"] = pytree.tree_map_only(
torch.Tensor,
lambda inp: torch.empty_like(inp, device="meta"),
state["example_inputs"],
)
else:
# tensor data is not needed here either; keep only meta
# placeholders since the real tensors are large.
state["example_inputs"] = pytree.tree_map_only(
torch.Tensor,
lambda inp: torch.empty_like(inp, device="meta"),
state["example_inputs"],
)
with patch.object(GraphPickler, "reducer_override", _graph_reducer_override):
state["graph_module"] = GraphPickler.dumps(
state["graph_module"], Options(ops_filter=None)
)
state["example_inputs"] = GraphPickler.dumps(state["example_inputs"])
if compiled_fn.vllm_backend:
(
standalone_compile_artifacts,
sym_shape_indices_map,
returns_tuple_map,
) = compiled_fn.vllm_backend.collect_standalone_compile_artifacts()
state["standalone_compile_artifacts"] = standalone_compile_artifacts
state["sym_shape_indices_map"] = sym_shape_indices_map
state["returns_tuple_map"] = returns_tuple_map
return pickle.dumps(state)
@classmethod
@@ -111,15 +259,48 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
from torch.fx._graph_pickler import GraphPickler
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from vllm.compilation.backends import VllmBackend
state = pickle.loads(data)
fake_mode = FakeTensorMode(shape_env=ShapeEnv())
state["graph_module"] = GraphPickler.loads(state["graph_module"], fake_mode)
state["graph_module"].recompile()
state["example_inputs"] = GraphPickler.loads(state["example_inputs"], fake_mode)
standalone_compile_artifacts = state.pop("standalone_compile_artifacts", None)
sym_shape_indices_map = state.pop("sym_shape_indices_map", {})
returns_tuple_map = state.pop("returns_tuple_map", {})
if envs.VLLM_USE_MEGA_AOT_ARTIFACT:
assert standalone_compile_artifacts is not None
submod_names = standalone_compile_artifacts.submodule_names()
num_submods = len(submod_names)
num_artifacts = standalone_compile_artifacts.num_artifacts()
logger.info(
"reconstructing serializable fn from standalone compile "
"artifacts. num_artifacts=%d num_submods=%d",
num_artifacts,
num_submods,
)
fn = reconstruct_serializable_fn_from_mega_artifact(
state=state,
standalone_compile_artifacts=standalone_compile_artifacts,
vllm_config=get_current_vllm_config(),
sym_shape_indices_map=sym_shape_indices_map,
returns_tuple_map=returns_tuple_map,
)
logger.info(
"reconstructed serializable fn from standalone compile artifacts"
)
return fn
# Fall back to standard VllmBackend
from vllm.compilation.backends import VllmBackend
is_encoder = state.get("is_encoder", False)
vllm_backend = VllmBackend(
vllm_backend: VllmBackend = VllmBackend(
get_current_vllm_config(), state["prefix"], is_encoder
)
@@ -152,7 +333,140 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
return "VllmSerializableFunction"
def compilation_config_hash_factors(vllm_config: VllmConfig) -> list[str]:
def reconstruct_serializable_fn_from_mega_artifact(
state: dict[str, Any],
standalone_compile_artifacts: "StandaloneCompiledArtifacts",
vllm_config: VllmConfig,
sym_shape_indices_map: dict[str, list[int]],
returns_tuple_map: dict[str, bool],
) -> "VllmSerializableFunction":
"""Construct a VllmSerializableFunction from cached inductor artifacts.
This function reconstructs a callable model from pre-compiled inductor
artifacts without re-running the compilation. It:
1. Loads all cached artifacts
2. Builds compiled callables for each submodule/shape
3. Creates PiecewiseBackend instances that dispatch to cached artifacts
4. Wraps with cudagraph if needed
5. Returns the final VllmSerializableFunction
Note: This function shares similar logic with PiecewiseCompileInterpreter
in backends.py. Both create PiecewiseBackend instances and wrap them with
cudagraph. The key difference is:
- this function: PiecewiseBackend receives pre-compiled runnables
(compiled_runnables is set, graph is None)
- PiecewiseCompileInterpreter: PiecewiseBackend receives the FX graph
to compile (graph is set, compiled_runnables is None)
If modifying the backend creation/wrapping logic, consider updating both.
Args:
state: Deserialized state dict containing graph_module, example_inputs,
prefix, sym_tensor_indices, is_encoder, etc.
standalone_compile_artifacts: The StandaloneCompiledArtifacts containing
pre-compiled artifacts for each submodule/shape combination.
vllm_config: The vLLM configuration.
sym_shape_indices_map: Mapping from submod_name to sym_shape_indices.
returns_tuple_map: Mapping from submod_name to returns_tuple.
Returns:
A VllmSerializableFunction that can be called directly.
"""
from vllm.compilation.backends import (
VllmBackend,
make_copy_and_call,
wrap_with_cudagraph_if_needed,
)
from vllm.compilation.piecewise_backend import PiecewiseBackend
prefix = state["prefix"]
is_encoder = state.get("is_encoder", False)
split_gm = state["graph_module"]
compilation_config = vllm_config.compilation_config
standalone_compile_artifacts.load_all()
submod_names = standalone_compile_artifacts.submodule_names()
compiled_callables: dict[str, dict[str, Callable]] = {}
for cache_key in standalone_compile_artifacts.submodule_bytes:
submod_name, shape_str = cache_key.rsplit("_", 1)
compiled_callables.setdefault(submod_name, {})[shape_str] = (
standalone_compile_artifacts.get_loaded(submod_name, shape_str)
)
vllm_backend = VllmBackend(vllm_config, prefix, is_encoder)
dummy_cache_dir = os.path.join(envs.VLLM_CACHE_ROOT, "dummy_cache")
os.makedirs(dummy_cache_dir, exist_ok=True)
vllm_backend.compiler_manager.initialize_cache(
cache_dir=dummy_cache_dir,
disable_cache=True,
prefix=prefix,
)
# spot check that cached submodules exist in the graph structure
graph_children = {name for name, _ in split_gm.named_children()}
missing = set(submod_names) - graph_children
assert not missing, (
f"artifacts reference submodules not in graph: {missing}. "
f"graph has: {sorted(graph_children)}"
)
for i, submod_name in enumerate(submod_names):
assert submod_name in sym_shape_indices_map and submod_name in returns_tuple_map
sym_shape_indices = sym_shape_indices_map[submod_name]
returns_tuple = returns_tuple_map[submod_name]
runnables = compiled_callables[submod_name]
piecewise_backend = PiecewiseBackend(
graph=None, # not needed for cached artifacts
vllm_config=vllm_config,
piecewise_compile_index=i,
total_piecewise_compiles=len(submod_names),
sym_shape_indices=sym_shape_indices,
vllm_backend=vllm_backend,
returns_tuple=returns_tuple,
compiled_runnables=runnables,
)
is_first = i == 0
is_last = i == len(submod_names) - 1
wrapped_backend = wrap_with_cudagraph_if_needed(
piecewise_backend,
vllm_config,
compilation_config,
is_first,
is_last,
)
split_gm.__dict__[submod_name] = wrapped_backend
logger.debug(
"Replaced submodule %s with piecewise backend from cache",
submod_name,
)
if compilation_config.cudagraph_copy_inputs:
sym_tensor_indices = state["sym_tensor_indices"]
input_buffers = [
torch.empty_like(
state["example_inputs"][idx], device=vllm_config.device_config.device
)
for idx in sym_tensor_indices
]
optimized_call = make_copy_and_call(sym_tensor_indices, input_buffers, split_gm)
else:
optimized_call = split_gm
fn = VllmSerializableFunction(
**state,
optimized_call=optimized_call,
vllm_backend=None,
)
return fn
def aot_compile_hash_factors(vllm_config: VllmConfig) -> list[str]:
factors = []
# 0. factors come from the env; for example, the values of
# VLLM_PP_LAYER_PARTITION will affect the computation graph.
@@ -163,6 +477,11 @@ def compilation_config_hash_factors(vllm_config: VllmConfig) -> list[str]:
# model is created)
config_hash = vllm_config.compute_hash()
factors.append(config_hash)
# 2. inductor factors if applicable
if envs.VLLM_USE_MEGA_AOT_ARTIFACT:
factors.extend(get_inductor_factors())
return factors

View File

@@ -16,9 +16,12 @@ import vllm.envs as envs
from vllm.compilation.counter import compilation_counter
from vllm.config import VllmConfig
from vllm.config.utils import Range
from vllm.logger import init_logger
from vllm.utils.hashing import safe_hash
from vllm.utils.torch_utils import is_torch_equal_or_newer
logger = init_logger(__name__)
class CompilerInterface:
"""
@@ -230,12 +233,42 @@ class InductorStandaloneAdaptor(CompilerInterface):
from torch._inductor import standalone_compile
compiled_graph = standalone_compile(
graph,
example_inputs,
dynamic_shapes=dynamic_shapes,
options={"config_patches": current_config},
)
supports_aot = is_torch_equal_or_newer("2.10.0.dev")
if not supports_aot and envs.VLLM_USE_MEGA_AOT_ARTIFACT:
logger.error(
"CRITICAL: VLLM_USE_MEGA_AOT_ARTIFACT "
"is enabled but PyTorch version does not support 'aot' "
"parameter in standalone_compile. This requires PyTorch "
"2.10.0+. Falling back to non-AOT mode."
)
compile_kwargs = {
"dynamic_shapes": dynamic_shapes,
"options": {
"config_patches": current_config,
},
}
use_aot: bool = supports_aot and envs.VLLM_USE_MEGA_AOT_ARTIFACT
# only add 'aot' parameter if both supported and enabled...
# this will set bundled_autograd_cache
# https://github.com/pytorch/pytorch/blob/9bbc5b2905c260adf41bc866a732f9c121a2828a/torch/_inductor/standalone_compile.py#L359 # noqa
if use_aot:
compile_kwargs["aot"] = True # type: ignore[assignment]
compiled_graph = standalone_compile(graph, example_inputs, **compile_kwargs)
if use_aot:
from torch._inductor.standalone_compile import AOTCompiledArtifact
assert isinstance(compiled_graph, AOTCompiledArtifact)
assert hasattr(compiled_graph, "serialize")
# just return the compiled graph and no key, since the artifact
# bytes are serialized via to_bytes and reloaded from the mega
# AOT artifact when reading
return compiled_graph, None
# Save the compiled artifact to disk in the specified path
assert key is not None
path = os.path.join(self.cache_dir, key)
@@ -619,7 +652,8 @@ def set_inductor_config(config: dict[str, Any], compile_range: Range) -> None:
def set_functorch_config() -> None:
torch._functorch.config.bundled_autograd_cache = False
if not envs.VLLM_USE_MEGA_AOT_ARTIFACT:
torch._functorch.config.bundled_autograd_cache = False
class EagerAdaptor(CompilerInterface):

View File

@@ -320,7 +320,7 @@ def _support_torch_compile(
return
self._check_shape_invariants = shape_invariants
self.was_aot_compile_fn_loaded_from_disk = False
compilation_counter.num_models_seen += 1
self.compiled = False
@@ -417,9 +417,9 @@ def _support_torch_compile(
serialized backend artifacts), then we need to generate a new AOT
compile artifact from scratch.
"""
from .caching import compilation_config_hash_factors
from .caching import aot_compile_hash_factors
factors: list[str] = compilation_config_hash_factors(self.vllm_config)
factors: list[str] = aot_compile_hash_factors(self.vllm_config)
factors.append(_model_hash_key(self.forward))
hash_key = hashlib.sha256(str(factors).encode()).hexdigest()
@@ -446,6 +446,7 @@ def _support_torch_compile(
if not self.compilation_config.dynamic_shapes_config.evaluate_guards:
loaded_fn.disable_guard_check()
self.aot_compiled_fn = loaded_fn
self.was_aot_compile_fn_loaded_from_disk = True
except Exception as e:
if os.path.exists(aot_compilation_path):
logger.warning(
@@ -547,26 +548,45 @@ def _support_torch_compile(
logger.warning("Detected eager backend, disabling AOT compile.")
use_aot_compile = False
if use_aot_compile:
self.aot_compiled_fn = self.aot_compile(*args, **kwargs)
output = self.aot_compiled_fn(self, *args, **kwargs)
assert aot_compilation_path is not None
assert cache_dir is not None
try:
os.makedirs(cache_dir, exist_ok=True)
self.aot_compiled_fn.save_compiled_function(aot_compilation_path)
except Exception as e:
logger.warning(
"Cannot save aot compilation to path %s, error: %s",
aot_compilation_path,
str(e),
)
from vllm.compilation.backends import set_on_compilation_complete
# store the path for saving after warmup
self._aot_compilation_path = aot_compilation_path
self._aot_cache_dir = cache_dir
# set callback in context so it's available when compilation completes
with set_on_compilation_complete(self.save_aot_compiled_function):
self.aot_compiled_fn = self.aot_compile(*args, **kwargs)
output = self.aot_compiled_fn(self, *args, **kwargs)
else:
output = TorchCompileWithNoGuardsWrapper.__call__(self, *args, **kwargs) # type: ignore[arg-type]
self.compiled = True
return output
# triggers VllmSerializableFunction.serialize()
def save_aot_compiled_function(self):
if self.was_aot_compile_fn_loaded_from_disk:
logger.debug("AOT compiled function was loaded from cache, skipping save")
return
assert (
self.aot_compiled_fn and self._aot_compilation_path and self._aot_cache_dir
)
logger.info("saving AOT compiled function to %s", self._aot_compilation_path)
try:
os.makedirs(self._aot_cache_dir, exist_ok=True)
self.aot_compiled_fn.save_compiled_function(self._aot_compilation_path)
logger.info("saved AOT compiled function to %s", self._aot_compilation_path)
except Exception as e:
logger.warning(
"unable to save AOT compiled function to %s: %s",
self._aot_compilation_path,
e,
)
cls.__call__ = __call__
cls.save_aot_compiled_function = save_aot_compiled_function
return cls

View File

@@ -2,10 +2,15 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
import io
import pickle
from collections.abc import Callable
from pickle import Pickler
from typing import Any
import torch._functorch.config
import torch.fx as fx
from torch._inductor.runtime.triton_heuristics import CachingAutotuner
from vllm.compilation.backends import VllmBackend
from vllm.compilation.monitor import end_monitoring_torch_compile
@@ -26,12 +31,14 @@ class RangeEntry:
class PiecewiseBackend:
def __init__(
self,
graph: fx.GraphModule,
graph: fx.GraphModule | None,
vllm_config: VllmConfig,
piecewise_compile_index: int,
total_piecewise_compiles: int,
sym_shape_indices: list[int],
vllm_backend: VllmBackend,
returns_tuple: bool,
compiled_runnables: dict[str, Callable] | None = None,
):
"""
The backend for piecewise compilation.
@@ -41,13 +48,28 @@ class PiecewiseBackend:
We will compile `self.graph` once for the general shape,
and then compile for different shapes specified in
`compilation_config.compile_sizes`.
This class supports two mutually exclusive modes:
1. Compilation (graph is set, compiled_runnables is None):
Used during initial compilation when we have the FX graph
and need to compile it for each shape range.
2. Precompilation (graph is None, compiled_runnables is set):
Used when loading from cache/AOT artifacts where we already
have pre-compiled callables and don't need the original graph.
Exactly one of graph or compiled_runnables must be provided.
"""
assert bool(graph is not None) ^ bool(compiled_runnables is not None), (
"exactly one of graph and compiled_runnables should be set."
)
self.graph = graph
self.vllm_config = vllm_config
self.compilation_config = vllm_config.compilation_config
self.piecewise_compile_index = piecewise_compile_index
self.total_piecewise_compiles = total_piecewise_compiles
self.vllm_backend = vllm_backend
self.compiled_runnables = compiled_runnables
self.is_first_graph = piecewise_compile_index == 0
self.is_last_graph = piecewise_compile_index == total_piecewise_compiles - 1
@@ -77,6 +99,7 @@ class PiecewiseBackend:
logger.debug_once(log_string)
self.sym_shape_indices = sym_shape_indices
self.returns_tuple = returns_tuple
# the entries for ranges that we need to either
self.range_entries: dict[Range, RangeEntry] = {}
@@ -108,12 +131,71 @@ class PiecewiseBackend:
compile_range=range,
)
# get the on_compilation_complete callback from context...
# PiecewiseBackend is created during the first call,
# which is when the context is set (see compilation/decorators.py)
from vllm.compilation.backends import _on_compilation_complete_callback
self.on_compilation_complete = _on_compilation_complete_callback.get()
def get_compiled_graph_wrapper(self, compiled_graph):
def compiled_graph_wrapper(*args):
graph_output = compiled_graph(*args)
# unpack the tuple if needed
# TODO(rzou): the implication is that we're not
# reading the python bytecode correctly in vLLM?
if self.returns_tuple or not isinstance(graph_output, (tuple, list)):
return graph_output
else:
return graph_output[0]
return compiled_graph_wrapper
def check_for_ending_compilation(self) -> None:
if self.is_last_graph and not self.to_be_compiled_ranges:
# no specific sizes to compile
# save the hash of the inductor graph for the next run
self.vllm_backend.compiler_manager.save_to_file()
end_monitoring_torch_compile(self.vllm_config)
# Call the completion callback (e.g., to save AOT compiled function)
if self.on_compilation_complete is not None:
self.on_compilation_complete()
def to_bytes(self) -> dict[str, bytes]:
class StandaloneCompiledArtifactsPickler(Pickler):
def reducer_override(self, obj):
if isinstance(obj, CachingAutotuner):
obj.prepare_for_pickle()
return pickle.loads, (
pickle.dumps(
obj,
),
)
return NotImplemented
def serialize(fn) -> bytes:
assert hasattr(fn, "serialize"), "fn must have a serialize method"
with torch._functorch.config.patch("bundled_autograd_cache", True):
entry = fn.serialize()
f = io.BytesIO()
StandaloneCompiledArtifactsPickler(f).dump(entry)
result = f.getvalue()
return result
out = {}
for range_key, entry in self.range_entries.items():
if not entry.compiled:
logger.debug(
"entry with range %s not compiled, so cannot get its bytes",
range_key,
)
continue
if hasattr(entry.runnable, "serialize"):
out[str(range_key)] = serialize(entry.runnable)
return out
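reducer_override is a standard pickle.Pickler hook; a minimal standalone sketch of the same pattern with a hypothetical Widget type (not a vLLM class) that must be prepared before it can be pickled:

import io
import pickle

class Widget:
    def __init__(self, value):
        self.value = value
        self.unpicklable = lambda: value  # lambdas cannot be pickled

    def prepare_for_pickle(self):
        self.unpicklable = None  # drop the offending field

class PreparePickler(pickle.Pickler):
    def reducer_override(self, obj):
        if isinstance(obj, Widget):
            obj.prepare_for_pickle()
            # embed the prepared object as pre-pickled bytes
            return pickle.loads, (pickle.dumps(obj),)
        return NotImplemented  # everything else uses default pickling

buf = io.BytesIO()
PreparePickler(buf).dump({"layer": Widget(42)})
restored = pickle.loads(buf.getvalue())
assert restored["layer"].value == 42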
def _fakify_args(self, args: tuple[Any, ...]) -> list[Any]:
# We need to pass fake example_inputs, otherwise torch.compile
@@ -127,6 +209,7 @@ class PiecewiseBackend:
# non fake tensors as example inputs!
# See issue https://github.com/vllm-project/vllm/issues/27899
fake_example_inputs = []
assert self.graph is not None
for node in self.graph.graph.nodes:
# All place holders come first
if node.op == "placeholder":
@@ -140,28 +223,37 @@ class PiecewiseBackend:
self, range_entry: RangeEntry, args: tuple[Any, ...]
) -> Any:
if not range_entry.compiled:
if self.compiled_runnables is not None:
range_entry.runnable = self.get_compiled_graph_wrapper(
self.compiled_runnables[str(range_entry.compile_range)]
)
else:
# args are real arguments
# fakify for range, real args for concrete size.
# For concrete size, we clear the shape env in
# compiler_manager.compile() so no need to fakify.
args_list = (
self._fakify_args(args)
if not range_entry.compile_range.is_single_size()
else list(args)
)
with (
torch._functorch.config.patch("bundled_autograd_cache", True),
):
range_entry.runnable = self.vllm_backend.compiler_manager.compile(
self.graph,
args_list,
self.vllm_backend.inductor_config,
self.compilation_config,
compile_range=range_entry.compile_range,
graph_index=self.piecewise_compile_index,
num_graphs=self.total_piecewise_compiles,
)
range_entry.compiled = True
self.to_be_compiled_ranges.remove(range_entry.compile_range)
# args are real arguments
# fakify for range, real args for concrete size.
# For concrete size, we clear the shape env in
# compiler_manager.compile() so no need to fakify.
args_list = (
self._fakify_args(args)
if not range_entry.compile_range.is_single_size()
else list(args)
)
range_entry.runnable = self.vllm_backend.compiler_manager.compile(
self.graph,
args_list,
self.vllm_backend.inductor_config,
self.compilation_config,
compile_range=range_entry.compile_range,
graph_index=self.piecewise_compile_index,
num_graphs=self.total_piecewise_compiles,
)
self.check_for_ending_compilation()
def _find_range_for_shape(self, runtime_shape: int) -> RangeEntry | None: