diff --git a/docs/design/moe_kernel_features.md b/docs/design/moe_kernel_features.md
index 18216b596..448816c28 100644
--- a/docs/design/moe_kernel_features.md
+++ b/docs/design/moe_kernel_features.md
@@ -92,7 +92,6 @@ To be used with a particular `FusedMoEPrepareAndFinalize` subclass, MoE kernels
| gpt oss triton | standard | N/A | N/A | 5 | Y | Y | [`triton_kernel_fused_experts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.triton_kernel_fused_experts],[`OAITritonExperts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.OAITritonExperts] |
| marlin | standard,batched | 3 / N/A | 3 / N/A | silu,swigluoai | Y | Y | [`fused_marlin_moe`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.fused_marlin_moe],[`MarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.MarlinExperts],[`BatchedMarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.BatchedMarlinExperts] |
| trtllm | standard | mxfp4,nvfp4 | G(16),G(32) | 5 | N | Y | [`TrtLlmGenExperts`][vllm.model_executor.layers.fused_moe.trtllm_moe.TrtLlmGenExperts] |
-| iterative | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_torch_iterative.fused_moe] |
| rocm aiter moe | standard | fp8 | G(128),A,T | silu, gelu | Y | N | [`rocm_aiter_fused_experts`][vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe.rocm_aiter_fused_experts] |
| cpu_fused_moe | standard | N/A | N/A | silu | N | N | [`CPUFusedMOE`][vllm.model_executor.layers.fused_moe.cpu_fused_moe.CPUFusedMOE] |
| naive batched4 | batched | int8,fp8 | G,A,T | silu, gelu | 6 | Y | [`NaiveBatchedExperts`][vllm.model_executor.layers.fused_moe.fused_batched_moe.NaiveBatchedExperts] |
diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py
index 07ced9769..91d3c119b 100644
--- a/tests/kernels/moe/test_moe.py
+++ b/tests/kernels/moe/test_moe.py
@@ -37,9 +37,6 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk,
modular_triton_fused_moe,
)
-from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
- fused_moe as iterative_moe,
-)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_permute_bias,
)
@@ -61,6 +58,64 @@ from vllm.scalar_type import ScalarType, scalar_types
from vllm.utils.torch_utils import set_random_seed
from vllm.v1.worker.workspace import init_workspace_manager
+
+def iterative_moe(
+ hidden_states: torch.Tensor,
+ w1: torch.Tensor,
+ w2: torch.Tensor,
+ gating_output: torch.Tensor,
+ topk: int,
+ global_num_experts: int,
+    expert_map: torch.Tensor | None = None,
+ renormalize: bool = False,
+) -> torch.Tensor:
+ """
+ Baseline implementation of fused moe.
+
+ Args:
+ hidden_states: [*, hidden_size]
+ w1: [num_experts, intermediate_size * 2, hidden_size]
+ w2: [num_experts, hidden_size, intermediate_size]
+        gating_output: [*, global_num_experts]
+        expert_map: [global_num_experts], maps global expert ids to local ones
+ """
+ orig_shape = hidden_states.shape
+ hidden_size = hidden_states.shape[-1]
+ num_tokens = hidden_states.shape[:-1].numel()
+ num_experts = w1.shape[0]
+ intermediate_size = w2.shape[-1]
+ dtype = hidden_states.dtype
+
+ hidden_states = hidden_states.view(num_tokens, hidden_size)
+ gating_output = gating_output.view(num_tokens, global_num_experts)
+ topk_weights = gating_output.softmax(dim=-1, dtype=torch.float)
+ topk_weights, selected_experts = topk_weights.topk(topk, dim=-1)
+ if renormalize:
+ topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+ topk_weights = topk_weights.to(dtype)
+
+ if expert_map is not None:
+ selected_experts = expert_map[selected_experts]
+
+ final_hidden_states = None
+ for expert_idx in range(num_experts):
+ expert_w1 = w1[expert_idx]
+ expert_w2 = w2[expert_idx]
+ expert_mask = selected_experts == expert_idx
+ expert_weights = (topk_weights * expert_mask).sum(dim=-1, keepdim=True)
+ x = F.linear(hidden_states, expert_w1)
+ gate = F.silu(x[:, :intermediate_size])
+ x = x[:, intermediate_size:] * gate
+ x = F.linear(x, expert_w2)
+ current_hidden_states = x * expert_weights
+ if final_hidden_states is None:
+ final_hidden_states = current_hidden_states
+ else:
+ final_hidden_states = final_hidden_states + current_hidden_states
+
+ return final_hidden_states.view(orig_shape) # type: ignore
+
+
NUM_EXPERTS = [8, 64, 192]
NUM_EXPERTS_LARGE = [128, 256]
EP_SIZE = [1, 4]
diff --git a/vllm/model_executor/layers/fused_moe/moe_torch_iterative.py b/vllm/model_executor/layers/fused_moe/moe_torch_iterative.py
deleted file mode 100644
index f721d00d7..000000000
--- a/vllm/model_executor/layers/fused_moe/moe_torch_iterative.py
+++ /dev/null
@@ -1,60 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-import torch
-import torch.nn.functional as F
-
-
-def fused_moe(
- hidden_states: torch.Tensor,
- w1: torch.Tensor,
- w2: torch.Tensor,
- gating_output: torch.Tensor,
- topk: int,
- global_num_experts: int,
- expert_map: torch.Tensor = None,
- renormalize: bool = False,
-) -> torch.Tensor:
- """
- Args:
- hidden_states: [*, hidden_size]
- w1: [num_experts, intermediate_size * 2, hidden_size]
- w2: [num_experts, hidden_size, intermediate_size]
- gating_output: [*, num_experts]
- expert_map: [num_experts]
- """
- orig_shape = hidden_states.shape
- hidden_size = hidden_states.shape[-1]
- num_tokens = hidden_states.shape[:-1].numel()
- num_experts = w1.shape[0]
- intermediate_size = w2.shape[-1]
- dtype = hidden_states.dtype
-
- hidden_states = hidden_states.view(num_tokens, hidden_size)
- gating_output = gating_output.view(num_tokens, global_num_experts)
- topk_weights = gating_output.softmax(dim=-1, dtype=torch.float)
- topk_weights, selected_experts = topk_weights.topk(topk, dim=-1)
- if renormalize:
- topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
- topk_weights = topk_weights.to(dtype)
-
- if expert_map is not None:
- selected_experts = expert_map[selected_experts]
-
- final_hidden_states = None
- for expert_idx in range(num_experts):
- expert_w1 = w1[expert_idx]
- expert_w2 = w2[expert_idx]
- expert_mask = selected_experts == expert_idx
- expert_weights = (topk_weights * expert_mask).sum(dim=-1, keepdim=True)
- x = F.linear(hidden_states, expert_w1)
- gate = F.silu(x[:, :intermediate_size])
- x = x[:, intermediate_size:] * gate
- x = F.linear(x, expert_w2)
- current_hidden_states = x * expert_weights
- if final_hidden_states is None:
- final_hidden_states = current_hidden_states
- else:
- final_hidden_states = final_hidden_states + current_hidden_states
-
- return final_hidden_states.view(orig_shape) # type: ignore