[MoE Refactor] Move Test Impl into Test Dirs (#32129)

Signed-off-by: Robert Shaw <rshaw@neuralmagic.com>
Co-authored-by: Robert Shaw <rshaw@neuralmagic.com>
This commit is contained in:
Robert Shaw
2026-01-17 23:16:59 -05:00
committed by GitHub
parent 4147910f1e
commit 4a6af8813f
3 changed files with 58 additions and 64 deletions

View File

@@ -37,9 +37,6 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_topk,
modular_triton_fused_moe,
)
from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
fused_moe as iterative_moe,
)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_permute_bias,
)
@@ -61,6 +58,64 @@ from vllm.scalar_type import ScalarType, scalar_types
from vllm.utils.torch_utils import set_random_seed
from vllm.v1.worker.workspace import init_workspace_manager
def iterative_moe(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
global_num_experts: int,
expert_map: torch.Tensor = None,
renormalize: bool = False,
) -> torch.Tensor:
"""
Baseline implementation of fused moe.
Args:
hidden_states: [*, hidden_size]
w1: [num_experts, intermediate_size * 2, hidden_size]
w2: [num_experts, hidden_size, intermediate_size]
gating_output: [*, num_experts]
expert_map: [num_experts]
"""
orig_shape = hidden_states.shape
hidden_size = hidden_states.shape[-1]
num_tokens = hidden_states.shape[:-1].numel()
num_experts = w1.shape[0]
intermediate_size = w2.shape[-1]
dtype = hidden_states.dtype
hidden_states = hidden_states.view(num_tokens, hidden_size)
gating_output = gating_output.view(num_tokens, global_num_experts)
topk_weights = gating_output.softmax(dim=-1, dtype=torch.float)
topk_weights, selected_experts = topk_weights.topk(topk, dim=-1)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
topk_weights = topk_weights.to(dtype)
if expert_map is not None:
selected_experts = expert_map[selected_experts]
final_hidden_states = None
for expert_idx in range(num_experts):
expert_w1 = w1[expert_idx]
expert_w2 = w2[expert_idx]
expert_mask = selected_experts == expert_idx
expert_weights = (topk_weights * expert_mask).sum(dim=-1, keepdim=True)
x = F.linear(hidden_states, expert_w1)
gate = F.silu(x[:, :intermediate_size])
x = x[:, intermediate_size:] * gate
x = F.linear(x, expert_w2)
current_hidden_states = x * expert_weights
if final_hidden_states is None:
final_hidden_states = current_hidden_states
else:
final_hidden_states = final_hidden_states + current_hidden_states
return final_hidden_states.view(orig_shape) # type: ignore
# Expert counts for the MoE tests in this module.
# NOTE(review): presumably consumed by pytest parametrization further down
# the file — confirm against the test definitions.
NUM_EXPERTS: list[int] = [8, 64, 192]
# Larger expert counts, kept as a separate list from NUM_EXPERTS.
NUM_EXPERTS_LARGE: list[int] = [128, 256]
# NOTE(review): presumably expert-parallel world sizes — confirm.
EP_SIZE: list[int] = [1, 4]