[Refactor] Replace activation: str with MoEActivation enum (#33843)

Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Author: Michael Goin
Date: 2026-02-11 20:29:32 -05:00
Committed by: GitHub
Parent: 83b47f67b1
Commit: ff1f83b056
48 changed files with 474 additions and 282 deletions
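
For orientation, here is a hedged sketch of what the new enum might look like. Only MoEActivation.SILU, the is_gated property, and the old "_no_mul" string convention are visible in the hunks below; every other member name is an assumption, and the real definition lives in vllm/model_executor/layers/fused_moe/activation.py.

# Hypothetical sketch of MoEActivation; member names other than SILU are
# assumptions inferred from the old "_no_mul" string convention in the diff.
from enum import Enum

class MoEActivation(Enum):
    SILU = "silu"                # gated: silu(x1) * x3, halves the gemm1 width
    SILU_NO_MUL = "silu_no_mul"  # ungated variant (assumed), keeps full width

    @property
    def is_gated(self) -> bool:
        # Replaces the old activation.endswith("_no_mul") string check.
        return not self.value.endswith("_no_mul")
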


@@ -12,6 +12,10 @@ import torch
 import vllm.envs as envs
 from vllm.forward_context import get_forward_context, is_forward_context_available
 from vllm.logger import init_logger
+from vllm.model_executor.layers.fused_moe.activation import (
+    MoEActivation,
+    apply_moe_activation,
+)
 from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEConfig,
     FusedMoEParallelConfig,
@@ -19,7 +23,6 @@ from vllm.model_executor.layers.fused_moe.config import (
 )
 from vllm.model_executor.layers.fused_moe.utils import (
     _resize_cache,
-    apply_moe_activation,
     count_expert_num_tokens,
     disable_inplace,
 )
@@ -536,7 +539,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
     @staticmethod
     @abstractmethod
-    def _supports_activation(activation: str) -> bool:
+    def _supports_activation(activation: MoEActivation) -> bool:
         """
         Whether the kernel supports a particular act function.
         """
@@ -658,7 +661,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
         global_num_experts: int,
         local_num_experts: int,
         expert_tokens_meta: ExpertTokensMetadata | None,
-        activation: str,
+        activation: MoEActivation,
     ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
         """
         Compute the shapes for the temporary and final outputs of the two gemms
@@ -690,7 +693,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
         raise NotImplementedError

     @staticmethod
-    def adjust_N_for_activation(N: int, activation: str) -> int:
+    def adjust_N_for_activation(N: int, activation: MoEActivation) -> int:
         """
         Calculate the output dimension for the activation function.
@@ -702,16 +705,15 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
         Args:
             N: The intermediate size (width of w1/w3 weights).
-            activation: The activation function name.
+            activation: The activation function enum.

         Returns:
             The output dimension after activation.
         """
-        is_no_mul = activation.endswith("_no_mul")
-        return N if is_no_mul else N // 2
+        return N if not activation.is_gated else N // 2

     def activation(
-        self, activation: str, output: torch.Tensor, input: torch.Tensor
+        self, activation: MoEActivation, output: torch.Tensor, input: torch.Tensor
     ) -> None:
         apply_moe_activation(activation, output, input)
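
The is_gated branch encodes the usual gated-MLP shape contract: gemm1 produces 2N columns, a gated activation (SiLU-and-mul) multiplies the two halves and writes N columns, while the ungated "_no_mul" variants keep the full width. A minimal sketch of that contract, independent of the vLLM kernels:

import torch
import torch.nn.functional as F

def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
    # Gated activation: [num_tokens, 2N] -> [num_tokens, N]
    n = x.shape[-1] // 2
    return F.silu(x[..., :n]) * x[..., n:]

gemm1_out = torch.randn(4, 16)     # gemm1 output, width 2N = 16
act_out = silu_and_mul(gemm1_out)  # width N = 8, i.e. adjust_N_for_activation(16, SILU)
assert act_out.shape == (4, 8)
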
@@ -732,7 +734,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
         w2: torch.Tensor,
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
-        activation: str,
+        activation: MoEActivation,
         global_num_experts: int,
         expert_map: torch.Tensor | None,
         a1q_scale: torch.Tensor | None,
@@ -892,7 +894,7 @@ class FusedMoEModularKernel(torch.nn.Module):
         global_num_experts: int,
         local_num_experts: int,
         expert_tokens_meta: ExpertTokensMetadata | None,
-        activation: str,
+        activation: MoEActivation,
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Allocate temporary and output buffers for the fused experts op.
@@ -1135,7 +1137,7 @@ class FusedMoEModularKernel(torch.nn.Module):
         w2: torch.Tensor,
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
-        activation: str,
+        activation: MoEActivation,
         global_num_experts: int,
         local_num_experts: int,
         expert_map: torch.Tensor | None,
@@ -1309,7 +1311,7 @@ class FusedMoEModularKernel(torch.nn.Module):
         w2: torch.Tensor,
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
-        activation: str = "silu",
+        activation: MoEActivation = MoEActivation.SILU,
         global_num_experts: int = -1,
         expert_map: torch.Tensor | None = None,
         apply_router_weight_on_input: bool = False,
@@ -1326,7 +1328,7 @@ class FusedMoEModularKernel(torch.nn.Module):
         - topk_weights (torch.Tensor): The topk weights applied at the end of
           the layer.
         - topk_ids (torch.Tensor): A map of row to expert id.
-        - activation (str): The activation function to apply after the first
+        - activation (MoEActivation): The activation function to apply after the first
           MoE layer.
         - global_num_experts (int): The total number of experts in the global
           expert space.
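
At call sites the migration is mechanical; a hedged before/after sketch (the leading hidden_states/w1 argument names are inferred from the surrounding signature and are not shown in these hunks):

from vllm.model_executor.layers.fused_moe.activation import MoEActivation

# Before: a bare string that each kernel validated ad hoc
# out = mk.forward(..., activation="silu")

# After: an explicit enum member that the type checker can verify
out = mk.forward(
    hidden_states,                  # assumed leading args; not shown in the hunks
    w1,
    w2,
    topk_weights,
    topk_ids,
    activation=MoEActivation.SILU,  # the new default, per the diff
    global_num_experts=global_num_experts,
    expert_map=None,
)
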