[Refactor] Replace activation: str with MoEActivation enum (#33843)
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
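For context, a minimal sketch of the MoEActivation enum that the test changes below assume. Only the members exercised in this diff are shown; the member values and the full definition in vllm/model_executor/layers/fused_moe/activation.py may differ (this is an illustration, not the actual vLLM source):

    from enum import Enum

    class MoEActivation(Enum):
        # Gated activations: the up-projection output is split in two and
        # the halves are multiplied, so the activation output dim is N // 2.
        SILU = "silu"
        GELU = "gelu"
        # "No mul" variants: applied elementwise with no gating multiply,
        # so the output dim stays N.
        SILU_NO_MUL = "silu_no_mul"
        GELU_NO_MUL = "gelu_no_mul"
        RELU2_NO_MUL = "relu2_no_mul"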
@@ -11,15 +11,11 @@ import pytest
 import torch
 
 from tests.kernels.moe.utils import make_dummy_moe_config
+from vllm.model_executor.layers.fused_moe.activation import MoEActivation
 from vllm.model_executor.layers.fused_moe.config import (
     FUSED_MOE_UNQUANTIZED_CONFIG,
 )
 from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
-from vllm.model_executor.layers.fused_moe.utils import (
-    GELU_NO_MUL,
-    RELU2_NO_MUL,
-    SILU_NO_MUL,
-)
 from vllm.platforms import current_platform
 
 # Test parameters
@@ -28,7 +24,11 @@ N_SIZES = [128, 256]
 K_SIZES = [64, 128]
 TOPK_VALUES = [1, 2]
 NUM_EXPERTS = 8
-NO_MUL_ACTIVATIONS = [SILU_NO_MUL, GELU_NO_MUL, RELU2_NO_MUL]
+NO_MUL_ACTIVATIONS = [
+    MoEActivation.SILU_NO_MUL,
+    MoEActivation.GELU_NO_MUL,
+    MoEActivation.RELU2_NO_MUL,
+]
 
 
 def make_test_tensors(
@@ -73,7 +73,7 @@ def test_triton_experts_no_mul_activation(
     n: int,
     k: int,
     topk: int,
-    activation: str,
+    activation: MoEActivation,
 ):
     hidden_states, w1, w2, topk_weights, topk_ids = make_test_tensors(
         m, n, k, NUM_EXPERTS, topk
@@ -161,11 +161,11 @@ def test_workspace_shapes_no_mul_vs_gated():
     )
 
     ws1_no_mul, _, out_no_mul = experts.workspace_shapes(
-        M, N, K, topk, 8, 8, None, SILU_NO_MUL
+        M, N, K, topk, 8, 8, None, MoEActivation.SILU_NO_MUL
     )
 
     ws1_gated, _, out_gated = experts.workspace_shapes(
-        M, N, K, topk, 8, 8, None, "silu"
+        M, N, K, topk, 8, 8, None, MoEActivation.SILU
     )
 
     # For no_mul: activation_out_dim = N
@@ -202,10 +202,10 @@ def test_adjust_n_for_activation():
     N = 256
 
     # Gated activations should return N // 2
-    assert experts.adjust_N_for_activation(N, "silu") == N // 2
-    assert experts.adjust_N_for_activation(N, "gelu") == N // 2
+    assert experts.adjust_N_for_activation(N, MoEActivation.SILU) == N // 2
+    assert experts.adjust_N_for_activation(N, MoEActivation.GELU) == N // 2
 
     # Non-gated activations should return N
-    assert experts.adjust_N_for_activation(N, SILU_NO_MUL) == N
-    assert experts.adjust_N_for_activation(N, GELU_NO_MUL) == N
-    assert experts.adjust_N_for_activation(N, RELU2_NO_MUL) == N
+    assert experts.adjust_N_for_activation(N, MoEActivation.SILU_NO_MUL) == N
+    assert experts.adjust_N_for_activation(N, MoEActivation.GELU_NO_MUL) == N
+    assert experts.adjust_N_for_activation(N, MoEActivation.RELU2_NO_MUL) == N
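The last hunk pins down the dimension semantics the enum carries through adjust_N_for_activation. A standalone sketch of that behavior, reconstructed from the asserts above (the real method lives on the experts object; the free-function form and the gated/non-gated split here are illustrative, not vLLM's implementation):

    def adjust_N_for_activation(N: int, activation: MoEActivation) -> int:
        # Gated activations consume two N//2 halves (act(x1) * x2), so the
        # activation output dim is halved; no-mul variants keep the full N.
        gated = activation in (MoEActivation.SILU, MoEActivation.GELU)
        return N // 2 if gated else N

    assert adjust_N_for_activation(256, MoEActivation.SILU) == 128
    assert adjust_N_for_activation(256, MoEActivation.RELU2_NO_MUL) == 256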