[Refactor] Replace activation: str with MoEActivation enum (#33843)
Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
@@ -22,6 +22,7 @@ from vllm.distributed import (
|
||||
)
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.model_executor.layers.fused_moe import fused_topk
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.model_executor.layers.fused_moe.all2all_utils import (
|
||||
maybe_make_prepare_finalize,
|
||||
)
|
||||
@@ -599,7 +600,7 @@ def make_modular_kernel(
|
||||
moe_parallel_config=moe_parallel_config,
|
||||
in_dtype=config.dtype,
|
||||
max_num_tokens=next_power_of_2(config.M),
|
||||
activation="silu",
|
||||
activation=MoEActivation.SILU,
|
||||
device=vllm_config.device_config.device,
|
||||
routing_method=RoutingMethodType.DeepSeekV3,
|
||||
)
|
||||
|
||||
@@ -6,6 +6,7 @@ import torch
|
||||
|
||||
from tests.kernels.allclose_default import get_default_atol, get_default_rtol
|
||||
from vllm._custom_ops import cpu_fused_moe, cpu_prepack_moe_weight
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.model_executor.layers.fused_moe.cpu_fused_moe import _CPU_MOE_ACT_FN
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import set_random_seed
|
||||
@@ -19,7 +20,7 @@ EXPERT_NUM = [
|
||||
HIDDEN_DIM = [128, 2880]
|
||||
INTERMEDIATE_DIM = [128, 2880]
|
||||
BATCH_SIZE = [1, 64, 256]
|
||||
ACT = ["silu", "swigluoai"]
|
||||
ACT = [MoEActivation.SILU, MoEActivation.SWIGLUOAI]
|
||||
USE_BIAS = [True, False]
|
||||
ISA = ["amx", "vec"] if torch._C._cpu._is_amx_tile_supported() else ["vec"]
|
||||
DTYPE = [torch.bfloat16]
|
||||
@@ -33,7 +34,7 @@ def ref_fused_moe(
|
||||
w2_bias: torch.Tensor | None,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
activation: MoEActivation,
|
||||
) -> torch.Tensor:
|
||||
len_experts = w13.size(0)
|
||||
|
||||
@@ -103,7 +104,7 @@ def test_cpu_fused_moe(
|
||||
intermediate_size: int,
|
||||
use_bias: bool,
|
||||
dtype: torch.dtype,
|
||||
act: str,
|
||||
act: MoEActivation,
|
||||
isa: str,
|
||||
):
|
||||
set_random_seed(0)
|
||||
@@ -153,7 +154,7 @@ def test_cpu_fused_moe(
|
||||
w2_bias,
|
||||
topk_weight,
|
||||
topk_ids,
|
||||
act,
|
||||
act.value,
|
||||
isa,
|
||||
)
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ from tests.kernels.moe.utils import make_dummy_moe_config
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe import fused_experts, fused_topk
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FUSED_MOE_UNQUANTIZED_CONFIG,
|
||||
FusedMoEQuantConfig,
|
||||
@@ -531,7 +532,7 @@ def test_run_cutlass_moe_fp8(
|
||||
c_strides1 = torch.full((e,), 2 * n, device="cuda", dtype=torch.int64)
|
||||
c_strides2 = torch.full((e,), k, device="cuda", dtype=torch.int64)
|
||||
|
||||
activation = "silu"
|
||||
activation = MoEActivation.SILU
|
||||
a1q, a1q_scale = moe_kernel_quantize_input(
|
||||
mt.a, mt.a_scale, torch.float8_e4m3fn, per_act_token
|
||||
)
|
||||
|
||||
@@ -16,6 +16,7 @@ from typing_extensions import ParamSpec
|
||||
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEQuantConfig,
|
||||
fp8_w8a8_moe_quant_config,
|
||||
@@ -324,7 +325,7 @@ def deepep_deepgemm_moe_impl(
|
||||
w2=w2,
|
||||
topk_weights=test_tensors.topk_weights,
|
||||
topk_ids=test_tensors.topk,
|
||||
activation="silu",
|
||||
activation=MoEActivation.SILU,
|
||||
global_num_experts=num_experts,
|
||||
expert_map=build_expert_map(),
|
||||
apply_router_weight_on_input=False,
|
||||
|
||||
@@ -15,6 +15,7 @@ from vllm import _custom_ops as ops
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import TritonExperts
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEQuantConfig,
|
||||
)
|
||||
@@ -260,7 +261,7 @@ def deep_ep_moe_impl(
|
||||
w2=w2,
|
||||
topk_weights=topk_weights_chunk,
|
||||
topk_ids=topk_chunk,
|
||||
activation="silu",
|
||||
activation=MoEActivation.SILU,
|
||||
global_num_experts=num_experts,
|
||||
expert_map=build_expert_map(),
|
||||
apply_router_weight_on_input=False,
|
||||
|
||||
@@ -7,6 +7,7 @@ import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEParallelConfig,
|
||||
@@ -93,9 +94,14 @@ class TestData:
|
||||
|
||||
@staticmethod
|
||||
def make_moe_tensors_8bit(
|
||||
m: int, k: int, n: int, e: int, is_trtllm: bool, activation: str = "silu"
|
||||
m: int,
|
||||
k: int,
|
||||
n: int,
|
||||
e: int,
|
||||
is_trtllm: bool,
|
||||
activation: MoEActivation = MoEActivation.SILU,
|
||||
) -> "TestData":
|
||||
is_gated = activation != "relu2_no_mul"
|
||||
is_gated = activation.is_gated
|
||||
|
||||
hidden_states = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / 10
|
||||
w13 = torch.randn(
|
||||
@@ -194,7 +200,7 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=False,
|
||||
activation="silu",
|
||||
activation=MoEActivation.SILU,
|
||||
global_num_experts=e,
|
||||
expert_map=None,
|
||||
apply_router_weight_on_input=True,
|
||||
@@ -219,21 +225,19 @@ def test_flashinfer_per_tensor_moe_fp8_no_graph(
|
||||
@pytest.mark.parametrize("m,n,k", MNK_FACTORS)
|
||||
@pytest.mark.parametrize("e", NUM_EXPERTS)
|
||||
@pytest.mark.parametrize("topk", TOP_KS)
|
||||
@pytest.mark.parametrize("activation", ["silu", "relu2_no_mul"])
|
||||
@pytest.mark.parametrize("activation", [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL])
|
||||
def test_flashinfer_cutlass_moe_fp8_no_graph(
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
e: int,
|
||||
topk: int,
|
||||
activation: str,
|
||||
activation: MoEActivation,
|
||||
monkeypatch,
|
||||
workspace_init,
|
||||
):
|
||||
set_random_seed(7)
|
||||
monkeypatch.setenv("VLLM_FUSED_MOE_CHUNK_SIZE", "8192")
|
||||
assert activation in ["silu", "relu2_no_mul"]
|
||||
is_act_and_mul = activation == "silu_and_mul"
|
||||
with set_current_vllm_config(vllm_config):
|
||||
td = TestData.make_moe_tensors_8bit(
|
||||
m, k, n, e, is_trtllm=False, activation=activation
|
||||
@@ -292,7 +296,7 @@ def test_flashinfer_cutlass_moe_fp8_no_graph(
|
||||
device="cuda",
|
||||
moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
|
||||
in_dtype=torch.bfloat16,
|
||||
is_act_and_mul=is_act_and_mul,
|
||||
is_act_and_mul=activation.is_gated,
|
||||
routing_method=RoutingMethodType.TopK,
|
||||
)
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ from tests.kernels.utils import torch_moe
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe import fused_topk
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEParallelConfig,
|
||||
@@ -54,7 +55,7 @@ MNK_FACTORS = [
|
||||
@pytest.mark.parametrize("e", [40, 64, 256])
|
||||
@pytest.mark.parametrize("topk", [1, 6, 8])
|
||||
@pytest.mark.parametrize("dtype", [torch.bfloat16])
|
||||
@pytest.mark.parametrize("activation", ["silu_and_mul", "relu2"])
|
||||
@pytest.mark.parametrize("activation", [MoEActivation.SILU, MoEActivation.RELU2_NO_MUL])
|
||||
@torch.inference_mode()
|
||||
def test_flashinfer_fp4_moe_no_graph(
|
||||
m: int,
|
||||
@@ -63,7 +64,7 @@ def test_flashinfer_fp4_moe_no_graph(
|
||||
e: int,
|
||||
topk: int,
|
||||
dtype: torch.dtype,
|
||||
activation: str,
|
||||
activation: MoEActivation,
|
||||
workspace_init,
|
||||
):
|
||||
set_random_seed(7)
|
||||
@@ -73,7 +74,7 @@ def test_flashinfer_fp4_moe_no_graph(
|
||||
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
||||
|
||||
quant_blocksize = 16
|
||||
is_gated_act = activation == "silu_and_mul"
|
||||
is_gated_act = activation.is_gated
|
||||
|
||||
w1_q, w2_q, quant_config = make_test_quant_config(
|
||||
e,
|
||||
@@ -112,15 +113,13 @@ def test_flashinfer_fp4_moe_no_graph(
|
||||
inplace=False,
|
||||
)
|
||||
|
||||
fi_activation = {"silu_and_mul": "silu", "relu2": "relu2_no_mul"}[activation]
|
||||
|
||||
flashinfer_output = flashinfer_experts(
|
||||
hidden_states=a,
|
||||
w1=w1_q,
|
||||
w2=w2_q,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation=fi_activation,
|
||||
activation=activation,
|
||||
)
|
||||
|
||||
# Reference check:
|
||||
|
||||
@@ -7,6 +7,7 @@ Test modular OAI Triton MoE
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.utils.import_utils import has_triton_kernels
|
||||
|
||||
if not has_triton_kernels():
|
||||
@@ -192,7 +193,7 @@ def oai_triton_moe_impl(
|
||||
w2=w2,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
activation="swigluoai",
|
||||
activation=MoEActivation.SWIGLUOAI,
|
||||
global_num_experts=num_experts,
|
||||
expert_map=None,
|
||||
apply_router_weight_on_input=False,
|
||||
|
||||
@@ -29,6 +29,7 @@ from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.distributed.parallel_state import init_distributed_environment
|
||||
from vllm.forward_context import get_forward_context, set_forward_context
|
||||
from vllm.model_executor.layers.fused_moe import (
|
||||
MoEActivation,
|
||||
fused_topk,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
@@ -1155,7 +1156,10 @@ def test_fused_marlin_moe_with_bias(m):
|
||||
@pytest.mark.parametrize("m", [1, 64, 256])
|
||||
@pytest.mark.parametrize("n,k", [(1024, 1024), (2048, 2048)])
|
||||
@pytest.mark.parametrize("e,topk", [(8, 2), (64, 4)])
|
||||
def test_fused_marlin_moe_non_gated(m: int, n: int, k: int, e: int, topk: int):
|
||||
@pytest.mark.parametrize("activation", [MoEActivation.RELU2_NO_MUL])
|
||||
def test_fused_marlin_moe_non_gated(
|
||||
m: int, n: int, k: int, e: int, topk: int, activation: MoEActivation
|
||||
):
|
||||
"""Test Marlin MoE with non-gated activation (relu2_no_mul).
|
||||
|
||||
Non-gated activations like relu2 don't have the gate-up projection pattern,
|
||||
@@ -1198,7 +1202,7 @@ def test_fused_marlin_moe_non_gated(m: int, n: int, k: int, e: int, topk: int):
|
||||
w2_data.w_ref,
|
||||
score,
|
||||
topk,
|
||||
activation="relu2",
|
||||
activation=activation,
|
||||
)
|
||||
|
||||
marlin_output = fused_marlin_moe(
|
||||
@@ -1223,7 +1227,7 @@ def test_fused_marlin_moe_non_gated(m: int, n: int, k: int, e: int, topk: int):
|
||||
w2_zeros=w2_data.zeros,
|
||||
quant_type_id=quant_type.id,
|
||||
is_k_full=is_k_full,
|
||||
activation="relu2_no_mul",
|
||||
activation=activation,
|
||||
)
|
||||
|
||||
torch.testing.assert_close(marlin_output, torch_output, atol=1e-1, rtol=0)
|
||||
@@ -1330,9 +1334,18 @@ def test_moe_sum(m: int, topk: int, k: int, dtype: torch.dtype):
|
||||
@pytest.mark.parametrize("topk", [2])
|
||||
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
|
||||
@pytest.mark.parametrize("with_bias", [False, True])
|
||||
@pytest.mark.parametrize("activation", ["silu"])
|
||||
@pytest.mark.parametrize("activation", [MoEActivation.SILU])
|
||||
@pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only test")
|
||||
def test_cpu_fused_moe_basic(m, n, k, e, topk, dtype, with_bias, activation):
|
||||
def test_cpu_fused_moe_basic(
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
e: int,
|
||||
topk: int,
|
||||
dtype: torch.dtype,
|
||||
with_bias: bool,
|
||||
activation: MoEActivation,
|
||||
):
|
||||
from vllm.model_executor.layers.fused_moe.cpu_fused_moe import CPUFusedMOE
|
||||
|
||||
device = "cpu"
|
||||
@@ -1608,6 +1621,7 @@ def test_unquantized_bf16_flashinfer_trtllm_backend(
|
||||
hidden_dim=k,
|
||||
intermediate_size_per_partition=n,
|
||||
num_local_experts=e,
|
||||
num_logical_experts=e,
|
||||
activation="silu",
|
||||
device="cuda",
|
||||
moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
|
||||
|
||||
@@ -9,6 +9,7 @@ from tests.kernels.utils import torch_experts
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe import fused_topk
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEParallelConfig,
|
||||
@@ -149,7 +150,7 @@ def pplx_cutlass_moe(
|
||||
num_local_experts=num_local_experts,
|
||||
num_logical_experts=num_experts,
|
||||
moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
|
||||
activation="silu",
|
||||
activation=MoEActivation.SILU,
|
||||
in_dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
routing_method=RoutingMethodType.Llama4,
|
||||
|
||||
@@ -11,15 +11,11 @@ import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.moe.utils import make_dummy_moe_config
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FUSED_MOE_UNQUANTIZED_CONFIG,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import TritonExperts
|
||||
from vllm.model_executor.layers.fused_moe.utils import (
|
||||
GELU_NO_MUL,
|
||||
RELU2_NO_MUL,
|
||||
SILU_NO_MUL,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
# Test parameters
|
||||
@@ -28,7 +24,11 @@ N_SIZES = [128, 256]
|
||||
K_SIZES = [64, 128]
|
||||
TOPK_VALUES = [1, 2]
|
||||
NUM_EXPERTS = 8
|
||||
NO_MUL_ACTIVATIONS = [SILU_NO_MUL, GELU_NO_MUL, RELU2_NO_MUL]
|
||||
NO_MUL_ACTIVATIONS = [
|
||||
MoEActivation.SILU_NO_MUL,
|
||||
MoEActivation.GELU_NO_MUL,
|
||||
MoEActivation.RELU2_NO_MUL,
|
||||
]
|
||||
|
||||
|
||||
def make_test_tensors(
|
||||
@@ -73,7 +73,7 @@ def test_triton_experts_no_mul_activation(
|
||||
n: int,
|
||||
k: int,
|
||||
topk: int,
|
||||
activation: str,
|
||||
activation: MoEActivation,
|
||||
):
|
||||
hidden_states, w1, w2, topk_weights, topk_ids = make_test_tensors(
|
||||
m, n, k, NUM_EXPERTS, topk
|
||||
@@ -161,11 +161,11 @@ def test_workspace_shapes_no_mul_vs_gated():
|
||||
)
|
||||
|
||||
ws1_no_mul, _, out_no_mul = experts.workspace_shapes(
|
||||
M, N, K, topk, 8, 8, None, SILU_NO_MUL
|
||||
M, N, K, topk, 8, 8, None, MoEActivation.SILU_NO_MUL
|
||||
)
|
||||
|
||||
ws1_gated, _, out_gated = experts.workspace_shapes(
|
||||
M, N, K, topk, 8, 8, None, "silu"
|
||||
M, N, K, topk, 8, 8, None, MoEActivation.SILU
|
||||
)
|
||||
|
||||
# For no_mul: activation_out_dim = N
|
||||
@@ -202,10 +202,10 @@ def test_adjust_n_for_activation():
|
||||
N = 256
|
||||
|
||||
# Gated activations should return N // 2
|
||||
assert experts.adjust_N_for_activation(N, "silu") == N // 2
|
||||
assert experts.adjust_N_for_activation(N, "gelu") == N // 2
|
||||
assert experts.adjust_N_for_activation(N, MoEActivation.SILU) == N // 2
|
||||
assert experts.adjust_N_for_activation(N, MoEActivation.GELU) == N // 2
|
||||
|
||||
# Non-gated activations should return N
|
||||
assert experts.adjust_N_for_activation(N, SILU_NO_MUL) == N
|
||||
assert experts.adjust_N_for_activation(N, GELU_NO_MUL) == N
|
||||
assert experts.adjust_N_for_activation(N, RELU2_NO_MUL) == N
|
||||
assert experts.adjust_N_for_activation(N, MoEActivation.SILU_NO_MUL) == N
|
||||
assert experts.adjust_N_for_activation(N, MoEActivation.GELU_NO_MUL) == N
|
||||
assert experts.adjust_N_for_activation(N, MoEActivation.RELU2_NO_MUL) == N
|
||||
|
||||
@@ -12,6 +12,7 @@ from vllm.model_executor.layers.fused_moe import (
|
||||
fused_experts,
|
||||
fused_topk,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.activation import MoEActivation
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEParallelConfig,
|
||||
@@ -54,7 +55,7 @@ def make_dummy_moe_config(
|
||||
num_local_experts=num_experts,
|
||||
num_logical_experts=num_experts,
|
||||
moe_parallel_config=FusedMoEParallelConfig.make_no_parallel(),
|
||||
activation="silu",
|
||||
activation=MoEActivation.SILU,
|
||||
in_dtype=in_dtype,
|
||||
device="cuda",
|
||||
routing_method=RoutingMethodType.TopK,
|
||||
|
||||
Reference in New Issue
Block a user