Add option to use unbacked, and backed size obl dynamic shapes for more sounds compilation. (#26199)
Signed-off-by: Laith Sakka <lsakka@meta.com>
This commit is contained in:
@@ -24,6 +24,7 @@ from vllm.config import (
|
||||
get_current_vllm_config,
|
||||
set_current_vllm_config,
|
||||
)
|
||||
from vllm.config.compilation import DynamicShapesType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils.import_utils import resolve_obj_by_qualname
|
||||
@@ -104,6 +105,7 @@ def support_torch_compile(
|
||||
dynamic_arg_dims: dict[str, int | list[int]] | None = None,
|
||||
mark_unbacked_dims: dict[str, int | list[int]] | None = None,
|
||||
enable_if: Callable[[VllmConfig], bool] | None = None,
|
||||
shape_invariants: Callable[..., None] = lambda *args, **kwargs: None,
|
||||
) -> Callable[[_T], _T] | _T:
|
||||
"""
|
||||
A decorator to add support for compiling the forward method of a class.
|
||||
@@ -161,6 +163,14 @@ def support_torch_compile(
|
||||
dim to be decorated with `mark_unbacked`. This is useful if we would like to
|
||||
enforce that dynamo does not specialize on 0/1 values in the case of dummy input
|
||||
such as for vision model compilation
|
||||
|
||||
`shape_invariants` is a function that gets compiled right before forward.
|
||||
The function should have the torch._check calls that are needed to set
|
||||
the relationships between different input sizes. For example:
|
||||
torch._check(input_ids.size()[0] == inputs_embeds.size()[0])
|
||||
This enforces constraints on the symbolic shapes without hardcoding
|
||||
specific values. It is needed for some models to avoid data dependent
|
||||
errors.
|
||||
"""
|
||||
|
||||
def cls_decorator_helper(cls: _T) -> _T:
|
||||
@@ -199,7 +209,11 @@ def support_torch_compile(
|
||||
f"Argument {k} not found in the forward method of {cls}"
|
||||
)
|
||||
return _support_torch_compile(
|
||||
cls, inferred_dynamic_arg_dims, mark_unbacked_dims, enable_if
|
||||
cls,
|
||||
inferred_dynamic_arg_dims,
|
||||
mark_unbacked_dims,
|
||||
enable_if,
|
||||
shape_invariants,
|
||||
)
|
||||
|
||||
if cls is not None:
|
||||
@@ -242,6 +256,7 @@ def _support_torch_compile(
|
||||
dynamic_arg_dims: dict[str, int | list[int]],
|
||||
mark_unbacked_dims: dict[str, int | list[int]] | None = None,
|
||||
enable_if: Callable[[VllmConfig], bool] | None = None,
|
||||
shape_invariants: Callable[..., None] = lambda *args, **kwargs: None,
|
||||
) -> _T:
|
||||
"""
|
||||
A decorator to add support for compiling the forward method of a class.
|
||||
@@ -276,11 +291,12 @@ def _support_torch_compile(
|
||||
old_init(self, **kwargs)
|
||||
|
||||
self.vllm_config = vllm_config
|
||||
self.compilation_config = self.vllm_config.compilation_config
|
||||
enable_compile = enable_if is None or enable_if(vllm_config)
|
||||
# for CompilationMode.STOCK_TORCH_COMPILE , the upper level model runner
|
||||
# will handle the compilation, so we don't need to do anything here.
|
||||
self.do_not_compile = (
|
||||
vllm_config.compilation_config.mode
|
||||
self.compilation_config.mode
|
||||
in [CompilationMode.NONE, CompilationMode.STOCK_TORCH_COMPILE]
|
||||
or not supports_dynamo()
|
||||
or _should_ignore_torch_compile(self.__class__)
|
||||
@@ -289,29 +305,38 @@ def _support_torch_compile(
|
||||
if self.do_not_compile:
|
||||
return
|
||||
|
||||
self._check_shape_invariants = shape_invariants
|
||||
|
||||
compilation_counter.num_models_seen += 1
|
||||
self.compiled = False
|
||||
TorchCompileWithNoGuardsWrapper.__init__(self)
|
||||
|
||||
cls.__init__ = __init__
|
||||
|
||||
def _mark_dynamic_inputs(mod, *args, **kwargs):
|
||||
def _mark_dynamic_inputs(mod, type, *args, **kwargs):
|
||||
def mark_dynamic(arg, dims):
|
||||
if type == DynamicShapesType.UNBACKED:
|
||||
torch._dynamo.decorators.mark_unbacked(arg, dims)
|
||||
else:
|
||||
torch._dynamo.mark_dynamic(arg, dims)
|
||||
|
||||
sig = inspect.signature(mod.__class__.forward)
|
||||
bound_args = sig.bind(mod, *args, **kwargs)
|
||||
bound_args.apply_defaults()
|
||||
for k, dims in dynamic_arg_dims.items():
|
||||
arg = bound_args.arguments.get(k)
|
||||
|
||||
if arg is not None:
|
||||
dims = [dims] if isinstance(dims, int) else dims
|
||||
if isinstance(arg, torch.Tensor):
|
||||
# In case dims is specified with negative indexing
|
||||
dims = [arg.ndim + dim if dim < 0 else dim for dim in dims]
|
||||
torch._dynamo.mark_dynamic(arg, dims)
|
||||
mark_dynamic(arg, dims)
|
||||
elif isinstance(arg, IntermediateTensors):
|
||||
for tensor in arg.tensors.values():
|
||||
# In case dims is specified with negative indexing
|
||||
dims = [tensor.ndim + dim if dim < 0 else dim for dim in dims]
|
||||
torch._dynamo.mark_dynamic(tensor, dims)
|
||||
mark_dynamic(tensor, dims)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unsupported dynamic dimensions"
|
||||
@@ -338,6 +363,7 @@ def _support_torch_compile(
|
||||
if getattr(self, "aot_compiled_fn", None) is not None:
|
||||
return self.aot_compiled_fn(self, *args, **kwargs)
|
||||
|
||||
ds_type = self.compilation_config.dynamic_shapes_config.type
|
||||
cache_dir = None
|
||||
aot_compilation_path = None
|
||||
if envs.VLLM_USE_AOT_COMPILE:
|
||||
@@ -352,6 +378,14 @@ def _support_torch_compile(
|
||||
serialized backend artifacts), then we need to generate a new AOT
|
||||
compile artifact from scratch.
|
||||
"""
|
||||
# Validate that AOT compile is not used with unbacked dynamic
|
||||
# shapes. aot_compile re-allocates backed symbols post dynamo!
|
||||
if ds_type == DynamicShapesType.UNBACKED:
|
||||
raise ValueError(
|
||||
"AOT compilation is not compatible with UNBACKED dynamic shapes. "
|
||||
"Please use BACKED or BACKED_SIZE_OBLIVIOUS dynamic shapes type "
|
||||
"when VLLM_USE_AOT_COMPILE is enabled."
|
||||
)
|
||||
from .caching import compilation_config_hash_factors
|
||||
|
||||
factors: list[str] = compilation_config_hash_factors(self.vllm_config)
|
||||
@@ -401,7 +435,12 @@ def _support_torch_compile(
|
||||
# This is the path for the first compilation.
|
||||
|
||||
# the first compilation needs to have dynamic shapes marked
|
||||
_mark_dynamic_inputs(self, *args, **kwargs)
|
||||
_mark_dynamic_inputs(
|
||||
self,
|
||||
ds_type,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# here, it is the starting point of the `torch.compile` process
|
||||
start_monitoring_torch_compile(self.vllm_config)
|
||||
@@ -417,9 +456,7 @@ def _support_torch_compile(
|
||||
# properly when any of these files change.
|
||||
|
||||
# 1. the file containing the top-level forward function
|
||||
self.vllm_config.compilation_config.traced_files.add(
|
||||
original_code_object.co_filename
|
||||
)
|
||||
self.compilation_config.traced_files.add(original_code_object.co_filename)
|
||||
|
||||
# 2. every time Dynamo sees a function call, it will inline
|
||||
# the function by calling InliningInstructionTranslator.inline_call_
|
||||
@@ -429,7 +466,7 @@ def _support_torch_compile(
|
||||
|
||||
def patched_inline_call(self_):
|
||||
code = self_.f_code
|
||||
self.vllm_config.compilation_config.traced_files.add(code.co_filename)
|
||||
self.compilation_config.traced_files.add(code.co_filename)
|
||||
return inline_call(self_)
|
||||
|
||||
# Disable the C++ compilation of symbolic shape guards. C++-fication
|
||||
@@ -445,12 +482,18 @@ def _support_torch_compile(
|
||||
# if the config doesn't exist
|
||||
logger.debug("enable_cpp_symbolic_shape_guards config not available")
|
||||
|
||||
# Prepare backed_size_oblivious config patch if needed
|
||||
fx_config_patches = {}
|
||||
if ds_type == DynamicShapesType.BACKED_SIZE_OBLIVIOUS:
|
||||
fx_config_patches["backed_size_oblivious"] = True
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
InliningInstructionTranslator, "inline_call_", patched_inline_call
|
||||
),
|
||||
torch._dynamo.config.patch(**dynamo_config_patches),
|
||||
maybe_use_cudagraph_partition_wrapper(self.vllm_config),
|
||||
torch.fx.experimental._config.patch(**fx_config_patches),
|
||||
_torch27_patch_tensor_subclasses(),
|
||||
):
|
||||
if envs.VLLM_USE_AOT_COMPILE:
|
||||
|
||||
@@ -6,6 +6,7 @@ import sys
|
||||
from abc import abstractmethod
|
||||
from contextlib import contextmanager
|
||||
from types import CodeType
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch._C._dynamo.guards
|
||||
@@ -85,6 +86,12 @@ class TorchCompileWithNoGuardsWrapper:
|
||||
since we drop all guards.
|
||||
"""
|
||||
|
||||
def check_invariants_and_forward(self, *args, **kwargs):
|
||||
assert hasattr(self, "_check_shape_invariants")
|
||||
self._check_shape_invariants(*args, **kwargs)
|
||||
|
||||
return self.forward(*args, **kwargs)
|
||||
|
||||
def __init__(self):
|
||||
self.compiled = False
|
||||
|
||||
@@ -104,6 +111,21 @@ class TorchCompileWithNoGuardsWrapper:
|
||||
# Drop all the guards.
|
||||
options["guard_filter_fn"] = lambda x: [False for _ in x]
|
||||
|
||||
# Validate that unbacked dynamic shapes require VLLM_USE_BYTECODE_HOOK=False
|
||||
from vllm.compilation.decorators import DynamicShapesType
|
||||
|
||||
ds_type = vllm_config.compilation_config.dynamic_shapes_config.type
|
||||
compiled_ptr: Any = self.forward
|
||||
if ds_type == DynamicShapesType.UNBACKED:
|
||||
if envs.VLLM_USE_BYTECODE_HOOK:
|
||||
# reason is that bytecode does this hack torch._dynamo.eval_frame.
|
||||
# remove_from_cache(self.original_code_object()) to force a new
|
||||
# re-compilation.
|
||||
raise ValueError(
|
||||
"UNBACKED dynamic shapes require VLLM_USE_BYTECODE_HOOK=0. "
|
||||
)
|
||||
compiled_ptr = self.check_invariants_and_forward
|
||||
|
||||
if envs.VLLM_USE_AOT_COMPILE:
|
||||
if hasattr(torch._dynamo.config, "enable_aot_compile"):
|
||||
torch._dynamo.config.enable_aot_compile = True
|
||||
@@ -114,7 +136,7 @@ class TorchCompileWithNoGuardsWrapper:
|
||||
logger.warning(msg)
|
||||
|
||||
self._compiled_callable = torch.compile(
|
||||
self.forward,
|
||||
compiled_ptr,
|
||||
fullgraph=True,
|
||||
dynamic=False,
|
||||
backend=backend,
|
||||
|
||||
Reference in New Issue
Block a user