# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os
import sys
from abc import abstractmethod
from collections.abc import Callable, Generator
from contextlib import contextmanager, nullcontext
from types import CodeType
from typing import Any, ParamSpec, TypeVar

import torch

import vllm.envs as envs
from vllm.config import CompilationMode, CUDAGraphMode, get_current_vllm_config
from vllm.config.compilation import DynamicShapesType
from vllm.logger import init_logger
from vllm.utils.nvtx_pytorch_hooks import layerwise_nvtx_marker_context

logger = init_logger(__name__)

R = TypeVar("R")
P = ParamSpec("P")


@contextmanager
def _compilation_context() -> Generator[None, None, None]:
    """Context manager for compilation settings.

    This manager sets higher Dynamo cache limits for compilation
    (needed for qwen2_5_vl; see test_qwen2_5_vl_evs_functionality).
    In general, a recompilation can happen whenever we use a new
    backend instance in torch.compile.
    """
    original_cache_size = torch._dynamo.config.cache_size_limit
    original_accumulated_cache = torch._dynamo.config.accumulated_cache_size_limit

    try:
        torch._dynamo.config.cache_size_limit = 2048
        torch._dynamo.config.accumulated_cache_size_limit = 8192
        yield
    finally:
        # Restore the original limits so the settings do not leak
        # outside this context.
        torch._dynamo.config.cache_size_limit = original_cache_size
        torch._dynamo.config.accumulated_cache_size_limit = original_accumulated_cache


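# A minimal sketch of the save/patch/restore pattern above; for a single
# knob, torch._dynamo.config.patch (also used for enable_aot_compile
# below) is equivalent:
#
#     with torch._dynamo.config.patch(cache_size_limit=2048):
#         ...  # compile while the higher limit is in effect

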
class TorchCompileWithNoGuardsWrapper:
    """
    A wrapper class for torch.compile that ensures all guards are dropped
    when CompilationMode is not CompilationMode.STOCK_TORCH_COMPILE.
    When guards are dropped, the first time __call__ is invoked, a single
    compilation is triggered; Dynamo should never trace again after that,
    since we drop all guards.
    """

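    # Illustrative sketch (not part of this module): a subclass only needs
    # to implement forward(); `MyModule` and its tensor shapes are
    # hypothetical, and construction requires a current vllm config.
    #
    #     class MyModule(TorchCompileWithNoGuardsWrapper):
    #         def forward(self, x: torch.Tensor) -> torch.Tensor:
    #             return x * 2
    #
    #     wrapped = MyModule()
    #     out = wrapped(torch.randn(8))  # first call triggers compilation
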
    def check_invariants_and_forward(self, *args: Any, **kwargs: Any) -> Any:
        # Re-validate shape invariants on every call before dispatching to
        # forward (used as the compile target for UNBACKED dynamic shapes).
        assert hasattr(self, "_check_shape_invariants")
        self._check_shape_invariants(*args, **kwargs)

        return self.forward(*args, **kwargs)

    def _call_with_optional_nvtx_range(
        self, callable_fn: Callable[P, R], *args: P.args, **kwargs: P.kwargs
    ) -> Any:
        if self.layerwise_nvtx_tracing_enabled:
            args_list = list(args)
            kwargs_dict = dict(kwargs)
            with layerwise_nvtx_marker_context(
                "Torch Compiled Module (input):{}".format(self.__class__.__name__),
                self,
                in_tensor=args_list,
                kwargs=kwargs_dict,
            ) as ctx:
                ctx.result = callable_fn(*args, **kwargs)
            return ctx.result
        return callable_fn(*args, **kwargs)

    def __init__(
        self,
        compile_prefix: str = "",
        is_encoder: bool = False,
    ) -> None:
        self.compiled = False
        self._compile_prefix = compile_prefix
        self._is_encoder = is_encoder

        vllm_config = get_current_vllm_config()
        self.vllm_config = vllm_config
        mode = vllm_config.compilation_config.mode
        self.layerwise_nvtx_tracing_enabled = (
            vllm_config.observability_config.enable_layerwise_nvtx_tracing
        )
        if mode is None:
            raise RuntimeError("Compilation mode cannot be NO_COMPILATION")

        backend = vllm_config.compilation_config.init_backend(
            vllm_config, prefix=compile_prefix, is_encoder=is_encoder
        )
        options = {}

        if isinstance(backend, str) and backend == "inductor":
            options = vllm_config.compilation_config.inductor_compile_config

        self.first_compile = True
        self.evaluate_guards = (
            vllm_config.compilation_config.dynamic_shapes_config.evaluate_guards
        )

        ds_type = vllm_config.compilation_config.dynamic_shapes_config.type

        if mode != CompilationMode.STOCK_TORCH_COMPILE:
            # Drop all the guards.
            if self.evaluate_guards:
                assert not envs.VLLM_USE_BYTECODE_HOOK, (
                    "compilation_config.dynamic_shapes_config.evaluate_guards "
                    "requires VLLM_USE_BYTECODE_HOOK=0. "
                )

                # Keep only shape-env guards so that guard evaluation can
                # catch invalid dynamic-shape recompilations.
                options["guard_filter_fn"] = lambda x: [
                    entry.guard_type == "SHAPE_ENV" for entry in x
                ]
            else:
                if hasattr(torch.compiler, "skip_all_guards_unsafe"):
                    # Torch 2.10+ provides skip_all_guards_unsafe.
                    options["guard_filter_fn"] = torch.compiler.skip_all_guards_unsafe
                else:
                    # Equivalent fallback for older PyTorch: drop every guard.
                    options["guard_filter_fn"] = lambda x: [False for _ in x]

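        # Contract of guard_filter_fn (a sketch based on the torch.compile
        # option used above): it receives the list of guard entries Dynamo
        # collected and returns one bool per entry, True meaning "keep this
        # guard". `entries` below is hypothetical:
        #
        #     keep = options["guard_filter_fn"](entries)
        #     assert len(keep) == len(entries)
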
        compiled_ptr: Any = self.forward
        # Unbacked dynamic shapes require VLLM_USE_BYTECODE_HOOK=0;
        # validated below.

        if ds_type == DynamicShapesType.UNBACKED:
            # The reason is that the bytecode-hook path calls
            # torch._dynamo.eval_frame.remove_from_cache(
            #     self.original_code_object())
            # to force a fresh compilation, and if we used
            # compiled_ptr = self.check_invariants_and_forward
            # that call would reset all cache entries.
            assert not envs.VLLM_USE_BYTECODE_HOOK, (
                "UNBACKED dynamic shapes require VLLM_USE_BYTECODE_HOOK=0. "
            )
            assert not self.evaluate_guards, "UNBACKED dynamic shapes do not add guards"

            compiled_ptr = self.check_invariants_and_forward

        # Apply the constrain_to_fx_strides patch before first compilation.
        # This covers STOCK_TORCH_COMPILE and DYNAMO_ONCE paths. The VLLM
        # compile paths call this from their own compile() methods too.
        from vllm.env_override import _apply_constrain_to_fx_strides_patch

        _apply_constrain_to_fx_strides_patch()

        aot_context = nullcontext()
        if envs.VLLM_USE_AOT_COMPILE:
            if hasattr(torch._dynamo.config, "enable_aot_compile"):
                aot_context = torch._dynamo.config.patch(enable_aot_compile=True)
            else:
                msg = (
                    "torch._dynamo.config.enable_aot_compile is not "
                    "available. AOT compile is disabled; please "
                    "upgrade PyTorch to use AOT compile."
                )
                logger.warning(msg)

        with aot_context:
            self._compiled_callable = torch.compile(
                compiled_ptr,
                fullgraph=True,
                dynamic=False,
                backend=backend,
                options=options,
            )

        if envs.VLLM_USE_BYTECODE_HOOK and mode != CompilationMode.STOCK_TORCH_COMPILE:
            torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook)
        self._compiled_bytecode: CodeType | None = None

    def aot_compile(self, *args: Any, **kwargs: Any) -> Any:
        if not hasattr(self._compiled_callable, "aot_compile"):
            raise RuntimeError(
                "aot_compile is not supported by the current configuration. "
                "Please make sure torch.compile is enabled with the latest "
                f"version of PyTorch (currently using torch: {torch.__version__})"
            )
        return self._compiled_callable.aot_compile((args, kwargs))

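    # Hypothetical usage sketch for aot_compile (assumes VLLM_USE_AOT_COMPILE
    # and a PyTorch build whose compiled callables expose aot_compile; it is
    # assumed the returned object is callable with the same signature):
    #
    #     fn = wrapper.aot_compile(*sample_args)
    #     out = fn(*sample_args)
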
    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        if envs.VLLM_USE_BYTECODE_HOOK:
            if (
                self.vllm_config.compilation_config.mode
                == CompilationMode.STOCK_TORCH_COMPILE
            ):
                return self._compiled_callable(*args, **kwargs)

            if not self._compiled_bytecode:
                # Make sure a compilation is triggered by clearing the
                # dynamo cache.
                torch._dynamo.eval_frame.remove_from_cache(self.original_code_object())
                return self._call_with_optional_nvtx_range(
                    self._compiled_callable, *args, **kwargs
                )
            else:
                # After the first compilation, bypass Dynamo and run the
                # saved bytecode directly.
                with self._dispatch_to_compiled_code():
                    return self._call_with_optional_nvtx_range(
                        self.forward, *args, **kwargs
                    )
        else:
            # When evaluating guards, any recompilation after the first call
            # means a kept shape guard failed, so fail loudly instead of
            # silently recompiling.
            ctx = (
                nullcontext()
                if self.first_compile or not self.evaluate_guards
                else torch.compiler.set_stance("fail_on_recompile")
            )
            self.first_compile = False
            with _compilation_context(), ctx:
                return self._call_with_optional_nvtx_range(
                    self._compiled_callable, *args, **kwargs
                )

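    # Dispatch summary for __call__ above:
    #
    #   VLLM_USE_BYTECODE_HOOK=1, STOCK_TORCH_COMPILE -> compiled callable
    #   VLLM_USE_BYTECODE_HOOK=1, first call          -> compile; hook saves bytecode
    #   VLLM_USE_BYTECODE_HOOK=1, later calls         -> run saved bytecode directly
    #   VLLM_USE_BYTECODE_HOOK=0                      -> compiled callable, with an
    #                                                    optional fail_on_recompile
    #                                                    stance when evaluating guards
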
    @abstractmethod
    def forward(self, *args: Any, **kwargs: Any) -> Any: ...

    def original_code_object(self) -> CodeType:
        """Return the original code object of the forward method."""
        return self.__class__.forward.__code__

    def bytecode_hook(self, old_code: CodeType, new_code: CodeType) -> None:
        """Hook to save the compiled bytecode for direct execution."""
        if old_code is not self.original_code_object():
            return
        # code borrowed from https://github.com/thuml/depyf/blob/f4ad79fadee27ea113b4c75202db1eb1a11c0dbc/depyf/explain/enable_debugging.py#L25
        # Walk up the stack to Dynamo's _compile frame, then pull the frame
        # being compiled out of its locals.
        frame = sys._getframe()
        while frame and frame.f_back:
            frame = frame.f_back
            code_name = frame.f_code.co_name
            file_name = frame.f_code.co_filename.split(os.path.sep)[-1]
            if code_name == "_compile" and file_name == "convert_frame.py":
                break
        frame = frame.f_locals["frame"]
        assert frame.f_code == old_code

        if frame.f_locals["self"] is not self:
            return

        self._compiled_bytecode = new_code

        path = self.vllm_config.compile_debug_dump_path()
        if path:
            decompiled_file = path / "transformed_code.py"
            if not decompiled_file.exists():
                try:
                    # Usually the decompilation will succeed for most models,
                    # as we guarantee a full-graph compilation in Dynamo,
                    # but there's no 100% guarantee, since decompilation is
                    # not a reversible process.
                    import depyf

                    src = depyf.decompile(new_code)

                    with open(decompiled_file, "w") as f:
                        f.write(src)

                    logger.debug("Dynamo transformed code saved to %s", decompiled_file)
                except Exception:
                    pass

        if (
            self.vllm_config.compilation_config.cudagraph_mode != CUDAGraphMode.NONE
            and "update" in new_code.co_names
        ):
            import depyf

            src = depyf.decompile(new_code)
            msg = (
                "Assigning / modifying buffers of nn.Module during forward pass is not "
                "allowed when using cudagraph inside the compiler because it will "
                "cause silent errors. Please use eager mode or fix the code. The "
                "following code contains clues about which buffer is being modified "
                f"(please search for the usage of the function `update`):\n{src}"
            )
            raise RuntimeError(msg)

    @contextmanager
    def _dispatch_to_compiled_code(self) -> Generator[None, None, None]:
        """
        Context manager to dispatch to internally compiled code for torch<2.8.
        Why does this work? Because Dynamo guarantees that the compiled
        bytecode has exactly the same arguments, cell variables, and free
        variables as the original code. Therefore we can directly switch
        the code object in the function and call it.

        See https://dev-discuss.pytorch.org/t/what-is-the-relationship-requirement-among-original-bytecode-transformed-bytecode-and-bytecode-returned-by-hooks-in-dynamo/1693/7 for more details.
        """  # noqa: E501 line too long
        original = self.original_code_object()
        assert self._compiled_bytecode is not None
        self.__class__.forward.__code__ = self._compiled_bytecode
        try:
            yield
        finally:
            self.__class__.forward.__code__ = original


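# A minimal sketch of the code-object swap used by _dispatch_to_compiled_code,
# with hypothetical functions f and g (both take no arguments and have no
# free variables, which is what makes the swap legal):
#
#     def f() -> int:
#         return 1
#
#     def g() -> int:
#         return 2
#
#     f.__code__ = g.__code__
#     assert f() == 2

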
def reset_compile_wrapper(model: torch.nn.Module) -> None:
    """
    Clean up compiled model and captured CUDA graphs for elastic EP.
    """
    if not isinstance(model, TorchCompileWithNoGuardsWrapper) and hasattr(
        model, "model"
    ):
        model = model.model
    if not isinstance(model, TorchCompileWithNoGuardsWrapper):
        return
    # model.do_not_compile is set by the @support_torch_compile decorator
    if hasattr(model, "do_not_compile") and model.do_not_compile:
        return
    from vllm.compilation.counter import compilation_counter

    # reset the compilation counter
    compilation_counter.num_models_seen = 0
    compilation_counter.num_graphs_seen = 0
    compilation_counter.num_piecewise_graphs_seen = 0
    compilation_counter.num_piecewise_capturable_graphs_seen = 0
    compilation_counter.num_backend_compilations = 0
    compilation_counter.num_gpu_runner_capture_triggers = 0
    compilation_counter.num_cudagraph_captured = 0
    compilation_counter.num_inductor_compiles = 0
    compilation_counter.num_eager_compiles = 0
    compilation_counter.num_cache_entries_updated = 0
    compilation_counter.num_compiled_artifacts_saved = 0
    compilation_counter.stock_torch_compile_count = 0
    compilation_counter.num_aot_compiles = 0
    compilation_counter.num_aot_artifacts_saved = 0
    compilation_counter.num_aot_artifacts_loaded = 0

    # Clear the AOT compiled function so the model is forced to
    # recompile on the next call. Without this, decorators.py
    # __call__ uses the stale aot_compiled_fn whose torchinductor
    # kernels have old parameters (expert_map size, for example)
    # baked in as compile-time constants.
    if hasattr(model, "aot_compiled_fn"):
        model.aot_compiled_fn = None
    if hasattr(model, "was_aot_compile_fn_loaded_from_disk"):
        model.was_aot_compile_fn_loaded_from_disk = False

    # Reset the cache_dir so VllmBackend recomputes the hash
    # (data_parallel_size changed, so the config hash differs).
    compilation_config = model.vllm_config.compilation_config
    compilation_config.cache_dir = ""
    compilation_config.local_cache_dir = ""

    # Restore the original forward code object, then re-run __init__ to
    # rebuild the compiled callable from scratch.
    model.__class__.forward.__code__ = model.original_code_object()
    TorchCompileWithNoGuardsWrapper.__init__(
        model,
        compile_prefix=model._compile_prefix,
        is_encoder=model._is_encoder,
    )