[Refactor] Replace activation: str with MoEActivation enum (#33843)
Signed-off-by: mgoin <mgoin64@gmail.com> Signed-off-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
@@ -13,6 +13,7 @@ from tests.kernels.utils import torch_moe
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe import fused_topk
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEParallelConfig,
|
||||
@@ -54,7 +55,7 @@ MNK_FACTORS = [
|
||||
@pytest.mark.parametrize("e", [40, 64, 256])
|
||||
@pytest.mark.parametrize("topk", [1, 6, 8])
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||
@pytest.mark.parametrize("activation", ["silu_and_mul", "relu2"])
|
||||
@pytest.mark.parametrize("activation", [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL])
|
||||
@torch.inference_mode()
|
||||
def test_flashinfer_fp4_moe_no_graph(
|
||||
m: int,
|
||||
@@ -63,7 +64,7 @@ def test_flashinfer_fp4_moe_no_graph(
|
||||
e: int,
|
||||
topk: int,
|
||||
dtype: torch.dtype,
|
||||
activation: str,
|
||||
activation: MoEActivation,
|
||||
workspace_init,
|
||||
):
|
||||
set_random_seed(7)
|
||||
@@ -73,7 +74,7 @@ def test_flashinfer_fp4_moe_no_graph(
|
||||
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
||||
|
||||
quant_blocksize = 16
|
||||
is_gated_act = activation == "silu_and_mul"
|
||||
is_gated_act = activation.is_gated
|
||||
|
||||
w1_q, w2_q, quant_config = make_test_quant_config(
|
||||
e,
|
||||
@@ -112,15 +113,13 @@ def test_flashinfer_fp4_moe_no_graph(
|
||||
inplace=False,
|
||||
)
|
||||
|
||||
fi_activation = {"silu_and_mul": "silu", "relu2": "relu2_no_mul"}[activation]
|
||||
|
||||
flashinfer_output = flashinfer_experts(
|
||||
hidden_states=a,
|
||||
w1=w1_q,
|
||||
w2=w2_q,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=fi_activation,
|
||||
activation=activation,
|
||||
)
|
||||
|
||||
# Reference check:
|
||||
|
||||
Reference in New Issue
Block a user