[Refactor] Replace activation: str with MoEActivation enum (#33843)
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
@@ -6,6 +6,7 @@ import torch
|
||||
|
||||
from tests.kernels.allclose_default import get_default_atol, get_default_rtol
|
||||
from vllm._custom_ops import cpu_fused_moe, cpu_prepack_moe_weight
|
||||
+from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.model_executor.layers.fused_moe.cpu_fused_moe import _CPU_MOE_ACT_FN
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import set_random_seed
|
||||
@@ -19,7 +20,7 @@ EXPERT_NUM = [
|
||||
HIDDEN_DIM = [128, 2880]
|
||||
INTERMEDIATE_DIM = [128, 2880]
|
||||
BATCH_SIZE = [1, 64, 256]
|
||||
-ACT = ["silu", "swigluoai"]
|
||||
+ACT = [MoEActivation.SILU, MoEActivation.SWIGLUOAI]
|
||||
USE_BIAS = [True, False]
|
||||
ISA = ["amx", "vec"] if torch._C._cpu._is_amx_tile_supported() else ["vec"]
|
||||
DTYPE = [torch.bfloat16]
|
||||
@@ -33,7 +34,7 @@ def ref_fused_moe(
|
||||
w2_bias: torch.Tensor | None,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
-activation: str,
|
||||
+activation: MoEActivation,
|
||||
) -> torch.Tensor:
|
||||
len_experts = w13.size(0)
|
||||
|
||||
@@ -103,7 +104,7 @@ def test_cpu_fused_moe(
|
||||
intermediate_size: int,
|
||||
use_bias: bool,
|
||||
dtype: torch.dtype,
|
||||
-act: str,
|
||||
+act: MoEActivation,
|
||||
isa: str,
|
||||
):
|
||||
set_random_seed(0)
|
||||
@@ -153,7 +154,7 @@ def test_cpu_fused_moe(
|
||||
w2_bias,
|
||||
topk_weight,
|
||||
topk_ids,
|
||||
-act,
|
||||
+act.value,
|
||||
isa,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user