[ROCm][Kernel] MoE weights padding (#14454)
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Signed-off-by: charlifu <charlifu@amd.com>
Co-authored-by: charlifu <charlifu@amd.com>
parent 8279201ce6
commit f533b5837f
committed by GitHub
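The tests below exercise the new padding path by padding the last dimension of the MoE weight tensors by 128 elements and then slicing the pad back off. A minimal sketch of that pad-then-slice trick follows, with illustrative tensor sizes (only the 128-element pad width comes from the diff): the result keeps the original logical shape while the underlying allocation, and therefore the strides, are widened. The presumed motivation, not stated in the diff itself, is that the padded layout is friendlier to the ROCm MoE kernels.

import torch
from torch.nn import functional as F

# Illustrative shape; the tests use (e, 2*n, k) and (e, k, n) expert weights.
w = torch.randn(4, 64, 256)

# Pad the last dim by 128, then slice the pad off again.
w_padded = F.pad(w, (0, 128), "constant", 0)[..., 0:-128]

# The logical shape is unchanged ...
assert w_padded.shape == w.shape
# ... but the data now lives in a (4, 64, 384) allocation, so the strides
# reflect the padded width rather than the logical one.
assert w_padded.stride() == (64 * 384, 384, 1)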
@@ -3,8 +3,11 @@
 Run `pytest tests/kernels/test_moe.py`.
 """
 import pytest
 import torch
+from torch.nn import Parameter
+from torch.nn import functional as F
 from transformers import MixtralConfig
 from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
@@ -37,6 +40,7 @@ TOP_KS = [2, 6]
 @pytest.mark.parametrize("topk", TOP_KS)
 @pytest.mark.parametrize("ep_size", EP_SIZE)
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("padding", [True, False])
 def test_fused_moe(
     m: int,
     n: int,
@@ -45,6 +49,7 @@ def test_fused_moe(
     topk: int,
     ep_size: int,
     dtype: torch.dtype,
+    padding: bool,
 ):
     a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
     w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
@@ -65,16 +70,7 @@
     else:
         e_map = None
 
-    triton_output = fused_moe(a,
-                              w1,
-                              w2,
-                              score,
-                              topk,
-                              global_num_experts=e,
-                              expert_map=e_map,
-                              renormalize=False)
     torch_output = torch_moe(a, w1, w2, score, topk, e_map)
-    torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
     iterative_output = iterative_moe(a,
                                      w1,
                                      w2,
@@ -83,6 +79,23 @@ def test_fused_moe(
                                      global_num_experts=e,
                                      expert_map=e_map,
                                      renormalize=False)
+
+    # Pad the weight if moe padding is enabled
+    if padding:
+        w1 = F.pad(w1, (0, 128), "constant", 0)[..., 0:-128]
+        torch.cuda.empty_cache()
+        w2 = F.pad(w2, (0, 128), "constant", 0)[..., 0:-128]
+        torch.cuda.empty_cache()
+
+    triton_output = fused_moe(a,
+                              w1,
+                              w2,
+                              score,
+                              topk,
+                              global_num_experts=e,
+                              expert_map=e_map,
+                              renormalize=False)
+    torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
     torch.testing.assert_close(iterative_output,
                                torch_output,
                                atol=2e-2,
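In the hunk above, the reference torch_output is computed before the weights are padded, and the fused kernel is invoked again afterwards, so the padded and unpadded layouts are checked against the same reference. Each padding step is followed by torch.cuda.empty_cache(); a plausible reading (not stated in the diff) is that padding allocates a fresh, larger tensor while the old one becomes garbage, and draining the caching allocator keeps the GPU footprint bounded across the many parameter combinations. A rough sketch of that effect, assuming a CUDA or ROCm device is available:

import torch
from torch.nn import functional as F

w = torch.randn((8, 2048, 2048), device="cuda", dtype=torch.float16)

# Replace w with its padded-and-sliced copy; the original ~64 MiB block is
# now unreferenced but still held by PyTorch's caching allocator.
w = F.pad(w, (0, 128), "constant", 0)[..., 0:-128]
print(torch.cuda.memory_allocated(), torch.cuda.memory_reserved())  # reserved well above allocated

# Hand the cached dead block back to the device.
torch.cuda.empty_cache()
print(torch.cuda.memory_allocated(), torch.cuda.memory_reserved())  # roughly equal again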
@@ -202,8 +215,9 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
 
 @pytest.mark.parametrize("dtype",
                          [torch.float32, torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("padding", [True, False])
 @torch.inference_mode()
-def test_mixtral_moe(dtype: torch.dtype):
+def test_mixtral_moe(dtype: torch.dtype, padding: bool):
     """Make sure our Mixtral MoE implementation agrees with the one from
     huggingface."""
@@ -233,6 +247,17 @@ def test_mixtral_moe(dtype: torch.dtype):
     # vLLM uses 1D query [num_tokens, hidden_dim]
     vllm_inputs = hf_inputs.flatten(0, 1)
 
+    # Pad the weight if moe padding is enabled
+    if padding:
+        vllm_moe.experts.w13_weight = Parameter(F.pad(
+            vllm_moe.experts.w13_weight, (0, 128), "constant", 0)[..., 0:-128],
+                                                requires_grad=False)
+        torch.cuda.empty_cache()
+        vllm_moe.experts.w2_weight = Parameter(F.pad(
+            vllm_moe.experts.w2_weight, (0, 128), "constant", 0)[..., 0:-128],
+                                               requires_grad=False)
+        torch.cuda.empty_cache()
+
     # Run forward passes for both MoE blocks
     hf_states, _ = hf_moe.forward(hf_inputs)
     vllm_states = vllm_moe.forward(vllm_inputs)
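In the Mixtral test the padded weights are wrapped back into Parameter objects because the pad-then-slice result is a brand-new tensor, not an in-place edit of the module's existing parameter; assigning it with requires_grad=False is what lets the padded layout survive inside the layer. A minimal sketch of that rewrap pattern on a generic module (the nn.Linear layer and its sizes are made up; only the Parameter(..., requires_grad=False) wrapping and the 128-element pad mirror the diff):

import torch
from torch import nn
from torch.nn import functional as F

layer = nn.Linear(256, 256, bias=False)

# Pad the last dim of the weight, then slice the pad back off.
padded = F.pad(layer.weight.data, (0, 128), "constant", 0)[..., 0:-128]

# Re-register as a Parameter: an in-place copy_ into the old parameter would
# collapse the data back to the original layout, so the new tensor is assigned.
layer.weight = nn.Parameter(padded, requires_grad=False)

assert layer.weight.shape == (256, 256)
assert layer.weight.stride() == (256 + 128, 1)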