[torch.compile] Allow usage of Opaque Objects in PyTorch 2.11 (#39286)
Signed-off-by: Richard Zou <zou3519@gmail.com>
This commit is contained in:
@@ -16,6 +16,7 @@ import vllm.envs as envs
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.utils import Range
|
||||
from vllm.env_override import _apply_constrain_to_fx_strides_patch
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.hashing import safe_hash
|
||||
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||
@@ -225,48 +226,6 @@ def _patch_standalone_compile_atomic_save() -> None:
|
||||
logger.debug("Patched %s.save for atomic writes (torch < 2.10)", cls.__name__)
|
||||
|
||||
|
||||
def _patch_constrain_to_fx_strides() -> contextlib.AbstractContextManager:
|
||||
"""Context manager that patches inductor's ``constrain_to_fx_strides``
|
||||
to handle opaque (non-tensor) arguments.
|
||||
|
||||
The original calls ``.stride()`` on every FX arg's meta value, which
|
||||
crashes on ``FakeScriptObject`` (the compile-time proxy for hoisted
|
||||
opaque types). The patched version skips args whose meta value is
|
||||
not a ``torch.Tensor``.
|
||||
|
||||
Returns ``nullcontext`` on torch < 2.11.
|
||||
Upstream issue: https://github.com/pytorch/pytorch/issues/175973
|
||||
"""
|
||||
if not is_torch_equal_or_newer("2.11.0.dev"):
|
||||
return contextlib.nullcontext()
|
||||
|
||||
import torch._inductor.ir as _ir
|
||||
import torch._inductor.lowering as _lowering
|
||||
from torch._inductor.virtualized import V as _V
|
||||
|
||||
def _patched(fx_node, *args, **kwargs):
|
||||
def apply_constraint(arg, fx_arg):
|
||||
if isinstance(arg, _ir.IRNode):
|
||||
meta_val = fx_arg.meta.get("val")
|
||||
if isinstance(meta_val, torch.Tensor):
|
||||
stride_order = _ir.get_stride_order(
|
||||
meta_val.stride(), _V.graph.sizevars.shape_env
|
||||
)
|
||||
return _ir.ExternKernel.require_stride_order(arg, stride_order)
|
||||
return arg
|
||||
if isinstance(arg, dict):
|
||||
return {key: apply_constraint(arg[key], fx_arg[key]) for key in arg}
|
||||
return arg
|
||||
|
||||
args = tuple(
|
||||
apply_constraint(arg, fx_arg) for arg, fx_arg in zip(args, fx_node.args)
|
||||
)
|
||||
kwargs = {k: apply_constraint(v, fx_node.kwargs[k]) for k, v in kwargs.items()}
|
||||
return args, kwargs
|
||||
|
||||
return patch.object(_lowering, "constrain_to_fx_strides", _patched)
|
||||
|
||||
|
||||
class InductorStandaloneAdaptor(CompilerInterface):
|
||||
"""
|
||||
The adaptor for the Inductor compiler.
|
||||
@@ -304,6 +263,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
|
||||
compile_range: Range,
|
||||
key: str | None = None,
|
||||
) -> tuple[Callable[..., Any] | None, Any | None]:
|
||||
_apply_constrain_to_fx_strides_patch()
|
||||
compilation_counter.num_inductor_compiles += 1
|
||||
current_config = {}
|
||||
if compiler_config is not None:
|
||||
@@ -387,7 +347,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
|
||||
else:
|
||||
fake_mode_ctx = contextlib.nullcontext()
|
||||
|
||||
with pregrad_ctx, fake_mode_ctx, _patch_constrain_to_fx_strides():
|
||||
with pregrad_ctx, fake_mode_ctx:
|
||||
compiled_graph = standalone_compile(graph, example_inputs, **compile_kwargs)
|
||||
|
||||
if use_aot:
|
||||
@@ -502,6 +462,7 @@ class InductorAdaptor(CompilerInterface):
|
||||
compile_range: Range,
|
||||
key: str | None = None,
|
||||
) -> tuple[Callable[..., Any] | None, Any | None]:
|
||||
_apply_constrain_to_fx_strides_patch()
|
||||
compilation_counter.num_inductor_compiles += 1
|
||||
from torch._inductor.compile_fx import compile_fx
|
||||
|
||||
@@ -630,7 +591,6 @@ class InductorAdaptor(CompilerInterface):
|
||||
stack.enter_context(
|
||||
torch._functorch.config.patch(enable_remote_autograd_cache=False)
|
||||
)
|
||||
stack.enter_context(_patch_constrain_to_fx_strides())
|
||||
|
||||
# Clear the tracing context before calling compile_fx.
|
||||
# vLLM calls compile_fx from within a PiecewiseCompileInterpreter
|
||||
|
||||
@@ -143,6 +143,13 @@ class TorchCompileWithNoGuardsWrapper:
|
||||
|
||||
compiled_ptr = self.check_invariants_and_forward
|
||||
|
||||
# Apply the constrain_to_fx_strides patch before first compilation.
|
||||
# This covers STOCK_TORCH_COMPILE and DYNAMO_ONCE paths. The VLLM
|
||||
# compile paths call this from their own compile() methods too.
|
||||
from vllm.env_override import _apply_constrain_to_fx_strides_patch
|
||||
|
||||
_apply_constrain_to_fx_strides_patch()
|
||||
|
||||
aot_context = nullcontext()
|
||||
if envs.VLLM_USE_AOT_COMPILE:
|
||||
if hasattr(torch._dynamo.config, "enable_aot_compile"):
|
||||
|
||||
@@ -500,6 +500,60 @@ if is_torch_equal("2.9.0"):
|
||||
# This mirrors the fix in https://github.com/pytorch/pytorch/pull/177558
|
||||
# and can be removed once torch >=2.12 is the minimum supported version.
|
||||
|
||||
# ===================================================
|
||||
# torch >= 2.11 Inductor constrain_to_fx_strides monkeypatch
|
||||
# ===================================================
|
||||
# Inductor's constrain_to_fx_strides calls .stride() on every FX arg's meta
|
||||
# value, which crashes on FakeScriptObject (the compile-time proxy for hoisted
|
||||
# opaque types). The patched version skips args whose meta value is not a
|
||||
# torch.Tensor.
|
||||
# Upstream issue: https://github.com/pytorch/pytorch/issues/175973
|
||||
|
||||
|
||||
_constrain_to_fx_strides_patched = False
|
||||
|
||||
|
||||
def _apply_constrain_to_fx_strides_patch():
|
||||
"""Patch lowering.constrain_to_fx_strides globally. Safe to call
|
||||
multiple times; only the first call does anything.
|
||||
Only applies for torch >= 2.11 and < 2.12."""
|
||||
global _constrain_to_fx_strides_patched
|
||||
if _constrain_to_fx_strides_patched:
|
||||
return
|
||||
_constrain_to_fx_strides_patched = True
|
||||
|
||||
if not is_torch_equal_or_newer("2.11.0.dev") or is_torch_equal_or_newer(
|
||||
"2.12.0.dev"
|
||||
):
|
||||
return
|
||||
|
||||
import torch._inductor.ir as _ir
|
||||
import torch._inductor.lowering as _lowering
|
||||
from torch._inductor.virtualized import V as _V
|
||||
|
||||
def _patched(fx_node, *args, **kwargs):
|
||||
def apply_constraint(arg, fx_arg):
|
||||
if isinstance(arg, _ir.IRNode):
|
||||
meta_val = fx_arg.meta.get("val")
|
||||
if isinstance(meta_val, torch.Tensor):
|
||||
stride_order = _ir.get_stride_order(
|
||||
meta_val.stride(), _V.graph.sizevars.shape_env
|
||||
)
|
||||
return _ir.ExternKernel.require_stride_order(arg, stride_order)
|
||||
return arg
|
||||
if isinstance(arg, dict):
|
||||
return {key: apply_constraint(arg[key], fx_arg[key]) for key in arg}
|
||||
return arg
|
||||
|
||||
args = tuple(
|
||||
apply_constraint(arg, fx_arg) for arg, fx_arg in zip(args, fx_node.args)
|
||||
)
|
||||
kwargs = {k: apply_constraint(v, fx_node.kwargs[k]) for k, v in kwargs.items()}
|
||||
return args, kwargs
|
||||
|
||||
_lowering.constrain_to_fx_strides = _patched
|
||||
|
||||
|
||||
if is_torch_equal_or_newer("2.10.0") and not is_torch_equal_or_newer("2.12.0"):
|
||||
import builtins as _builtins
|
||||
import pickle
|
||||
|
||||
@@ -706,7 +706,7 @@ def is_torch_equal(target: str) -> bool:
|
||||
return Version(importlib.metadata.version("torch")) == Version(target)
|
||||
|
||||
|
||||
HAS_OPAQUE_TYPE = is_torch_equal_or_newer("2.12.0.dev")
|
||||
HAS_OPAQUE_TYPE = is_torch_equal_or_newer("2.11.0.dev")
|
||||
|
||||
if HAS_OPAQUE_TYPE:
|
||||
from torch._opaque_base import OpaqueBase
|
||||
|
||||
@@ -4857,6 +4857,9 @@ class GPUModelRunner(
|
||||
self.vllm_config.compilation_config.mode
|
||||
== CompilationMode.STOCK_TORCH_COMPILE
|
||||
):
|
||||
from vllm.env_override import _apply_constrain_to_fx_strides_patch
|
||||
|
||||
_apply_constrain_to_fx_strides_patch()
|
||||
backend = self.vllm_config.compilation_config.init_backend(self.vllm_config)
|
||||
compilation_counter.stock_torch_compile_count += 1
|
||||
self.model.compile(fullgraph=True, backend=backend)
|
||||
|
||||
Reference in New Issue
Block a user