Avoid bytecode hook and simplify TorchCompileWrapperWithCustomDipatch (#25110)
Signed-off-by: Laith Sakka <lsakka@meta.com>
This commit is contained in:
@@ -17,7 +17,7 @@ from torch._dynamo.symbolic_convert import InliningInstructionTranslator
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
|
||||
from vllm.compilation.wrapper import TorchCompileWithNoGuardsWrapper
|
||||
from vllm.config import (
|
||||
CompilationMode,
|
||||
VllmConfig,
|
||||
@@ -246,14 +246,14 @@ def _support_torch_compile(
|
||||
"""
|
||||
A decorator to add support for compiling the forward method of a class.
|
||||
"""
|
||||
if TorchCompileWrapperWithCustomDispatcher in cls.__bases__:
|
||||
if TorchCompileWithNoGuardsWrapper in cls.__bases__:
|
||||
# support decorating multiple times
|
||||
return cls
|
||||
|
||||
# take care of method resolution order
|
||||
# make sure super().__init__ is called on the base class
|
||||
# other than TorchCompileWrapperWithCustomDispatcher
|
||||
cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher,)
|
||||
# other than TorchCompileWithNoGuardsWrapper
|
||||
cls.__bases__ = cls.__bases__ + (TorchCompileWithNoGuardsWrapper,)
|
||||
|
||||
old_init = cls.__init__
|
||||
|
||||
@@ -290,12 +290,43 @@ def _support_torch_compile(
|
||||
return
|
||||
|
||||
compilation_counter.num_models_seen += 1
|
||||
TorchCompileWrapperWithCustomDispatcher.__init__(
|
||||
self, compilation_mode=vllm_config.compilation_config.mode
|
||||
)
|
||||
self.compiled = False
|
||||
TorchCompileWithNoGuardsWrapper.__init__(self)
|
||||
|
||||
cls.__init__ = __init__
|
||||
|
||||
def _mark_dynamic_inputs(mod, *args, **kwargs):
|
||||
sig = inspect.signature(mod.__class__.forward)
|
||||
bound_args = sig.bind(mod, *args, **kwargs)
|
||||
bound_args.apply_defaults()
|
||||
for k, dims in dynamic_arg_dims.items():
|
||||
arg = bound_args.arguments.get(k)
|
||||
if arg is not None:
|
||||
dims = [dims] if isinstance(dims, int) else dims
|
||||
if isinstance(arg, torch.Tensor):
|
||||
# In case dims is specified with negative indexing
|
||||
dims = [arg.ndim + dim if dim < 0 else dim for dim in dims]
|
||||
torch._dynamo.mark_dynamic(arg, dims)
|
||||
elif isinstance(arg, IntermediateTensors):
|
||||
for tensor in arg.tensors.values():
|
||||
# In case dims is specified with negative indexing
|
||||
dims = [tensor.ndim + dim if dim < 0 else dim for dim in dims]
|
||||
torch._dynamo.mark_dynamic(tensor, dims)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unsupported dynamic dimensions"
|
||||
f" {dims} for argument {k} with type {type(arg)}."
|
||||
)
|
||||
if mark_unbacked_dims:
|
||||
for k, dims in mark_unbacked_dims.items():
|
||||
arg = bound_args.arguments.get(k)
|
||||
if arg is not None:
|
||||
dims = [dims] if isinstance(dims, int) else dims
|
||||
if isinstance(arg, torch.Tensor):
|
||||
# In case dims is specified with negative indexing
|
||||
dims = [arg.ndim + dim if dim < 0 else dim for dim in dims]
|
||||
torch._dynamo.decorators.mark_unbacked(arg, dims)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
# torch.compiler.is_compiling() means we are inside the compilation
|
||||
# e.g. TPU has the compilation logic in model runner, so we don't
|
||||
@@ -303,6 +334,7 @@ def _support_torch_compile(
|
||||
if self.do_not_compile or torch.compiler.is_compiling():
|
||||
return self.forward(*args, **kwargs)
|
||||
|
||||
# if aot_compiled_fn is set, just call it.
|
||||
if getattr(self, "aot_compiled_fn", None) is not None:
|
||||
return self.aot_compiled_fn(self, *args, **kwargs)
|
||||
|
||||
@@ -362,120 +394,84 @@ def _support_torch_compile(
|
||||
)
|
||||
return self.aot_compiled_fn(self, *args, **kwargs)
|
||||
|
||||
if self.compiled:
|
||||
assert not envs.VLLM_USE_AOT_COMPILE
|
||||
return TorchCompileWithNoGuardsWrapper.__call__(self, *args, **kwargs)
|
||||
|
||||
# This is the path for the first compilation.
|
||||
|
||||
# the first compilation needs to have dynamic shapes marked
|
||||
if len(self.compiled_codes) < 1:
|
||||
sig = inspect.signature(self.__class__.forward)
|
||||
bound_args = sig.bind(self, *args, **kwargs)
|
||||
bound_args.apply_defaults()
|
||||
for k, dims in dynamic_arg_dims.items():
|
||||
arg = bound_args.arguments.get(k)
|
||||
if arg is not None:
|
||||
dims = [dims] if isinstance(dims, int) else dims
|
||||
if isinstance(arg, torch.Tensor):
|
||||
# In case dims is specified with negative indexing
|
||||
dims = [arg.ndim + dim if dim < 0 else dim for dim in dims]
|
||||
torch._dynamo.mark_dynamic(arg, dims)
|
||||
elif isinstance(arg, IntermediateTensors):
|
||||
for tensor in arg.tensors.values():
|
||||
# In case dims is specified with negative indexing
|
||||
dims = [
|
||||
tensor.ndim + dim if dim < 0 else dim for dim in dims
|
||||
]
|
||||
torch._dynamo.mark_dynamic(tensor, dims)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unsupported dynamic dimensions"
|
||||
f" {dims} for argument {k} with type {type(arg)}."
|
||||
)
|
||||
if mark_unbacked_dims:
|
||||
for k, dims in mark_unbacked_dims.items():
|
||||
arg = bound_args.arguments.get(k)
|
||||
if arg is not None:
|
||||
dims = [dims] if isinstance(dims, int) else dims
|
||||
if isinstance(arg, torch.Tensor):
|
||||
# In case dims is specified with negative indexing
|
||||
dims = [arg.ndim + dim if dim < 0 else dim for dim in dims]
|
||||
torch._dynamo.decorators.mark_unbacked(arg, dims)
|
||||
# here, it is the starting point of the `torch.compile` process
|
||||
start_monitoring_torch_compile(self.vllm_config)
|
||||
logger.debug("Start compiling function %s", self.original_code_object)
|
||||
_mark_dynamic_inputs(self, *args, **kwargs)
|
||||
|
||||
# if we don't use custom dispatcher, we can directly call the
|
||||
# compiled function and let torch.compile handle the dispatching,
|
||||
# with the overhead of guard evaluation and recompilation.
|
||||
if len(self.compiled_codes) < 1 or not self.use_custom_dispatcher:
|
||||
# it seems Dynamo reuse the compilation across instances,
|
||||
# while we need to make sure the compiled code is not reused.
|
||||
# we need to control all the compilation of the model.
|
||||
torch._dynamo.eval_frame.remove_from_cache(self.original_code_object)
|
||||
# here, it is the starting point of the `torch.compile` process
|
||||
start_monitoring_torch_compile(self.vllm_config)
|
||||
original_code_object = self.original_code_object()
|
||||
logger.debug("Start compiling function %s", original_code_object)
|
||||
|
||||
# collect all relevant files traced by Dynamo,
|
||||
# so that the compilation cache can trigger re-compilation
|
||||
# properly when any of these files change.
|
||||
# we do not want tp delete the original code object entries since
|
||||
# we depend on them now to look up cached compiled functions.
|
||||
# torch._dynamo.eval_frame.remove_from_cache(original_code_object)
|
||||
|
||||
# 1. the file containing the top-level forward function
|
||||
self.vllm_config.compilation_config.traced_files.add(
|
||||
self.original_code_object.co_filename
|
||||
)
|
||||
# collect all relevant files traced by Dynamo,
|
||||
# so that the compilation cache can trigger re-compilation
|
||||
# properly when any of these files change.
|
||||
|
||||
# 2. every time Dynamo sees a function call, it will inline
|
||||
# the function by calling InliningInstructionTranslator.inline_call_
|
||||
# we hijack this function to know all the functions called
|
||||
# during Dynamo tracing, and their corresponding files
|
||||
inline_call = InliningInstructionTranslator.inline_call_
|
||||
# 1. the file containing the top-level forward function
|
||||
self.vllm_config.compilation_config.traced_files.add(
|
||||
original_code_object.co_filename
|
||||
)
|
||||
|
||||
def patched_inline_call(self_):
|
||||
code = self_.f_code
|
||||
self.vllm_config.compilation_config.traced_files.add(code.co_filename)
|
||||
return inline_call(self_)
|
||||
# 2. every time Dynamo sees a function call, it will inline
|
||||
# the function by calling InliningInstructionTranslator.inline_call_
|
||||
# we hijack this function to know all the functions called
|
||||
# during Dynamo tracing, and their corresponding files
|
||||
inline_call = InliningInstructionTranslator.inline_call_
|
||||
|
||||
# Disable the C++ compilation of symbolic shape guards. C++-fication
|
||||
# of symbolic shape guards can improve guard overhead. But, since
|
||||
# vllm skip guards anyways, setting this flag to False can improve
|
||||
# compile time.
|
||||
dynamo_config_patches = {}
|
||||
try:
|
||||
_ = torch._dynamo.config.enable_cpp_symbolic_shape_guards
|
||||
dynamo_config_patches["enable_cpp_symbolic_shape_guards"] = False
|
||||
except AttributeError:
|
||||
# Note: this config is not available in torch 2.6, we can skip
|
||||
# if the config doesn't exist
|
||||
logger.debug("enable_cpp_symbolic_shape_guards config not available")
|
||||
def patched_inline_call(self_):
|
||||
code = self_.f_code
|
||||
self.vllm_config.compilation_config.traced_files.add(code.co_filename)
|
||||
return inline_call(self_)
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
InliningInstructionTranslator, "inline_call_", patched_inline_call
|
||||
),
|
||||
torch._dynamo.config.patch(**dynamo_config_patches),
|
||||
maybe_use_cudagraph_partition_wrapper(self.vllm_config),
|
||||
_torch27_patch_tensor_subclasses(),
|
||||
):
|
||||
if envs.VLLM_USE_AOT_COMPILE:
|
||||
self.aot_compiled_fn = self.aot_compile(*args, **kwargs)
|
||||
output = self.aot_compiled_fn(self, *args, **kwargs)
|
||||
assert aot_compilation_path is not None
|
||||
assert cache_dir is not None
|
||||
try:
|
||||
os.makedirs(cache_dir, exist_ok=True)
|
||||
self.aot_compiled_fn.save_compiled_function(
|
||||
aot_compilation_path
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Cannot save aot compilation to path %s, error: %s",
|
||||
aot_compilation_path,
|
||||
str(e),
|
||||
)
|
||||
else:
|
||||
output = self.compiled_callable(*args, **kwargs)
|
||||
return output
|
||||
# Disable the C++ compilation of symbolic shape guards. C++-fication
|
||||
# of symbolic shape guards can improve guard overhead. But, since
|
||||
# vllm skip guards anyways, setting this flag to False can improve
|
||||
# compile time.
|
||||
dynamo_config_patches = {}
|
||||
try:
|
||||
_ = torch._dynamo.config.enable_cpp_symbolic_shape_guards
|
||||
dynamo_config_patches["enable_cpp_symbolic_shape_guards"] = False
|
||||
except AttributeError:
|
||||
# Note: this config is not available in torch 2.6, we can skip
|
||||
# if the config doesn't exist
|
||||
logger.debug("enable_cpp_symbolic_shape_guards config not available")
|
||||
|
||||
# usually, capturing the model once is enough, and then we can
|
||||
# dispatch to the compiled code directly, without going through
|
||||
# the Dynamo guard mechanism.
|
||||
with self.dispatch_to_code(0):
|
||||
model_output = self.forward(*args, **kwargs)
|
||||
return model_output
|
||||
with (
|
||||
patch.object(
|
||||
InliningInstructionTranslator, "inline_call_", patched_inline_call
|
||||
),
|
||||
torch._dynamo.config.patch(**dynamo_config_patches),
|
||||
maybe_use_cudagraph_partition_wrapper(self.vllm_config),
|
||||
_torch27_patch_tensor_subclasses(),
|
||||
):
|
||||
if envs.VLLM_USE_AOT_COMPILE:
|
||||
self.aot_compiled_fn = self.aot_compile(*args, **kwargs)
|
||||
output = self.aot_compiled_fn(self, *args, **kwargs)
|
||||
assert aot_compilation_path is not None
|
||||
assert cache_dir is not None
|
||||
try:
|
||||
os.makedirs(cache_dir, exist_ok=True)
|
||||
self.aot_compiled_fn.save_compiled_function(aot_compilation_path)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Cannot save aot compilation to path %s, error: %s",
|
||||
aot_compilation_path,
|
||||
str(e),
|
||||
)
|
||||
else:
|
||||
output = TorchCompileWithNoGuardsWrapper.__call__(self, *args, **kwargs)
|
||||
|
||||
self.compiled = True
|
||||
return output
|
||||
|
||||
cls.__call__ = __call__
|
||||
return cls
|
||||
|
||||
@@ -4,11 +4,11 @@
|
||||
import os
|
||||
import sys
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Callable
|
||||
from contextlib import contextmanager
|
||||
from types import CodeType
|
||||
|
||||
import torch
|
||||
import torch._C._dynamo.guards
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import CompilationMode, CUDAGraphMode, get_current_vllm_config
|
||||
@@ -17,88 +17,153 @@ from vllm.logger import init_logger
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class TorchCompileWrapperWithCustomDispatcher:
|
||||
def _noop_add_global_state_guard(self, *args, **kwargs):
|
||||
"""No-op to skip the GLOBAL_STATE guard entirely"""
|
||||
pass
|
||||
|
||||
|
||||
def _noop_add_torch_function_mode_stack_guard(self, *args, **kwargs):
|
||||
"""No-op to skip the TORCH_FUNCTION_MODE_STACK guard entirely"""
|
||||
pass
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _compilation_context():
|
||||
"""Context manager for compilation settings and patches.
|
||||
|
||||
This manager:
|
||||
1. Sets higher dynamo cache limits for compilation. (Needed for
|
||||
qwen2_5_vl see test_qwen2_5_vl_evs_functionality).
|
||||
Generally a recompilation can happen whenever we use a new
|
||||
backend instance in torch.compile.
|
||||
2. Patches out add_global_state_guard to skip GLOBAL_STATE guards
|
||||
3. Patches out add_torch_function_mode_stack_guard to skip
|
||||
TORCH_FUNCTION_MODE_STACK guards.
|
||||
4. Restores everything when compilation completes
|
||||
"""
|
||||
A wrapper class for torch.compile, with a custom dispatch logic.
|
||||
Subclasses should:
|
||||
1. Implement the forward method
|
||||
2. Implement the dispatch logic in the __call__ method
|
||||
It can use `self.compiled_codes` to access the compiled bytecode,
|
||||
and `with self.dispatch_to_code(index):` to dispatch to
|
||||
the compiled code.
|
||||
3. Implement the `__init__` method to determine how to call
|
||||
`torch.compile` over the forward method.
|
||||
# Save original values
|
||||
original_global_state_guard = (
|
||||
torch._C._dynamo.guards.GuardManager.add_global_state_guard
|
||||
)
|
||||
original_torch_function_mode_stack_guard = (
|
||||
torch._C._dynamo.guards.GuardManager.add_torch_function_mode_stack_guard
|
||||
)
|
||||
original_cache_size = torch._dynamo.config.cache_size_limit
|
||||
original_accumulated_cache = torch._dynamo.config.accumulated_cache_size_limit
|
||||
|
||||
try:
|
||||
# Set higher cache limits for compilation
|
||||
torch._dynamo.config.cache_size_limit = 2048
|
||||
torch._dynamo.config.accumulated_cache_size_limit = 8192
|
||||
|
||||
# Patch guard manager
|
||||
torch._C._dynamo.guards.GuardManager.add_global_state_guard = (
|
||||
_noop_add_global_state_guard
|
||||
)
|
||||
torch._C._dynamo.guards.GuardManager.add_torch_function_mode_stack_guard = (
|
||||
_noop_add_torch_function_mode_stack_guard
|
||||
)
|
||||
yield
|
||||
finally:
|
||||
# Restore original values
|
||||
torch._C._dynamo.guards.GuardManager.add_global_state_guard = (
|
||||
original_global_state_guard
|
||||
)
|
||||
torch._C._dynamo.guards.GuardManager.add_torch_function_mode_stack_guard = (
|
||||
original_torch_function_mode_stack_guard
|
||||
)
|
||||
torch._dynamo.config.cache_size_limit = original_cache_size
|
||||
torch._dynamo.config.accumulated_cache_size_limit = original_accumulated_cache
|
||||
|
||||
|
||||
class TorchCompileWithNoGuardsWrapper:
|
||||
"""
|
||||
A wrapper class for torch.compile, it ensures that all guards are dropped
|
||||
when CompilationMode is not CompilationMode.STOCK_TORCH_COMPILE.
|
||||
When guards are dropped, the first time __call__ is invoked, a single
|
||||
compilation is triggered. Dynamo should never be traced again after that
|
||||
since we drop all guards.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
compiled_callable: Callable | None = None,
|
||||
compilation_mode: CompilationMode = CompilationMode.NONE,
|
||||
):
|
||||
def __init__(self):
|
||||
self.compiled = False
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
self.vllm_config = vllm_config
|
||||
if compiled_callable is None:
|
||||
# default compilation settings
|
||||
# compiling the forward method
|
||||
mode = vllm_config.compilation_config.mode
|
||||
if mode is None:
|
||||
raise RuntimeError("Compilation mode cannot be NO_COMPILATION")
|
||||
|
||||
backend = vllm_config.compilation_config.init_backend(vllm_config)
|
||||
options = None
|
||||
if isinstance(backend, str) and backend == "inductor":
|
||||
options = (
|
||||
get_current_vllm_config().compilation_config.inductor_compile_config
|
||||
)
|
||||
if envs.VLLM_USE_AOT_COMPILE:
|
||||
options = options or {}
|
||||
# This effectively drop all the guards.
|
||||
# We need this because bytecode hook is not used any more to
|
||||
# drop guards in the AOT compile mode.
|
||||
options["guard_filter_fn"] = lambda guards: [False for _ in guards]
|
||||
if hasattr(torch._dynamo.config, "enable_aot_compile"):
|
||||
torch._dynamo.config.enable_aot_compile = True
|
||||
else:
|
||||
msg = "torch._dynamo.config.enable_aot_compile is not "
|
||||
msg += "available. AOT compile is disabled and please "
|
||||
msg += "upgrade PyTorch version to use AOT compile."
|
||||
logger.warning(msg)
|
||||
backend = vllm_config.compilation_config.init_backend(vllm_config)
|
||||
options = {}
|
||||
|
||||
compiled_callable = torch.compile(
|
||||
self.forward, fullgraph=True, backend=backend, options=options
|
||||
)
|
||||
if isinstance(backend, str) and backend == "inductor":
|
||||
options = vllm_config.compilation_config.inductor_compile_config
|
||||
|
||||
self.compiled_callable = compiled_callable
|
||||
self.original_code_object = self.__class__.forward.__code__
|
||||
self.compiled_codes: list[CodeType] = []
|
||||
torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook)
|
||||
if mode != CompilationMode.STOCK_TORCH_COMPILE:
|
||||
# Drop all the guards.
|
||||
options["guard_filter_fn"] = lambda x: [False for _ in x]
|
||||
|
||||
# read the env var to determine whether to use the custom dispatcher
|
||||
# subclasses can use this to switch between the custom dispatcher
|
||||
# and the default Dynamo guard mechanism.
|
||||
self.use_custom_dispatcher: bool = (
|
||||
compilation_mode >= CompilationMode.DYNAMO_TRACE_ONCE
|
||||
if envs.VLLM_USE_AOT_COMPILE:
|
||||
if hasattr(torch._dynamo.config, "enable_aot_compile"):
|
||||
torch._dynamo.config.enable_aot_compile = True
|
||||
else:
|
||||
msg = "torch._dynamo.config.enable_aot_compile is not "
|
||||
msg += "available. AOT compile is disabled and please "
|
||||
msg += "upgrade PyTorch version to use AOT compile."
|
||||
logger.warning(msg)
|
||||
|
||||
self._compiled_callable = torch.compile(
|
||||
self.forward,
|
||||
fullgraph=True,
|
||||
dynamic=False,
|
||||
backend=backend,
|
||||
options=options,
|
||||
)
|
||||
|
||||
if envs.VLLM_USE_BYTECODE_HOOK and mode != CompilationMode.STOCK_TORCH_COMPILE:
|
||||
torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook)
|
||||
self._compiled_bytecode = None
|
||||
|
||||
def aot_compile(self, *args, **kwargs):
|
||||
if not hasattr(self.compiled_callable, "aot_compile"):
|
||||
if not hasattr(self._compiled_callable, "aot_compile"):
|
||||
raise RuntimeError(
|
||||
"aot_compile is not supported by the current configuration. "
|
||||
+ "Please make sure torch.compile is enabled with the latest "
|
||||
+ f"version of PyTorch (current using torch: {torch.__version__})"
|
||||
)
|
||||
return self.compiled_callable.aot_compile((args, kwargs))
|
||||
return self._compiled_callable.aot_compile((args, kwargs))
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
"""Implement the dispatch logic here, beyond the torch.compile mode.
|
||||
NOTE: this function can have additional arguments beyond the forward
|
||||
method, for directly dispatching to the compiled code.
|
||||
"""
|
||||
return self.compiled_callable(*args, **kwargs)
|
||||
if envs.VLLM_USE_BYTECODE_HOOK:
|
||||
if (
|
||||
self.vllm_config.compilation_config.mode
|
||||
== CompilationMode.STOCK_TORCH_COMPILE
|
||||
):
|
||||
return self._compiled_callable(*args, **kwargs)
|
||||
|
||||
if not self._compiled_bytecode:
|
||||
# Make sure a compilation is triggered by clearing dynamo
|
||||
# cache.
|
||||
torch._dynamo.eval_frame.remove_from_cache(self.original_code_object())
|
||||
return self._compiled_callable(*args, **kwargs)
|
||||
else:
|
||||
with self._dispatch_to_compiled_code():
|
||||
return self.forward(*args, **kwargs)
|
||||
else:
|
||||
with _compilation_context():
|
||||
return self._compiled_callable(*args, **kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def forward(self, *args, **kwargs): ...
|
||||
|
||||
def original_code_object(self) -> CodeType:
|
||||
"""Return the original code object of the forward method."""
|
||||
return self.__class__.forward.__code__
|
||||
|
||||
def bytecode_hook(self, old_code: CodeType, new_code: CodeType):
|
||||
"""Hook to save the compiled bytecode for direct execution."""
|
||||
if old_code is not self.original_code_object:
|
||||
if old_code is not self.original_code_object():
|
||||
return
|
||||
# code borrowed from https://github.com/thuml/depyf/blob/f4ad79fadee27ea113b4c75202db1eb1a11c0dbc/depyf/explain/enable_debugging.py#L25
|
||||
frame = sys._getframe()
|
||||
@@ -114,7 +179,7 @@ class TorchCompileWrapperWithCustomDispatcher:
|
||||
if frame.f_locals["self"] is not self:
|
||||
return
|
||||
|
||||
self.compiled_codes.append(new_code)
|
||||
self._compiled_bytecode = new_code
|
||||
|
||||
path = self.vllm_config.compile_debug_dump_path()
|
||||
if path:
|
||||
@@ -153,16 +218,21 @@ class TorchCompileWrapperWithCustomDispatcher:
|
||||
raise RuntimeError(msg)
|
||||
|
||||
@contextmanager
|
||||
def dispatch_to_code(self, index: int):
|
||||
"""Context manager to dispatch to the compiled code.
|
||||
def _dispatch_to_compiled_code(self):
|
||||
# noqa: E501
|
||||
"""
|
||||
Context manager to dispatch to internally compiled code for torch<2.8.
|
||||
Why does this work? Because Dynamo guarantees that the compiled
|
||||
bytecode has exactly the same arguments, cell variables, and free
|
||||
variables as the original code. Therefore we can directly switch
|
||||
the code object in the function and call it.
|
||||
|
||||
See https://dev-discuss.pytorch.org/t/what-is-the-relationship-requirement-among-original-bytecode-transformed-bytecode-and-bytecode-returned-by-hooks-in-dynamo/1693/7
|
||||
for more details.
|
||||
"""
|
||||
self.__class__.forward.__code__ = self.compiled_codes[index]
|
||||
yield
|
||||
self.__class__.forward.__code__ = self.original_code_object
|
||||
See https://dev-discuss.pytorch.org/t/what-is-the-relationship-requirement-among-original-bytecode-transformed-bytecode-and-bytecode-returned-by-hooks-in-dynamo/1693/7 for more details.
|
||||
""" # noqa: E501 line too long
|
||||
original = self.original_code_object()
|
||||
assert self._compiled_bytecode is not None
|
||||
self.__class__.forward.__code__ = self._compiled_bytecode
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self.__class__.forward.__code__ = original
|
||||
|
||||
Reference in New Issue
Block a user