From ba4a78eb5d2ea30477b58a0bb8109b129f35c8b1 Mon Sep 17 00:00:00 2001 From: Richard Zou Date: Thu, 9 Apr 2026 01:21:10 +0200 Subject: [PATCH] [torch.compile] Allow usage of Opaque Objects in PyTorch 2.11 (#39286) Signed-off-by: Richard Zou --- vllm/compilation/compiler_interface.py | 48 ++--------------------- vllm/compilation/wrapper.py | 7 ++++ vllm/env_override.py | 54 ++++++++++++++++++++++++++ vllm/utils/torch_utils.py | 2 +- vllm/v1/worker/gpu_model_runner.py | 3 ++ 5 files changed, 69 insertions(+), 45 deletions(-) diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py index 5c34d3a1b..79339ebaf 100644 --- a/vllm/compilation/compiler_interface.py +++ b/vllm/compilation/compiler_interface.py @@ -16,6 +16,7 @@ import vllm.envs as envs from vllm.compilation.counter import compilation_counter from vllm.config import VllmConfig from vllm.config.utils import Range +from vllm.env_override import _apply_constrain_to_fx_strides_patch from vllm.logger import init_logger from vllm.utils.hashing import safe_hash from vllm.utils.torch_utils import is_torch_equal_or_newer @@ -225,48 +226,6 @@ def _patch_standalone_compile_atomic_save() -> None: logger.debug("Patched %s.save for atomic writes (torch < 2.10)", cls.__name__) -def _patch_constrain_to_fx_strides() -> contextlib.AbstractContextManager: - """Context manager that patches inductor's ``constrain_to_fx_strides`` - to handle opaque (non-tensor) arguments. - - The original calls ``.stride()`` on every FX arg's meta value, which - crashes on ``FakeScriptObject`` (the compile-time proxy for hoisted - opaque types). The patched version skips args whose meta value is - not a ``torch.Tensor``. - - Returns ``nullcontext`` on torch < 2.11. - Upstream issue: https://github.com/pytorch/pytorch/issues/175973 - """ - if not is_torch_equal_or_newer("2.11.0.dev"): - return contextlib.nullcontext() - - import torch._inductor.ir as _ir - import torch._inductor.lowering as _lowering - from torch._inductor.virtualized import V as _V - - def _patched(fx_node, *args, **kwargs): - def apply_constraint(arg, fx_arg): - if isinstance(arg, _ir.IRNode): - meta_val = fx_arg.meta.get("val") - if isinstance(meta_val, torch.Tensor): - stride_order = _ir.get_stride_order( - meta_val.stride(), _V.graph.sizevars.shape_env - ) - return _ir.ExternKernel.require_stride_order(arg, stride_order) - return arg - if isinstance(arg, dict): - return {key: apply_constraint(arg[key], fx_arg[key]) for key in arg} - return arg - - args = tuple( - apply_constraint(arg, fx_arg) for arg, fx_arg in zip(args, fx_node.args) - ) - kwargs = {k: apply_constraint(v, fx_node.kwargs[k]) for k, v in kwargs.items()} - return args, kwargs - - return patch.object(_lowering, "constrain_to_fx_strides", _patched) - - class InductorStandaloneAdaptor(CompilerInterface): """ The adaptor for the Inductor compiler. @@ -304,6 +263,7 @@ class InductorStandaloneAdaptor(CompilerInterface): compile_range: Range, key: str | None = None, ) -> tuple[Callable[..., Any] | None, Any | None]: + _apply_constrain_to_fx_strides_patch() compilation_counter.num_inductor_compiles += 1 current_config = {} if compiler_config is not None: @@ -387,7 +347,7 @@ class InductorStandaloneAdaptor(CompilerInterface): else: fake_mode_ctx = contextlib.nullcontext() - with pregrad_ctx, fake_mode_ctx, _patch_constrain_to_fx_strides(): + with pregrad_ctx, fake_mode_ctx: compiled_graph = standalone_compile(graph, example_inputs, **compile_kwargs) if use_aot: @@ -502,6 +462,7 @@ class InductorAdaptor(CompilerInterface): compile_range: Range, key: str | None = None, ) -> tuple[Callable[..., Any] | None, Any | None]: + _apply_constrain_to_fx_strides_patch() compilation_counter.num_inductor_compiles += 1 from torch._inductor.compile_fx import compile_fx @@ -630,7 +591,6 @@ class InductorAdaptor(CompilerInterface): stack.enter_context( torch._functorch.config.patch(enable_remote_autograd_cache=False) ) - stack.enter_context(_patch_constrain_to_fx_strides()) # Clear the tracing context before calling compile_fx. # vLLM calls compile_fx from within a PiecewiseCompileInterpreter diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py index d5eb35e21..dc48528e9 100644 --- a/vllm/compilation/wrapper.py +++ b/vllm/compilation/wrapper.py @@ -143,6 +143,13 @@ class TorchCompileWithNoGuardsWrapper: compiled_ptr = self.check_invariants_and_forward + # Apply the constrain_to_fx_strides patch before first compilation. + # This covers STOCK_TORCH_COMPILE and DYNAMO_ONCE paths. The VLLM + # compile paths call this from their own compile() methods too. + from vllm.env_override import _apply_constrain_to_fx_strides_patch + + _apply_constrain_to_fx_strides_patch() + aot_context = nullcontext() if envs.VLLM_USE_AOT_COMPILE: if hasattr(torch._dynamo.config, "enable_aot_compile"): diff --git a/vllm/env_override.py b/vllm/env_override.py index aa09c4a9d..55dd5099d 100644 --- a/vllm/env_override.py +++ b/vllm/env_override.py @@ -500,6 +500,60 @@ if is_torch_equal("2.9.0"): # This mirrors the fix in https://github.com/pytorch/pytorch/pull/177558 # and can be removed once torch >=2.12 is the minimum supported version. +# =================================================== +# torch >= 2.11 Inductor constrain_to_fx_strides monkeypatch +# =================================================== +# Inductor's constrain_to_fx_strides calls .stride() on every FX arg's meta +# value, which crashes on FakeScriptObject (the compile-time proxy for hoisted +# opaque types). The patched version skips args whose meta value is not a +# torch.Tensor. +# Upstream issue: https://github.com/pytorch/pytorch/issues/175973 + + +_constrain_to_fx_strides_patched = False + + +def _apply_constrain_to_fx_strides_patch(): + """Patch lowering.constrain_to_fx_strides globally. Safe to call + multiple times; only the first call does anything. + Only applies for torch >= 2.11 and < 2.12.""" + global _constrain_to_fx_strides_patched + if _constrain_to_fx_strides_patched: + return + _constrain_to_fx_strides_patched = True + + if not is_torch_equal_or_newer("2.11.0.dev") or is_torch_equal_or_newer( + "2.12.0.dev" + ): + return + + import torch._inductor.ir as _ir + import torch._inductor.lowering as _lowering + from torch._inductor.virtualized import V as _V + + def _patched(fx_node, *args, **kwargs): + def apply_constraint(arg, fx_arg): + if isinstance(arg, _ir.IRNode): + meta_val = fx_arg.meta.get("val") + if isinstance(meta_val, torch.Tensor): + stride_order = _ir.get_stride_order( + meta_val.stride(), _V.graph.sizevars.shape_env + ) + return _ir.ExternKernel.require_stride_order(arg, stride_order) + return arg + if isinstance(arg, dict): + return {key: apply_constraint(arg[key], fx_arg[key]) for key in arg} + return arg + + args = tuple( + apply_constraint(arg, fx_arg) for arg, fx_arg in zip(args, fx_node.args) + ) + kwargs = {k: apply_constraint(v, fx_node.kwargs[k]) for k, v in kwargs.items()} + return args, kwargs + + _lowering.constrain_to_fx_strides = _patched + + if is_torch_equal_or_newer("2.10.0") and not is_torch_equal_or_newer("2.12.0"): import builtins as _builtins import pickle diff --git a/vllm/utils/torch_utils.py b/vllm/utils/torch_utils.py index 150c9cba5..94f8c096e 100644 --- a/vllm/utils/torch_utils.py +++ b/vllm/utils/torch_utils.py @@ -706,7 +706,7 @@ def is_torch_equal(target: str) -> bool: return Version(importlib.metadata.version("torch")) == Version(target) -HAS_OPAQUE_TYPE = is_torch_equal_or_newer("2.12.0.dev") +HAS_OPAQUE_TYPE = is_torch_equal_or_newer("2.11.0.dev") if HAS_OPAQUE_TYPE: from torch._opaque_base import OpaqueBase diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 9405b5f72..a9a4497a3 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -4857,6 +4857,9 @@ class GPUModelRunner( self.vllm_config.compilation_config.mode == CompilationMode.STOCK_TORCH_COMPILE ): + from vllm.env_override import _apply_constrain_to_fx_strides_patch + + _apply_constrain_to_fx_strides_patch() backend = self.vllm_config.compilation_config.init_backend(self.vllm_config) compilation_counter.stock_torch_compile_count += 1 self.model.compile(fullgraph=True, backend=backend)