[torch.compile] Allow usage of Opaque Objects in PyTorch 2.11 (#39286)
Signed-off-by: Richard Zou <zou3519@gmail.com>
This commit is contained in:
@@ -16,6 +16,7 @@ import vllm.envs as envs
|
|||||||
from vllm.compilation.counter import compilation_counter
|
from vllm.compilation.counter import compilation_counter
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.config.utils import Range
|
from vllm.config.utils import Range
|
||||||
|
from vllm.env_override import _apply_constrain_to_fx_strides_patch
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.utils.hashing import safe_hash
|
from vllm.utils.hashing import safe_hash
|
||||||
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||||
@@ -225,48 +226,6 @@ def _patch_standalone_compile_atomic_save() -> None:
|
|||||||
logger.debug("Patched %s.save for atomic writes (torch < 2.10)", cls.__name__)
|
logger.debug("Patched %s.save for atomic writes (torch < 2.10)", cls.__name__)
|
||||||
|
|
||||||
|
|
||||||
def _patch_constrain_to_fx_strides() -> contextlib.AbstractContextManager:
|
|
||||||
"""Context manager that patches inductor's ``constrain_to_fx_strides``
|
|
||||||
to handle opaque (non-tensor) arguments.
|
|
||||||
|
|
||||||
The original calls ``.stride()`` on every FX arg's meta value, which
|
|
||||||
crashes on ``FakeScriptObject`` (the compile-time proxy for hoisted
|
|
||||||
opaque types). The patched version skips args whose meta value is
|
|
||||||
not a ``torch.Tensor``.
|
|
||||||
|
|
||||||
Returns ``nullcontext`` on torch < 2.11.
|
|
||||||
Upstream issue: https://github.com/pytorch/pytorch/issues/175973
|
|
||||||
"""
|
|
||||||
if not is_torch_equal_or_newer("2.11.0.dev"):
|
|
||||||
return contextlib.nullcontext()
|
|
||||||
|
|
||||||
import torch._inductor.ir as _ir
|
|
||||||
import torch._inductor.lowering as _lowering
|
|
||||||
from torch._inductor.virtualized import V as _V
|
|
||||||
|
|
||||||
def _patched(fx_node, *args, **kwargs):
|
|
||||||
def apply_constraint(arg, fx_arg):
|
|
||||||
if isinstance(arg, _ir.IRNode):
|
|
||||||
meta_val = fx_arg.meta.get("val")
|
|
||||||
if isinstance(meta_val, torch.Tensor):
|
|
||||||
stride_order = _ir.get_stride_order(
|
|
||||||
meta_val.stride(), _V.graph.sizevars.shape_env
|
|
||||||
)
|
|
||||||
return _ir.ExternKernel.require_stride_order(arg, stride_order)
|
|
||||||
return arg
|
|
||||||
if isinstance(arg, dict):
|
|
||||||
return {key: apply_constraint(arg[key], fx_arg[key]) for key in arg}
|
|
||||||
return arg
|
|
||||||
|
|
||||||
args = tuple(
|
|
||||||
apply_constraint(arg, fx_arg) for arg, fx_arg in zip(args, fx_node.args)
|
|
||||||
)
|
|
||||||
kwargs = {k: apply_constraint(v, fx_node.kwargs[k]) for k, v in kwargs.items()}
|
|
||||||
return args, kwargs
|
|
||||||
|
|
||||||
return patch.object(_lowering, "constrain_to_fx_strides", _patched)
|
|
||||||
|
|
||||||
|
|
||||||
class InductorStandaloneAdaptor(CompilerInterface):
|
class InductorStandaloneAdaptor(CompilerInterface):
|
||||||
"""
|
"""
|
||||||
The adaptor for the Inductor compiler.
|
The adaptor for the Inductor compiler.
|
||||||
@@ -304,6 +263,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
|
|||||||
compile_range: Range,
|
compile_range: Range,
|
||||||
key: str | None = None,
|
key: str | None = None,
|
||||||
) -> tuple[Callable[..., Any] | None, Any | None]:
|
) -> tuple[Callable[..., Any] | None, Any | None]:
|
||||||
|
_apply_constrain_to_fx_strides_patch()
|
||||||
compilation_counter.num_inductor_compiles += 1
|
compilation_counter.num_inductor_compiles += 1
|
||||||
current_config = {}
|
current_config = {}
|
||||||
if compiler_config is not None:
|
if compiler_config is not None:
|
||||||
@@ -387,7 +347,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
|
|||||||
else:
|
else:
|
||||||
fake_mode_ctx = contextlib.nullcontext()
|
fake_mode_ctx = contextlib.nullcontext()
|
||||||
|
|
||||||
with pregrad_ctx, fake_mode_ctx, _patch_constrain_to_fx_strides():
|
with pregrad_ctx, fake_mode_ctx:
|
||||||
compiled_graph = standalone_compile(graph, example_inputs, **compile_kwargs)
|
compiled_graph = standalone_compile(graph, example_inputs, **compile_kwargs)
|
||||||
|
|
||||||
if use_aot:
|
if use_aot:
|
||||||
@@ -502,6 +462,7 @@ class InductorAdaptor(CompilerInterface):
|
|||||||
compile_range: Range,
|
compile_range: Range,
|
||||||
key: str | None = None,
|
key: str | None = None,
|
||||||
) -> tuple[Callable[..., Any] | None, Any | None]:
|
) -> tuple[Callable[..., Any] | None, Any | None]:
|
||||||
|
_apply_constrain_to_fx_strides_patch()
|
||||||
compilation_counter.num_inductor_compiles += 1
|
compilation_counter.num_inductor_compiles += 1
|
||||||
from torch._inductor.compile_fx import compile_fx
|
from torch._inductor.compile_fx import compile_fx
|
||||||
|
|
||||||
@@ -630,7 +591,6 @@ class InductorAdaptor(CompilerInterface):
|
|||||||
stack.enter_context(
|
stack.enter_context(
|
||||||
torch._functorch.config.patch(enable_remote_autograd_cache=False)
|
torch._functorch.config.patch(enable_remote_autograd_cache=False)
|
||||||
)
|
)
|
||||||
stack.enter_context(_patch_constrain_to_fx_strides())
|
|
||||||
|
|
||||||
# Clear the tracing context before calling compile_fx.
|
# Clear the tracing context before calling compile_fx.
|
||||||
# vLLM calls compile_fx from within a PiecewiseCompileInterpreter
|
# vLLM calls compile_fx from within a PiecewiseCompileInterpreter
|
||||||
|
|||||||
@@ -143,6 +143,13 @@ class TorchCompileWithNoGuardsWrapper:
|
|||||||
|
|
||||||
compiled_ptr = self.check_invariants_and_forward
|
compiled_ptr = self.check_invariants_and_forward
|
||||||
|
|
||||||
|
# Apply the constrain_to_fx_strides patch before first compilation.
|
||||||
|
# This covers STOCK_TORCH_COMPILE and DYNAMO_ONCE paths. The VLLM
|
||||||
|
# compile paths call this from their own compile() methods too.
|
||||||
|
from vllm.env_override import _apply_constrain_to_fx_strides_patch
|
||||||
|
|
||||||
|
_apply_constrain_to_fx_strides_patch()
|
||||||
|
|
||||||
aot_context = nullcontext()
|
aot_context = nullcontext()
|
||||||
if envs.VLLM_USE_AOT_COMPILE:
|
if envs.VLLM_USE_AOT_COMPILE:
|
||||||
if hasattr(torch._dynamo.config, "enable_aot_compile"):
|
if hasattr(torch._dynamo.config, "enable_aot_compile"):
|
||||||
|
|||||||
@@ -500,6 +500,60 @@ if is_torch_equal("2.9.0"):
|
|||||||
# This mirrors the fix in https://github.com/pytorch/pytorch/pull/177558
|
# This mirrors the fix in https://github.com/pytorch/pytorch/pull/177558
|
||||||
# and can be removed once torch >=2.12 is the minimum supported version.
|
# and can be removed once torch >=2.12 is the minimum supported version.
|
||||||
|
|
||||||
|
# ===================================================
|
||||||
|
# torch >= 2.11 Inductor constrain_to_fx_strides monkeypatch
|
||||||
|
# ===================================================
|
||||||
|
# Inductor's constrain_to_fx_strides calls .stride() on every FX arg's meta
|
||||||
|
# value, which crashes on FakeScriptObject (the compile-time proxy for hoisted
|
||||||
|
# opaque types). The patched version skips args whose meta value is not a
|
||||||
|
# torch.Tensor.
|
||||||
|
# Upstream issue: https://github.com/pytorch/pytorch/issues/175973
|
||||||
|
|
||||||
|
|
||||||
|
_constrain_to_fx_strides_patched = False
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_constrain_to_fx_strides_patch():
|
||||||
|
"""Patch lowering.constrain_to_fx_strides globally. Safe to call
|
||||||
|
multiple times; only the first call does anything.
|
||||||
|
Only applies for torch >= 2.11 and < 2.12."""
|
||||||
|
global _constrain_to_fx_strides_patched
|
||||||
|
if _constrain_to_fx_strides_patched:
|
||||||
|
return
|
||||||
|
_constrain_to_fx_strides_patched = True
|
||||||
|
|
||||||
|
if not is_torch_equal_or_newer("2.11.0.dev") or is_torch_equal_or_newer(
|
||||||
|
"2.12.0.dev"
|
||||||
|
):
|
||||||
|
return
|
||||||
|
|
||||||
|
import torch._inductor.ir as _ir
|
||||||
|
import torch._inductor.lowering as _lowering
|
||||||
|
from torch._inductor.virtualized import V as _V
|
||||||
|
|
||||||
|
def _patched(fx_node, *args, **kwargs):
|
||||||
|
def apply_constraint(arg, fx_arg):
|
||||||
|
if isinstance(arg, _ir.IRNode):
|
||||||
|
meta_val = fx_arg.meta.get("val")
|
||||||
|
if isinstance(meta_val, torch.Tensor):
|
||||||
|
stride_order = _ir.get_stride_order(
|
||||||
|
meta_val.stride(), _V.graph.sizevars.shape_env
|
||||||
|
)
|
||||||
|
return _ir.ExternKernel.require_stride_order(arg, stride_order)
|
||||||
|
return arg
|
||||||
|
if isinstance(arg, dict):
|
||||||
|
return {key: apply_constraint(arg[key], fx_arg[key]) for key in arg}
|
||||||
|
return arg
|
||||||
|
|
||||||
|
args = tuple(
|
||||||
|
apply_constraint(arg, fx_arg) for arg, fx_arg in zip(args, fx_node.args)
|
||||||
|
)
|
||||||
|
kwargs = {k: apply_constraint(v, fx_node.kwargs[k]) for k, v in kwargs.items()}
|
||||||
|
return args, kwargs
|
||||||
|
|
||||||
|
_lowering.constrain_to_fx_strides = _patched
|
||||||
|
|
||||||
|
|
||||||
if is_torch_equal_or_newer("2.10.0") and not is_torch_equal_or_newer("2.12.0"):
|
if is_torch_equal_or_newer("2.10.0") and not is_torch_equal_or_newer("2.12.0"):
|
||||||
import builtins as _builtins
|
import builtins as _builtins
|
||||||
import pickle
|
import pickle
|
||||||
|
|||||||
@@ -706,7 +706,7 @@ def is_torch_equal(target: str) -> bool:
|
|||||||
return Version(importlib.metadata.version("torch")) == Version(target)
|
return Version(importlib.metadata.version("torch")) == Version(target)
|
||||||
|
|
||||||
|
|
||||||
HAS_OPAQUE_TYPE = is_torch_equal_or_newer("2.12.0.dev")
|
HAS_OPAQUE_TYPE = is_torch_equal_or_newer("2.11.0.dev")
|
||||||
|
|
||||||
if HAS_OPAQUE_TYPE:
|
if HAS_OPAQUE_TYPE:
|
||||||
from torch._opaque_base import OpaqueBase
|
from torch._opaque_base import OpaqueBase
|
||||||
|
|||||||
@@ -4857,6 +4857,9 @@ class GPUModelRunner(
|
|||||||
self.vllm_config.compilation_config.mode
|
self.vllm_config.compilation_config.mode
|
||||||
== CompilationMode.STOCK_TORCH_COMPILE
|
== CompilationMode.STOCK_TORCH_COMPILE
|
||||||
):
|
):
|
||||||
|
from vllm.env_override import _apply_constrain_to_fx_strides_patch
|
||||||
|
|
||||||
|
_apply_constrain_to_fx_strides_patch()
|
||||||
backend = self.vllm_config.compilation_config.init_backend(self.vllm_config)
|
backend = self.vllm_config.compilation_config.init_backend(self.vllm_config)
|
||||||
compilation_counter.stock_torch_compile_count += 1
|
compilation_counter.stock_torch_compile_count += 1
|
||||||
self.model.compile(fullgraph=True, backend=backend)
|
self.model.compile(fullgraph=True, backend=backend)
|
||||||
|
|||||||
Reference in New Issue
Block a user