[Refactor] Replace activation: str with MoEActivation enum (#33843)

Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Author: Michael Goin
Date: 2026-02-11 20:29:32 -05:00 (committed by GitHub)
Parent: 83b47f67b1
Commit: ff1f83b056

48 changed files with 474 additions and 282 deletions
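For context: the diff below uses a new `MoEActivation` enum (imported from `vllm/model_executor/layers/fused_moe/activation.py`) whose definition is not part of this excerpt. The sketch below is a minimal reconstruction inferred only from the usages visible in the hunks — the members `SILU` and `RELU2_NO_MUL`, the `is_gated` property, and the FlashInfer names in the deleted string mapping; the actual member values and the full member list are assumptions.

```python
from enum import Enum


class MoEActivation(Enum):
    """Illustrative sketch only; not the actual vLLM definition.

    The diff shows that SILU and RELU2_NO_MUL exist and that the enum
    exposes an `is_gated` property; everything else here is assumed.
    """

    # Values mirror the FlashInfer names from the deleted mapping
    # {"silu_and_mul": "silu", "relu2": "relu2_no_mul"} (an assumption).
    SILU = "silu"  # gated: SiLU applied to the gate half, then multiplied
    RELU2_NO_MUL = "relu2_no_mul"  # ungated: squared ReLU, no gating multiply

    @property
    def is_gated(self) -> bool:
        # Replaces string checks like `activation == "silu_and_mul"`.
        return self is MoEActivation.SILU


# Usage mirroring the test changes below.
assert MoEActivation.SILU.is_gated
assert not MoEActivation.RELU2_NO_MUL.is_gated
```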

@@ -13,6 +13,7 @@ from tests.kernels.utils import torch_moe
 from vllm import _custom_ops as ops
 from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
 from vllm.model_executor.layers.fused_moe import fused_topk
+from vllm.model_executor.layers.fused_moe.activation import MoEActivation
 from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEConfig,
     FusedMoEParallelConfig,
@@ -54,7 +55,7 @@ MNK_FACTORS = [
 @pytest.mark.parametrize("e", [40, 64, 256])
 @pytest.mark.parametrize("topk", [1, 6, 8])
 @pytest.mark.parametrize("dtype", [torch.bfloat16])
-@pytest.mark.parametrize("activation", ["silu_and_mul", "relu2"])
+@pytest.mark.parametrize("activation", [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL])
 @torch.inference_mode()
 def test_flashinfer_fp4_moe_no_graph(
     m: int,
@@ -63,7 +64,7 @@ def test_flashinfer_fp4_moe_no_graph(
     e: int,
     topk: int,
     dtype: torch.dtype,
-    activation: str,
+    activation: MoEActivation,
     workspace_init,
 ):
     set_random_seed(7)
@@ -73,7 +74,7 @@ def test_flashinfer_fp4_moe_no_graph(
     a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
     quant_blocksize = 16
-    is_gated_act = activation == "silu_and_mul"
+    is_gated_act = activation.is_gated
     w1_q, w2_q, quant_config = make_test_quant_config(
         e,
@@ -112,15 +113,13 @@ def test_flashinfer_fp4_moe_no_graph(
         inplace=False,
     )
-    fi_activation = {"silu_and_mul": "silu", "relu2": "relu2_no_mul"}[activation]
     flashinfer_output = flashinfer_experts(
         hidden_states=a,
         w1=w1_q,
         w2=w2_q,
         topk_weights=topk_weights,
         topk_ids=topk_ids,
-        activation=fi_activation,
+        activation=activation,
     )
     # Reference check:
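
Design note: the last hunk shows the payoff of the enum. With `activation: str`, the FlashInfer call site had to translate vLLM's activation names into FlashInfer's (the deleted `fi_activation` dict); with `MoEActivation`, the same value flows through unchanged, and string comparisons such as `activation == "silu_and_mul"` become properties like `activation.is_gated`, so a misspelled activation fails at attribute lookup instead of silently selecting the wrong kernel.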