[Bugfix] Fix Marlin MoE act order when is_k_full == False (#8741)

Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
ElizaWszola
2024-09-29 03:19:40 +02:00
committed by GitHub
parent 5bf8789b2a
commit d081da0064
4 changed files with 37 additions and 18 deletions

View File

@@ -145,6 +145,7 @@ def compute_max_diff(output, output_ref):
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
@pytest.mark.parametrize("act_order", [True, False])
@pytest.mark.parametrize("num_bits", [4, 8])
@pytest.mark.parametrize("is_k_full", [True, False])
def test_fused_marlin_moe(
m: int,
n: int,
@@ -154,6 +155,7 @@ def test_fused_marlin_moe(
group_size: int,
act_order: bool,
num_bits: int,
is_k_full: bool,
):
seed_everything(7)
@@ -166,6 +168,9 @@ def test_fused_marlin_moe(
return
if group_size in (k, n):
return
else:
if not is_k_full:
return
quant_type = (scalar_types.uint4b8
if num_bits == 4 else scalar_types.uint8b128)
@@ -246,6 +251,7 @@ def test_fused_marlin_moe(
w1_scale=scales1,
w2_scale=scales2,
num_bits=num_bits,
is_k_full=is_k_full,
)
assert compute_max_diff(marlin_output, triton_output) < 4e-2
@@ -290,6 +296,7 @@ def test_fused_marlin_moe(
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
@pytest.mark.parametrize("act_order", [True, False])
@pytest.mark.parametrize("num_bits", [4, 8])
@pytest.mark.parametrize("is_k_full", [True, False])
def test_single_marlin_moe_multiply(
m: int,
n: int,
@@ -299,6 +306,7 @@ def test_single_marlin_moe_multiply(
group_size: int,
act_order: bool,
num_bits: int,
is_k_full: bool,
):
if topk > e:
return
@@ -309,6 +317,9 @@ def test_single_marlin_moe_multiply(
return
if group_size == k:
return
else:
if not is_k_full:
return
quant_type = (scalar_types.uint4b8
if num_bits == 4 else scalar_types.uint8b128)
@@ -339,15 +350,18 @@ def test_single_marlin_moe_multiply(
sort_indices = stack_and_dev(sort_indices_l)
score = torch.randn((m, e), device="cuda", dtype=dtype)
marlin_output = single_marlin_moe(a,
qweight,
scales,
score,
g_idx,
sort_indices,
topk,
renormalize=False,
num_bits=num_bits)
marlin_output = single_marlin_moe(
a,
qweight,
scales,
score,
g_idx,
sort_indices,
topk,
renormalize=False,
num_bits=num_bits,
is_k_full=is_k_full,
)
torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk)
assert compute_max_diff(marlin_output, torch_output) < 1e-2