[CPU] Refactor CPU fused MOE (#30531)
Signed-off-by: jiang1.li <jiang1.li@intel.com>
This commit is contained in:
172
tests/kernels/moe/test_cpu_fused_moe.py
Normal file
172
tests/kernels/moe/test_cpu_fused_moe.py
Normal file
@@ -0,0 +1,172 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from tests.kernels.allclose_default import get_default_atol, get_default_rtol
|
||||
from vllm._custom_ops import cpu_fused_moe, cpu_prepack_moe_weight
|
||||
from vllm.model_executor.layers.activation import SiluAndMul, SwigluOAIAndMul
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if not current_platform.is_cpu():
|
||||
pytest.skip("skipping CPU-only tests", allow_module_level=True)
|
||||
|
||||
EXPERT_NUM = [
|
||||
8,
|
||||
]
|
||||
HIDDEN_DIM = [128, 2880]
|
||||
INTERMEDIATE_DIM = [128, 2880]
|
||||
BATCH_SIZE = [1, 64, 256]
|
||||
ACT = ["silu", "swigluoai"]
|
||||
USE_BIAS = [True, False]
|
||||
ISA = ["amx", "vec"] if torch._C._cpu._is_amx_tile_supported() else ["vec"]
|
||||
DTYPE = [torch.bfloat16]
|
||||
|
||||
_CPU_MOE_ACT = {
|
||||
"silu": SiluAndMul(),
|
||||
"swigluoai": SwigluOAIAndMul(),
|
||||
}
|
||||
|
||||
|
||||
def ref_fused_moe(
|
||||
input: torch.Tensor,
|
||||
w13: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
w13_bias: torch.Tensor | None,
|
||||
w2_bias: torch.Tensor | None,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
) -> torch.Tensor:
|
||||
len_experts = w13.size(0)
|
||||
|
||||
cnts = topk_ids.new_zeros((topk_ids.shape[0], len_experts))
|
||||
cnts.scatter_(1, topk_ids.to(torch.int64), 1)
|
||||
tokens_per_expert = cnts.sum(dim=0)
|
||||
idxs = topk_ids.view(-1).argsort()
|
||||
|
||||
sorted_tokens = input[idxs // topk_ids.shape[1]]
|
||||
tokens_per_expert = tokens_per_expert.cpu().numpy()
|
||||
|
||||
outputs = []
|
||||
start_idx = 0
|
||||
|
||||
for i, num_tokens in enumerate(tokens_per_expert):
|
||||
end_idx = start_idx + num_tokens
|
||||
if num_tokens == 0:
|
||||
continue
|
||||
tokens_for_this_expert = sorted_tokens[start_idx:end_idx].float()
|
||||
curr_w13 = w13[i].float()
|
||||
curr_w2 = w2[i].float()
|
||||
|
||||
curr_w13_bias = None
|
||||
if w13_bias is not None:
|
||||
curr_w13_bias = w13_bias[i].float()
|
||||
|
||||
curr_w2_bias = None
|
||||
if w2_bias is not None:
|
||||
curr_w2_bias = w2_bias[i].float()
|
||||
|
||||
gate_up = torch.nn.functional.linear(
|
||||
tokens_for_this_expert, curr_w13, curr_w13_bias
|
||||
)
|
||||
# Note: to simulate the kernel implementation
|
||||
gate_up = (
|
||||
_CPU_MOE_ACT[activation]
|
||||
.forward_native(gate_up)
|
||||
.to(dtype=input.dtype)
|
||||
.float()
|
||||
)
|
||||
expert_out = torch.nn.functional.linear(gate_up, curr_w2, curr_w2_bias)
|
||||
|
||||
outputs.append(expert_out)
|
||||
start_idx = end_idx
|
||||
|
||||
outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
|
||||
new_x = torch.empty_like(outs)
|
||||
|
||||
new_x[idxs] = outs
|
||||
final_out = (
|
||||
new_x.view(*topk_ids.shape, -1)
|
||||
.mul_(topk_weights.unsqueeze(dim=-1))
|
||||
.sum(dim=1)
|
||||
.type(input.dtype)
|
||||
)
|
||||
return final_out
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size", BATCH_SIZE)
|
||||
@pytest.mark.parametrize("expert_num", EXPERT_NUM)
|
||||
@pytest.mark.parametrize("hidden_size", HIDDEN_DIM)
|
||||
@pytest.mark.parametrize("intermediate_size", INTERMEDIATE_DIM)
|
||||
@pytest.mark.parametrize("use_bias", USE_BIAS)
|
||||
@pytest.mark.parametrize("dtype", DTYPE)
|
||||
@pytest.mark.parametrize("act", ACT)
|
||||
@pytest.mark.parametrize("isa", ISA)
|
||||
def test_cpu_fused_moe(
|
||||
batch_size: int,
|
||||
expert_num: int,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
use_bias: bool,
|
||||
dtype: torch.dtype,
|
||||
act: str,
|
||||
isa: str,
|
||||
):
|
||||
current_platform.seed_everything(0)
|
||||
|
||||
topk_num = max(expert_num // 2, 1)
|
||||
up_dim = 2 * intermediate_size
|
||||
|
||||
input = torch.randn((batch_size, hidden_size), dtype=dtype) / (
|
||||
0.5 * hidden_size**0.5
|
||||
)
|
||||
w13 = torch.randn((expert_num, up_dim, hidden_size), dtype=dtype) / (
|
||||
0.5 * hidden_size**0.5
|
||||
)
|
||||
w2 = torch.randn((expert_num, hidden_size, intermediate_size), dtype=dtype) / (
|
||||
0.5 * intermediate_size**0.5
|
||||
)
|
||||
router_logits = torch.randn((batch_size, expert_num), dtype=dtype)
|
||||
w13_bias = None
|
||||
w2_bias = None
|
||||
if use_bias:
|
||||
w13_bias = torch.randn((expert_num, up_dim), dtype=dtype) / (0.5 * up_dim**0.5)
|
||||
w2_bias = torch.randn((expert_num, hidden_size), dtype=dtype) / (
|
||||
0.5 * hidden_size**0.5
|
||||
)
|
||||
score = torch.softmax(router_logits, dim=-1, dtype=torch.float32)
|
||||
topk_weight, topk_ids = torch.topk(score, topk_num)
|
||||
topk_ids = topk_ids.to(torch.int32)
|
||||
|
||||
ref_output = ref_fused_moe(
|
||||
input,
|
||||
w13,
|
||||
w2,
|
||||
w13_bias,
|
||||
w2_bias,
|
||||
topk_weight,
|
||||
topk_ids,
|
||||
act,
|
||||
)
|
||||
|
||||
packed_w13 = cpu_prepack_moe_weight(w13, isa)
|
||||
packed_w2 = cpu_prepack_moe_weight(w2, isa)
|
||||
output = cpu_fused_moe(
|
||||
input,
|
||||
packed_w13,
|
||||
packed_w2,
|
||||
w13_bias,
|
||||
w2_bias,
|
||||
topk_weight,
|
||||
topk_ids,
|
||||
act,
|
||||
isa,
|
||||
)
|
||||
|
||||
atol, rtol = get_default_atol(output), get_default_rtol(output)
|
||||
(
|
||||
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol),
|
||||
f"{torch.max(torch.abs(output - ref_output))}",
|
||||
)
|
||||
Reference in New Issue
Block a user