[compile] Stop unconditionally patching constrain_to_fx_strides (#36152)
Signed-off-by: Richard Zou <zou3519@gmail.com>
This commit is contained in:
@@ -225,6 +225,48 @@ def _patch_standalone_compile_atomic_save() -> None:
         logger.debug("Patched %s.save for atomic writes (torch < 2.10)", cls.__name__)
def _patch_constrain_to_fx_strides() -> contextlib.AbstractContextManager:
    """Patch inductor's ``constrain_to_fx_strides`` to tolerate opaque args.

    Upstream's implementation calls ``.stride()`` on the meta value of every
    FX argument, which crashes on ``FakeScriptObject`` (the compile-time
    stand-in for hoisted opaque types). The replacement applies stride
    constraints only to arguments whose meta value is a real
    ``torch.Tensor`` and passes everything else through untouched.

    Returns a no-op ``nullcontext`` on torch < 2.11, where the patch is
    not needed.

    Upstream issue: https://github.com/pytorch/pytorch/issues/175973
    """
    if not is_torch_equal_or_newer("2.11.0.dev"):
        return contextlib.nullcontext()

    import torch._inductor.ir as _ir
    import torch._inductor.lowering as _lowering
    from torch._inductor.virtualized import V as _V

    def _constrain(node_arg, fx_arg):
        # Dicts are constrained element-wise, mirroring the upstream helper.
        if isinstance(node_arg, dict):
            return {k: _constrain(node_arg[k], fx_arg[k]) for k in node_arg}
        if not isinstance(node_arg, _ir.IRNode):
            return node_arg
        meta_val = fx_arg.meta.get("val")
        if not isinstance(meta_val, torch.Tensor):
            # Non-tensor meta value (e.g. FakeScriptObject): no strides to
            # constrain, so leave the arg as-is instead of crashing.
            return node_arg
        order = _ir.get_stride_order(meta_val.stride(), _V.graph.sizevars.shape_env)
        return _ir.ExternKernel.require_stride_order(node_arg, order)

    def _patched(fx_node, *args, **kwargs):
        new_args = tuple(_constrain(a, fa) for a, fa in zip(args, fx_node.args))
        new_kwargs = {
            name: _constrain(val, fx_node.kwargs[name]) for name, val in kwargs.items()
        }
        return new_args, new_kwargs

    return patch.object(_lowering, "constrain_to_fx_strides", _patched)
 class InductorStandaloneAdaptor(CompilerInterface):
     """
     The adaptor for the Inductor compiler.
@@ -312,7 +354,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
             "torch._inductor.compile_fx._recursive_pre_grad_passes",
             lambda gm, _: gm,
         )
-        with ctx:
+        with ctx, _patch_constrain_to_fx_strides():
             compiled_graph = standalone_compile(graph, example_inputs, **compile_kwargs)

         if use_aot:
@@ -555,6 +597,7 @@ class InductorAdaptor(CompilerInterface):
             stack.enter_context(
                 torch._functorch.config.patch(enable_remote_autograd_cache=False)
             )
+            stack.enter_context(_patch_constrain_to_fx_strides())

             compiled_graph = compile_fx(
                 graph,

Reference in New Issue
Block a user