[AOT compilation] support torch.compile inductor artifacts in VllmCompiledFunction (#25205)
Signed-off-by: dolpm <34420038+dolpm@users.noreply.github.com>
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import ast
|
||||
import contextvars
|
||||
import dataclasses
|
||||
import hashlib
|
||||
import json
|
||||
@@ -34,7 +35,6 @@ from vllm.platforms import current_platform
|
||||
from vllm.utils.import_utils import resolve_obj_by_qualname
|
||||
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||
|
||||
from .caching import VllmSerializableFunction
|
||||
from .compiler_interface import (
|
||||
CompilerInterface,
|
||||
EagerAdaptor,
|
||||
@@ -49,7 +49,48 @@ from .pass_manager import PostGradPassManager
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def make_copy_and_call(
|
||||
sym_tensor_indices: list[int],
|
||||
input_buffers: list[torch.Tensor | None],
|
||||
callable_fn: Callable[..., Any],
|
||||
) -> Callable[..., Any]:
|
||||
"""Create a wrapper that copies inputs to static buffers before calling.
|
||||
|
||||
This is used for cudagraph input copying where we need to copy dynamic
|
||||
tensors to static buffers before invoking the compiled graph.
|
||||
|
||||
Args:
|
||||
sym_tensor_indices: Indices of tensors with symbolic shapes
|
||||
input_buffers: List of static buffers (can contain None for lazy init)
|
||||
callable_fn: The compiled function to call
|
||||
|
||||
Returns:
|
||||
A wrapper function that copies inputs and calls the compiled function
|
||||
"""
|
||||
|
||||
def copy_and_call(*args):
|
||||
list_args = list(args)
|
||||
for i, index in enumerate(sym_tensor_indices):
|
||||
runtime_tensor = list_args[index]
|
||||
runtime_shape = runtime_tensor.shape[0]
|
||||
|
||||
# lazy initialization of buffer on first call
|
||||
if input_buffers[i] is None:
|
||||
input_buffers[i] = runtime_tensor.clone()
|
||||
|
||||
static_tensor = input_buffers[i][:runtime_shape] # type: ignore[index]
|
||||
static_tensor.copy_(runtime_tensor)
|
||||
list_args[index] = static_tensor
|
||||
return callable_fn(*list_args)
|
||||
|
||||
return copy_and_call
|
||||
|
||||
|
||||
def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
|
||||
assert not envs.VLLM_USE_MEGA_AOT_ARTIFACT or envs.VLLM_USE_STANDALONE_COMPILE, (
|
||||
"VLLM_USE_MEGA_AOT_ARTIFACT=1 requires VLLM_USE_STANDALONE_COMPILE=1"
|
||||
)
|
||||
|
||||
if compilation_config.backend == "inductor":
|
||||
# Use standalone compile only if requested, version is new enough,
|
||||
# and the symbol actually exists in this PyTorch build.
|
||||
@@ -355,6 +396,60 @@ def split_graph(
|
||||
compilation_start_time = 0.0
|
||||
|
||||
|
||||
def wrap_with_cudagraph_if_needed(
    piecewise_backend: Any,
    vllm_config: VllmConfig,
    compilation_config: CompilationConfig,
    is_first_graph: bool,
    is_last_graph: bool,
) -> Any:
    """
    Wrap a piecewise backend with a CUDA graph wrapper if needed.
    This function is shared between VllmBackend and
    construct_serializable_fn_from_inductor_cache.

    Args:
        piecewise_backend: The backend to wrap
        vllm_config: The vLLM configuration
        compilation_config: The compilation configuration
        is_first_graph: Whether this is the first graph in the sequence
        is_last_graph: Whether this is the last graph in the sequence

    Returns:
        The wrapped backend if CUDA graphs are enabled, otherwise the original backend
    """
    wrapping_applies = (
        compilation_config.cudagraph_mode.has_piecewise_cudagraphs()
        and not compilation_config.use_inductor_graph_partition
    )
    if not wrapping_applies:
        return piecewise_backend

    # Dynamo-based piecewise splitting is in effect: wrap the whole
    # subgraph with a static graph wrapper.
    from .cuda_graph import CUDAGraphOptions

    # The static graph wrapper class (e.g. CUDAGraphWrapper) is
    # platform dependent and resolved by qualified name.
    wrapper_cls = resolve_obj_by_qualname(
        current_platform.get_static_graph_wrapper_cls()
    )

    # Only the first graph logs / keeps GC enabled; only the last graph
    # weak-refs its output.
    options = CUDAGraphOptions(
        debug_log_enable=is_first_graph,
        gc_disable=not is_first_graph,
        weak_ref_output=is_last_graph,
    )

    # The wrapper always runs in PIECEWISE runtime mode to distinguish
    # it from the FULL cudagraph runtime mode, regardless of whether it
    # wraps a full or a piecewise fx graph.
    return wrapper_cls(
        runnable=piecewise_backend,
        vllm_config=vllm_config,
        runtime_mode=CUDAGraphMode.PIECEWISE,
        cudagraph_options=options,
    )
|
||||
|
||||
|
||||
class PiecewiseCompileInterpreter(torch.fx.Interpreter): # type: ignore[misc]
|
||||
"""Code adapted from `torch.fx.passes.shape_prop.ShapeProp`.
|
||||
It runs the given graph with fake inputs, and compile some
|
||||
@@ -365,6 +460,18 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): # type: ignore[misc]
|
||||
it will be used to determine the order of the compiled piecewise
|
||||
graphs. The first graph will handle logging, and the last graph
|
||||
has some special cudagraph output handling.
|
||||
|
||||
Note: This class shares similar logic with
|
||||
reconstruct_serializable_fn_from_mega_artifact in caching.py.
|
||||
Both create PiecewiseBackend instances and wrap them with cudagraph.
|
||||
The key difference is:
|
||||
- reconstruct_serializable_fn_from_mega_artifact: PiecewiseBackend receives
|
||||
pre-compiled runnables (compiled_runnables is set, graph is None)
|
||||
- this class: PiecewiseBackend receives the FX graph to compile
|
||||
(graph is set, compiled_runnables is None)
|
||||
|
||||
|
||||
If modifying the backend creation/wrapping logic, consider updating both.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -413,6 +520,8 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): # type: ignore[misc]
|
||||
]
|
||||
|
||||
# Lazy import here to avoid circular import
|
||||
from torch._inductor.compile_fx import graph_returns_tuple
|
||||
|
||||
from .piecewise_backend import PiecewiseBackend
|
||||
|
||||
piecewise_backend = PiecewiseBackend(
|
||||
@@ -422,38 +531,16 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): # type: ignore[misc]
|
||||
len(self.compile_submod_names),
|
||||
sym_shape_indices,
|
||||
self.vllm_backend,
|
||||
graph_returns_tuple(submod),
|
||||
)
|
||||
|
||||
if (
|
||||
self.compilation_config.cudagraph_mode.has_piecewise_cudagraphs()
|
||||
and not self.compilation_config.use_inductor_graph_partition
|
||||
):
|
||||
# We're using Dynamo-based piecewise splitting, so we wrap
|
||||
# the whole subgraph with a static graph wrapper.
|
||||
from .cuda_graph import CUDAGraphOptions
|
||||
|
||||
# resolve the static graph wrapper class (e.g. CUDAGraphWrapper
|
||||
# class) as platform dependent.
|
||||
static_graph_wrapper_class = resolve_obj_by_qualname(
|
||||
current_platform.get_static_graph_wrapper_cls()
|
||||
)
|
||||
|
||||
# Always assign PIECEWISE runtime mode to the
|
||||
# CUDAGraphWrapper for piecewise_backend, to distinguish
|
||||
# it from the FULL cudagraph runtime mode, no matter it
|
||||
# is wrapped on a full or piecewise fx graph.
|
||||
self.module.__dict__[target] = static_graph_wrapper_class(
|
||||
runnable=piecewise_backend,
|
||||
vllm_config=self.vllm_config,
|
||||
runtime_mode=CUDAGraphMode.PIECEWISE,
|
||||
cudagraph_options=CUDAGraphOptions(
|
||||
debug_log_enable=piecewise_backend.is_first_graph,
|
||||
gc_disable=not piecewise_backend.is_first_graph,
|
||||
weak_ref_output=piecewise_backend.is_last_graph,
|
||||
),
|
||||
)
|
||||
else:
|
||||
self.module.__dict__[target] = piecewise_backend
|
||||
self.module.__dict__[target] = wrap_with_cudagraph_if_needed(
|
||||
piecewise_backend,
|
||||
self.vllm_config,
|
||||
self.compilation_config,
|
||||
piecewise_backend.is_first_graph,
|
||||
piecewise_backend.is_last_graph,
|
||||
)
|
||||
|
||||
compilation_counter.num_piecewise_capturable_graphs_seen += 1
|
||||
|
||||
@@ -465,6 +552,21 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): # type: ignore[misc]
|
||||
model_tag: str = "backbone"
|
||||
model_is_encoder: bool = False
|
||||
|
||||
_on_compilation_complete_callback: contextvars.ContextVar[Callable[[], None] | None] = (
|
||||
contextvars.ContextVar("on_compilation_complete_callback", default=None)
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def set_on_compilation_complete(
|
||||
callback: Callable[[], None],
|
||||
) -> Generator[None, None, None]:
|
||||
token = _on_compilation_complete_callback.set(callback)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
_on_compilation_complete_callback.reset(token)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def set_model_tag(tag: str, is_encoder: bool = False) -> Generator[None, None, None]:
|
||||
@@ -509,8 +611,6 @@ class VllmBackend:
|
||||
returned_callable: Callable[..., Any]
|
||||
# Inductor passes to run on the graph pre-defunctionalization
|
||||
post_grad_passes: Sequence[Callable[..., Any]]
|
||||
sym_tensor_indices: list[int]
|
||||
input_buffers: list[torch.Tensor]
|
||||
compiler_manager: CompilerManager
|
||||
# Copy of CompilationConfig.inductor_compile_config +
|
||||
# an entry for PostGradPassManager
|
||||
@@ -539,9 +639,6 @@ class VllmBackend:
|
||||
)()
|
||||
self.pass_key = current_platform.pass_key
|
||||
|
||||
self.sym_tensor_indices = []
|
||||
self.input_buffers = []
|
||||
|
||||
self.vllm_config = vllm_config
|
||||
self.compilation_config = vllm_config.compilation_config
|
||||
|
||||
@@ -558,6 +655,68 @@ class VllmBackend:
|
||||
# `torch.compile` is JIT compiled, so we don't need to
|
||||
# do anything here
|
||||
|
||||
def collect_standalone_compile_artifacts(
    self,
) -> tuple[Any, dict[str, list[int]] | None, dict[str, bool] | None]:
    """Collect inductor cache artifacts from all piecewise backends.

    Only active when VLLM_USE_MEGA_AOT_ARTIFACT is set; otherwise
    returns (None, None, None).

    Returns:
        tuple: (standalone_compile_artifacts, sym_shape_indices_map,
            returns_tuple_map)
            - standalone_compile_artifacts: StandaloneCompiledArtifacts
              with compiled artifacts
            - sym_shape_indices_map: dict mapping submod_name to
              sym_shape_indices
            - returns_tuple_map: dict mapping submod_name to
              returns_tuple
    """

    if not envs.VLLM_USE_MEGA_AOT_ARTIFACT:
        return None, None, None

    # Local imports — presumably to avoid an import cycle with
    # .caching / .piecewise_backend; confirm before hoisting.
    from .caching import StandaloneCompiledArtifacts
    from .piecewise_backend import PiecewiseBackend

    standalone_compile_artifacts = StandaloneCompiledArtifacts()
    sym_shape_indices_map = {}
    returns_tuple_map = {}

    for name, _ in self.split_gm.named_children():
        # get the actual attribute (shadowed by PiecewiseBackend in __dict__);
        # named_children() yields the originally registered submodule,
        # getattr() picks up the replacement stored in __dict__.
        child = getattr(self.split_gm, name)
        # unwrap the static graph wrapper class if applicable
        piecewise_backend = child.runnable if hasattr(child, "runnable") else child

        # skip anything that is not a compiled piecewise backend
        if not isinstance(piecewise_backend, PiecewiseBackend):
            continue

        submod_name = name
        sym_shape_indices_map[submod_name] = piecewise_backend.sym_shape_indices
        returns_tuple_map[submod_name] = piecewise_backend.returns_tuple

        # one serialized artifact per compiled shape of this submodule
        for shape_str, bytes_data in piecewise_backend.to_bytes().items():
            standalone_compile_artifacts.insert(submod_name, shape_str, bytes_data)
            logger.debug(
                "collected artifact for %s shape %s (%d bytes)",
                submod_name,
                shape_str,
                len(bytes_data),
            )

    logger.info(
        "collected artifacts: %d entries, %d artifacts, %d bytes total",
        standalone_compile_artifacts.num_entries(),
        standalone_compile_artifacts.num_artifacts(),
        standalone_compile_artifacts.size_bytes(),
    )

    logger.debug(
        "standalone compile artifact keys: %s",
        list(standalone_compile_artifacts.submodule_bytes.keys()),
    )

    return standalone_compile_artifacts, sym_shape_indices_map, returns_tuple_map
|
||||
|
||||
def configure_post_pass(self) -> None:
    """Configure `self.pass_manager` with the current vLLM config."""
    self.pass_manager.configure(self.vllm_config)
|
||||
|
||||
@@ -579,9 +738,11 @@ class VllmBackend:
|
||||
)
|
||||
self.inductor_config[self.pass_key] = self.pass_manager
|
||||
|
||||
def __call__(
|
||||
self, graph: fx.GraphModule, example_inputs: Sequence[Any]
|
||||
) -> VllmSerializableFunction:
|
||||
def __call__(self, graph: fx.GraphModule, example_inputs: Sequence[Any]) -> Any:
|
||||
from .caching import (
|
||||
VllmSerializableFunction,
|
||||
)
|
||||
|
||||
vllm_config = self.vllm_config
|
||||
# Minimal hashing here with existing utilities, reused below.
|
||||
|
||||
@@ -721,6 +882,12 @@ class VllmBackend:
|
||||
|
||||
self.split_gm, self.piecewise_graphs = split_graph(graph, fx_split_ops)
|
||||
|
||||
# keep a split_gm copy from BEFORE the interpreter replaces
|
||||
# submodules with PiecewiseBackend -- used for serialization
|
||||
original_split_gm = None
|
||||
if envs.VLLM_USE_MEGA_AOT_ARTIFACT:
|
||||
original_split_gm = deepcopy(self.split_gm)
|
||||
|
||||
from torch._dynamo.utils import lazy_format_graph_code
|
||||
|
||||
# depyf will hook lazy_format_graph_code and dump the graph
|
||||
@@ -792,13 +959,21 @@ class VllmBackend:
|
||||
)
|
||||
|
||||
self._called = True
|
||||
graph_to_serialize = (
|
||||
original_split_gm if envs.VLLM_USE_MEGA_AOT_ARTIFACT else self.graph
|
||||
)
|
||||
|
||||
if (
|
||||
self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
|
||||
or not self.compilation_config.cudagraph_copy_inputs
|
||||
):
|
||||
return VllmSerializableFunction(
|
||||
graph, example_inputs, self.prefix, self.split_gm, self.is_encoder
|
||||
graph_to_serialize,
|
||||
example_inputs,
|
||||
self.prefix,
|
||||
self.split_gm,
|
||||
is_encoder=self.is_encoder,
|
||||
vllm_backend=self,
|
||||
)
|
||||
|
||||
# index of tensors that have symbolic shapes (batch size)
|
||||
@@ -806,7 +981,7 @@ class VllmBackend:
|
||||
# symbolic shape only happens for input tensors.
|
||||
from torch.fx.experimental.symbolic_shapes import is_symbolic
|
||||
|
||||
self.sym_tensor_indices = [
|
||||
sym_tensor_indices = [
|
||||
i
|
||||
for i, x in enumerate(fake_args)
|
||||
if isinstance(x, torch._subclasses.fake_tensor.FakeTensor)
|
||||
@@ -816,25 +991,18 @@ class VllmBackend:
|
||||
# compiler managed cudagraph input buffers
|
||||
# we assume the first run with symbolic shapes
|
||||
# has the maximum size among all the tensors
|
||||
self.input_buffers = [
|
||||
example_inputs[x].clone() for x in self.sym_tensor_indices
|
||||
]
|
||||
|
||||
# this is the callable we return to Dynamo to run
|
||||
def copy_and_call(*args: Any) -> Any:
|
||||
list_args = list(args)
|
||||
for i, index in enumerate(self.sym_tensor_indices):
|
||||
runtime_tensor = list_args[index]
|
||||
runtime_shape = runtime_tensor.shape[0]
|
||||
static_tensor = self.input_buffers[i][:runtime_shape]
|
||||
|
||||
# copy the tensor to the static buffer
|
||||
static_tensor.copy_(runtime_tensor)
|
||||
|
||||
# replace the tensor in the list_args to the static buffer
|
||||
list_args[index] = static_tensor
|
||||
return self.split_gm(*list_args)
|
||||
copy_and_call = make_copy_and_call(
|
||||
sym_tensor_indices,
|
||||
[example_inputs[x].clone() for x in sym_tensor_indices],
|
||||
self.split_gm,
|
||||
)
|
||||
|
||||
return VllmSerializableFunction(
|
||||
graph, example_inputs, self.prefix, copy_and_call, self.is_encoder
|
||||
graph_to_serialize,
|
||||
example_inputs,
|
||||
self.prefix,
|
||||
copy_and_call,
|
||||
is_encoder=self.is_encoder,
|
||||
vllm_backend=self,
|
||||
sym_tensor_indices=sym_tensor_indices,
|
||||
)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import hashlib
|
||||
import inspect
|
||||
import os
|
||||
import pickle
|
||||
@@ -12,6 +13,7 @@ import torch
|
||||
from torch.utils import _pytree as pytree
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.compiler_interface import get_inductor_factors
|
||||
from vllm.config import VllmConfig, get_current_vllm_config
|
||||
from vllm.config.utils import hash_factors
|
||||
from vllm.logger import init_logger
|
||||
@@ -27,6 +29,121 @@ assert isinstance(SerializableCallable, type)
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class StandaloneCompiledArtifacts:
    """Deduplicating store for standalone-compiled inductor artifacts.

    Content-based deduplication works through two levels of indirection:
    1. `submodule_bytes` maps a cache key "{submod_name}_{shape}" to the
       SHA256 digest of the artifact bytes.
    2. `submodule_bytes_store` maps that digest to the bytes themselves.

    On insert, the SHA256 digest of the incoming bytes is computed. If
    the digest is already present in `submodule_bytes_store`, the
    existing entry is shared rather than storing a duplicate copy. This
    is common because submodules often compile to identical artifacts
    (e.g., identical transformer layers split on attn).
    """

    def __init__(self):
        # cache key ("{submod_name}_{shape}") -> SHA256 digest
        self.submodule_bytes = {}
        # SHA256 digest -> raw serialized artifact bytes
        self.submodule_bytes_store = {}
        # SHA256 digest -> deserialized artifact (populated by load_all)
        self.loaded_submodule_store = {}

    def insert(self, submod_name: str, shape: str, entry: bytes):
        """Record *entry* under (submod_name, shape), deduplicating by content."""
        hex_digest = hashlib.sha256(entry).hexdigest()
        self.submodule_bytes[f"{submod_name}_{shape}"] = hex_digest
        if hex_digest in self.submodule_bytes_store:
            # identical bytes already stored under this digest
            logger.debug(
                "reusing existing cache artifact for submod %s "
                "with shape %s (%s bytes) at hash %s",
                submod_name,
                shape,
                len(entry),
                hex_digest,
            )
        else:
            self.submodule_bytes_store[hex_digest] = entry
            logger.debug(
                "inserting new artifact for submod %s with shape %s "
                "(%s bytes) at hash %s",
                submod_name,
                shape,
                len(entry),
                hex_digest,
            )

    def get(self, submod_name: str, shape: str) -> bytes:
        """Return the raw artifact bytes stored for (submod_name, shape)."""
        logger.debug(
            "getting artifact for submod %s with shape %s",
            submod_name,
            shape,
        )
        digest = self.submodule_bytes[f"{submod_name}_{shape}"]
        return self.submodule_bytes_store[digest]

    def get_loaded(self, submod_name: str, shape: str):
        """Return the deserialized artifact for (submod_name, shape).

        Requires a prior call to `load_all`.
        """
        logger.debug(
            "getting artifact for submod %s with shape %s",
            submod_name,
            shape,
        )
        digest = self.submodule_bytes[f"{submod_name}_{shape}"]
        return self.loaded_submodule_store[digest]

    def size_bytes(self) -> int:
        """Total size in bytes of all unique stored artifacts."""
        return sum(len(blob) for blob in self.submodule_bytes_store.values())

    def num_artifacts(self) -> int:
        """Number of unique (deduplicated) artifacts."""
        return len(self.submodule_bytes_store)

    def num_entries(self) -> int:
        """Number of (submod_name, shape) cache entries."""
        return len(self.submodule_bytes)

    def submodule_names(self) -> list[str]:
        """Unique submodule names extracted from the cache keys,
        preserving first-seen order.

        Cache keys have the form "{submod_name}_{shape}"; the trailing
        shape component is stripped with a single rsplit.
        """
        seen = dict.fromkeys(
            cache_key.rsplit("_", 1)[0] for cache_key in self.submodule_bytes
        )
        return list(seen)

    def load_all(self) -> None:
        """Deserialize every stored artifact in parallel (idempotent)."""
        import concurrent.futures

        # everything already deserialized -> nothing to do
        if len(self.loaded_submodule_store) == len(self.submodule_bytes_store):
            return

        from torch._inductor.standalone_compile import AOTCompiledArtifact

        def _load_entry(entry_bytes) -> AOTCompiledArtifact:
            entry = pickle.loads(entry_bytes)
            return AOTCompiledArtifact.deserialize(entry)

        with concurrent.futures.ThreadPoolExecutor() as executor:
            digests = list(self.submodule_bytes_store.keys())
            blobs = [self.submodule_bytes_store[d] for d in digests]
            loaded = list(executor.map(_load_entry, blobs))

        for digest, artifact in zip(digests, loaded):
            self.loaded_submodule_store[digest] = artifact

        logger.debug("loaded all %s submodules", self.num_artifacts())

    def __getstate__(self):
        # loaded artifacts are rebuilt from the raw bytes after
        # unpickling, so only the byte stores travel.
        return {
            "submodule_bytes": self.submodule_bytes,
            "submodule_bytes_store": self.submodule_bytes_store,
        }

    def __setstate__(self, state):
        self.submodule_bytes = state["submodule_bytes"]
        self.submodule_bytes_store = state["submodule_bytes_store"]
        self.loaded_submodule_store = {}
|
||||
|
||||
|
||||
class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
|
||||
"""
|
||||
A wrapper around a compiled function by vllm. It will forward the tensor
|
||||
@@ -46,6 +163,8 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
|
||||
prefix: str,
|
||||
optimized_call: Callable[..., Any],
|
||||
is_encoder: bool = False,
|
||||
vllm_backend: Any | None = None,
|
||||
sym_tensor_indices: list[int] | None = None,
|
||||
) -> None:
|
||||
assert isinstance(graph_module, torch.fx.GraphModule)
|
||||
self.graph_module = graph_module
|
||||
@@ -54,6 +173,8 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
|
||||
self.optimized_call = optimized_call
|
||||
self.is_encoder = is_encoder
|
||||
self.shape_env = None
|
||||
self.vllm_backend = vllm_backend
|
||||
self.sym_tensor_indices = sym_tensor_indices
|
||||
sym_input = next(
|
||||
(i for i in self.example_inputs if isinstance(i, torch.SymInt)), None
|
||||
)
|
||||
@@ -74,9 +195,15 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
|
||||
state = compiled_fn.__dict__.copy()
|
||||
state.pop("optimized_call")
|
||||
state.pop("shape_env")
|
||||
state.pop("vllm_backend", None)
|
||||
for node in state["graph_module"].graph.nodes:
|
||||
node.meta.pop("source_fn_stack", None)
|
||||
node.meta.pop("nn_module_stack", None)
|
||||
for name, submod in state["graph_module"].named_children():
|
||||
if hasattr(submod, "graph"):
|
||||
for node in submod.graph.nodes:
|
||||
node.meta.pop("source_fn_stack", None)
|
||||
node.meta.pop("nn_module_stack", None)
|
||||
|
||||
graph_reducer_override = GraphPickler.reducer_override
|
||||
|
||||
@@ -93,15 +220,36 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
|
||||
return type(None), ()
|
||||
return graph_reducer_override(self, obj)
|
||||
|
||||
# Mask off tensor inputs since they are large and not needed.
|
||||
state["example_inputs"] = pytree.tree_map_only(
|
||||
torch.Tensor, lambda _: None, state["example_inputs"]
|
||||
)
|
||||
if state.get("sym_tensor_indices"):
|
||||
# put tensor inputs on meta device since their data
|
||||
# isn't needed, yet we need the meta for make_copy_and_call
|
||||
state["example_inputs"] = pytree.tree_map_only(
|
||||
torch.Tensor,
|
||||
lambda inp: torch.empty_like(inp, device="meta"),
|
||||
state["example_inputs"],
|
||||
)
|
||||
else:
|
||||
# mask off all tensor inputs since they are large and not needed.
|
||||
state["example_inputs"] = pytree.tree_map_only(
|
||||
torch.Tensor,
|
||||
lambda inp: torch.empty_like(inp, device="meta"),
|
||||
state["example_inputs"],
|
||||
)
|
||||
with patch.object(GraphPickler, "reducer_override", _graph_reducer_override):
|
||||
state["graph_module"] = GraphPickler.dumps(
|
||||
state["graph_module"], Options(ops_filter=None)
|
||||
)
|
||||
state["example_inputs"] = GraphPickler.dumps(state["example_inputs"])
|
||||
|
||||
if compiled_fn.vllm_backend:
|
||||
(
|
||||
standalone_compile_artifacts,
|
||||
sym_shape_indices_map,
|
||||
returns_tuple_map,
|
||||
) = compiled_fn.vllm_backend.collect_standalone_compile_artifacts()
|
||||
state["standalone_compile_artifacts"] = standalone_compile_artifacts
|
||||
state["sym_shape_indices_map"] = sym_shape_indices_map
|
||||
state["returns_tuple_map"] = returns_tuple_map
|
||||
return pickle.dumps(state)
|
||||
|
||||
@classmethod
|
||||
@@ -111,15 +259,48 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
|
||||
from torch.fx._graph_pickler import GraphPickler
|
||||
from torch.fx.experimental.symbolic_shapes import ShapeEnv
|
||||
|
||||
from vllm.compilation.backends import VllmBackend
|
||||
|
||||
state = pickle.loads(data)
|
||||
fake_mode = FakeTensorMode(shape_env=ShapeEnv())
|
||||
state["graph_module"] = GraphPickler.loads(state["graph_module"], fake_mode)
|
||||
state["graph_module"].recompile()
|
||||
state["example_inputs"] = GraphPickler.loads(state["example_inputs"], fake_mode)
|
||||
|
||||
standalone_compile_artifacts = state.pop("standalone_compile_artifacts", None)
|
||||
sym_shape_indices_map = state.pop("sym_shape_indices_map", {})
|
||||
returns_tuple_map = state.pop("returns_tuple_map", {})
|
||||
|
||||
if envs.VLLM_USE_MEGA_AOT_ARTIFACT:
|
||||
assert standalone_compile_artifacts is not None
|
||||
submod_names = standalone_compile_artifacts.submodule_names()
|
||||
num_submods = len(submod_names)
|
||||
num_artifacts = standalone_compile_artifacts.num_artifacts()
|
||||
|
||||
logger.info(
|
||||
"reconstructing serializable fn from standalone compile "
|
||||
"artifacts. num_artifacts=%d num_submods=%d",
|
||||
num_artifacts,
|
||||
num_submods,
|
||||
)
|
||||
|
||||
fn = reconstruct_serializable_fn_from_mega_artifact(
|
||||
state=state,
|
||||
standalone_compile_artifacts=standalone_compile_artifacts,
|
||||
vllm_config=get_current_vllm_config(),
|
||||
sym_shape_indices_map=sym_shape_indices_map,
|
||||
returns_tuple_map=returns_tuple_map,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"reconstructed serializable fn from standalone compile artifacts"
|
||||
)
|
||||
|
||||
return fn
|
||||
|
||||
# Fall back to standard VllmBackend
|
||||
from vllm.compilation.backends import VllmBackend
|
||||
|
||||
is_encoder = state.get("is_encoder", False)
|
||||
vllm_backend = VllmBackend(
|
||||
vllm_backend: VllmBackend = VllmBackend(
|
||||
get_current_vllm_config(), state["prefix"], is_encoder
|
||||
)
|
||||
|
||||
@@ -152,7 +333,140 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
|
||||
return "VllmSerializableFunction"
|
||||
|
||||
|
||||
def compilation_config_hash_factors(vllm_config: VllmConfig) -> list[str]:
|
||||
def reconstruct_serializable_fn_from_mega_artifact(
    state: dict[str, Any],
    standalone_compile_artifacts: "StandaloneCompiledArtifacts",
    vllm_config: VllmConfig,
    sym_shape_indices_map: dict[str, list[int]],
    returns_tuple_map: dict[str, bool],
) -> "VllmSerializableFunction":
    """Construct a VllmSerializableFunction from cached inductor artifacts.

    This function reconstructs a callable model from pre-compiled inductor
    artifacts without re-running the compilation. It:
    1. Loads all cached artifacts
    2. Builds compiled callables for each submodule/shape
    3. Creates PiecewiseBackend instances that dispatch to cached artifacts
    4. Wraps with cudagraph if needed
    5. Returns the final VllmSerializableFunction

    Note: This function shares similar logic with PiecewiseCompileInterpreter
    in backends.py. Both create PiecewiseBackend instances and wrap them with
    cudagraph. The key difference is:
    - this function: PiecewiseBackend receives pre-compiled runnables
      (compiled_runnables is set, graph is None)
    - PiecewiseCompileInterpreter: PiecewiseBackend receives the FX graph
      to compile (graph is set, compiled_runnables is None)

    If modifying the backend creation/wrapping logic, consider updating both.

    Args:
        state: Deserialized state dict containing graph_module, example_inputs,
            prefix, sym_tensor_indices, is_encoder, etc.
        standalone_compile_artifacts: The StandaloneCompiledArtifacts containing
            pre-compiled artifacts for each submodule/shape combination.
        vllm_config: The vLLM configuration.
        sym_shape_indices_map: Mapping from submod_name to sym_shape_indices.
        returns_tuple_map: Mapping from submod_name to returns_tuple.

    Returns:
        A VllmSerializableFunction that can be called directly.
    """
    # local imports -- presumably to avoid a circular import with
    # backends.py, which itself imports from this module; confirm.
    from vllm.compilation.backends import (
        VllmBackend,
        make_copy_and_call,
        wrap_with_cudagraph_if_needed,
    )
    from vllm.compilation.piecewise_backend import PiecewiseBackend

    prefix = state["prefix"]
    is_encoder = state.get("is_encoder", False)
    split_gm = state["graph_module"]
    compilation_config = vllm_config.compilation_config

    # deserialize every artifact up front (parallel, idempotent)
    standalone_compile_artifacts.load_all()

    submod_names = standalone_compile_artifacts.submodule_names()
    # submod_name -> {shape_str -> compiled runnable}
    compiled_callables: dict[str, dict[str, Callable]] = {}

    for cache_key in standalone_compile_artifacts.submodule_bytes:
        # cache keys have the form "{submod_name}_{shape}"
        submod_name, shape_str = cache_key.rsplit("_", 1)
        compiled_callables.setdefault(submod_name, {})[shape_str] = (
            standalone_compile_artifacts.get_loaded(submod_name, shape_str)
        )

    # A backend instance is still needed by PiecewiseBackend; its compiler
    # cache is pointed at a throwaway dir and disabled since all artifacts
    # are already in hand.
    vllm_backend = VllmBackend(vllm_config, prefix, is_encoder)
    dummy_cache_dir = os.path.join(envs.VLLM_CACHE_ROOT, "dummy_cache")
    os.makedirs(dummy_cache_dir, exist_ok=True)
    vllm_backend.compiler_manager.initialize_cache(
        cache_dir=dummy_cache_dir,
        disable_cache=True,
        prefix=prefix,
    )

    # spot check that cached submodules exist in the graph structure
    graph_children = {name for name, _ in split_gm.named_children()}
    missing = set(submod_names) - graph_children
    assert not missing, (
        f"artifacts reference submodules not in graph: {missing}. "
        f"graph has: {sorted(graph_children)}"
    )

    for i, submod_name in enumerate(submod_names):
        assert submod_name in sym_shape_indices_map and submod_name in returns_tuple_map

        sym_shape_indices = sym_shape_indices_map[submod_name]
        returns_tuple = returns_tuple_map[submod_name]
        runnables = compiled_callables[submod_name]

        piecewise_backend = PiecewiseBackend(
            graph=None,  # not needed for cached artifacts
            vllm_config=vllm_config,
            piecewise_compile_index=i,
            total_piecewise_compiles=len(submod_names),
            sym_shape_indices=sym_shape_indices,
            vllm_backend=vllm_backend,
            returns_tuple=returns_tuple,
            compiled_runnables=runnables,
        )

        is_first = i == 0
        is_last = i == len(submod_names) - 1
        wrapped_backend = wrap_with_cudagraph_if_needed(
            piecewise_backend,
            vllm_config,
            compilation_config,
            is_first,
            is_last,
        )

        # assigning via __dict__ shadows the registered submodule so the
        # graph's calls dispatch to the backend instead
        split_gm.__dict__[submod_name] = wrapped_backend
        logger.debug(
            "Replaced submodule %s with piecewise backend from cache",
            submod_name,
        )

    if compilation_config.cudagraph_copy_inputs:
        sym_tensor_indices = state["sym_tensor_indices"]
        # example_inputs were serialized as meta tensors; empty_like with
        # an explicit device materializes real static buffers for
        # make_copy_and_call (sized from the serialized example shapes)
        input_buffers = [
            torch.empty_like(
                state["example_inputs"][idx], device=vllm_config.device_config.device
            )
            for idx in sym_tensor_indices
        ]
        optimized_call = make_copy_and_call(sym_tensor_indices, input_buffers, split_gm)
    else:
        optimized_call = split_gm

    fn = VllmSerializableFunction(
        **state,
        optimized_call=optimized_call,
        vllm_backend=None,
    )
    return fn
|
||||
|
||||
|
||||
def aot_compile_hash_factors(vllm_config: VllmConfig) -> list[str]:
|
||||
factors = []
|
||||
# 0. factors come from the env, for example, The values of
|
||||
# VLLM_PP_LAYER_PARTITION will affect the computation graph.
|
||||
@@ -163,6 +477,11 @@ def compilation_config_hash_factors(vllm_config: VllmConfig) -> list[str]:
|
||||
# model is created)
|
||||
config_hash = vllm_config.compute_hash()
|
||||
factors.append(config_hash)
|
||||
|
||||
# 2. inductor factors if applicable
|
||||
if envs.VLLM_USE_MEGA_AOT_ARTIFACT:
|
||||
factors.extend(get_inductor_factors())
|
||||
|
||||
return factors
|
||||
|
||||
|
||||
|
||||
@@ -16,9 +16,12 @@ import vllm.envs as envs
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.utils import Range
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.hashing import safe_hash
|
||||
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class CompilerInterface:
|
||||
"""
|
||||
@@ -230,12 +233,42 @@ class InductorStandaloneAdaptor(CompilerInterface):
|
||||
|
||||
from torch._inductor import standalone_compile
|
||||
|
||||
compiled_graph = standalone_compile(
|
||||
graph,
|
||||
example_inputs,
|
||||
dynamic_shapes=dynamic_shapes,
|
||||
options={"config_patches": current_config},
|
||||
)
|
||||
supports_aot = is_torch_equal_or_newer("2.10.0.dev")
|
||||
|
||||
if not supports_aot and envs.VLLM_USE_MEGA_AOT_ARTIFACT:
|
||||
logger.error(
|
||||
"CRITICAL: VLLM_USE_MEGA_AOT_ARTIFACT "
|
||||
"is enabled but PyTorch version does not support 'aot' "
|
||||
"parameter in standalone_compile. This requires PyTorch "
|
||||
"2.10.0+. Falling back to non-AOT mode."
|
||||
)
|
||||
|
||||
compile_kwargs = {
|
||||
"dynamic_shapes": dynamic_shapes,
|
||||
"options": {
|
||||
"config_patches": current_config,
|
||||
},
|
||||
}
|
||||
|
||||
use_aot: bool = supports_aot and envs.VLLM_USE_MEGA_AOT_ARTIFACT
|
||||
# only add 'aot' parameter if both supported and enabled...
|
||||
# this will set bundled_autograd_cache
|
||||
# https://github.com/pytorch/pytorch/blob/9bbc5b2905c260adf41bc866a732f9c121a2828a/torch/_inductor/standalone_compile.py#L359 # noqa
|
||||
if use_aot:
|
||||
compile_kwargs["aot"] = True # type: ignore[assignment]
|
||||
|
||||
compiled_graph = standalone_compile(graph, example_inputs, **compile_kwargs)
|
||||
|
||||
if use_aot:
|
||||
from torch._inductor.standalone_compile import AOTCompiledArtifact
|
||||
|
||||
assert isinstance(compiled_graph, AOTCompiledArtifact)
|
||||
assert hasattr(compiled_graph, "serialize")
|
||||
# just return the compiled graph and a key
|
||||
# since we can serialize the bytes using to_bytes
|
||||
# and reload it using the key when reading
|
||||
return compiled_graph, None
|
||||
|
||||
# Save the compiled artifact to disk in the specified path
|
||||
assert key is not None
|
||||
path = os.path.join(self.cache_dir, key)
|
||||
@@ -619,7 +652,8 @@ def set_inductor_config(config: dict[str, Any], compile_range: Range) -> None:
|
||||
|
||||
|
||||
def set_functorch_config() -> None:
|
||||
torch._functorch.config.bundled_autograd_cache = False
|
||||
if not envs.VLLM_USE_MEGA_AOT_ARTIFACT:
|
||||
torch._functorch.config.bundled_autograd_cache = False
|
||||
|
||||
|
||||
class EagerAdaptor(CompilerInterface):
|
||||
|
||||
@@ -320,7 +320,7 @@ def _support_torch_compile(
|
||||
return
|
||||
|
||||
self._check_shape_invariants = shape_invariants
|
||||
|
||||
self.was_aot_compile_fn_loaded_from_disk = False
|
||||
compilation_counter.num_models_seen += 1
|
||||
self.compiled = False
|
||||
|
||||
@@ -417,9 +417,9 @@ def _support_torch_compile(
|
||||
serialized backend artifacts), then we need to generate a new AOT
|
||||
compile artifact from scratch.
|
||||
"""
|
||||
from .caching import compilation_config_hash_factors
|
||||
from .caching import aot_compile_hash_factors
|
||||
|
||||
factors: list[str] = compilation_config_hash_factors(self.vllm_config)
|
||||
factors: list[str] = aot_compile_hash_factors(self.vllm_config)
|
||||
|
||||
factors.append(_model_hash_key(self.forward))
|
||||
hash_key = hashlib.sha256(str(factors).encode()).hexdigest()
|
||||
@@ -446,6 +446,7 @@ def _support_torch_compile(
|
||||
if not self.compilation_config.dynamic_shapes_config.evaluate_guards:
|
||||
loaded_fn.disable_guard_check()
|
||||
self.aot_compiled_fn = loaded_fn
|
||||
self.was_aot_compile_fn_loaded_from_disk = True
|
||||
except Exception as e:
|
||||
if os.path.exists(aot_compilation_path):
|
||||
logger.warning(
|
||||
@@ -547,26 +548,45 @@ def _support_torch_compile(
|
||||
logger.warning("Detected eager backend, disabling AOT compile.")
|
||||
use_aot_compile = False
|
||||
if use_aot_compile:
|
||||
self.aot_compiled_fn = self.aot_compile(*args, **kwargs)
|
||||
output = self.aot_compiled_fn(self, *args, **kwargs)
|
||||
assert aot_compilation_path is not None
|
||||
assert cache_dir is not None
|
||||
try:
|
||||
os.makedirs(cache_dir, exist_ok=True)
|
||||
self.aot_compiled_fn.save_compiled_function(aot_compilation_path)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Cannot save aot compilation to path %s, error: %s",
|
||||
aot_compilation_path,
|
||||
str(e),
|
||||
)
|
||||
from vllm.compilation.backends import set_on_compilation_complete
|
||||
|
||||
# store the path for saving after warmup
|
||||
self._aot_compilation_path = aot_compilation_path
|
||||
self._aot_cache_dir = cache_dir
|
||||
# set callback in context so it's available when compilation completes
|
||||
with set_on_compilation_complete(self.save_aot_compiled_function):
|
||||
self.aot_compiled_fn = self.aot_compile(*args, **kwargs)
|
||||
output = self.aot_compiled_fn(self, *args, **kwargs)
|
||||
else:
|
||||
output = TorchCompileWithNoGuardsWrapper.__call__(self, *args, **kwargs) # type: ignore[arg-type]
|
||||
|
||||
self.compiled = True
|
||||
return output
|
||||
|
||||
# triggers VllmSerializableFunction.serialize()
|
||||
def save_aot_compiled_function(self):
|
||||
if self.was_aot_compile_fn_loaded_from_disk:
|
||||
logger.debug("AOT compiled function was loaded from cache, skipping save")
|
||||
return
|
||||
|
||||
assert (
|
||||
self.aot_compiled_fn and self._aot_compilation_path and self._aot_cache_dir
|
||||
)
|
||||
|
||||
logger.info("saving AOT compiled function to %s", self._aot_compilation_path)
|
||||
try:
|
||||
os.makedirs(self._aot_cache_dir, exist_ok=True)
|
||||
self.aot_compiled_fn.save_compiled_function(self._aot_compilation_path)
|
||||
logger.info("saved AOT compiled function to %s", self._aot_compilation_path)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"unable to save AOT compiled function to %s: %s",
|
||||
self._aot_compilation_path,
|
||||
e,
|
||||
)
|
||||
|
||||
cls.__call__ = __call__
|
||||
cls.save_aot_compiled_function = save_aot_compiled_function
|
||||
return cls
|
||||
|
||||
|
||||
|
||||
@@ -2,10 +2,15 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import dataclasses
|
||||
import io
|
||||
import pickle
|
||||
from collections.abc import Callable
|
||||
from pickle import Pickler
|
||||
from typing import Any
|
||||
|
||||
import torch._functorch.config
|
||||
import torch.fx as fx
|
||||
from torch._inductor.runtime.triton_heuristics import CachingAutotuner
|
||||
|
||||
from vllm.compilation.backends import VllmBackend
|
||||
from vllm.compilation.monitor import end_monitoring_torch_compile
|
||||
@@ -26,12 +31,14 @@ class RangeEntry:
|
||||
class PiecewiseBackend:
|
||||
def __init__(
|
||||
self,
|
||||
graph: fx.GraphModule,
|
||||
graph: fx.GraphModule | None,
|
||||
vllm_config: VllmConfig,
|
||||
piecewise_compile_index: int,
|
||||
total_piecewise_compiles: int,
|
||||
sym_shape_indices: list[int],
|
||||
vllm_backend: VllmBackend,
|
||||
returns_tuple: bool,
|
||||
compiled_runnables: dict[str, Callable] | None = None,
|
||||
):
|
||||
"""
|
||||
The backend for piecewise compilation.
|
||||
@@ -41,13 +48,28 @@ class PiecewiseBackend:
|
||||
We will compile `self.graph` once for the general shape,
|
||||
and then compile for different shapes specified in
|
||||
`compilation_config.compile_sizes`.
|
||||
|
||||
This class supports two mutually exclusive modes:
|
||||
1. Compilation (graph is set, compiled_runnables is None):
|
||||
Used during initial compilation when we have the FX graph
|
||||
and need to compile it for each shape range.
|
||||
2. Precompilation (graph is None, compiled_runnables is set):
|
||||
Used when loading from cache/AOT artifacts where we already
|
||||
have pre-compiled callables and don't need the original graph.
|
||||
|
||||
Exactly one of graph or compiled_runnables must be provided.
|
||||
"""
|
||||
assert bool(graph is not None) ^ bool(compiled_runnables is not None), (
|
||||
"exactly one of graph and compiled_runnables should be set."
|
||||
)
|
||||
|
||||
self.graph = graph
|
||||
self.vllm_config = vllm_config
|
||||
self.compilation_config = vllm_config.compilation_config
|
||||
self.piecewise_compile_index = piecewise_compile_index
|
||||
self.total_piecewise_compiles = total_piecewise_compiles
|
||||
self.vllm_backend = vllm_backend
|
||||
self.compiled_runnables = compiled_runnables
|
||||
|
||||
self.is_first_graph = piecewise_compile_index == 0
|
||||
self.is_last_graph = piecewise_compile_index == total_piecewise_compiles - 1
|
||||
@@ -77,6 +99,7 @@ class PiecewiseBackend:
|
||||
logger.debug_once(log_string)
|
||||
|
||||
self.sym_shape_indices = sym_shape_indices
|
||||
self.returns_tuple = returns_tuple
|
||||
|
||||
# the entries for ranges that we need to either
|
||||
self.range_entries: dict[Range, RangeEntry] = {}
|
||||
@@ -108,12 +131,71 @@ class PiecewiseBackend:
|
||||
compile_range=range,
|
||||
)
|
||||
|
||||
# get the on_compilation_complete callback from context...
|
||||
# PiecewiseBackend is created during the first call,
|
||||
# which is when the context is set (see compilation/decorators.py)
|
||||
from vllm.compilation.backends import _on_compilation_complete_callback
|
||||
|
||||
self.on_compilation_complete = _on_compilation_complete_callback.get()
|
||||
|
||||
def get_compiled_graph_wrapper(self, compiled_graph):
|
||||
def compiled_graph_wrapper(*args):
|
||||
graph_output = compiled_graph(*args)
|
||||
# unpack the tuple if needed
|
||||
# TODO(rzou): the implication is that we're not
|
||||
# reading the python bytecode correctly in vLLM?
|
||||
if self.returns_tuple or not isinstance(graph_output, (tuple, list)):
|
||||
return graph_output
|
||||
else:
|
||||
return graph_output[0]
|
||||
|
||||
return compiled_graph_wrapper
|
||||
|
||||
def check_for_ending_compilation(self) -> None:
|
||||
if self.is_last_graph and not self.to_be_compiled_ranges:
|
||||
# no specific sizes to compile
|
||||
# save the hash of the inductor graph for the next run
|
||||
self.vllm_backend.compiler_manager.save_to_file()
|
||||
end_monitoring_torch_compile(self.vllm_config)
|
||||
# Call the completion callback (e.g., to save AOT compiled function)
|
||||
if self.on_compilation_complete is not None:
|
||||
self.on_compilation_complete()
|
||||
|
||||
def to_bytes(self) -> dict[str, bytes]:
|
||||
class StandaloneCompiledArtifactsPickler(Pickler):
|
||||
def reducer_override(self, obj):
|
||||
if isinstance(obj, CachingAutotuner):
|
||||
obj.prepare_for_pickle()
|
||||
return pickle.loads, (
|
||||
pickle.dumps(
|
||||
obj,
|
||||
),
|
||||
)
|
||||
return NotImplemented
|
||||
|
||||
def serialize(fn) -> bytes:
|
||||
assert hasattr(fn, "serialize"), "fn must have serialize method"
|
||||
with torch._functorch.config.patch("bundled_autograd_cache", True):
|
||||
entry = fn.serialize()
|
||||
|
||||
f = io.BytesIO()
|
||||
StandaloneCompiledArtifactsPickler(f).dump(entry)
|
||||
result = f.getvalue()
|
||||
return result
|
||||
|
||||
out = {}
|
||||
|
||||
for range_key, entry in self.range_entries.items():
|
||||
if not entry.compiled:
|
||||
logger.debug(
|
||||
"entry with range %s not compiled, so cannot get its bytes",
|
||||
range_key,
|
||||
)
|
||||
continue
|
||||
if hasattr(entry.runnable, "serialize"):
|
||||
out[str(range_key)] = serialize(entry.runnable)
|
||||
|
||||
return out
|
||||
|
||||
def _fakify_args(self, args: tuple[Any, ...]) -> list[Any]:
|
||||
# We need to pass fake example_inputs, otherwise torch.compile
|
||||
@@ -127,6 +209,7 @@ class PiecewiseBackend:
|
||||
# non fake tensors as example inputs!
|
||||
# See issue https://github.com/vllm-project/vllm/issues/27899
|
||||
fake_example_inputs = []
|
||||
assert self.graph is not None
|
||||
for node in self.graph.graph.nodes:
|
||||
# All place holders come first
|
||||
if node.op == "placeholder":
|
||||
@@ -140,28 +223,37 @@ class PiecewiseBackend:
|
||||
self, range_entry: RangeEntry, args: tuple[Any, ...]
|
||||
) -> Any:
|
||||
if not range_entry.compiled:
|
||||
if self.compiled_runnables is not None:
|
||||
range_entry.runnable = self.get_compiled_graph_wrapper(
|
||||
self.compiled_runnables[str(range_entry.compile_range)]
|
||||
)
|
||||
else:
|
||||
# args are real arguments
|
||||
# fakify for range, real args for concrete size.
|
||||
# For concrete size, we clear the shape env in
|
||||
# compiler_manager.compile() so no need to fakify.
|
||||
args_list = (
|
||||
self._fakify_args(args)
|
||||
if not range_entry.compile_range.is_single_size()
|
||||
else list(args)
|
||||
)
|
||||
|
||||
with (
|
||||
torch._functorch.config.patch("bundled_autograd_cache", True),
|
||||
):
|
||||
range_entry.runnable = self.vllm_backend.compiler_manager.compile(
|
||||
self.graph,
|
||||
args_list,
|
||||
self.vllm_backend.inductor_config,
|
||||
self.compilation_config,
|
||||
compile_range=range_entry.compile_range,
|
||||
graph_index=self.piecewise_compile_index,
|
||||
num_graphs=self.total_piecewise_compiles,
|
||||
)
|
||||
|
||||
range_entry.compiled = True
|
||||
self.to_be_compiled_ranges.remove(range_entry.compile_range)
|
||||
|
||||
# args are real arguments
|
||||
# fakify for range, real args for concrete size.
|
||||
# For concrete size, we clear the shape env in
|
||||
# compiler_manager.compile() so no need to fakify.
|
||||
args_list = (
|
||||
self._fakify_args(args)
|
||||
if not range_entry.compile_range.is_single_size()
|
||||
else list(args)
|
||||
)
|
||||
range_entry.runnable = self.vllm_backend.compiler_manager.compile(
|
||||
self.graph,
|
||||
args_list,
|
||||
self.vllm_backend.inductor_config,
|
||||
self.compilation_config,
|
||||
compile_range=range_entry.compile_range,
|
||||
graph_index=self.piecewise_compile_index,
|
||||
num_graphs=self.total_piecewise_compiles,
|
||||
)
|
||||
|
||||
self.check_for_ending_compilation()
|
||||
|
||||
def _find_range_for_shape(self, runtime_shape: int) -> RangeEntry | None:
|
||||
|
||||
Reference in New Issue
Block a user