[torch.compile] Undo the fast_moe_cold_start hack in torch>=2.11 (#35475)
Signed-off-by: Richard Zou <zou3519@gmail.com>
This commit is contained in:
@@ -883,7 +883,13 @@ class VllmConfig:
|
|||||||
self.compilation_config.pass_config.enable_sp = False
|
self.compilation_config.pass_config.enable_sp = False
|
||||||
self.compilation_config.pass_config.fuse_gemm_comms = False
|
self.compilation_config.pass_config.fuse_gemm_comms = False
|
||||||
|
|
||||||
if self.compilation_config.fast_moe_cold_start is None:
|
from vllm.utils.torch_utils import HAS_OPAQUE_TYPE
|
||||||
|
|
||||||
|
if HAS_OPAQUE_TYPE:
|
||||||
|
# On torch >= 2.11 the hoisted OpaqueObject approach supersedes
|
||||||
|
# fast_moe_cold_start, so force it off.
|
||||||
|
self.compilation_config.fast_moe_cold_start = False
|
||||||
|
elif self.compilation_config.fast_moe_cold_start is None:
|
||||||
# resolve default behavior: try to be as safe as possible
|
# resolve default behavior: try to be as safe as possible
|
||||||
# this config is unsafe if any spec decoding draft model has a MOE.
|
# this config is unsafe if any spec decoding draft model has a MOE.
|
||||||
# We'll conservatively turn it off if we see spec decoding.
|
# We'll conservatively turn it off if we see spec decoding.
|
||||||
|
|||||||
@@ -482,3 +482,44 @@ if is_torch_equal("2.9.0"):
|
|||||||
|
|
||||||
PythonWrapperCodegen.memory_plan_reuse = memory_plan_reuse_patched
|
PythonWrapperCodegen.memory_plan_reuse = memory_plan_reuse_patched
|
||||||
GraphLowering._update_scheduler = _update_scheduler_patched
|
GraphLowering._update_scheduler = _update_scheduler_patched
|
||||||
|
|
||||||
|
# ===================================================
|
||||||
|
# torch 2.11 Inductor constrain_to_fx_strides monkeypatch
|
||||||
|
# ===================================================
|
||||||
|
# Patch the inductor's `constrain_to_fx_strides` to handle opaque
|
||||||
|
# (non-tensor) arguments. The original calls `.stride()` on every FX
|
||||||
|
# arg's meta value, which crashes on FakeScriptObject (the compile-time
|
||||||
|
# proxy for hoisted opaque types). The patched version skips args
|
||||||
|
# whose meta value is not a torch.Tensor.
|
||||||
|
# Upstream issue: https://github.com/pytorch/pytorch/issues/175973
|
||||||
|
|
||||||
|
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||||
|
|
||||||
|
if is_torch_equal_or_newer("2.11.0.dev"):
    import torch._inductor.ir as _ir
    import torch._inductor.lowering as _lowering
    from torch._inductor.virtualized import V as _V

    # Handle to the stock implementation; not referenced again in this block.
    _orig_constrain = _lowering.constrain_to_fx_strides

    def _patched_constrain_to_fx_strides(fx_node, *args, **kwargs):
        """Constrain IR args to their FX strides, skipping non-tensor metas.

        Each positional/keyword argument is paired with the matching entry of
        ``fx_node.args`` / ``fx_node.kwargs``. An ``IRNode`` whose FX
        counterpart carries a ``torch.Tensor`` under ``meta["val"]`` is
        required to follow that tensor's stride order; an ``IRNode`` with any
        other meta value is returned untouched. Dict arguments are processed
        key by key, and everything else passes through unchanged.

        Returns:
            A ``(args, kwargs)`` pair with the constraints applied.
        """

        def _constrain_one(ir_arg, fx_arg):
            if not isinstance(ir_arg, _ir.IRNode):
                if isinstance(ir_arg, dict):
                    # Recurse into dict-valued arguments entry by entry.
                    return {
                        name: _constrain_one(ir_arg[name], fx_arg[name])
                        for name in ir_arg
                    }
                return ir_arg

            fake_val = fx_arg.meta.get("val")
            if not isinstance(fake_val, torch.Tensor):
                # Non-tensor meta value: there is no stride to enforce.
                return ir_arg

            order = _ir.get_stride_order(
                fake_val.stride(), _V.graph.sizevars.shape_env
            )
            return _ir.ExternKernel.require_stride_order(ir_arg, order)

        constrained_args = tuple(
            _constrain_one(ir_arg, fx_arg)
            for ir_arg, fx_arg in zip(args, fx_node.args)
        )
        constrained_kwargs = {
            name: _constrain_one(ir_arg, fx_node.kwargs[name])
            for name, ir_arg in kwargs.items()
        }
        return constrained_args, constrained_kwargs

    # Install the replacement on the lowering module.
    _lowering.constrain_to_fx_strides = _patched_constrain_to_fx_strides
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
@@ -30,6 +31,8 @@ from vllm.model_executor.layers.fused_moe.runner.moe_runner import MoERunner
|
|||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
from vllm.utils.math_utils import cdiv
|
from vllm.utils.math_utils import cdiv
|
||||||
from vllm.utils.torch_utils import (
|
from vllm.utils.torch_utils import (
|
||||||
|
HAS_OPAQUE_TYPE,
|
||||||
|
ModuleName,
|
||||||
aux_stream,
|
aux_stream,
|
||||||
current_stream,
|
current_stream,
|
||||||
direct_register_custom_op,
|
direct_register_custom_op,
|
||||||
@@ -56,13 +59,27 @@ def get_layer_from_name(layer_name: str) -> torch.nn.Module:
|
|||||||
return forward_context.no_compile_layers[layer_name]
|
return forward_context.no_compile_layers[layer_name]
|
||||||
|
|
||||||
|
|
||||||
|
# On torch >= 2.11, layer_name is a hoisted ModuleName opaque object;
|
||||||
|
# on older versions it remains a plain str.
|
||||||
|
if not TYPE_CHECKING:
    # At runtime exactly one concrete type is bound, chosen by whether the
    # installed torch supports opaque types.
    _layer_name_type = ModuleName if HAS_OPAQUE_TYPE else str
else:
    from typing import TypeAlias

    # Static checkers see the union of both possible runtime types.
    _layer_name_type: TypeAlias = str | ModuleName
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_layer_name(layer_name: str | ModuleName) -> str:
    """Return the plain string name, unwrapping a ``ModuleName`` if given."""
    if isinstance(layer_name, ModuleName):
        return layer_name.value
    return layer_name
|
||||||
|
|
||||||
|
|
||||||
def _moe_forward(
|
def _moe_forward(
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
router_logits: torch.Tensor,
|
router_logits: torch.Tensor,
|
||||||
shared_experts_input: torch.Tensor | None,
|
shared_experts_input: torch.Tensor | None,
|
||||||
layer_name: str,
|
layer_name: _layer_name_type,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
layer = get_layer_from_name(layer_name)
|
layer = get_layer_from_name(_resolve_layer_name(layer_name))
|
||||||
# TODO(bnell): this can be removed after MK migration is complete.
|
# TODO(bnell): this can be removed after MK migration is complete.
|
||||||
layer.ensure_moe_quant_config_init()
|
layer.ensure_moe_quant_config_init()
|
||||||
return layer.runner.forward_impl(
|
return layer.runner.forward_impl(
|
||||||
@@ -74,7 +91,7 @@ def _moe_forward_fake(
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
router_logits: torch.Tensor,
|
router_logits: torch.Tensor,
|
||||||
shared_experts_input: torch.Tensor | None,
|
shared_experts_input: torch.Tensor | None,
|
||||||
layer_name: str,
|
layer_name: _layer_name_type,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
return torch.empty_like(hidden_states)
|
return torch.empty_like(hidden_states)
|
||||||
|
|
||||||
@@ -83,9 +100,9 @@ def _moe_forward_shared(
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
router_logits: torch.Tensor,
|
router_logits: torch.Tensor,
|
||||||
shared_experts_input: torch.Tensor | None,
|
shared_experts_input: torch.Tensor | None,
|
||||||
layer_name: str,
|
layer_name: _layer_name_type,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
layer = get_layer_from_name(layer_name)
|
layer = get_layer_from_name(_resolve_layer_name(layer_name))
|
||||||
# TODO(bnell): this can be removed after MK migration is complete.
|
# TODO(bnell): this can be removed after MK migration is complete.
|
||||||
layer.ensure_moe_quant_config_init()
|
layer.ensure_moe_quant_config_init()
|
||||||
return layer.runner.forward_impl(
|
return layer.runner.forward_impl(
|
||||||
@@ -97,7 +114,7 @@ def _moe_forward_shared_fake(
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
router_logits: torch.Tensor,
|
router_logits: torch.Tensor,
|
||||||
shared_experts_input: torch.Tensor | None,
|
shared_experts_input: torch.Tensor | None,
|
||||||
layer_name: str,
|
layer_name: _layer_name_type,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
# Output shapes:
|
# Output shapes:
|
||||||
# - fused_out: same as hidden_states (routed experts use transformed size)
|
# - fused_out: same as hidden_states (routed experts use transformed size)
|
||||||
@@ -105,12 +122,10 @@ def _moe_forward_shared_fake(
|
|||||||
# hidden_states
|
# hidden_states
|
||||||
# (For latent MoE: shared experts use original hidden_size, not latent size)
|
# (For latent MoE: shared experts use original hidden_size, not latent size)
|
||||||
fused_out = torch.empty_like(hidden_states)
|
fused_out = torch.empty_like(hidden_states)
|
||||||
|
|
||||||
if shared_experts_input is not None:
|
if shared_experts_input is not None:
|
||||||
shared_out = torch.empty_like(shared_experts_input)
|
shared_out = torch.empty_like(shared_experts_input)
|
||||||
else:
|
else:
|
||||||
shared_out = torch.empty_like(hidden_states)
|
shared_out = torch.empty_like(hidden_states)
|
||||||
|
|
||||||
return shared_out, fused_out
|
return shared_out, fused_out
|
||||||
|
|
||||||
|
|
||||||
@@ -367,7 +382,9 @@ class DefaultMoERunner(MoERunner):
|
|||||||
assert len(trunc_sizes) == 1
|
assert len(trunc_sizes) == 1
|
||||||
return func(states, trunc_sizes[0])
|
return func(states, trunc_sizes[0])
|
||||||
|
|
||||||
def _encode_layer_name(self) -> str:
|
def _encode_layer_name(self) -> str | ModuleName:
|
||||||
|
if HAS_OPAQUE_TYPE:
|
||||||
|
return ModuleName(self.layer_name)
|
||||||
# Can be unavailable or None in unittests
|
# Can be unavailable or None in unittests
|
||||||
if (
|
if (
|
||||||
is_forward_context_available()
|
is_forward_context_available()
|
||||||
|
|||||||
@@ -740,6 +740,41 @@ def is_torch_equal(target: str) -> bool:
|
|||||||
return Version(importlib.metadata.version("torch")) == Version(target)
|
return Version(importlib.metadata.version("torch")) == Version(target)
|
||||||
|
|
||||||
|
|
||||||
|
# True when the installed torch is at least 2.11.0.dev.
HAS_OPAQUE_TYPE = is_torch_equal_or_newer("2.11.0.dev")

if HAS_OPAQUE_TYPE:
    from torch._opaque_base import OpaqueBase
else:
    # Without opaque-type support, fall back to ``object`` so classes deriving
    # from OpaqueBase can still be defined.
    OpaqueBase = object  # type: ignore[misc, assignment]
|
||||||
|
|
||||||
|
|
||||||
|
class ModuleName(OpaqueBase):  # type: ignore[misc]
    """Wraps a module name string for use as a torch opaque type.

    When torch >= 2.11, this is registered as a hoisted value-type opaque
    object so that torch.compile lifts it as a graph input instead of baking
    it as a constant. This avoids per-layer recompilation for MOE ops.

    Two instances compare equal (and hash identically) exactly when their
    wrapped strings are equal; a ``ModuleName`` never compares equal to a
    bare ``str``.

    Attributes are documented inline:
        value -- the wrapped module name string.
    """

    def __init__(self, value: str) -> None:
        self.value = value

    def __eq__(self, other: object) -> bool:
        return isinstance(other, ModuleName) and self.value == other.value

    def __hash__(self) -> int:
        # Hash on the wrapped string so equal instances share a hash.
        return hash(self.value)

    def __repr__(self) -> str:
        # Value-based repr instead of the default address-based one, so logs
        # and debug dumps are readable and stable across runs.
        return f"ModuleName({self.value!r})"

    def __fx_repr__(self):
        # Returns a (source expression, set of referenced types) pair.
        return (f"ModuleName({self.value!r})", {ModuleName})
|
||||||
|
|
||||||
|
|
||||||
|
if HAS_OPAQUE_TYPE:
    from torch._library.opaque_object import register_opaque_type

    # Register ModuleName as a value-type opaque object with hoisting enabled.
    register_opaque_type(ModuleName, typ="value", hoist=True)
|
||||||
|
|
||||||
|
|
||||||
# Supports xccl with PyTorch versions >= 2.8.0.dev for XPU platform
|
# Supports xccl with PyTorch versions >= 2.8.0.dev for XPU platform
|
||||||
def supports_xccl() -> bool:
|
def supports_xccl() -> bool:
|
||||||
return torch.distributed.is_xccl_available()
|
return torch.distributed.is_xccl_available()
|
||||||
|
|||||||
Reference in New Issue
Block a user