[Refactor] Replace activation: str with MoEActivation enum (#33843)

Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
Michael Goin
2026-02-11 20:29:32 -05:00
committed by GitHub
parent 83b47f67b1
commit ff1f83b056
48 changed files with 474 additions and 282 deletions

View File

@@ -22,6 +22,7 @@ from vllm.distributed import (
)
from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.all2all_utils import (
maybe_make_prepare_finalize,
)
@@ -599,7 +600,7 @@ def make_modular_kernel(
moe_parallel_config=moe_parallel_config,
in_dtype=config.dtype,
max_num_tokens=next_power_of_2(config.M),
activation="silu",
activation=MoEActivation.SILU,
device=vllm_config.device_config.device,
routing_method=RoutingMethodType.DeepSeekV3,
)

View File

@@ -6,6 +6,7 @@ import torch
from tests.kernels.allclose_default import get_default_atol, get_default_rtol
from vllm._custom_ops import cpu_fused_moe, cpu_prepack_moe_weight
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.cpu_fused_moe import _CPU_MOE_ACT_FN
from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed
@@ -19,7 +20,7 @@ EXPERT_NUM = [
HIDDEN_DIM = [128, 2880]
INTERMEDIATE_DIM = [128, 2880]
BATCH_SIZE = [1, 64, 256]
ACT = ["silu", "swigluoai"]
ACT = [MoEActivation.SILU, MoEActivation.SWIGLUOAI]
USE_BIAS = [True, False]
ISA = ["amx", "vec"] if torch._C._cpu._is_amx_tile_supported() else ["vec"]
DTYPE = [torch.bfloat16]
@@ -33,7 +34,7 @@ def ref_fused_moe(
w2_bias: torch.Tensor | None,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
activation: MoEActivation,
) -> torch.Tensor:
len_experts = w13.size(0)
@@ -103,7 +104,7 @@ def test_cpu_fused_moe(
intermediate_size: int,
use_bias: bool,
dtype: torch.dtype,
act: str,
act: MoEActivation,
isa: str,
):
set_random_seed(0)
@@ -153,7 +154,7 @@ def test_cpu_fused_moe(
w2_bias,
topk_weight,
topk_ids,
act,
act.value,
isa,
)

View File

@@ -12,6 +12,7 @@ from tests.kernels.moe.utils import make_dummy_moe_config
from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG,
FusedMoEQuantConfig,
@@ -531,7 +532,7 @@ def test_run_cutlass_moe_fp8(
c_strides1 = torch.full((e,), 2 * n, device="cuda", dtype=torch.int64)
c_strides2 = torch.full((e,), k, device="cuda", dtype=torch.int64)
activation = "silu"
activation = MoEActivation.SILU
a1q, a1q_scale = moe_kernel_quantize_input(
mt.a, mt.a_scale, torch.float8_e4m3fn, per_act_token
)

View File

@@ -16,6 +16,7 @@ from typing_extensions import ParamSpec
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.forward_context import set_forward_context
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
fp8_w8a8_moe_quant_config,
@@ -324,7 +325,7 @@ def deepep_deepgemm_moe_impl(
w2=w2,
topk_weights=test_tensors.topk_weights,
topk_ids=test_tensors.topk,
activation="silu",
activation=MoEActivation.SILU,
global_num_experts=num_experts,
expert_map=build_expert_map(),
apply_router_weight_on_input=False,

View File

@@ -15,6 +15,7 @@ from vllm import _custom_ops as ops
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import TritonExperts
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEQuantConfig,
)
@@ -260,7 +261,7 @@ def deep_ep_moe_impl(
w2=w2,
topk_weights=topk_weights_chunk,
topk_ids=topk_chunk,
activation="silu",
activation=MoEActivation.SILU,
global_num_experts=num_experts,
expert_map=build_expert_map(),
apply_router_weight_on_input=False,

View File

@@ -7,6 +7,7 @@ import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
@@ -93,9 +94,14 @@ class TestData:
@staticmethod
def make_moe_tensors_8bit(
m: int, k: int, n: int, e: int, is_trtllm: bool, activation: str = "silu"
m: int,
k: int,
n: int,
e: int,
is_trtllm: bool,
activation: MoEActivation = MoEActivation.SILU,
) -> "TestData":
is_gated = activation != "relu2_no_mul"
is_gated = activation.is_gated
hidden_states = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
w13 = torch.randn(
@@ -194,7 +200,7 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=False,
activation="silu",
activation=MoEActivation.SILU,
global_num_experts=e,
expert_map=None,
apply_router_weight_on_input=True,
@@ -219,21 +225,19 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("activation", ["silu", "relu2_no_mul"])
@pytest.mark.parametrize("activation", [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL])
def test_flashinfer_cutlass_moe_fp8_no_graph(
m: int,
n: int,
k: int,
e: int,
topk: int,
activation: str,
activation: MoEActivation,
monkeypatch,
workspace_init,
):
set_random_seed(7)
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
assert activation in ["silu", "relu2_no_mul"]
is_act_and_mul = activation == "silu_and_mul"
with set_current_vllm_config(vllm_config):
td = TestData.make_moe_tensors_8bit(
m, k, n, e, is_trtllm=False, activation=activation
@@ -292,7 +296,7 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
device="cuda",
moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
in_dtype=torch.bfloat16,
is_act_and_mul=is_act_and_mul,
is_act_and_mul=activation.is_gated,
routing_method=RoutingMethodType.TopK,
)

View File

@@ -13,6 +13,7 @@ from tests.kernels.utils import torch_moe
from vllm import _custom_ops as ops
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
@@ -54,7 +55,7 @@ MNK_FACTORS = [
@pytest.mark.parametrize("e", [40, 64, 256])
@pytest.mark.parametrize("topk", [1, 6, 8])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("activation", ["silu_and_mul", "relu2"])
@pytest.mark.parametrize("activation", [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL])
@torch.inference_mode()
def test_flashinfer_fp4_moe_no_graph(
m: int,
@@ -63,7 +64,7 @@ def test_flashinfer_fp4_moe_no_graph(
e: int,
topk: int,
dtype: torch.dtype,
activation: str,
activation: MoEActivation,
workspace_init,
):
set_random_seed(7)
@@ -73,7 +74,7 @@ def test_flashinfer_fp4_moe_no_graph(
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
quant_blocksize = 16
is_gated_act = activation == "silu_and_mul"
is_gated_act = activation.is_gated
w1_q, w2_q, quant_config = make_test_quant_config(
e,
@@ -112,15 +113,13 @@ def test_flashinfer_fp4_moe_no_graph(
inplace=False,
)
fi_activation = {"silu_and_mul": "silu", "relu2": "relu2_no_mul"}[activation]
flashinfer_output = flashinfer_experts(
hidden_states=a,
w1=w1_q,
w2=w2_q,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=fi_activation,
activation=activation,
)
# Reference check:

View File

@@ -7,6 +7,7 @@ Test modular OAI Triton MoE
import pytest
import torch
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.utils.import_utils import has_triton_kernels
if not has_triton_kernels():
@@ -192,7 +193,7 @@ def oai_triton_moe_impl(
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation="swigluoai",
activation=MoEActivation.SWIGLUOAI,
global_num_experts=num_experts,
expert_map=None,
apply_router_weight_on_input=False,

View File

@@ -29,6 +29,7 @@ from vllm.config import VllmConfig, set_current_vllm_config
from vllm.distributed.parallel_state import init_distributed_environment
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.model_executor.layers.fused_moe import (
MoEActivation,
fused_topk,
)
from vllm.model_executor.layers.fused_moe.config import (
@@ -1155,7 +1156,10 @@ def test_fused_marlin_moe_with_bias(m):
@pytest.mark.parametrize("m", [1, 64, 256])
@pytest.mark.parametrize("n,k", [(1024, 1024), (2048, 2048)])
@pytest.mark.parametrize("e,topk", [(8, 2), (64, 4)])
def test_fused_marlin_moe_non_gated(m: int, n: int, k: int, e: int, topk: int):
@pytest.mark.parametrize("activation", [MoEActivation.RELU2_NO_MUL])
def test_fused_marlin_moe_non_gated(
m: int, n: int, k: int, e: int, topk: int, activation: MoEActivation
):
"""Test Marlin MoE with non-gated activation (relu2_no_mul).
Non-gated activations like relu2 don't have the gate-up projection pattern,
@@ -1198,7 +1202,7 @@ def test_fused_marlin_moe_non_gated(m: int, n: int, k: int, e: int, topk: int):
w2_data.w_ref,
score,
topk,
activation="relu2",
activation=activation,
)
marlin_output = fused_marlin_moe(
@@ -1223,7 +1227,7 @@ def test_fused_marlin_moe_non_gated(m: int, n: int, k: int, e: int, topk: int):
w2_zeros=w2_data.zeros,
quant_type_id=quant_type.id,
is_k_full=is_k_full,
activation="relu2_no_mul",
activation=activation,
)
torch.testing.assert_close(marlin_output, torch_output, atol=1e-1, rtol=0)
@@ -1330,9 +1334,18 @@ def test_moe_sum(m: int, topk: int, k: int, dtype: torch.dtype):
@pytest.mark.parametrize("topk", [2])
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
@pytest.mark.parametrize("with_bias", [False, True])
@pytest.mark.parametrize("activation", ["silu"])
@pytest.mark.parametrize("activation", [MoEActivation.SILU])
@pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only test")
def test_cpu_fused_moe_basic(m, n, k, e, topk, dtype, with_bias, activation):
def test_cpu_fused_moe_basic(
m: int,
n: int,
k: int,
e: int,
topk: int,
dtype: torch.dtype,
with_bias: bool,
activation: MoEActivation,
):
from vllm.model_executor.layers.fused_moe.cpu_fused_moe import CPUFusedMOE
device = "cpu"
@@ -1608,6 +1621,7 @@ def test_unquantized_bf16_flashinfer_trtllm_backend(
hidden_dim=k,
intermediate_size_per_partition=n,
num_local_experts=e,
num_logical_experts=e,
activation="silu",
device="cuda",
moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),

View File

@@ -9,6 +9,7 @@ from tests.kernels.utils import torch_experts
from vllm import _custom_ops as ops
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
@@ -149,7 +150,7 @@ def pplx_cutlass_moe(
num_local_experts=num_local_experts,
num_logical_experts=num_experts,
moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
activation="silu",
activation=MoEActivation.SILU,
in_dtype=torch.bfloat16,
device="cuda",
routing_method=RoutingMethodType.Llama4,

View File

@@ -11,15 +11,11 @@ import pytest
import torch
from tests.kernels.moe.utils import make_dummy_moe_config
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
FUSED_MOE_UNQUANTIZED_CONFIG,
)
from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
from vllm.model_executor.layers.fused_moe.utils import (
GELU_NO_MUL,
RELU2_NO_MUL,
SILU_NO_MUL,
)
from vllm.platforms import current_platform
# Test parameters
@@ -28,7 +24,11 @@ N_SIZES = [128, 256]
K_SIZES = [64, 128]
TOPK_VALUES = [1, 2]
NUM_EXPERTS = 8
NO_MUL_ACTIVATIONS = [SILU_NO_MUL, GELU_NO_MUL, RELU2_NO_MUL]
NO_MUL_ACTIVATIONS = [
MoEActivation.SILU_NO_MUL,
MoEActivation.GELU_NO_MUL,
MoEActivation.RELU2_NO_MUL,
]
def make_test_tensors(
@@ -73,7 +73,7 @@ def test_triton_experts_no_mul_activation(
n: int,
k: int,
topk: int,
activation: str,
activation: MoEActivation,
):
hidden_states, w1, w2, topk_weights, topk_ids = make_test_tensors(
m, n, k, NUM_EXPERTS, topk
@@ -161,11 +161,11 @@ def test_workspace_shapes_no_mul_vs_gated():
)
ws1_no_mul, _, out_no_mul = experts.workspace_shapes(
M, N, K, topk, 8, 8, None, SILU_NO_MUL
M, N, K, topk, 8, 8, None, MoEActivation.SILU_NO_MUL
)
ws1_gated, _, out_gated = experts.workspace_shapes(
M, N, K, topk, 8, 8, None, "silu"
M, N, K, topk, 8, 8, None, MoEActivation.SILU
)
# For no_mul: activation_out_dim = N
@@ -202,10 +202,10 @@ def test_adjust_n_for_activation():
N = 256
# Gated activations should return N // 2
assert experts.adjust_N_for_activation(N, "silu") == N // 2
assert experts.adjust_N_for_activation(N, "gelu") == N // 2
assert experts.adjust_N_for_activation(N, MoEActivation.SILU) == N // 2
assert experts.adjust_N_for_activation(N, MoEActivation.GELU) == N // 2
# Non-gated activations should return N
assert experts.adjust_N_for_activation(N, SILU_NO_MUL) == N
assert experts.adjust_N_for_activation(N, GELU_NO_MUL) == N
assert experts.adjust_N_for_activation(N, RELU2_NO_MUL) == N
assert experts.adjust_N_for_activation(N, MoEActivation.SILU_NO_MUL) == N
assert experts.adjust_N_for_activation(N, MoEActivation.GELU_NO_MUL) == N
assert experts.adjust_N_for_activation(N, MoEActivation.RELU2_NO_MUL) == N

View File

@@ -12,6 +12,7 @@ from vllm.model_executor.layers.fused_moe import (
fused_experts,
fused_topk,
)
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEParallelConfig,
@@ -54,7 +55,7 @@ def make_dummy_moe_config(
num_local_experts=num_experts,
num_logical_experts=num_experts,
moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
activation="silu",
activation=MoEActivation.SILU,
in_dtype=in_dtype,
device="cuda",
routing_method=RoutingMethodType.TopK,