[torch.compile] Stop lazily compiling (#35472)

Signed-off-by: Richard Zou <zou3519@gmail.com>
This commit is contained in:
Richard Zou
2026-03-04 15:13:17 -05:00
committed by GitHub
parent 138d891d7f
commit 5569f5218d
7 changed files with 177 additions and 150 deletions

View File

@@ -73,6 +73,7 @@ def test_compile_ranges(use_fresh_inductor_cache):
Range(start=16, end=16),
Range(start=9, end=32),
Range(start=64, end=64),
Range(start=128, end=128),
Range(start=33, end=8192),
]
)
@@ -95,16 +96,16 @@ def test_compile_ranges(use_fresh_inductor_cache):
with set_current_vllm_config(vllm_config):
model = TestModel(vllm_config=vllm_config, prefix="").eval()
# Number of compilations: 3 for each compile range + 2 compile sizes
# Number of compilations: 3 compile ranges + 3 compile sizes
batch_sizes = [1, 4, 16, 24, 48, 64, 8192]
with compilation_counter.expect(
num_graphs_seen=1,
num_piecewise_graphs_seen=1,
num_backend_compilations=5,
num_backend_compilations=6,
):
run_model(vllm_config, model, batch_sizes)
assert post_grad_range_checker.num_calls == 5
assert post_grad_range_checker.num_calls == 6
def test_compile_config_get_compile_ranges():

View File

@@ -109,9 +109,9 @@ def test_vllm_structured_logging_artifacts(use_fresh_inductor_cache):
f"got {len(vllm_piecewise_split_graph)}"
)
compile_start_artifacts = capture.get("artifact", "vllm_piecewise_compile_start")
assert len(compile_start_artifacts) == 2, (
"Expected 2 vllm_piecewise_compile_start "
"(one for dynamic ranges, one for compile size), "
assert len(compile_start_artifacts) == 4, (
"Expected 4 vllm_piecewise_compile_start "
"(2 subgraphs x 2 ranges each: dynamic + compile size), "
f"got {len(compile_start_artifacts)}"
)
submod_dumps = capture.get("graph_dump", r"vllm_submod_.*")

View File

@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ast
import contextvars
import dataclasses
import hashlib
import json
@@ -18,7 +17,7 @@ from typing import Any
import torch
import torch.fx as fx
from torch._dispatch.python import enable_python_dispatcher
from torch._dynamo.utils import dynamo_timed
from torch._logging._internal import trace_structured
import vllm.envs as envs
@@ -510,9 +509,9 @@ def wrap_with_cudagraph_if_needed(
class PiecewiseCompileInterpreter(torch.fx.Interpreter): # type: ignore[misc]
"""Code adapted from `torch.fx.passes.shape_prop.ShapeProp`.
It runs the given graph with fake inputs, and compiles some
submodules specified by `compile_submod_names` with the given
compilation configs.
It runs the given split graph interpreter, and for each submodule in
`compile_submod_names`, creates a PiecewiseBackend and compiles all
ranges up front.
NOTE: the order in `compile_submod_names` matters, because
it will be used to determine the order of the compiled piecewise
@@ -540,9 +539,6 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): # type: ignore[misc]
vllm_backend: "VllmBackend",
) -> None:
super().__init__(module)
from torch._guards import detect_fake_mode
self.fake_mode = detect_fake_mode()
self.compile_submod_names = compile_submod_names
self.compilation_config = vllm_config.compilation_config
self.vllm_config = vllm_config
@@ -552,13 +548,7 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): # type: ignore[misc]
@instrument(span_name="Inductor compilation")
def run(self, *args: Any) -> Any:
# maybe instead just assert inputs are fake?
fake_args = [
self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t
for t in args
]
with self.fake_mode, enable_python_dispatcher():
return super().run(*fake_args)
return super().run(*args)
def call_module(
self,
@@ -614,21 +604,6 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter): # type: ignore[misc]
model_tag: str = "backbone"
model_is_encoder: bool = False
_on_compilation_complete_callback: contextvars.ContextVar[Callable[[], None] | None] = (
contextvars.ContextVar("on_compilation_complete_callback", default=None)
)
@contextmanager
def set_on_compilation_complete(
callback: Callable[[], None],
) -> Generator[None, None, None]:
token = _on_compilation_complete_callback.set(callback)
try:
yield
finally:
_on_compilation_complete_callback.reset(token)
@contextmanager
def set_model_tag(tag: str, is_encoder: bool = False) -> Generator[None, None, None]:
@@ -846,6 +821,7 @@ class VllmBackend:
),
)
@dynamo_timed("vllm_backend")
def __call__(self, graph: fx.GraphModule, example_inputs: Sequence[Any]) -> Any:
from .caching import (
VllmSerializableFunction,
@@ -1036,11 +1012,24 @@ class VllmBackend:
]
# propagate the split graph to the piecewise backend,
# compile submodules with symbolic shapes
# compile submodules with symbolic shapes, and compile all ranges
# up front so that compilation is complete before the callable
# is returned.
PiecewiseCompileInterpreter(
self.split_gm, submod_names_to_compile, self.vllm_config, self
).run(*fake_args)
# All compilation is done. Save the cache.
time_before_saving = time.perf_counter()
self.compiler_manager.save_to_file()
elapsed = time.perf_counter() - time_before_saving
if elapsed > 1:
logger.info_once(
"Saved compiler manager cache in %.2f seconds.",
elapsed,
scope="local",
)
from torch._guards import detect_fake_mode
fake_mode = detect_fake_mode()

View File

@@ -313,30 +313,26 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
return fn
# Fall back to standard VllmBackend
# Fall back to standard VllmBackend.
# Use a lazy closure: the backend needs traced_files for cache
# dir computation, but those are only populated after
# _verify_source_unchanged runs in decorators.py (which happens
# after deserialization completes).
from vllm.compilation.backends import VllmBackend
is_encoder = state.get("is_encoder", False)
vllm_backend: VllmBackend = VllmBackend(
get_current_vllm_config(), state["prefix"], is_encoder
)
vllm_config = get_current_vllm_config()
compile_inputs = list(state["example_inputs"])
def optimized_call(*example_inputs: Any) -> Any:
"""
On the first run of the optimized call, we rerun the compiler
backend which should result in a cache hit. After the backend
call returns, we just do a one-time replacement of the optimized
call with the compiled function, so that subsequent calls are on
the AOT compiled path.
"""
compile_inputs = [
inp if inp is not None else example_inputs[i]
for i, inp in enumerate(fn.example_inputs)
]
vllm_backend: VllmBackend = VllmBackend(
vllm_config, state["prefix"], is_encoder
)
with tracing(TracingContext(fake_mode)):
fn.optimized_call = vllm_backend(
state["graph_module"], compile_inputs
).optimized_call
fn.vllm_backend = vllm_backend
return fn.optimized_call(*example_inputs)
fn = cls(**state, optimized_call=optimized_call)

View File

@@ -466,8 +466,12 @@ def _support_torch_compile(
"Directly load AOT compilation from path %s", aot_compilation_path
)
# Apply partition wrapper context for proper CUDA graph capture
from .monitor import end_monitoring_torch_compile
with maybe_use_cudagraph_partition_wrapper(self.vllm_config):
return self.aot_compiled_fn(self, *args, **kwargs)
output = self.aot_compiled_fn(self, *args, **kwargs)
end_monitoring_torch_compile(self.vllm_config)
return output
if self.compiled:
assert (
@@ -552,18 +556,19 @@ def _support_torch_compile(
logger.warning("Detected eager backend, disabling AOT compile.")
use_aot_compile = False
if use_aot_compile:
from vllm.compilation.backends import set_on_compilation_complete
# store the path for saving after warmup
self._aot_compilation_path = aot_compilation_path
self._aot_cache_dir = cache_dir
# set callback in context so it's available when compilation completes
with set_on_compilation_complete(self.save_aot_compiled_function):
self.aot_compiled_fn = self.aot_compile(*args, **kwargs)
output = self.aot_compiled_fn(self, *args, **kwargs)
self.aot_compiled_fn = self.aot_compile(*args, **kwargs)
# All compilation is done at this point, save the AOT artifact.
self.save_aot_compiled_function()
output = self.aot_compiled_fn(self, *args, **kwargs)
else:
output = TorchCompileWithNoGuardsWrapper.__call__(self, *args, **kwargs) # type: ignore[arg-type]
from .monitor import end_monitoring_torch_compile
end_monitoring_torch_compile(self.vllm_config)
self.compiled = True
return output

View File

@@ -33,7 +33,7 @@ def end_monitoring_torch_compile(vllm_config: VllmConfig) -> None:
total_compile_time: float = time.perf_counter() - torch_compile_start_time
if compilation_config.mode == CompilationMode.VLLM_COMPILE:
logger.info_once(
"torch.compile takes %.2f s in total",
"torch.compile and initial profiling run took %.2f s in total",
total_compile_time,
scope="local",
)

View File

@@ -5,7 +5,6 @@ import dataclasses
import io
import json
import pickle
import time
from collections.abc import Callable
from pickle import Pickler
from typing import Any
@@ -16,7 +15,6 @@ from torch._inductor.runtime.triton_heuristics import CachingAutotuner
from torch._logging._internal import trace_structured
from vllm.compilation.backends import VllmBackend
from vllm.compilation.monitor import end_monitoring_torch_compile
from vllm.config import VllmConfig
from vllm.config.utils import Range
from vllm.logger import init_logger
@@ -24,6 +22,55 @@ from vllm.logger import init_logger
logger = init_logger(__name__)
def get_fake_args_from_graph(graph: fx.GraphModule) -> list[Any]:
    """Collect the ``example_value`` metadata of the graph's placeholder nodes.

    Placeholders always appear first in an fx graph, so iteration stops at
    the first non-placeholder node.
    """
    fake_args: list[Any] = []
    for node in graph.graph.nodes:
        if node.op != "placeholder":
            break
        fake_args.append(node.meta["example_value"])
    return fake_args
def create_concrete_args(graph: fx.GraphModule, size: int) -> list[Any]:
"""Create example inputs with symbolic dims replaced by a concrete size.
Used for single-size eager compilation where we need concrete-shaped
inputs but don't have real runtime tensors yet.
"""
from torch._prims_common import compute_required_storage_length
from torch.fx.experimental.symbolic_shapes import is_symbolic
def concretize(sym_val: Any) -> int:
"""Replace all symbolic variables in a SymInt expression with size."""
if not is_symbolic(sym_val):
return int(sym_val)
expr = sym_val.node.expr
return int(expr.subs({s: size for s in expr.free_symbols}))
args: list[Any] = []
for node in graph.graph.nodes:
if node.op != "placeholder":
break
val = node.meta["example_value"]
if isinstance(val, torch.SymInt):
args.append(concretize(val))
elif isinstance(val, torch.Tensor):
new_shape = tuple(concretize(d) for d in val.shape)
new_strides = tuple(concretize(s) for s in val.stride())
new_storage_offset = concretize(val.storage_offset())
needed_size = compute_required_storage_length(
new_shape, new_strides, new_storage_offset
)
t = torch.empty(needed_size, dtype=val.dtype, device=val.device)
t = t.as_strided(new_shape, new_strides, new_storage_offset)
args.append(t)
else:
args.append(val)
return args
@dataclasses.dataclass
class RangeEntry:
compile_range: Range
@@ -109,10 +156,6 @@ class PiecewiseBackend:
# the entries for ranges that we need to either compile or load
self.range_entries: dict[Range, RangeEntry] = {}
# to_be_compiled_ranges tracks the remaining ranges to compile,
# and updates during the compilation process, so we need to copy it
self.to_be_compiled_ranges: set[Range] = set(self.compile_ranges)
# We only keep compilation management inside this class directly.
if self.compile_sizes is not None:
for size in self.compile_sizes:
@@ -129,7 +172,6 @@ class PiecewiseBackend:
self.range_entries[range] = RangeEntry(
compile_range=range,
)
self.to_be_compiled_ranges.add(range)
for range in self.compile_ranges:
self.range_entries[range] = RangeEntry(
@@ -139,12 +181,10 @@ class PiecewiseBackend:
# Track whether we've logged the graph for this subgraph (only log once)
self._graph_logged = False
# get the on_compilation_complete callback from context...
# PiecewiseBackend is created during the first call,
# which is when the context is set (see compilation/decorators.py)
from vllm.compilation.backends import _on_compilation_complete_callback
self.on_compilation_complete = _on_compilation_complete_callback.get()
if self.graph is not None:
self.compile_all_ranges()
else:
self.load_all_ranges()
def get_compiled_graph_wrapper(
self, compiled_graph: Callable[..., Any]
@@ -161,25 +201,6 @@ class PiecewiseBackend:
return compiled_graph_wrapper
def check_for_ending_compilation(self) -> None:
    """Finalize compilation when this is the last graph and no ranges remain.

    Persists the compiler-manager cache, ends compile monitoring, and fires
    the completion callback (e.g. to save the AOT compiled function).
    """
    if not self.is_last_graph or self.to_be_compiled_ranges:
        return
    # No specific sizes left to compile: save the hash of the inductor
    # graph for the next run.
    start = time.perf_counter()
    self.vllm_backend.compiler_manager.save_to_file()
    saving_time = time.perf_counter() - start
    if saving_time > 1:
        logger.info_once(
            "Saved compiler manager cache in %.2f seconds.",
            saving_time,
            scope="local",
        )
    end_monitoring_torch_compile(self.vllm_config)
    # Call the completion callback (e.g., to save AOT compiled function)
    if self.on_compilation_complete is not None:
        self.on_compilation_complete()
def to_bytes(self) -> dict[str, bytes]:
class StandaloneCompiledArtifactsPickler(Pickler):
def reducer_override(self, obj: object) -> Any:
@@ -216,27 +237,54 @@ class PiecewiseBackend:
return out
def _fakify_args(self, args: tuple[Any, ...]) -> list[Any]:
# We need to pass fake example_inputs, otherwise torch.compile
# will fakify the example_inputs potentially causing some non dynamic
dimension to be duck shaped to other existing shapes that have hints
# matching their values.
# This is a problem because it can lead to unintended specializations!
# if the new wrongly dynamic dim is specialized
# it will force specializing the whole shape
# torch.compile probably should not accept
# non fake tensors as example inputs!
# See issue https://github.com/vllm-project/vllm/issues/27899
fake_example_inputs = []
assert self.graph is not None
for node in self.graph.graph.nodes:
# All placeholders come first
if node.op == "placeholder":
fake_example_inputs.append(node.meta["example_value"])
def compile_all_ranges(self) -> None:
    """Compile all range entries for this piecewise subgraph up front.

    Fix: the diff paste interleaved leftover lines from the removed
    ``_fakify_args`` (a stray ``break``/assert/return) inside the ``else``
    branch, leaving the method syntactically invalid; this reconstructs
    the intended control flow.
    """
    assert self.graph is not None, (
        "Cannot compile without a graph. "
        "When loading from cache/AOT artifacts, "
        "compile_all_ranges should not be called."
    )
    for range_entry in self.range_entries.values():
        if range_entry.compiled:
            continue
        self._log_compile_start(range_entry.compile_range)
        if range_entry.compile_range.is_single_size():
            # Single-size compile: build concrete-shaped example inputs.
            args_list = create_concrete_args(
                self.graph, range_entry.compile_range.start
            )
        else:
            # Range compile: reuse the fake (symbolic) placeholder values.
            args_list = get_fake_args_from_graph(self.graph)
        # TODO(https://github.com/vllm-project/vllm/issues/35766)
        # Can we remove strict_autograd_cache and
        # force_non_lazy_backward_lowering overrides?
        # I added them explicitly because this is what they are
        # set to before the refactor
        # (https://github.com/vllm-project/vllm/pull/35472).
        # They affect the aotautograd cache key computation
        # but they shouldn't have any effect on the actual
        # compilation.
        config_patches = dict(
            bundled_autograd_cache=True,
            strict_autograd_cache=False,
        )
        if hasattr(torch._functorch.config, "force_non_lazy_backward_lowering"):
            config_patches["force_non_lazy_backward_lowering"] = False
        with torch._functorch.config.patch(**config_patches):
            range_entry.runnable = self.vllm_backend.compiler_manager.compile(
                self.graph,
                args_list,
                self.vllm_backend.inductor_config,
                self.compilation_config,
                compile_range=range_entry.compile_range,
                graph_index=self.piecewise_compile_index,
                num_graphs=self.total_piecewise_compiles,
            )
        range_entry.compiled = True
def _log_compile_start(self, compile_range: Range):
"""Log compilation event for TORCH_TRACE/tlparse."""
@@ -277,44 +325,29 @@ class PiecewiseBackend:
payload_fn=lambda: self.graph.print_readable(print_output=False),
)
def _maybe_compile_for_range_entry(
self, range_entry: RangeEntry, args: tuple[Any, ...]
) -> Any:
if not range_entry.compiled:
if self.compiled_runnables is not None:
range_entry.runnable = self.get_compiled_graph_wrapper(
self.compiled_runnables[str(range_entry.compile_range)]
)
else:
self._log_compile_start(range_entry.compile_range)
# args are real arguments
# fakify for range, real args for concrete size.
# For concrete size, we clear the shape env in
# compiler_manager.compile() so no need to fakify.
args_list = (
self._fakify_args(args)
if not range_entry.compile_range.is_single_size()
else list(args)
)
with (
torch._functorch.config.patch("bundled_autograd_cache", True),
):
range_entry.runnable = self.vllm_backend.compiler_manager.compile(
self.graph,
args_list,
self.vllm_backend.inductor_config,
self.compilation_config,
compile_range=range_entry.compile_range,
graph_index=self.piecewise_compile_index,
num_graphs=self.total_piecewise_compiles,
)
def load_all_ranges(self) -> None:
    """Populate every range entry from the cached compiled runnables.

    Warm-start counterpart of compile_all_ranges(): instead of compiling,
    each range entry's runnable is wrapped up front from
    ``self.compiled_runnables``.
    """
    assert self.compiled_runnables is not None, (
        "load_all_ranges should only be called when compiled_runnables "
        "is set (warm start / cache loading path)."
    )
    for entry in self.range_entries.values():
        if entry.compiled:
            continue
        key = str(entry.compile_range)
        assert key in self.compiled_runnables, (
            f"Missing compiled runnable for range {entry.compile_range}. "
            f"Available keys: {list(self.compiled_runnables.keys())}"
        )
        entry.runnable = self.get_compiled_graph_wrapper(
            self.compiled_runnables[key]
        )
        entry.compiled = True
self.to_be_compiled_ranges.remove(range_entry.compile_range)
self.check_for_ending_compilation()
def _find_range_for_shape(self, runtime_shape: int) -> RangeEntry | None:
# First we try to find the range entry for the concrete compile size
@@ -338,6 +371,9 @@ class PiecewiseBackend:
assert range_entry is not None, (
f"Shape: {runtime_shape} out of considered ranges: {self.compile_ranges}"
)
self._maybe_compile_for_range_entry(range_entry, args)
assert range_entry.compiled, (
"All ranges should be compiled or loaded up front in "
"PiecewiseBackend.__init__. "
f"range_entry={range_entry.compile_range}"
)
return range_entry.runnable(*args)