[ROCm][Kernel] MoE weights padding (#14454)

Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Signed-off-by: charlifu <charlifu@amd.com>
Co-authored-by: charlifu <charlifu@amd.com>
Author: Gregory Shtrasberg
Date: 2025-03-24 19:45:30 -04:00
Committed by: GitHub
Parent: 8279201ce6
Commit: f533b5837f

5 changed files with 65 additions and 16 deletions
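The trick at the heart of this change: padding the last dimension of a weight tensor by 128 elements and immediately slicing the pad back off yields a tensor with the original shape whose rows live in the padded allocation, i.e. a larger row stride, which is presumably what the ROCm kernels benefit from. A minimal standalone sketch of that invariant (sizes are illustrative, not from the commit):

# Sketch of the pad-then-slice trick used throughout this diff: the
# logical shape is unchanged, but the slice is a view into the padded
# allocation, so the row stride now includes the 128-element pad.
import torch
from torch.nn import functional as F

w = torch.randn(2, 32, 256)              # (experts, n, k); sizes illustrative
w_padded = F.pad(w, (0, 128), "constant", 0)[..., 0:-128]

assert w_padded.shape == w.shape         # logical shape unchanged
assert w_padded.stride(1) == 256 + 128   # row stride now includes the pad
assert not w_padded.is_contiguous()
torch.testing.assert_close(w_padded, w)  # values are identical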


tests/kernels/test_moe.py
@@ -3,8 +3,11 @@
 Run `pytest tests/kernels/test_moe.py`.
 """
 import pytest
 import torch
+from torch.nn import Parameter
+from torch.nn import functional as F
 from transformers import MixtralConfig
 from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
@@ -37,6 +40,7 @@ TOP_KS = [2, 6]
 @pytest.mark.parametrize("topk", TOP_KS)
 @pytest.mark.parametrize("ep_size", EP_SIZE)
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("padding", [True, False])
 def test_fused_moe(
     m: int,
     n: int,
@@ -45,6 +49,7 @@ def test_fused_moe(
     topk: int,
     ep_size: int,
     dtype: torch.dtype,
+    padding: bool,
 ):
     a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
     w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
@@ -65,16 +70,7 @@ def test_fused_moe(
     else:
         e_map = None
-    triton_output = fused_moe(a,
-                              w1,
-                              w2,
-                              score,
-                              topk,
-                              global_num_experts=e,
-                              expert_map=e_map,
-                              renormalize=False)
     torch_output = torch_moe(a, w1, w2, score, topk, e_map)
-    torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
     iterative_output = iterative_moe(a,
                                      w1,
                                      w2,
@@ -83,6 +79,23 @@ def test_fused_moe(
                                      global_num_experts=e,
                                      expert_map=e_map,
                                      renormalize=False)
+
+    # Pad the weight if moe padding is enabled
+    if padding:
+        w1 = F.pad(w1, (0, 128), "constant", 0)[..., 0:-128]
+        torch.cuda.empty_cache()
+        w2 = F.pad(w2, (0, 128), "constant", 0)[..., 0:-128]
+        torch.cuda.empty_cache()
+
+    triton_output = fused_moe(a,
+                              w1,
+                              w2,
+                              score,
+                              topk,
+                              global_num_experts=e,
+                              expert_map=e_map,
+                              renormalize=False)
+    torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
     torch.testing.assert_close(iterative_output,
                                torch_output,
                                atol=2e-2,
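The reordering above makes `fused_moe` run only after the optional padding. Because the pad-then-slice copy preserves values exactly, any divergence from `torch_output` would have to come from the kernel's handling of the padded stride, not from the data. A quick standalone check of that premise (sizes illustrative, not from the test):

# The padded view holds the same values as the original weight, only the
# layout differs, so a plain matmul is unaffected by the padding.
import torch
from torch.nn import functional as F

a = torch.randn(16, 256)
w = torch.randn(64, 256)
w_padded = F.pad(w, (0, 128), "constant", 0)[..., 0:-128]

torch.testing.assert_close(a @ w_padded.t(), a @ w.t())  # same result
assert w_padded.stride() != w.stride()                   # different layout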
@@ -202,8 +215,9 @@ def test_fused_moe_wn16(m: int, n: int, k: int, e: int, topk: int,
 @pytest.mark.parametrize("dtype",
                          [torch.float32, torch.float16, torch.bfloat16])
+@pytest.mark.parametrize("padding", [True, False])
 @torch.inference_mode()
-def test_mixtral_moe(dtype: torch.dtype):
+def test_mixtral_moe(dtype: torch.dtype, padding: bool):
     """Make sure our Mixtral MoE implementation agrees with the one from
     huggingface."""
@@ -233,6 +247,17 @@ def test_mixtral_moe(dtype: torch.dtype):
     # vLLM uses 1D query [num_tokens, hidden_dim]
     vllm_inputs = hf_inputs.flatten(0, 1)
+
+    # Pad the weight if moe padding is enabled
+    if padding:
+        vllm_moe.experts.w13_weight = Parameter(F.pad(
+            vllm_moe.experts.w13_weight, (0, 128), "constant", 0)[..., 0:-128],
+                                                requires_grad=False)
+        torch.cuda.empty_cache()
+        vllm_moe.experts.w2_weight = Parameter(F.pad(
+            vllm_moe.experts.w2_weight, (0, 128), "constant", 0)[..., 0:-128],
+                                               requires_grad=False)
+        torch.cuda.empty_cache()
     # Run forward passes for both MoE blocks
     hf_states, _ = hf_moe.forward(hf_inputs)
     vllm_states = vllm_moe.forward(vllm_inputs)
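For readers who want to apply the same transform outside the test, the inline `Parameter` re-wrapping above could be factored into a small helper. A hypothetical sketch, not code this commit adds:

# Hypothetical helper (not part of this commit) mirroring the inline test
# code above: re-wrap a module weight as a padded-layout, non-trainable
# Parameter so the module keeps working after the swap.
import torch
from torch.nn import Parameter
from torch.nn import functional as F

def pad_moe_weight(weight: torch.Tensor, pad: int = 128) -> Parameter:
    """Return `weight` re-laid-out with `pad` trailing elements of stride padding."""
    padded = F.pad(weight, (0, pad), "constant", 0)[..., 0:-pad]
    return Parameter(padded, requires_grad=False)

# e.g. vllm_moe.experts.w2_weight = pad_moe_weight(vllm_moe.experts.w2_weight)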