[compile] Split compile/warmup monitoring (#36098)
This commit is contained in:
@@ -189,13 +189,13 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
|
||||
self.shape_env = None
|
||||
self.vllm_backend = vllm_backend
|
||||
self.sym_tensor_indices = sym_tensor_indices
|
||||
self._fake_mode: Any | None = None
|
||||
|
||||
import torch._functorch.config as functorch_config
|
||||
|
||||
self.aot_autograd_config = (
|
||||
aot_autograd_config or functorch_config.save_config_portable()
|
||||
)
|
||||
|
||||
sym_input = next(
|
||||
(i for i in self.example_inputs if isinstance(i, torch.SymInt)), None
|
||||
)
|
||||
@@ -217,6 +217,7 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
|
||||
state.pop("optimized_call")
|
||||
state.pop("shape_env")
|
||||
state.pop("vllm_backend", None)
|
||||
state.pop("_fake_mode", None)
|
||||
for node in state["graph_module"].graph.nodes:
|
||||
node.meta.pop("source_fn_stack", None)
|
||||
node.meta.pop("nn_module_stack", None)
|
||||
@@ -351,8 +352,31 @@ class VllmSerializableFunction(SerializableCallable): # type: ignore[misc]
|
||||
return fn.optimized_call(*example_inputs)
|
||||
|
||||
fn = cls(**state, optimized_call=optimized_call)
|
||||
fn._fake_mode = fake_mode
|
||||
return fn
|
||||
|
||||
def finalize_loading(self, vllm_config: VllmConfig) -> None:
    """Build the compiled backend eagerly and finish loading this function.

    Call this only after ``_verify_source_unchanged`` has populated
    ``compilation_config.traced_files``, which the cache dir computation
    needs.

    The call is a no-op when ``self._fake_mode`` is ``None``. Otherwise a
    ``VllmBackend`` is constructed, run on ``self.graph_module`` with
    ``self.example_inputs`` under a tracing context built from the stored
    fake mode, and the resulting ``optimized_call`` and the backend are
    stored on ``self``.

    Args:
        vllm_config: Configuration forwarded to ``VllmBackend``.
    """
    fake_mode = self._fake_mode
    if fake_mode is None:
        # Already finalized, or mega path (no _fake_mode set)
        return

    from torch._guards import TracingContext, tracing

    from vllm.compilation.backends import VllmBackend

    backend = VllmBackend(vllm_config, self.prefix, self.is_encoder)
    tracing_context = TracingContext(fake_mode)
    with tracing(tracing_context):
        compiled = backend(self.graph_module, list(self.example_inputs))

    self.optimized_call = compiled.optimized_call
    self.vllm_backend = backend

    # Clearing the fake mode makes any later call return early above.
    self._fake_mode = None
|
||||
|
||||
@property
|
||||
def co_name(self) -> Literal["VllmSerializableFunction"]:
|
||||
"""
|
||||
|
||||
@@ -30,7 +30,7 @@ from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils.import_utils import resolve_obj_by_qualname
|
||||
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||
|
||||
from .monitor import start_monitoring_torch_compile
|
||||
from .monitor import monitor_profiling_run, monitor_torch_compile
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# Only added on nightly/2.10 so wrap
|
||||
@@ -434,17 +434,24 @@ def _support_torch_compile(
|
||||
cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}")
|
||||
aot_compilation_path = os.path.join(cache_dir, "model")
|
||||
try:
|
||||
with (
|
||||
set_current_vllm_config(self.vllm_config),
|
||||
open(aot_compilation_path, "rb") as f,
|
||||
):
|
||||
start_monitoring_torch_compile(self.vllm_config)
|
||||
loaded_fn = torch.compiler.load_compiled_function(
|
||||
f, f_globals=self.forward.__globals__
|
||||
)
|
||||
_verify_source_unchanged(loaded_fn.source_info(), self.vllm_config)
|
||||
if not self.compilation_config.dynamic_shapes_config.evaluate_guards:
|
||||
loaded_fn.disable_guard_check()
|
||||
with monitor_torch_compile(self.vllm_config):
|
||||
with (
|
||||
set_current_vllm_config(self.vllm_config),
|
||||
open(aot_compilation_path, "rb") as f,
|
||||
):
|
||||
loaded_fn = torch.compiler.load_compiled_function(
|
||||
f, f_globals=self.forward.__globals__
|
||||
)
|
||||
_verify_source_unchanged(loaded_fn.source_info(), self.vllm_config)
|
||||
ds_config = self.compilation_config.dynamic_shapes_config
|
||||
if not ds_config.evaluate_guards:
|
||||
loaded_fn.disable_guard_check()
|
||||
# Eagerly load compiled artifacts now that traced_files
|
||||
# is populated by _verify_source_unchanged.
|
||||
with maybe_use_cudagraph_partition_wrapper(self.vllm_config):
|
||||
loaded_fn._artifacts.compiled_fn.finalize_loading(
|
||||
self.vllm_config
|
||||
)
|
||||
self.aot_compiled_fn = loaded_fn
|
||||
self.was_aot_compile_fn_loaded_from_disk = True
|
||||
except Exception as e:
|
||||
@@ -465,12 +472,11 @@ def _support_torch_compile(
|
||||
logger.info(
|
||||
"Directly load AOT compilation from path %s", aot_compilation_path
|
||||
)
|
||||
# Apply partition wrapper context for proper CUDA graph capture
|
||||
from .monitor import end_monitoring_torch_compile
|
||||
|
||||
with maybe_use_cudagraph_partition_wrapper(self.vllm_config):
|
||||
with (
|
||||
monitor_profiling_run(),
|
||||
maybe_use_cudagraph_partition_wrapper(self.vllm_config),
|
||||
):
|
||||
output = self.aot_compiled_fn(self, *args, **kwargs)
|
||||
end_monitoring_torch_compile(self.vllm_config)
|
||||
return output
|
||||
|
||||
if self.compiled:
|
||||
@@ -489,8 +495,6 @@ def _support_torch_compile(
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# here, it is the starting point of the `torch.compile` process
|
||||
start_monitoring_torch_compile(self.vllm_config)
|
||||
original_code_object = self.original_code_object()
|
||||
logger.debug("Start compiling function %s", original_code_object)
|
||||
|
||||
@@ -559,16 +563,26 @@ def _support_torch_compile(
|
||||
# store the path for saving after warmup
|
||||
self._aot_compilation_path = aot_compilation_path
|
||||
self._aot_cache_dir = cache_dir
|
||||
self.aot_compiled_fn = self.aot_compile(*args, **kwargs)
|
||||
# All compilation is done at this point, save the AOT artifact.
|
||||
self.save_aot_compiled_function()
|
||||
output = self.aot_compiled_fn(self, *args, **kwargs)
|
||||
with monitor_torch_compile(self.vllm_config):
|
||||
self.aot_compiled_fn = self.aot_compile(*args, **kwargs)
|
||||
# All compilation is done at this point, save the
|
||||
# AOT artifact.
|
||||
self.save_aot_compiled_function()
|
||||
|
||||
with monitor_profiling_run():
|
||||
output = self.aot_compiled_fn(self, *args, **kwargs)
|
||||
else:
|
||||
output = TorchCompileWithNoGuardsWrapper.__call__(self, *args, **kwargs) # type: ignore[arg-type]
|
||||
with monitor_torch_compile(
|
||||
self.vllm_config,
|
||||
"torch.compile and initial profiling/warmup "
|
||||
"run together took %.2f s in total",
|
||||
):
|
||||
output = TorchCompileWithNoGuardsWrapper.__call__(
|
||||
self, # type: ignore[arg-type]
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
from .monitor import end_monitoring_torch_compile
|
||||
|
||||
end_monitoring_torch_compile(self.vllm_config)
|
||||
self.compiled = True
|
||||
return output
|
||||
|
||||
|
||||
@@ -1,46 +1,83 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import contextlib
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
|
||||
from vllm.config import CompilationConfig, CompilationMode, VllmConfig
|
||||
from vllm.config import CompilationMode, VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
context_manager = None
|
||||
# Shared global so backends.py can read the start time for Dynamo timing.
|
||||
torch_compile_start_time: float = 0.0
|
||||
|
||||
|
||||
def start_monitoring_torch_compile(vllm_config: VllmConfig) -> None:
|
||||
@contextlib.contextmanager
def monitor_torch_compile(
    vllm_config: VllmConfig,
    message: str = "torch.compile took %.2f s in total",
) -> Generator[None, None, None]:
    """Context manager that times torch.compile and manages depyf debugging.

    On entry the module-level ``torch_compile_start_time`` is set to the
    current ``time.perf_counter()`` value. When the compilation mode is
    ``CompilationMode.VLLM_COMPILE`` and a debug dump path is configured, a
    depyf debug context is entered for that path.

    On normal exit: logs the elapsed time using ``message`` (only in
    ``VLLM_COMPILE`` mode) and exits depyf.
    On exception: exits depyf without logging, and the exception propagates.

    Args:
        vllm_config: Source of ``compilation_config`` and of the debug dump
            path (``compile_debug_dump_path()``).
        message: ``%``-style log format that receives the elapsed seconds
            as its single argument.
    """
    global torch_compile_start_time
    torch_compile_start_time = time.perf_counter()

    compilation_config = vllm_config.compilation_config
    depyf_cm = None
    path = vllm_config.compile_debug_dump_path()
    if compilation_config.mode == CompilationMode.VLLM_COMPILE and path:
        import depyf

        path.mkdir(parents=True, exist_ok=True)
        logger.debug("Dumping depyf output to %s", path)
        # Enter depyf exactly once and keep the handle local, so the
        # ``finally`` clause below is always the one that exits it.
        depyf_cm = depyf.prepare_debug(path.as_posix())
        depyf_cm.__enter__()

    try:
        yield
        # Reached only when the wrapped body finished without raising.
        total_compile_time = time.perf_counter() - torch_compile_start_time
        if compilation_config.mode == CompilationMode.VLLM_COMPILE:
            logger.info_once(message, total_compile_time, scope="local")
    finally:
        if depyf_cm is not None:
            try:
                depyf_cm.__exit__(None, None, None)
            except Exception:
                # Best-effort cleanup: report the failure but do not let
                # it replace an exception raised by the wrapped body.
                logger.warning("Exception during depyf cleanup.", exc_info=True)
|
||||
|
||||
|
||||
def end_monitoring_torch_compile(vllm_config: VllmConfig) -> None:
    """Log the time elapsed since monitoring started and close depyf.

    The elapsed time is measured from the module-level
    ``torch_compile_start_time`` and is logged only when the compilation
    mode is ``CompilationMode.VLLM_COMPILE``. If the module-level
    ``context_manager`` is set, it is exited and reset to ``None``.

    Args:
        vllm_config: Source of ``compilation_config``.
    """
    global context_manager

    compilation_config = vllm_config.compilation_config
    total_compile_time: float = time.perf_counter() - torch_compile_start_time
    if compilation_config.mode == CompilationMode.VLLM_COMPILE:
        logger.info_once(
            "torch.compile and initial profiling run took %.2f s in total",
            total_compile_time,
            scope="local",
        )

    if context_manager is not None:
        try:
            context_manager.__exit__(None, None, None)
        finally:
            # Reset even if __exit__ raises, so a later call never exits
            # the same context a second time.
            context_manager = None
|
||||
@contextlib.contextmanager
def monitor_profiling_run() -> Generator[None, None, None]:
    """Context manager that times the initial profiling run.

    Records ``compilation_counter.num_backend_compilations`` before the
    wrapped body runs and asserts afterwards that it has not changed: all
    compilation should have completed before the profiling run starts.
    The elapsed time is logged once the body has finished without raising.
    """
    from vllm.compilation.counter import compilation_counter

    compilations_at_entry = compilation_counter.num_backend_compilations
    started_at = time.perf_counter()
    yield
    duration = time.perf_counter() - started_at

    assert compilation_counter.num_backend_compilations == compilations_at_entry, (
        "backend compilation occurred during the initial profiling run; "
        "all compilation should be complete before the profiling run starts."
    )
    logger.info_once("Initial profiling/warmup run took %.2f s", duration, scope="local")
|
||||
|
||||
|
||||
cudagraph_capturing_enabled: bool = True
|
||||
|
||||
Reference in New Issue
Block a user