diff --git a/docs/design/moe_kernel_features.md b/docs/design/moe_kernel_features.md index 18216b596..448816c28 100644 --- a/docs/design/moe_kernel_features.md +++ b/docs/design/moe_kernel_features.md @@ -92,7 +92,6 @@ To be used with a particular `FusedMoEPrepareAndFinalize` subclass, MoE kernels | gpt oss triton | standard | N/A | N/A | 5 | Y | Y | [`triton_kernel_fused_experts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.triton_kernel_fused_experts],
[`OAITritonExperts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.OAITritonExperts] | | marlin | standard,
batched | 3 / N/A | 3 / N/A | silu,
swigluoai | Y | Y | [`fused_marlin_moe`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.fused_marlin_moe],
[`MarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.MarlinExperts],
[`BatchedMarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.BatchedMarlinExperts] | | trtllm | standard | mxfp4,
nvfp4 | G(16),G(32) | 5 | N | Y | [`TrtLlmGenExperts`][vllm.model_executor.layers.fused_moe.trtllm_moe.TrtLlmGenExperts] | -| iterative | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_torch_iterative.fused_moe] | | rocm aiter moe | standard | fp8 | G(128),A,T | silu, gelu | Y | N | [`rocm_aiter_fused_experts`][vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe.rocm_aiter_fused_experts] | | cpu_fused_moe | standard | N/A | N/A | silu | N | N | [`CPUFusedMOE`][vllm.model_executor.layers.fused_moe.cpu_fused_moe.CPUFusedMOE] | | naive batched4 | batched | int8,
fp8 | G,A,T | silu, gelu | 6 | Y | [`NaiveBatchedExperts`][vllm.model_executor.layers.fused_moe.fused_batched_moe.NaiveBatchedExperts] | diff --git a/tests/kernels/moe/test_moe.py b/tests/kernels/moe/test_moe.py index 07ced9769..91d3c119b 100644 --- a/tests/kernels/moe/test_moe.py +++ b/tests/kernels/moe/test_moe.py @@ -37,9 +37,6 @@ from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_topk, modular_triton_fused_moe, ) -from vllm.model_executor.layers.fused_moe.moe_torch_iterative import ( - fused_moe as iterative_moe, -) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( marlin_permute_bias, ) @@ -61,6 +58,64 @@ from vllm.scalar_type import ScalarType, scalar_types from vllm.utils.torch_utils import set_random_seed from vllm.v1.worker.workspace import init_workspace_manager + +def iterative_moe( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + global_num_experts: int, + expert_map: torch.Tensor | None = None, + renormalize: bool = False, +) -> torch.Tensor: + """ + Baseline implementation of fused MoE, kept here as a test reference.
+ + Args: + hidden_states: [*, hidden_size] + w1: [num_experts, intermediate_size * 2, hidden_size] + w2: [num_experts, hidden_size, intermediate_size] + gating_output: [*, num_experts] + expert_map: [num_experts] + """ + orig_shape = hidden_states.shape + hidden_size = hidden_states.shape[-1] + num_tokens = hidden_states.shape[:-1].numel() + num_experts = w1.shape[0] + intermediate_size = w2.shape[-1] + dtype = hidden_states.dtype + + hidden_states = hidden_states.view(num_tokens, hidden_size) + gating_output = gating_output.view(num_tokens, global_num_experts) + topk_weights = gating_output.softmax(dim=-1, dtype=torch.float) + topk_weights, selected_experts = topk_weights.topk(topk, dim=-1) + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + topk_weights = topk_weights.to(dtype) + + if expert_map is not None: + selected_experts = expert_map[selected_experts] + + final_hidden_states = None + for expert_idx in range(num_experts): + expert_w1 = w1[expert_idx] + expert_w2 = w2[expert_idx] + expert_mask = selected_experts == expert_idx + expert_weights = (topk_weights * expert_mask).sum(dim=-1, keepdim=True) + x = F.linear(hidden_states, expert_w1) + gate = F.silu(x[:, :intermediate_size]) + x = x[:, intermediate_size:] * gate + x = F.linear(x, expert_w2) + current_hidden_states = x * expert_weights + if final_hidden_states is None: + final_hidden_states = current_hidden_states + else: + final_hidden_states = final_hidden_states + current_hidden_states + + return final_hidden_states.view(orig_shape) # type: ignore + + NUM_EXPERTS = [8, 64, 192] NUM_EXPERTS_LARGE = [128, 256] EP_SIZE = [1, 4] diff --git a/vllm/model_executor/layers/fused_moe/moe_torch_iterative.py b/vllm/model_executor/layers/fused_moe/moe_torch_iterative.py deleted file mode 100644 index f721d00d7..000000000 --- a/vllm/model_executor/layers/fused_moe/moe_torch_iterative.py +++ /dev/null @@ -1,60 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# 
SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import torch -import torch.nn.functional as F - - -def fused_moe( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - gating_output: torch.Tensor, - topk: int, - global_num_experts: int, - expert_map: torch.Tensor = None, - renormalize: bool = False, -) -> torch.Tensor: - """ - Args: - hidden_states: [*, hidden_size] - w1: [num_experts, intermediate_size * 2, hidden_size] - w2: [num_experts, hidden_size, intermediate_size] - gating_output: [*, num_experts] - expert_map: [num_experts] - """ - orig_shape = hidden_states.shape - hidden_size = hidden_states.shape[-1] - num_tokens = hidden_states.shape[:-1].numel() - num_experts = w1.shape[0] - intermediate_size = w2.shape[-1] - dtype = hidden_states.dtype - - hidden_states = hidden_states.view(num_tokens, hidden_size) - gating_output = gating_output.view(num_tokens, global_num_experts) - topk_weights = gating_output.softmax(dim=-1, dtype=torch.float) - topk_weights, selected_experts = topk_weights.topk(topk, dim=-1) - if renormalize: - topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) - topk_weights = topk_weights.to(dtype) - - if expert_map is not None: - selected_experts = expert_map[selected_experts] - - final_hidden_states = None - for expert_idx in range(num_experts): - expert_w1 = w1[expert_idx] - expert_w2 = w2[expert_idx] - expert_mask = selected_experts == expert_idx - expert_weights = (topk_weights * expert_mask).sum(dim=-1, keepdim=True) - x = F.linear(hidden_states, expert_w1) - gate = F.silu(x[:, :intermediate_size]) - x = x[:, intermediate_size:] * gate - x = F.linear(x, expert_w2) - current_hidden_states = x * expert_weights - if final_hidden_states is None: - final_hidden_states = current_hidden_states - else: - final_hidden_states = final_hidden_states + current_hidden_states - - return final_hidden_states.view(orig_shape) # type: ignore