[Refactor] Replace activation: str with MoEActivation enum (#33843)

Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Author: Michael Goin
Date: 2026-02-11 20:29:32 -05:00 (committed by GitHub)
parent 83b47f67b1
commit ff1f83b056
48 changed files with 474 additions and 282 deletions
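For context, a minimal sketch of what the new enum plausibly looks like, assuming a plain Python Enum with the members the tests below exercise; the real definition in vllm/model_executor/layers/fused_moe/activation.py may use different values or carry extra members:

# Hypothetical sketch of MoEActivation, inferred from the tests below;
# not the actual vLLM definition.
from enum import Enum

class MoEActivation(Enum):
    # Gated variants: the width-N intermediate is split in half and
    # combined as act(x[:, :N//2]) * x[:, N//2:].
    SILU = "silu"
    GELU = "gelu"
    # Non-gated ("no mul") variants: act(x) applied elementwise.
    SILU_NO_MUL = "silu_no_mul"
    GELU_NO_MUL = "gelu_no_mul"
    RELU2_NO_MUL = "relu2_no_mul"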


@@ -11,15 +11,11 @@ import pytest
 import torch
 from tests.kernels.moe.utils import make_dummy_moe_config
+from vllm.model_executor.layers.fused_moe.activation import MoEActivation
 from vllm.model_executor.layers.fused_moe.config import (
     FUSED_MOE_UNQUANTIZED_CONFIG,
 )
 from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
-from vllm.model_executor.layers.fused_moe.utils import (
-    GELU_NO_MUL,
-    RELU2_NO_MUL,
-    SILU_NO_MUL,
-)
 from vllm.platforms import current_platform


 # Test parameters
@@ -28,7 +24,11 @@ N_SIZES = [128, 256]
 K_SIZES = [64, 128]
 TOPK_VALUES = [1, 2]
 NUM_EXPERTS = 8
-NO_MUL_ACTIVATIONS = [SILU_NO_MUL, GELU_NO_MUL, RELU2_NO_MUL]
+NO_MUL_ACTIVATIONS = [
+    MoEActivation.SILU_NO_MUL,
+    MoEActivation.GELU_NO_MUL,
+    MoEActivation.RELU2_NO_MUL,
+]


 def make_test_tensors(
@@ -73,7 +73,7 @@ def test_triton_experts_no_mul_activation(
     n: int,
     k: int,
     topk: int,
-    activation: str,
+    activation: MoEActivation,
 ):
     hidden_states, w1, w2, topk_weights, topk_ids = make_test_tensors(
         m, n, k, NUM_EXPERTS, topk
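Since the activation parameter is now an enum, parametrized tests receive MoEActivation members rather than strings. A self-contained usage sketch (the test body here is illustrative, not the one in this file):

import pytest
from vllm.model_executor.layers.fused_moe.activation import MoEActivation

NO_MUL_ACTIVATIONS = [
    MoEActivation.SILU_NO_MUL,
    MoEActivation.GELU_NO_MUL,
    MoEActivation.RELU2_NO_MUL,
]

@pytest.mark.parametrize("activation", NO_MUL_ACTIVATIONS)
def test_activation_is_enum(activation: MoEActivation):
    # Each parametrized case now receives an enum member, not a raw string.
    assert isinstance(activation, MoEActivation)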
@@ -161,11 +161,11 @@ def test_workspace_shapes_no_mul_vs_gated():
     )
     ws1_no_mul, _, out_no_mul = experts.workspace_shapes(
-        M, N, K, topk, 8, 8, None, SILU_NO_MUL
+        M, N, K, topk, 8, 8, None, MoEActivation.SILU_NO_MUL
     )
     ws1_gated, _, out_gated = experts.workspace_shapes(
-        M, N, K, topk, 8, 8, None, "silu"
+        M, N, K, topk, 8, 8, None, MoEActivation.SILU
     )

     # For no_mul: activation_out_dim = N
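The shape difference this test asserts follows one rule: gated activations consume the width-N intermediate as two halves, so their output is N // 2 wide, while no-mul activations keep the full N. A hedged sketch of that rule (this helper is illustrative, not vLLM's implementation):

# Illustrative shape rule, mirroring what the assertions below check;
# assumes the MoEActivation sketch above.
def activation_out_dim(n: int, activation: MoEActivation) -> int:
    if activation.name.endswith("_NO_MUL"):
        return n       # elementwise: width preserved
    return n // 2      # gated: gate/value halves are multiplied together

assert activation_out_dim(256, MoEActivation.SILU) == 128
assert activation_out_dim(256, MoEActivation.SILU_NO_MUL) == 256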
@@ -202,10 +202,10 @@ def test_adjust_n_for_activation():
     N = 256

     # Gated activations should return N // 2
-    assert experts.adjust_N_for_activation(N, "silu") == N // 2
-    assert experts.adjust_N_for_activation(N, "gelu") == N // 2
+    assert experts.adjust_N_for_activation(N, MoEActivation.SILU) == N // 2
+    assert experts.adjust_N_for_activation(N, MoEActivation.GELU) == N // 2

     # Non-gated activations should return N
-    assert experts.adjust_N_for_activation(N, SILU_NO_MUL) == N
-    assert experts.adjust_N_for_activation(N, GELU_NO_MUL) == N
-    assert experts.adjust_N_for_activation(N, RELU2_NO_MUL) == N
+    assert experts.adjust_N_for_activation(N, MoEActivation.SILU_NO_MUL) == N
+    assert experts.adjust_N_for_activation(N, MoEActivation.GELU_NO_MUL) == N
+    assert experts.adjust_N_for_activation(N, MoEActivation.RELU2_NO_MUL) == N
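To make the gated vs. no-mul distinction concrete, a small PyTorch sketch of the two elementwise computations (a standalone illustration, not the fused Triton kernels under test):

import torch
import torch.nn.functional as F

x = torch.randn(4, 256)  # intermediate of width N = 256

# Gated SiLU (MoEActivation.SILU): silu(gate) * value over the two halves.
gate, value = x.chunk(2, dim=-1)
gated_out = F.silu(gate) * value   # shape (4, 128) -> N // 2

# Non-gated SiLU (MoEActivation.SILU_NO_MUL): applied elementwise.
no_mul_out = F.silu(x)             # shape (4, 256) -> N

assert gated_out.shape[-1] == x.shape[-1] // 2
assert no_mul_out.shape[-1] == x.shape[-1]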