[Refactor] Replace activation: str with MoEActivation enum (#33843)
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
@@ -29,6 +29,7 @@ from vllm.config import VllmConfig, set_current_vllm_config
 from vllm.distributed.parallel_state import init_distributed_environment
 from vllm.forward_context import get_forward_context, set_forward_context
 from vllm.model_executor.layers.fused_moe import (
+    MoEActivation,
     fused_topk,
 )
 from vllm.model_executor.layers.fused_moe.config import (
@@ -1155,7 +1156,10 @@ def test_fused_marlin_moe_with_bias(m):
 @pytest.mark.parametrize("m", [1, 64, 256])
 @pytest.mark.parametrize("n,k", [(1024, 1024), (2048, 2048)])
 @pytest.mark.parametrize("e,topk", [(8, 2), (64, 4)])
-def test_fused_marlin_moe_non_gated(m: int, n: int, k: int, e: int, topk: int):
+@pytest.mark.parametrize("activation", [MoEActivation.RELU2_NO_MUL])
+def test_fused_marlin_moe_non_gated(
+    m: int, n: int, k: int, e: int, topk: int, activation: MoEActivation
+):
     """Test Marlin MoE with non-gated activation (relu2_no_mul).
 
     Non-gated activations like relu2 don't have the gate-up projection pattern,
@@ -1198,7 +1202,7 @@ def test_fused_marlin_moe_non_gated(m: int, n: int, k: int, e: int, topk: int):
         w2_data.w_ref,
         score,
         topk,
-        activation="relu2",
+        activation=activation,
     )
 
     marlin_output = fused_marlin_moe(
@@ -1223,7 +1227,7 @@ def test_fused_marlin_moe_non_gated(m: int, n: int, k: int, e: int, topk: int):
         w2_zeros=w2_data.zeros,
         quant_type_id=quant_type.id,
         is_k_full=is_k_full,
-        activation="relu2_no_mul",
+        activation=activation,
     )
 
     torch.testing.assert_close(marlin_output, torch_output, atol=1e-1, rtol=0)
@@ -1330,9 +1334,18 @@ def test_moe_sum(m: int, topk: int, k: int, dtype: torch.dtype):
 @pytest.mark.parametrize("topk", [2])
 @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
 @pytest.mark.parametrize("with_bias", [False, True])
-@pytest.mark.parametrize("activation", ["silu"])
+@pytest.mark.parametrize("activation", [MoEActivation.SILU])
 @pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only test")
-def test_cpu_fused_moe_basic(m, n, k, e, topk, dtype, with_bias, activation):
+def test_cpu_fused_moe_basic(
+    m: int,
+    n: int,
+    k: int,
+    e: int,
+    topk: int,
+    dtype: torch.dtype,
+    with_bias: bool,
+    activation: MoEActivation,
+):
     from vllm.model_executor.layers.fused_moe.cpu_fused_moe import CPUFusedMOE
 
     device = "cpu"
@@ -1608,6 +1621,7 @@ def test_unquantized_bf16_flashinfer_trtllm_backend(
         hidden_dim=k,
         intermediate_size_per_partition=n,
         num_local_experts=e,
         num_logical_experts=e,
         activation="silu",
         device="cuda",
         moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
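
For orientation, here is a minimal sketch of the pattern this diff applies at call sites: a string-valued activation argument becomes an enum member. The MoEActivation definition and the pick_kernel helper below are stand-ins inferred from the member names used in the tests (SILU, RELU2_NO_MUL); the real enum is imported from vllm.model_executor.layers.fused_moe, and its members, values, and base class may differ.

from enum import Enum


class MoEActivation(str, Enum):
    # Stand-in enum: member names come from the test diff above; the
    # string values are assumed to mirror the old activation strings.
    SILU = "silu"
    RELU2_NO_MUL = "relu2_no_mul"


def pick_kernel(activation: MoEActivation = MoEActivation.SILU) -> str:
    # Hypothetical helper, not part of vLLM. It only illustrates that call
    # sites now pass MoEActivation.RELU2_NO_MUL instead of the bare string
    # "relu2_no_mul", so a typo fails at attribute lookup instead of deep
    # inside kernel dispatch.
    return f"fused MoE kernel for {activation.value}"


print(pick_kernel(MoEActivation.RELU2_NO_MUL))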