[torch.compile] Undo the fast_moe_cold_start hack in torch>=2.11 (#35475)
Signed-off-by: Richard Zou <zou3519@gmail.com>
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from contextlib import nullcontext
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -30,6 +31,8 @@ from vllm.model_executor.layers.fused_moe.runner.moe_runner import MoERunner
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.utils.torch_utils import (
|
||||
HAS_OPAQUE_TYPE,
|
||||
ModuleName,
|
||||
aux_stream,
|
||||
current_stream,
|
||||
direct_register_custom_op,
|
||||
@@ -56,13 +59,27 @@ def get_layer_from_name(layer_name: str) -> torch.nn.Module:
|
||||
return forward_context.no_compile_layers[layer_name]
|
||||
|
||||
|
||||
# Type used to annotate the ``layer_name`` argument of the MoE forward ops.
# With torch >= 2.11 (HAS_OPAQUE_TYPE) the name travels as a hoisted
# ModuleName opaque object; older versions keep passing a plain str.
# Static type checkers see the union of both; at runtime the alias is bound
# to exactly one concrete type.
# NOTE(review): presumably the concrete runtime type matters because the
# custom-op registration inspects these annotations — confirm.
if TYPE_CHECKING:
    from typing import TypeAlias

    _layer_name_type: TypeAlias = str | ModuleName
elif HAS_OPAQUE_TYPE:
    _layer_name_type = ModuleName
else:
    _layer_name_type = str
|
||||
|
||||
|
||||
def _resolve_layer_name(layer_name: str | ModuleName) -> str:
    """Return the plain string form of ``layer_name``.

    A ``ModuleName`` is unwrapped to its ``value`` attribute; any other
    argument is returned unchanged.
    """
    if isinstance(layer_name, ModuleName):
        return layer_name.value
    return layer_name
|
||||
|
||||
|
||||
def _moe_forward(
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
layer_name: str,
|
||||
layer_name: _layer_name_type,
|
||||
) -> torch.Tensor:
|
||||
layer = get_layer_from_name(layer_name)
|
||||
layer = get_layer_from_name(_resolve_layer_name(layer_name))
|
||||
# TODO(bnell): this can be removed after MK migration is complete.
|
||||
layer.ensure_moe_quant_config_init()
|
||||
return layer.runner.forward_impl(
|
||||
@@ -74,7 +91,7 @@ def _moe_forward_fake(
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
layer_name: str,
|
||||
layer_name: _layer_name_type,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(hidden_states)
|
||||
|
||||
@@ -83,9 +100,9 @@ def _moe_forward_shared(
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
layer_name: str,
|
||||
layer_name: _layer_name_type,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
layer = get_layer_from_name(layer_name)
|
||||
layer = get_layer_from_name(_resolve_layer_name(layer_name))
|
||||
# TODO(bnell): this can be removed after MK migration is complete.
|
||||
layer.ensure_moe_quant_config_init()
|
||||
return layer.runner.forward_impl(
|
||||
@@ -97,7 +114,7 @@ def _moe_forward_shared_fake(
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
shared_experts_input: torch.Tensor | None,
|
||||
layer_name: str,
|
||||
layer_name: _layer_name_type,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# Output shapes:
|
||||
# - fused_out: same as hidden_states (routed experts use transformed size)
|
||||
@@ -105,12 +122,10 @@ def _moe_forward_shared_fake(
|
||||
# hidden_states
|
||||
# (For latent MoE: shared experts use original hidden_size, not latent size)
|
||||
fused_out = torch.empty_like(hidden_states)
|
||||
|
||||
if shared_experts_input is not None:
|
||||
shared_out = torch.empty_like(shared_experts_input)
|
||||
else:
|
||||
shared_out = torch.empty_like(hidden_states)
|
||||
|
||||
return shared_out, fused_out
|
||||
|
||||
|
||||
@@ -367,7 +382,9 @@ class DefaultMoERunner(MoERunner):
|
||||
assert len(trunc_sizes) == 1
|
||||
return func(states, trunc_sizes[0])
|
||||
|
||||
def _encode_layer_name(self) -> str:
|
||||
def _encode_layer_name(self) -> str | ModuleName:
|
||||
if HAS_OPAQUE_TYPE:
|
||||
return ModuleName(self.layer_name)
|
||||
# Can be unavailable or None in unittests
|
||||
if (
|
||||
is_forward_context_available()
|
||||
|
||||
Reference in New Issue
Block a user