diff --git a/vllm/compilation/caching.py b/vllm/compilation/caching.py
index 3eda948b6..70fbaabb4 100644
--- a/vllm/compilation/caching.py
+++ b/vllm/compilation/caching.py
@@ -189,13 +189,13 @@ class VllmSerializableFunction(SerializableCallable):  # type: ignore[misc]
         self.shape_env = None
         self.vllm_backend = vllm_backend
         self.sym_tensor_indices = sym_tensor_indices
+        self._fake_mode: Any | None = None
 
         import torch._functorch.config as functorch_config
 
         self.aot_autograd_config = (
             aot_autograd_config or functorch_config.save_config_portable()
         )
-
         sym_input = next(
             (i for i in self.example_inputs if isinstance(i, torch.SymInt)), None
         )
@@ -217,6 +217,7 @@ class VllmSerializableFunction(SerializableCallable):  # type: ignore[misc]
         state.pop("optimized_call")
         state.pop("shape_env")
         state.pop("vllm_backend", None)
+        state.pop("_fake_mode", None)
         for node in state["graph_module"].graph.nodes:
             node.meta.pop("source_fn_stack", None)
             node.meta.pop("nn_module_stack", None)
@@ -351,8 +352,34 @@ class VllmSerializableFunction(SerializableCallable):  # type: ignore[misc]
             return fn.optimized_call(*example_inputs)
 
         fn = cls(**state, optimized_call=optimized_call)
+        fn._fake_mode = fake_mode
         return fn
 
+    def finalize_loading(self, vllm_config: VllmConfig) -> None:
+        """Eagerly initialize the compiled backend and perform all loading.
+
+        Must be called after _verify_source_unchanged has populated
+        compilation_config.traced_files, which is needed for cache dir
+        computation.
+
+        The first successful call clears ``_fake_mode``, so repeated calls
+        are no-ops.
+        """
+        if self._fake_mode is None:
+            return  # Already finalized, or mega path (no _fake_mode set)
+
+        from torch._guards import TracingContext, tracing
+
+        from vllm.compilation.backends import VllmBackend
+
+        vllm_backend = VllmBackend(vllm_config, self.prefix, self.is_encoder)
+        with tracing(TracingContext(self._fake_mode)):
+            result = vllm_backend(self.graph_module, list(self.example_inputs))
+            self.optimized_call = result.optimized_call
+            self.vllm_backend = vllm_backend
+
+        self._fake_mode = None
+
     @property
     def co_name(self) -> Literal["VllmSerializableFunction"]:
         """
diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py
index fe0984baf..f8629be34 100644
--- a/vllm/compilation/decorators.py
+++ b/vllm/compilation/decorators.py
@@ -30,7 +30,7 @@ from vllm.sequence import IntermediateTensors
 from vllm.utils.import_utils import resolve_obj_by_qualname
 from vllm.utils.torch_utils import is_torch_equal_or_newer
 
-from .monitor import start_monitoring_torch_compile
+from .monitor import monitor_profiling_run, monitor_torch_compile
 
 if TYPE_CHECKING:
     # Only added on nightly/2.10 so wrap
@@ -434,17 +434,24 @@ def _support_torch_compile(
             cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}")
             aot_compilation_path = os.path.join(cache_dir, "model")
             try:
-                with (
-                    set_current_vllm_config(self.vllm_config),
-                    open(aot_compilation_path, "rb") as f,
-                ):
-                    start_monitoring_torch_compile(self.vllm_config)
-                    loaded_fn = torch.compiler.load_compiled_function(
-                        f, f_globals=self.forward.__globals__
-                    )
-                _verify_source_unchanged(loaded_fn.source_info(), self.vllm_config)
-                if not self.compilation_config.dynamic_shapes_config.evaluate_guards:
-                    loaded_fn.disable_guard_check()
+                with monitor_torch_compile(self.vllm_config):
+                    with (
+                        set_current_vllm_config(self.vllm_config),
+                        open(aot_compilation_path, "rb") as f,
+                    ):
+                        loaded_fn = torch.compiler.load_compiled_function(
+                            f, f_globals=self.forward.__globals__
+                        )
+                    _verify_source_unchanged(loaded_fn.source_info(), self.vllm_config)
+                    ds_config = self.compilation_config.dynamic_shapes_config
+                    if not ds_config.evaluate_guards:
+                        loaded_fn.disable_guard_check()
+                    # Eagerly load compiled artifacts now that traced_files
+                    # is populated by _verify_source_unchanged.
+                    with maybe_use_cudagraph_partition_wrapper(self.vllm_config):
+                        loaded_fn._artifacts.compiled_fn.finalize_loading(
+                            self.vllm_config
+                        )
                 self.aot_compiled_fn = loaded_fn
                 self.was_aot_compile_fn_loaded_from_disk = True
             except Exception as e:
@@ -465,12 +472,11 @@ def _support_torch_compile(
                 logger.info(
                     "Directly load AOT compilation from path %s", aot_compilation_path
                 )
-                # Apply partition wrapper context for proper CUDA graph capture
-                from .monitor import end_monitoring_torch_compile
-
-                with maybe_use_cudagraph_partition_wrapper(self.vllm_config):
+                with (
+                    monitor_profiling_run(),
+                    maybe_use_cudagraph_partition_wrapper(self.vllm_config),
+                ):
                     output = self.aot_compiled_fn(self, *args, **kwargs)
-                end_monitoring_torch_compile(self.vllm_config)
                 return output
 
         if self.compiled:
@@ -489,8 +495,6 @@ def _support_torch_compile(
                 **kwargs,
             )
 
-        # here, it is the starting point of the `torch.compile` process
-        start_monitoring_torch_compile(self.vllm_config)
         original_code_object = self.original_code_object()
         logger.debug("Start compiling function %s", original_code_object)
 
@@ -559,16 +563,26 @@ def _support_torch_compile(
                 # store the path for saving after warmup
                 self._aot_compilation_path = aot_compilation_path
                 self._aot_cache_dir = cache_dir
-                self.aot_compiled_fn = self.aot_compile(*args, **kwargs)
-                # All compilation is done at this point, save the AOT artifact.
-                self.save_aot_compiled_function()
-                output = self.aot_compiled_fn(self, *args, **kwargs)
+                with monitor_torch_compile(self.vllm_config):
+                    self.aot_compiled_fn = self.aot_compile(*args, **kwargs)
+                    # All compilation is done at this point, save the
+                    # AOT artifact.
+                    self.save_aot_compiled_function()
+
+                with monitor_profiling_run():
+                    output = self.aot_compiled_fn(self, *args, **kwargs)
             else:
-                output = TorchCompileWithNoGuardsWrapper.__call__(self, *args, **kwargs)  # type: ignore[arg-type]
+                with monitor_torch_compile(
+                    self.vllm_config,
+                    "torch.compile and initial profiling/warmup "
+                    "run together took %.2f s in total",
+                ):
+                    output = TorchCompileWithNoGuardsWrapper.__call__(
+                        self,  # type: ignore[arg-type]
+                        *args,
+                        **kwargs,
+                    )
 
-        from .monitor import end_monitoring_torch_compile
-
-        end_monitoring_torch_compile(self.vllm_config)
         self.compiled = True
         return output
 
diff --git a/vllm/compilation/monitor.py b/vllm/compilation/monitor.py
index fb9dfa3ac..f584f526f 100644
--- a/vllm/compilation/monitor.py
+++ b/vllm/compilation/monitor.py
@@ -1,46 +1,94 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+import contextlib
 import time
+from collections.abc import Generator
 
-from vllm.config import CompilationConfig, CompilationMode, VllmConfig
+from vllm.config import CompilationMode, VllmConfig
 from vllm.logger import init_logger
 
 logger = init_logger(__name__)
 
-context_manager = None
+# Shared global so backends.py can read the start time for Dynamo timing.
 torch_compile_start_time: float = 0.0
 
 
-def start_monitoring_torch_compile(vllm_config: VllmConfig) -> None:
+@contextlib.contextmanager
+def monitor_torch_compile(
+    vllm_config: VllmConfig,
+    message: str = "torch.compile took %.2f s in total",
+) -> Generator[None, None, None]:
+    """Context manager that times torch.compile and manages depyf debugging.
+
+    On normal exit: logs the compile time and exits depyf.
+    On exception: cleans up depyf without logging (compilation failed).
+
+    Args:
+        vllm_config: Supplies the compilation mode and the debug dump path.
+        message: %-style log format; the elapsed seconds are its only
+            argument.
+    """
     global torch_compile_start_time
     torch_compile_start_time = time.perf_counter()
+    # Measure against a local copy: the module global is overwritten by
+    # every monitor_torch_compile() entry, so reading it back on exit
+    # could report the wrong duration.
+    start_time = torch_compile_start_time
 
-    compilation_config: CompilationConfig = vllm_config.compilation_config
+    compilation_config = vllm_config.compilation_config
+    depyf_cm = None
     path = vllm_config.compile_debug_dump_path()
     if compilation_config.mode == CompilationMode.VLLM_COMPILE and path:
         import depyf
 
         path.mkdir(parents=True, exist_ok=True)
         logger.debug("Dumping depyf output to %s", path)
-        global context_manager
-        context_manager = depyf.prepare_debug(path.as_posix())
-        context_manager.__enter__()
+        depyf_cm = depyf.prepare_debug(path.as_posix())
+        depyf_cm.__enter__()
+
+    try:
+        yield
+        # Reached only when the managed block exits normally; an exception
+        # raised at the yield skips straight to the cleanup below.
+        total_compile_time = time.perf_counter() - start_time
+        if compilation_config.mode == CompilationMode.VLLM_COMPILE:
+            logger.info_once(message, total_compile_time, scope="local")
+    finally:
+        if depyf_cm is not None:
+            try:
+                depyf_cm.__exit__(None, None, None)
+            except Exception:
+                logger.warning("Exception during depyf cleanup.", exc_info=True)
 
 
-def end_monitoring_torch_compile(vllm_config: VllmConfig) -> None:
-    compilation_config: CompilationConfig = vllm_config.compilation_config
-    total_compile_time: float = time.perf_counter() - torch_compile_start_time
-    if compilation_config.mode == CompilationMode.VLLM_COMPILE:
-        logger.info_once(
-            "torch.compile and initial profiling run took %.2f s in total",
-            total_compile_time,
-            scope="local",
-        )
-        global context_manager
-        if context_manager is not None:
-            context_manager.__exit__(None, None, None)
-            context_manager = None
+@contextlib.contextmanager
+def monitor_profiling_run() -> Generator[None, None, None]:
+    """Context manager that times the initial profiling run.
+
+    Asserts that no backend compilation occurs during the profiling run
+    (all compilation should have completed before this point).
+
+    If the managed block raises, the exception propagates and neither
+    the compilation check nor the timing log runs.
+    """
+    from vllm.compilation.counter import compilation_counter
+
+    backend_compilations_before = compilation_counter.num_backend_compilations
+    start = time.perf_counter()
+    yield
+    elapsed = time.perf_counter() - start
+    assert (
+        compilation_counter.num_backend_compilations == backend_compilations_before
+    ), (
+        "backend compilation occurred during the initial profiling run; "
+        "all compilation should be complete before the profiling run starts."
+    )
+    logger.info_once(
+        "Initial profiling/warmup run took %.2f s",
+        elapsed,
+        scope="local",
+    )
 
 
 cudagraph_capturing_enabled: bool = True