[torch.compile] Allow usage of Opaque Objects in PyTorch 2.11 (#39286)

Signed-off-by: Richard Zou <zou3519@gmail.com>
This commit is contained in:
Richard Zou
2026-04-09 01:21:10 +02:00
committed by GitHub
parent f3c7941ec8
commit ba4a78eb5d
5 changed files with 69 additions and 45 deletions

View File

@@ -16,6 +16,7 @@ import vllm.envs as envs
from vllm.compilation.counter import compilation_counter
from vllm.config import VllmConfig
from vllm.config.utils import Range
from vllm.env_override import _apply_constrain_to_fx_strides_patch
from vllm.logger import init_logger
from vllm.utils.hashing import safe_hash
from vllm.utils.torch_utils import is_torch_equal_or_newer
@@ -225,48 +226,6 @@ def _patch_standalone_compile_atomic_save() -> None:
logger.debug("Patched %s.save for atomic writes (torch < 2.10)", cls.__name__)
def _patch_constrain_to_fx_strides() -> contextlib.AbstractContextManager:
    """Return a context manager that temporarily replaces inductor's
    ``constrain_to_fx_strides`` with an opaque-argument-tolerant version.

    The stock implementation unconditionally calls ``.stride()`` on each FX
    argument's meta value, which fails for ``FakeScriptObject`` (the
    compile-time stand-in for hoisted opaque types). The replacement applies
    stride constraints only to arguments whose meta value is an actual
    ``torch.Tensor``.

    On torch < 2.11 no patching is needed, so a ``nullcontext`` is returned.
    Upstream issue: https://github.com/pytorch/pytorch/issues/175973
    """
    if not is_torch_equal_or_newer("2.11.0.dev"):
        return contextlib.nullcontext()

    import torch._inductor.ir as _ir
    import torch._inductor.lowering as _lowering
    from torch._inductor.virtualized import V as _V

    def _patched(fx_node, *args, **kwargs):
        def _constrain(node_arg, fx_arg):
            if isinstance(node_arg, dict):
                # Recurse into dict-valued args, pairing entries by key.
                return {k: _constrain(node_arg[k], fx_arg[k]) for k in node_arg}
            if not isinstance(node_arg, _ir.IRNode):
                return node_arg
            meta_val = fx_arg.meta.get("val")
            if not isinstance(meta_val, torch.Tensor):
                # Opaque (non-tensor) meta value: skip stride constraints.
                return node_arg
            order = _ir.get_stride_order(
                meta_val.stride(), _V.graph.sizevars.shape_env
            )
            return _ir.ExternKernel.require_stride_order(node_arg, order)

        new_args = tuple(
            _constrain(a, fa) for a, fa in zip(args, fx_node.args)
        )
        new_kwargs = {
            name: _constrain(val, fx_node.kwargs[name])
            for name, val in kwargs.items()
        }
        return new_args, new_kwargs

    return patch.object(_lowering, "constrain_to_fx_strides", _patched)
class InductorStandaloneAdaptor(CompilerInterface):
"""
The adaptor for the Inductor compiler.
@@ -304,6 +263,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
compile_range: Range,
key: str | None = None,
) -> tuple[Callable[..., Any] | None, Any | None]:
_apply_constrain_to_fx_strides_patch()
compilation_counter.num_inductor_compiles += 1
current_config = {}
if compiler_config is not None:
@@ -387,7 +347,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
else:
fake_mode_ctx = contextlib.nullcontext()
with pregrad_ctx, fake_mode_ctx, _patch_constrain_to_fx_strides():
with pregrad_ctx, fake_mode_ctx:
compiled_graph = standalone_compile(graph, example_inputs, **compile_kwargs)
if use_aot:
@@ -502,6 +462,7 @@ class InductorAdaptor(CompilerInterface):
compile_range: Range,
key: str | None = None,
) -> tuple[Callable[..., Any] | None, Any | None]:
_apply_constrain_to_fx_strides_patch()
compilation_counter.num_inductor_compiles += 1
from torch._inductor.compile_fx import compile_fx
@@ -630,7 +591,6 @@ class InductorAdaptor(CompilerInterface):
stack.enter_context(
torch._functorch.config.patch(enable_remote_autograd_cache=False)
)
stack.enter_context(_patch_constrain_to_fx_strides())
# Clear the tracing context before calling compile_fx.
# vLLM calls compile_fx from within a PiecewiseCompileInterpreter

View File

@@ -143,6 +143,13 @@ class TorchCompileWithNoGuardsWrapper:
compiled_ptr = self.check_invariants_and_forward
# Apply the constrain_to_fx_strides patch before first compilation.
# This covers STOCK_TORCH_COMPILE and DYNAMO_ONCE paths. The VLLM
# compile paths call this from their own compile() methods too.
from vllm.env_override import _apply_constrain_to_fx_strides_patch
_apply_constrain_to_fx_strides_patch()
aot_context = nullcontext()
if envs.VLLM_USE_AOT_COMPILE:
if hasattr(torch._dynamo.config, "enable_aot_compile"):

View File

@@ -500,6 +500,60 @@ if is_torch_equal("2.9.0"):
# This mirrors the fix in https://github.com/pytorch/pytorch/pull/177558
# and can be removed once torch >=2.12 is the minimum supported version.
# ===================================================
# torch >= 2.11 Inductor constrain_to_fx_strides monkeypatch
# ===================================================
# Inductor's constrain_to_fx_strides calls .stride() on every FX arg's meta
# value, which crashes on FakeScriptObject (the compile-time proxy for hoisted
# opaque types). The patched version skips args whose meta value is not a
# torch.Tensor.
# Upstream issue: https://github.com/pytorch/pytorch/issues/175973
# Guard so the global monkeypatch below is installed at most once per process.
_constrain_to_fx_strides_patched = False


def _apply_constrain_to_fx_strides_patch():
    """Globally replace ``lowering.constrain_to_fx_strides`` with a version
    that tolerates opaque (non-tensor) arguments.

    Safe to call any number of times; only the first invocation has any
    effect. The patch is installed only for torch >= 2.11 and < 2.12.
    """
    global _constrain_to_fx_strides_patched
    if _constrain_to_fx_strides_patched:
        return
    # Mark before the version check so repeat calls return immediately
    # regardless of whether the patch actually applied.
    _constrain_to_fx_strides_patched = True

    if not is_torch_equal_or_newer("2.11.0.dev") or is_torch_equal_or_newer(
        "2.12.0.dev"
    ):
        return

    import torch._inductor.ir as _ir
    import torch._inductor.lowering as _lowering
    from torch._inductor.virtualized import V as _V

    def _patched(fx_node, *args, **kwargs):
        def _constrain(node_arg, fx_arg):
            if isinstance(node_arg, dict):
                # Recurse into dict-valued args, pairing entries by key.
                return {k: _constrain(node_arg[k], fx_arg[k]) for k in node_arg}
            if not isinstance(node_arg, _ir.IRNode):
                return node_arg
            meta_val = fx_arg.meta.get("val")
            if not isinstance(meta_val, torch.Tensor):
                # Opaque (non-tensor) meta value: skip stride constraints.
                return node_arg
            order = _ir.get_stride_order(
                meta_val.stride(), _V.graph.sizevars.shape_env
            )
            return _ir.ExternKernel.require_stride_order(node_arg, order)

        new_args = tuple(
            _constrain(a, fa) for a, fa in zip(args, fx_node.args)
        )
        new_kwargs = {
            name: _constrain(val, fx_node.kwargs[name])
            for name, val in kwargs.items()
        }
        return new_args, new_kwargs

    _lowering.constrain_to_fx_strides = _patched
if is_torch_equal_or_newer("2.10.0") and not is_torch_equal_or_newer("2.12.0"):
import builtins as _builtins
import pickle

View File

@@ -706,7 +706,7 @@ def is_torch_equal(target: str) -> bool:
return Version(importlib.metadata.version("torch")) == Version(target)
HAS_OPAQUE_TYPE = is_torch_equal_or_newer("2.12.0.dev")
HAS_OPAQUE_TYPE = is_torch_equal_or_newer("2.11.0.dev")
if HAS_OPAQUE_TYPE:
from torch._opaque_base import OpaqueBase

View File

@@ -4857,6 +4857,9 @@ class GPUModelRunner(
self.vllm_config.compilation_config.mode
== CompilationMode.STOCK_TORCH_COMPILE
):
from vllm.env_override import _apply_constrain_to_fx_strides_patch
_apply_constrain_to_fx_strides_patch()
backend = self.vllm_config.compilation_config.init_backend(self.vllm_config)
compilation_counter.stock_torch_compile_count += 1
self.model.compile(fullgraph=True, backend=backend)