Add option to use unbacked, and backed size obl dynamic shapes for more sounds compilation. (#26199)

Signed-off-by: Laith Sakka <lsakka@meta.com>
This commit is contained in:
Laith Sakka
2025-11-24 07:12:41 -08:00
committed by GitHub
parent f716a15372
commit 7a228b5305
8 changed files with 442 additions and 15 deletions

View File

@@ -24,6 +24,7 @@ from vllm.config import (
get_current_vllm_config,
set_current_vllm_config,
)
from vllm.config.compilation import DynamicShapesType
from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors
from vllm.utils.import_utils import resolve_obj_by_qualname
@@ -104,6 +105,7 @@ def support_torch_compile(
dynamic_arg_dims: dict[str, int | list[int]] | None = None,
mark_unbacked_dims: dict[str, int | list[int]] | None = None,
enable_if: Callable[[VllmConfig], bool] | None = None,
shape_invariants: Callable[..., None] = lambda *args, **kwargs: None,
) -> Callable[[_T], _T] | _T:
"""
A decorator to add support for compiling the forward method of a class.
@@ -161,6 +163,14 @@ def support_torch_compile(
dim to be decorated with `mark_unbacked`. This is useful if we would like to
enforce that dynamo does not specialize on 0/1 values in the case of dummy input
such as for vision model compilation
`shape_invariants` is a function that gets compiled right before forward.
The function should have the torch._check calls that are needed to set
the relationships between different input sizes. For example:
torch._check(input_ids.size()[0] == inputs_embeds.size()[0])
This enforces constraints on the symbolic shapes without hardcoding
specific values. It is needed for some models to avoid data dependent
errors.
"""
def cls_decorator_helper(cls: _T) -> _T:
@@ -199,7 +209,11 @@ def support_torch_compile(
f"Argument {k} not found in the forward method of {cls}"
)
return _support_torch_compile(
cls, inferred_dynamic_arg_dims, mark_unbacked_dims, enable_if
cls,
inferred_dynamic_arg_dims,
mark_unbacked_dims,
enable_if,
shape_invariants,
)
if cls is not None:
@@ -242,6 +256,7 @@ def _support_torch_compile(
dynamic_arg_dims: dict[str, int | list[int]],
mark_unbacked_dims: dict[str, int | list[int]] | None = None,
enable_if: Callable[[VllmConfig], bool] | None = None,
shape_invariants: Callable[..., None] = lambda *args, **kwargs: None,
) -> _T:
"""
A decorator to add support for compiling the forward method of a class.
@@ -276,11 +291,12 @@ def _support_torch_compile(
old_init(self, **kwargs)
self.vllm_config = vllm_config
self.compilation_config = self.vllm_config.compilation_config
enable_compile = enable_if is None or enable_if(vllm_config)
# for CompilationMode.STOCK_TORCH_COMPILE , the upper level model runner
# will handle the compilation, so we don't need to do anything here.
self.do_not_compile = (
vllm_config.compilation_config.mode
self.compilation_config.mode
in [CompilationMode.NONE, CompilationMode.STOCK_TORCH_COMPILE]
or not supports_dynamo()
or _should_ignore_torch_compile(self.__class__)
@@ -289,29 +305,38 @@ def _support_torch_compile(
if self.do_not_compile:
return
self._check_shape_invariants = shape_invariants
compilation_counter.num_models_seen += 1
self.compiled = False
TorchCompileWithNoGuardsWrapper.__init__(self)
cls.__init__ = __init__
def _mark_dynamic_inputs(mod, *args, **kwargs):
def _mark_dynamic_inputs(mod, type, *args, **kwargs):
def mark_dynamic(arg, dims):
if type == DynamicShapesType.UNBACKED:
torch._dynamo.decorators.mark_unbacked(arg, dims)
else:
torch._dynamo.mark_dynamic(arg, dims)
sig = inspect.signature(mod.__class__.forward)
bound_args = sig.bind(mod, *args, **kwargs)
bound_args.apply_defaults()
for k, dims in dynamic_arg_dims.items():
arg = bound_args.arguments.get(k)
if arg is not None:
dims = [dims] if isinstance(dims, int) else dims
if isinstance(arg, torch.Tensor):
# In case dims is specified with negative indexing
dims = [arg.ndim + dim if dim < 0 else dim for dim in dims]
torch._dynamo.mark_dynamic(arg, dims)
mark_dynamic(arg, dims)
elif isinstance(arg, IntermediateTensors):
for tensor in arg.tensors.values():
# In case dims is specified with negative indexing
dims = [tensor.ndim + dim if dim < 0 else dim for dim in dims]
torch._dynamo.mark_dynamic(tensor, dims)
mark_dynamic(tensor, dims)
else:
raise ValueError(
"Unsupported dynamic dimensions"
@@ -338,6 +363,7 @@ def _support_torch_compile(
if getattr(self, "aot_compiled_fn", None) is not None:
return self.aot_compiled_fn(self, *args, **kwargs)
ds_type = self.compilation_config.dynamic_shapes_config.type
cache_dir = None
aot_compilation_path = None
if envs.VLLM_USE_AOT_COMPILE:
@@ -352,6 +378,14 @@ def _support_torch_compile(
serialized backend artifacts), then we need to generate a new AOT
compile artifact from scratch.
"""
# Validate that AOT compile is not used with unbacked dynamic
# shapes. aot_compile re-allocates backed symbols post dynamo!
if ds_type == DynamicShapesType.UNBACKED:
raise ValueError(
"AOT compilation is not compatible with UNBACKED dynamic shapes. "
"Please use BACKED or BACKED_SIZE_OBLIVIOUS dynamic shapes type "
"when VLLM_USE_AOT_COMPILE is enabled."
)
from .caching import compilation_config_hash_factors
factors: list[str] = compilation_config_hash_factors(self.vllm_config)
@@ -401,7 +435,12 @@ def _support_torch_compile(
# This is the path for the first compilation.
# the first compilation needs to have dynamic shapes marked
_mark_dynamic_inputs(self, *args, **kwargs)
_mark_dynamic_inputs(
self,
ds_type,
*args,
**kwargs,
)
# here, it is the starting point of the `torch.compile` process
start_monitoring_torch_compile(self.vllm_config)
@@ -417,9 +456,7 @@ def _support_torch_compile(
# properly when any of these files change.
# 1. the file containing the top-level forward function
self.vllm_config.compilation_config.traced_files.add(
original_code_object.co_filename
)
self.compilation_config.traced_files.add(original_code_object.co_filename)
# 2. every time Dynamo sees a function call, it will inline
# the function by calling InliningInstructionTranslator.inline_call_
@@ -429,7 +466,7 @@ def _support_torch_compile(
def patched_inline_call(self_):
code = self_.f_code
self.vllm_config.compilation_config.traced_files.add(code.co_filename)
self.compilation_config.traced_files.add(code.co_filename)
return inline_call(self_)
# Disable the C++ compilation of symbolic shape guards. C++-fication
@@ -445,12 +482,18 @@ def _support_torch_compile(
# if the config doesn't exist
logger.debug("enable_cpp_symbolic_shape_guards config not available")
# Prepare backed_size_oblivious config patch if needed
fx_config_patches = {}
if ds_type == DynamicShapesType.BACKED_SIZE_OBLIVIOUS:
fx_config_patches["backed_size_oblivious"] = True
with (
patch.object(
InliningInstructionTranslator, "inline_call_", patched_inline_call
),
torch._dynamo.config.patch(**dynamo_config_patches),
maybe_use_cudagraph_partition_wrapper(self.vllm_config),
torch.fx.experimental._config.patch(**fx_config_patches),
_torch27_patch_tensor_subclasses(),
):
if envs.VLLM_USE_AOT_COMPILE:

View File

@@ -6,6 +6,7 @@ import sys
from abc import abstractmethod
from contextlib import contextmanager
from types import CodeType
from typing import Any
import torch
import torch._C._dynamo.guards
@@ -85,6 +86,12 @@ class TorchCompileWithNoGuardsWrapper:
since we drop all guards.
"""
def check_invariants_and_forward(self, *args, **kwargs):
assert hasattr(self, "_check_shape_invariants")
self._check_shape_invariants(*args, **kwargs)
return self.forward(*args, **kwargs)
def __init__(self):
self.compiled = False
@@ -104,6 +111,21 @@ class TorchCompileWithNoGuardsWrapper:
# Drop all the guards.
options["guard_filter_fn"] = lambda x: [False for _ in x]
# Validate that unbacked dynamic shapes require VLLM_USE_BYTECODE_HOOK=False
from vllm.compilation.decorators import DynamicShapesType
ds_type = vllm_config.compilation_config.dynamic_shapes_config.type
compiled_ptr: Any = self.forward
if ds_type == DynamicShapesType.UNBACKED:
if envs.VLLM_USE_BYTECODE_HOOK:
# reason is that bytecode does this hack torch._dynamo.eval_frame.
# remove_from_cache(self.original_code_object()) to force a new
# re-compilation.
raise ValueError(
"UNBACKED dynamic shapes require VLLM_USE_BYTECODE_HOOK=0. "
)
compiled_ptr = self.check_invariants_and_forward
if envs.VLLM_USE_AOT_COMPILE:
if hasattr(torch._dynamo.config, "enable_aot_compile"):
torch._dynamo.config.enable_aot_compile = True
@@ -114,7 +136,7 @@ class TorchCompileWithNoGuardsWrapper:
logger.warning(msg)
self._compiled_callable = torch.compile(
self.forward,
compiled_ptr,
fullgraph=True,
dynamic=False,
backend=backend,

View File

@@ -192,6 +192,54 @@ class PassConfig:
self.enable_qk_norm_rope_fusion = False
class DynamicShapesType(str, enum.Enum):
"""Types of dynamic shapes handling in torch.compile().
see Dynamic shapes and vllm guard dropping in torch_compile.md
for more details."""
BACKED = "backed"
"""Use backed dynamic shapes. torch.compile() guards on backed dynamic
shapes and may add guards. Symbols are specialized to 0, 1, or >=2 even
without encountering branching on those ranges."""
UNBACKED = "unbacked"
"""Use unbacked dynamic shapes. Guaranteed not to be guarded on and not
0/1 specialized, but may throw data dependent errors when branches require
their value without explicit unbacked handling."""
BACKED_SIZE_OBLIVIOUS = "backed_size_oblivious"
"""Experimental flag that treats backed symbols as unbacked when explicit
unbacked handling is defined."""
@config
@dataclass
class DynamicShapesConfig:
"""Configuration to control/debug torch compile dynamic shapes."""
type: DynamicShapesType = DynamicShapesType.BACKED
"""Controls the type of dynamic shapes handling to use with torch.compile().
- BACKED: Default PyTorch behavior with potential guards ignored.
- UNBACKED: No guards guaranteed (most sound) but may throw
data dependent errors.
- BACKED_SIZE_OBLIVIOUS: Experimental safer alternative to
backed/unbacked.
"""
# TODO add a debug mode to fail
def compute_hash(self) -> str:
"""
Provide a hash for DynamicShapesConfig
"""
from vllm.config.utils import get_hash_factors, hash_factors
factors = get_hash_factors(self, {})
return hash_factors(factors)
@config
@dataclass
class CompilationConfig:
@@ -322,7 +370,7 @@ class CompilationConfig:
If empty list [], no ops are excluded (suitable for full cudagraphs)."""
compile_mm_encoder: bool = False
"""Whether or not to compile the multimodal encoder.
Currently, this only works for `Qwen2_5_vl` on selected platforms.
Currently, this only works for `Qwen2_5_vl` on selected platforms.
Disabled by default until more models are supported/tested to work."""
# Inductor capture
@@ -348,9 +396,11 @@ class CompilationConfig:
"""Sizes to compile for inductor. In addition
to integers, it also supports "cudagraph_capture_sizes" to
specify the sizes for cudagraph capture."""
inductor_compile_config: dict = field(default_factory=dict)
"""Additional configurations for inductor.
- None: use default configurations."""
inductor_passes: dict[str, str] = field(default_factory=dict)
"""Additional passes for inductor. It is a dictionary
from pass name to pass function qualified name. We use function
@@ -460,8 +510,15 @@ class CompilationConfig:
max_num_seqs, and prevents capture of many large graphs (>512) that would
greatly increase startup time with limited performance benefit.
"""
dynamic_shapes_config: DynamicShapesConfig = field(
default_factory=DynamicShapesConfig
)
"""Configuration for dynamic shapes options"""
local_cache_dir: str = field(default=None, init=False) # type: ignore
"""local cache dir for each rank"""
bs_to_padded_graph_size: list[int] = field(
default=None, # type: ignore
init=False,
@@ -530,6 +587,7 @@ class CompilationConfig:
from vllm.config.utils import get_hash_factors, hash_factors
factors = get_hash_factors(self, ignored_factors)
factors["pass_config"] = self.pass_config.compute_hash()
return hash_factors(factors)

View File

@@ -354,7 +354,17 @@ class LlamaDecoderLayer(nn.Module):
return vllm_config.quant_config
@support_torch_compile
def llama_model_invariants(
input_ids, positions, intermediate_tensors=None, inputs_embeds=None
):
"""Shape invariants for Llama model compilation, those are translated to
runtime assertions for unbacked dynamic shapes and are compiled away for
backed"""
if input_ids is not None:
torch._check(positions.size()[0] == input_ids.size()[0])
@support_torch_compile(shape_invariants=llama_model_invariants)
class LlamaModel(nn.Module):
def __init__(
self,

View File

@@ -274,6 +274,38 @@ class Qwen2DecoderLayer(nn.Module):
return hidden_states, residual
def qwen_2_model_invariants(
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
):
"""Shape invariants for Qwen2Model Model, those are translated to
runtime assertions for unbacked dynamic shapes and are compiled away for
backed"""
# All these should be equal.
# input_ids.size()[0]
# positions.size()[-1]
# intermediate_tensors["hidden_states"].size()[0]
# inputs_embeds.size()[0]
torch._check(input_ids.size()[0] == positions.size()[-1])
if intermediate_tensors is not None:
torch._check(
input_ids.size()[0] == intermediate_tensors["hidden_states"].size()[0]
)
if inputs_embeds is not None:
torch._check(input_ids.size()[0] == inputs_embeds.size()[0])
# Hidden dimensions should match (hidden_size)
# intermediate_tensors["hidden_states"].size()[1]
# inputs_embeds.size()[1]
if inputs_embeds is not None and intermediate_tensors is not None:
torch._check(
inputs_embeds.size()[1] == intermediate_tensors["hidden_states"].size()[1]
)
@support_torch_compile(
dynamic_arg_dims={
"input_ids": 0,
@@ -282,7 +314,8 @@ class Qwen2DecoderLayer(nn.Module):
"positions": -1,
"intermediate_tensors": 0,
"inputs_embeds": 0,
}
},
shape_invariants=qwen_2_model_invariants,
)
class Qwen2Model(nn.Module):
def __init__(