[MoE Refactor] Move Test Impl into Test Dirs (#32129)
Signed-off-by: Robert Shaw <rshaw@neuralmagic.com>
Co-authored-by: Robert Shaw <rshaw@neuralmagic.com>
@@ -37,9 +37,6 @@ from vllm.model_executor.layers.fused_moe.fused_moe import (
     fused_topk,
     modular_triton_fused_moe,
 )
-from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
-    fused_moe as iterative_moe,
-)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     marlin_permute_bias,
 )
@@ -61,6 +58,64 @@ from vllm.scalar_type import ScalarType, scalar_types
 from vllm.utils.torch_utils import set_random_seed
 from vllm.v1.worker.workspace import init_workspace_manager
+
+
+def iterative_moe(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    global_num_experts: int,
+    expert_map: torch.Tensor = None,
+    renormalize: bool = False,
+) -> torch.Tensor:
+    """
+    Baseline implementation of fused moe.
+
+    Args:
+        hidden_states: [*, hidden_size]
+        w1: [num_experts, intermediate_size * 2, hidden_size]
+        w2: [num_experts, hidden_size, intermediate_size]
+        gating_output: [*, num_experts]
+        expert_map: [num_experts]
+    """
+    orig_shape = hidden_states.shape
+    hidden_size = hidden_states.shape[-1]
+    num_tokens = hidden_states.shape[:-1].numel()
+    num_experts = w1.shape[0]
+    intermediate_size = w2.shape[-1]
+    dtype = hidden_states.dtype
+
+    hidden_states = hidden_states.view(num_tokens, hidden_size)
+    gating_output = gating_output.view(num_tokens, global_num_experts)
+    topk_weights = gating_output.softmax(dim=-1, dtype=torch.float)
+    topk_weights, selected_experts = topk_weights.topk(topk, dim=-1)
+    if renormalize:
+        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+    topk_weights = topk_weights.to(dtype)
+
+    if expert_map is not None:
+        selected_experts = expert_map[selected_experts]
+
+    final_hidden_states = None
+    for expert_idx in range(num_experts):
+        expert_w1 = w1[expert_idx]
+        expert_w2 = w2[expert_idx]
+        expert_mask = selected_experts == expert_idx
+        expert_weights = (topk_weights * expert_mask).sum(dim=-1, keepdim=True)
+        x = F.linear(hidden_states, expert_w1)
+        gate = F.silu(x[:, :intermediate_size])
+        x = x[:, intermediate_size:] * gate
+        x = F.linear(x, expert_w2)
+        current_hidden_states = x * expert_weights
+        if final_hidden_states is None:
+            final_hidden_states = current_hidden_states
+        else:
+            final_hidden_states = final_hidden_states + current_hidden_states
+
+    return final_hidden_states.view(orig_shape)  # type: ignore
+
 
 NUM_EXPERTS = [8, 64, 192]
 NUM_EXPERTS_LARGE = [128, 256]
 EP_SIZE = [1, 4]
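
For context on how this relocated baseline is exercised, the sketch below builds inputs with the shapes documented in the docstring and checks the output shape. It is a minimal sketch, not part of the commit; the tensor sizes are arbitrary illustration values, and it assumes the iterative_moe defined in the hunk above is in scope.

    # Minimal usage sketch for the baseline above (not part of the commit).
    import torch

    num_tokens, hidden_size, intermediate_size = 7, 64, 128
    num_experts, topk = 8, 2

    torch.manual_seed(0)
    hidden_states = torch.randn(num_tokens, hidden_size)
    # w1 holds gate and up projections stacked along dim 1, per the docstring.
    w1 = torch.randn(num_experts, 2 * intermediate_size, hidden_size) / 10
    w2 = torch.randn(num_experts, hidden_size, intermediate_size) / 10
    gating_output = torch.randn(num_tokens, num_experts)

    out = iterative_moe(
        hidden_states, w1, w2, gating_output,
        topk=topk, global_num_experts=num_experts, renormalize=True,
    )
    assert out.shape == hidden_states.shape

In the tests, this per-expert loop serves as a slow but easy-to-verify reference that the fused Triton kernels are compared against.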
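The expert_map argument ties into the expert-parallel cases (EP_SIZE = [1, 4]): each rank then holds only a slice of the experts, and expert_map translates global expert ids from the router into local rows of w1/w2. The following is a hypothetical construction, assuming the convention that experts not owned by the rank map to -1; remapped -1 ids never match any local expert_idx in the loop, so those tokens contribute nothing on this rank.

    # Hypothetical expert_map for one EP rank (illustration only; assumes
    # the -1-for-remote-experts convention implied by the loop above).
    import torch

    global_num_experts, ep_size, ep_rank = 8, 4, 1
    local = global_num_experts // ep_size  # 2 experts per rank

    expert_map = torch.full((global_num_experts,), -1, dtype=torch.int64)
    start = ep_rank * local
    expert_map[start:start + local] = torch.arange(local)
    # rank 1 -> tensor([-1, -1,  0,  1, -1, -1, -1, -1])

    selected_experts = torch.tensor([[2, 5], [3, 0]])
    print(expert_map[selected_experts])  # tensor([[ 0, -1], [ 1, -1]])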