[compile] Stop unconditionally patching constrain_to_fx_strides (#36152)

Signed-off-by: Richard Zou <zou3519@gmail.com>
This commit is contained in:
Richard Zou
2026-03-06 10:17:27 -05:00
committed by GitHub
parent 39f9ea0da4
commit 54756b6109
2 changed files with 44 additions and 42 deletions

View File

@@ -225,6 +225,48 @@ def _patch_standalone_compile_atomic_save() -> None:
logger.debug("Patched %s.save for atomic writes (torch < 2.10)", cls.__name__)
def _patch_constrain_to_fx_strides() -> contextlib.AbstractContextManager:
    """Patch inductor's ``constrain_to_fx_strides`` to tolerate opaque args.

    The stock inductor implementation calls ``.stride()`` on the meta value
    of every FX argument, which crashes on ``FakeScriptObject`` (the
    compile-time stand-in for hoisted opaque types).  The replacement
    applies the stride constraint only to arguments whose meta value is a
    real ``torch.Tensor`` and passes everything else through untouched.

    Returns a no-op ``nullcontext`` on torch < 2.11, where the patch is
    not needed.

    Upstream issue: https://github.com/pytorch/pytorch/issues/175973
    """
    if not is_torch_equal_or_newer("2.11.0.dev"):
        return contextlib.nullcontext()

    import torch._inductor.ir as _ir
    import torch._inductor.lowering as _lowering
    from torch._inductor.virtualized import V as _V

    def _patched(fx_node, *args, **kwargs):
        def constrain_one(ir_arg, fx_arg):
            if isinstance(ir_arg, _ir.IRNode):
                val = fx_arg.meta.get("val")
                if not isinstance(val, torch.Tensor):
                    # Opaque (non-tensor) meta value: no stride to match,
                    # so leave the IR node as-is.
                    return ir_arg
                order = _ir.get_stride_order(
                    val.stride(), _V.graph.sizevars.shape_env
                )
                return _ir.ExternKernel.require_stride_order(ir_arg, order)
            if isinstance(ir_arg, dict):
                # Dict-valued args are constrained entry-by-entry against
                # the matching keys of the FX-side dict.
                return {k: constrain_one(ir_arg[k], fx_arg[k]) for k in ir_arg}
            return ir_arg

        new_args = tuple(
            constrain_one(a, f) for a, f in zip(args, fx_node.args)
        )
        new_kwargs = {
            name: constrain_one(value, fx_node.kwargs[name])
            for name, value in kwargs.items()
        }
        return new_args, new_kwargs

    return patch.object(_lowering, "constrain_to_fx_strides", _patched)
class InductorStandaloneAdaptor(CompilerInterface):
"""
The adaptor for the Inductor compiler.
@@ -312,7 +354,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
"torch._inductor.compile_fx._recursive_pre_grad_passes",
lambda gm, _: gm,
)
with ctx:
with ctx, _patch_constrain_to_fx_strides():
compiled_graph = standalone_compile(graph, example_inputs, **compile_kwargs)
if use_aot:
@@ -555,6 +597,7 @@ class InductorAdaptor(CompilerInterface):
stack.enter_context(
torch._functorch.config.patch(enable_remote_autograd_cache=False)
)
stack.enter_context(_patch_constrain_to_fx_strides())
compiled_graph = compile_fx(
graph,