From e82fbeec7b360af4fb908bf67a659b22f93266d3 Mon Sep 17 00:00:00 2001 From: Richard Zou Date: Sun, 1 Mar 2026 16:44:22 -0500 Subject: [PATCH] [torch.compile] Undo the fast_moe_cold_start hack in torch>=2.11 (#35475) Signed-off-by: Richard Zou --- vllm/config/vllm.py | 8 +++- vllm/env_override.py | 41 +++++++++++++++++++ .../fused_moe/runner/default_moe_runner.py | 35 ++++++++++++---- vllm/utils/torch_utils.py | 35 ++++++++++++++++ 4 files changed, 109 insertions(+), 10 deletions(-) diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index 7f7b21316..d781d778e 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -883,7 +883,13 @@ class VllmConfig: self.compilation_config.pass_config.enable_sp = False self.compilation_config.pass_config.fuse_gemm_comms = False - if self.compilation_config.fast_moe_cold_start is None: + from vllm.utils.torch_utils import HAS_OPAQUE_TYPE + + if HAS_OPAQUE_TYPE: + # On torch >= 2.11 the hoisted OpaqueObject approach supersedes + # fast_moe_cold_start, so force it off. + self.compilation_config.fast_moe_cold_start = False + elif self.compilation_config.fast_moe_cold_start is None: # resolve default behavior: try to be as safe as possible # this config is unsafe if any spec decoding draft model has a MOE. # We'll conservatively turn it off if we see spec decoding. diff --git a/vllm/env_override.py b/vllm/env_override.py index 181d000a6..27992218f 100644 --- a/vllm/env_override.py +++ b/vllm/env_override.py @@ -482,3 +482,44 @@ if is_torch_equal("2.9.0"): PythonWrapperCodegen.memory_plan_reuse = memory_plan_reuse_patched GraphLowering._update_scheduler = _update_scheduler_patched + +# =================================================== +# torch 2.11 Inductor constrain_to_fx_strides monkeypatch +# =================================================== +# Patch the inductor's `constrain_to_fx_strides` to handle opaque +# (non-tensor) arguments. The original calls `.stride()` on every FX +# arg's meta value, which crashes on FakeScriptObject (the compile-time +# proxy for hoisted opaque types). The patched version skips args +# whose meta value is not a torch.Tensor. 
+# Upstream issue: https://github.com/pytorch/pytorch/issues/175973 + +from vllm.utils.torch_utils import is_torch_equal_or_newer + +if is_torch_equal_or_newer("2.11.0.dev"): + import torch._inductor.ir as _ir + import torch._inductor.lowering as _lowering + from torch._inductor.virtualized import V as _V + + _orig_constrain = _lowering.constrain_to_fx_strides + + def _patched_constrain_to_fx_strides(fx_node, *args, **kwargs): + def apply_constraint(arg, fx_arg): + if isinstance(arg, _ir.IRNode): + meta_val = fx_arg.meta.get("val") + if isinstance(meta_val, torch.Tensor): + stride_order = _ir.get_stride_order( + meta_val.stride(), _V.graph.sizevars.shape_env + ) + return _ir.ExternKernel.require_stride_order(arg, stride_order) + return arg + if isinstance(arg, dict): + return {key: apply_constraint(arg[key], fx_arg[key]) for key in arg} + return arg + + args = tuple( + apply_constraint(arg, fx_arg) for arg, fx_arg in zip(args, fx_node.args) + ) + kwargs = {k: apply_constraint(v, fx_node.kwargs[k]) for k, v in kwargs.items()} + return args, kwargs + + _lowering.constrain_to_fx_strides = _patched_constrain_to_fx_strides diff --git a/vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py b/vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py index 9c2adf799..274929c07 100644 --- a/vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py +++ b/vllm/model_executor/layers/fused_moe/runner/default_moe_runner.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from contextlib import nullcontext +from typing import TYPE_CHECKING import torch import torch.nn.functional as F @@ -30,6 +31,8 @@ from vllm.model_executor.layers.fused_moe.runner.moe_runner import MoERunner from vllm.platforms import current_platform from vllm.utils.math_utils import cdiv from vllm.utils.torch_utils import ( + HAS_OPAQUE_TYPE, + ModuleName, aux_stream, current_stream, direct_register_custom_op, @@ -56,13 +59,27 @@ def get_layer_from_name(layer_name: str) -> torch.nn.Module: return forward_context.no_compile_layers[layer_name] +# On torch >= 2.11, layer_name is a hoisted ModuleName opaque object; +# on older versions it remains a plain str. +if TYPE_CHECKING: + from typing import TypeAlias + + _layer_name_type: TypeAlias = str | ModuleName +else: + _layer_name_type = ModuleName if HAS_OPAQUE_TYPE else str + + +def _resolve_layer_name(layer_name: str | ModuleName) -> str: + return layer_name.value if isinstance(layer_name, ModuleName) else layer_name + + def _moe_forward( hidden_states: torch.Tensor, router_logits: torch.Tensor, shared_experts_input: torch.Tensor | None, - layer_name: str, + layer_name: _layer_name_type, ) -> torch.Tensor: - layer = get_layer_from_name(layer_name) + layer = get_layer_from_name(_resolve_layer_name(layer_name)) # TODO(bnell): this can be removed after MK migration is complete. 
layer.ensure_moe_quant_config_init() return layer.runner.forward_impl( @@ -74,7 +91,7 @@ def _moe_forward_fake( hidden_states: torch.Tensor, router_logits: torch.Tensor, shared_experts_input: torch.Tensor | None, - layer_name: str, + layer_name: _layer_name_type, ) -> torch.Tensor: return torch.empty_like(hidden_states) @@ -83,9 +100,9 @@ def _moe_forward_shared( hidden_states: torch.Tensor, router_logits: torch.Tensor, shared_experts_input: torch.Tensor | None, - layer_name: str, + layer_name: _layer_name_type, ) -> tuple[torch.Tensor, torch.Tensor]: - layer = get_layer_from_name(layer_name) + layer = get_layer_from_name(_resolve_layer_name(layer_name)) # TODO(bnell): this can be removed after MK migration is complete. layer.ensure_moe_quant_config_init() return layer.runner.forward_impl( @@ -97,7 +114,7 @@ def _moe_forward_shared_fake( hidden_states: torch.Tensor, router_logits: torch.Tensor, shared_experts_input: torch.Tensor | None, - layer_name: str, + layer_name: _layer_name_type, ) -> tuple[torch.Tensor, torch.Tensor]: # Output shapes: # - fused_out: same as hidden_states (routed experts use transformed size) @@ -105,12 +122,10 @@ def _moe_forward_shared_fake( # hidden_states # (For latent MoE: shared experts use original hidden_size, not latent size) fused_out = torch.empty_like(hidden_states) - if shared_experts_input is not None: shared_out = torch.empty_like(shared_experts_input) else: shared_out = torch.empty_like(hidden_states) - return shared_out, fused_out @@ -367,7 +382,9 @@ class DefaultMoERunner(MoERunner): assert len(trunc_sizes) == 1 return func(states, trunc_sizes[0]) - def _encode_layer_name(self) -> str: + def _encode_layer_name(self) -> str | ModuleName: + if HAS_OPAQUE_TYPE: + return ModuleName(self.layer_name) # Can be unavailable or None in unittests if ( is_forward_context_available() diff --git a/vllm/utils/torch_utils.py b/vllm/utils/torch_utils.py index e834108ca..e4aa4fe61 100644 --- a/vllm/utils/torch_utils.py +++ b/vllm/utils/torch_utils.py @@ -740,6 +740,41 @@ def is_torch_equal(target: str) -> bool: return Version(importlib.metadata.version("torch")) == Version(target) +HAS_OPAQUE_TYPE = is_torch_equal_or_newer("2.11.0.dev") + +if HAS_OPAQUE_TYPE: + from torch._opaque_base import OpaqueBase +else: + OpaqueBase = object # type: ignore[misc, assignment] + + +class ModuleName(OpaqueBase): # type: ignore[misc] + """Wraps a module name string for use as a torch opaque type. + + When torch >= 2.11, this is registered as a hoisted value-type opaque + object so that torch.compile lifts it as a graph input instead of baking + it as a constant. This avoids per-layer recompilation for MOE ops. + """ + + def __init__(self, value: str): + self.value = value + + def __eq__(self, other): + return isinstance(other, ModuleName) and self.value == other.value + + def __hash__(self): + return hash(self.value) + + def __fx_repr__(self): + return (f"ModuleName({self.value!r})", {ModuleName}) + + +if HAS_OPAQUE_TYPE: + from torch._library.opaque_object import register_opaque_type + + register_opaque_type(ModuleName, typ="value", hoist=True) + + # Supports xccl with PyTorch versions >= 2.8.0.dev for XPU platform def supports_xccl() -> bool: return torch.distributed.is_xccl_available()
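
A minimal round-trip sketch of the new helpers (illustrative only, not part
of the diff; assumes this patch is applied on a torch >= 2.11 build so that
HAS_OPAQUE_TYPE is True, and _resolve_layer_name is a private helper imported
here purely for demonstration):

from vllm.model_executor.layers.fused_moe.runner.default_moe_runner import (
    _resolve_layer_name,
)
from vllm.utils.torch_utils import HAS_OPAQUE_TYPE, ModuleName

assert HAS_OPAQUE_TYPE  # torch >= 2.11.0.dev

# _encode_layer_name() wraps the layer name; _resolve_layer_name() unwraps
# it on the other side of the custom-op boundary. Plain strings pass through
# untouched, which keeps the torch < 2.11 path working.
assert _resolve_layer_name(ModuleName("model.layers.0.mlp")) == "model.layers.0.mlp"
assert _resolve_layer_name("model.layers.0.mlp") == "model.layers.0.mlp"

# Value semantics: __eq__ and __hash__ delegate to the wrapped string, so
# ModuleName keys behave like the underlying names in dicts and caches.
assert ModuleName("a") == ModuleName("a")
assert hash(ModuleName("a")) == hash("a")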
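
And a sketch of the compile-time behavior the typ="value", hoist=True
registration is meant to buy, per the ModuleName docstring (fn is
hypothetical, and the dynamo behavior shown is the intended torch >= 2.11
semantics rather than something verifiable against a released build):

import torch

from vllm.utils.torch_utils import ModuleName

@torch.compile(fullgraph=True)
def fn(x: torch.Tensor, layer_name: ModuleName) -> torch.Tensor:
    # With hoisting, dynamo should lift `layer_name` into the graph as a
    # FakeScriptObject input instead of specializing on the concrete
    # string, so all MOE layers can share one compiled artifact.
    return x * 2

x = torch.randn(8)
fn(x, ModuleName("model.layers.0.mlp"))
fn(x, ModuleName("model.layers.1.mlp"))  # intended: no per-layer recompile

That FakeScriptObject input is also what the inductor monkeypatch above
guards against: it has no .stride(), so without the patched
constrain_to_fx_strides, lowering an op that constrains args to their FX
strides would crash (pytorch/pytorch#175973).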
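
For completeness, the fast_moe_cold_start resolution in vllm/config/vllm.py
collapses to the following precedence (a standalone restatement with a
hypothetical function name; the spec-decoding default branch is left
abbreviated because its body sits outside the hunk shown above):

def resolve_fast_moe_cold_start(user_value, has_opaque_type):
    # torch >= 2.11: the hoisted opaque-object path supersedes the hack,
    # so it is forced off regardless of what the user configured.
    if has_opaque_type:
        return False
    # torch < 2.11: only an unset value (None) falls through to the
    # conservative spec-decoding default, whose body is not in this hunk;
    # an explicit user setting stands.
    return user_value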