diff --git a/vllm/compilation/compiler_interface.py b/vllm/compilation/compiler_interface.py
index e7748e380..035370063 100644
--- a/vllm/compilation/compiler_interface.py
+++ b/vllm/compilation/compiler_interface.py
@@ -225,6 +225,48 @@ def _patch_standalone_compile_atomic_save() -> None:
         logger.debug("Patched %s.save for atomic writes (torch < 2.10)", cls.__name__)
 
 
+def _patch_constrain_to_fx_strides() -> contextlib.AbstractContextManager:
+    """Context manager that patches inductor's ``constrain_to_fx_strides``
+    to handle opaque (non-tensor) arguments.
+
+    The original calls ``.stride()`` on every FX arg's meta value, which
+    crashes on ``FakeScriptObject`` (the compile-time proxy for hoisted
+    opaque types). The patched version skips args whose meta value is
+    not a ``torch.Tensor``.
+
+    Returns ``nullcontext`` on torch < 2.11.
+    Upstream issue: https://github.com/pytorch/pytorch/issues/175973
+    """
+    if not is_torch_equal_or_newer("2.11.0.dev"):
+        return contextlib.nullcontext()
+
+    import torch._inductor.ir as _ir
+    import torch._inductor.lowering as _lowering
+    from torch._inductor.virtualized import V as _V
+
+    def _patched(fx_node, *args, **kwargs):
+        def apply_constraint(arg, fx_arg):
+            if isinstance(arg, _ir.IRNode):
+                meta_val = fx_arg.meta.get("val")
+                if isinstance(meta_val, torch.Tensor):
+                    stride_order = _ir.get_stride_order(
+                        meta_val.stride(), _V.graph.sizevars.shape_env
+                    )
+                    return _ir.ExternKernel.require_stride_order(arg, stride_order)
+                return arg
+            if isinstance(arg, dict):
+                return {key: apply_constraint(arg[key], fx_arg[key]) for key in arg}
+            return arg
+
+        args = tuple(
+            apply_constraint(arg, fx_arg) for arg, fx_arg in zip(args, fx_node.args)
+        )
+        kwargs = {k: apply_constraint(v, fx_node.kwargs[k]) for k, v in kwargs.items()}
+        return args, kwargs
+
+    return patch.object(_lowering, "constrain_to_fx_strides", _patched)
+
+
 class InductorStandaloneAdaptor(CompilerInterface):
     """
     The adaptor for the Inductor compiler.
@@ -312,7 +354,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
                 "torch._inductor.compile_fx._recursive_pre_grad_passes",
                 lambda gm, _: gm,
             )
-        with ctx:
+        with ctx, _patch_constrain_to_fx_strides():
             compiled_graph = standalone_compile(graph, example_inputs, **compile_kwargs)
 
         if use_aot:
@@ -555,6 +597,7 @@ class InductorAdaptor(CompilerInterface):
             stack.enter_context(
                 torch._functorch.config.patch(enable_remote_autograd_cache=False)
             )
+            stack.enter_context(_patch_constrain_to_fx_strides())
 
             compiled_graph = compile_fx(
                 graph,
diff --git a/vllm/env_override.py b/vllm/env_override.py
index 27992218f..181d000a6 100644
--- a/vllm/env_override.py
+++ b/vllm/env_override.py
@@ -482,44 +482,3 @@ if is_torch_equal("2.9.0"):
 
     PythonWrapperCodegen.memory_plan_reuse = memory_plan_reuse_patched
     GraphLowering._update_scheduler = _update_scheduler_patched
-
-# ===================================================
-# torch 2.11 Inductor constrain_to_fx_strides monkeypatch
-# ===================================================
-# Patch the inductor's `constrain_to_fx_strides` to handle opaque
-# (non-tensor) arguments. The original calls `.stride()` on every FX
-# arg's meta value, which crashes on FakeScriptObject (the compile-time
-# proxy for hoisted opaque types). The patched version skips args
-# whose meta value is not a torch.Tensor.
-# Upstream issue: https://github.com/pytorch/pytorch/issues/175973
-
-from vllm.utils.torch_utils import is_torch_equal_or_newer
-
-if is_torch_equal_or_newer("2.11.0.dev"):
-    import torch._inductor.ir as _ir
-    import torch._inductor.lowering as _lowering
-    from torch._inductor.virtualized import V as _V
-
-    _orig_constrain = _lowering.constrain_to_fx_strides
-
-    def _patched_constrain_to_fx_strides(fx_node, *args, **kwargs):
-        def apply_constraint(arg, fx_arg):
-            if isinstance(arg, _ir.IRNode):
-                meta_val = fx_arg.meta.get("val")
-                if isinstance(meta_val, torch.Tensor):
-                    stride_order = _ir.get_stride_order(
-                        meta_val.stride(), _V.graph.sizevars.shape_env
-                    )
-                    return _ir.ExternKernel.require_stride_order(arg, stride_order)
-                return arg
-            if isinstance(arg, dict):
-                return {key: apply_constraint(arg[key], fx_arg[key]) for key in arg}
-            return arg
-
-        args = tuple(
-            apply_constraint(arg, fx_arg) for arg, fx_arg in zip(args, fx_node.args)
-        )
-        kwargs = {k: apply_constraint(v, fx_node.kwargs[k]) for k, v in kwargs.items()}
-        return args, kwargs
-
-    _lowering.constrain_to_fx_strides = _patched_constrain_to_fx_strides
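
For reviewers, a minimal standalone sketch of the pattern this diff moves to: a helper that returns either a no-op context manager or a `patch.object` context manager, so callers can always write `with helper():`. The names here (`_maybe_patch`, the `_lowering` namespace) are invented for the sketch and use only the standard library, so it runs without torch; the real code patches `torch._inductor.lowering`.

```python
import contextlib
import types
from unittest.mock import patch

# Hypothetical stand-in for the real torch._inductor.lowering module,
# so this sketch runs without torch installed.
_lowering = types.SimpleNamespace(
    constrain_to_fx_strides=lambda fx_node, *args, **kwargs: "original"
)


def _maybe_patch(needed: bool) -> contextlib.AbstractContextManager:
    # Same shape as _patch_constrain_to_fx_strides in the diff: return a
    # no-op context manager when the patch does not apply, otherwise a
    # patch.object that swaps the function in and restores it on exit.
    if not needed:
        return contextlib.nullcontext()

    def _patched(fx_node, *args, **kwargs):
        return "patched"

    return patch.object(_lowering, "constrain_to_fx_strides", _patched)


with _maybe_patch(needed=True):
    assert _lowering.constrain_to_fx_strides(None) == "patched"
# On exit, patch.object undoes the monkeypatch automatically.
assert _lowering.constrain_to_fx_strides(None) == "original"
```

This is also what the diff buys over the deleted `env_override.py` block: the old version mutated `_lowering.constrain_to_fx_strides` at import time and never restored it, while the context-manager form scopes the patch to the compilation call and unwinds it on exit (including via `stack.enter_context` in `InductorAdaptor`).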
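The crash itself comes from unconditionally reading `.stride()` from every FX arg's meta value. A toy illustration of the guard the patched `apply_constraint` adds (not the inductor code; `FakeOpaque` and `read_strides` are invented here, with `FakeOpaque` standing in for torch's `FakeScriptObject`):

```python
import torch


class FakeOpaque:
    # Hypothetical stand-in for FakeScriptObject: it carries no strides,
    # so calling .stride() on it would raise AttributeError.
    pass


def read_strides(meta_val):
    # Mirrors the guard in the patched lowering: only real tensors have
    # their strides read; any other meta value passes through untouched.
    if isinstance(meta_val, torch.Tensor):
        return meta_val.stride()
    return meta_val


print(read_strides(torch.empty(2, 3)))  # (3, 1)
print(read_strides(FakeOpaque()))       # passes through, no AttributeError
```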