[Bugfix] Fix Marlin MoE act order when is_k_full == False (#8741)
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
@@ -145,6 +145,7 @@ def compute_max_diff(output, output_ref):
|
||||
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
|
||||
@pytest.mark.parametrize("act_order", [True, False])
|
||||
@pytest.mark.parametrize("num_bits", [4, 8])
|
||||
@pytest.mark.parametrize("is_k_full", [True, False])
|
||||
def test_fused_marlin_moe(
|
||||
m: int,
|
||||
n: int,
|
||||
@@ -154,6 +155,7 @@ def test_fused_marlin_moe(
|
||||
group_size: int,
|
||||
act_order: bool,
|
||||
num_bits: int,
|
||||
is_k_full: bool,
|
||||
):
|
||||
seed_everything(7)
|
||||
|
||||
@@ -166,6 +168,9 @@ def test_fused_marlin_moe(
|
||||
return
|
||||
if group_size in (k, n):
|
||||
return
|
||||
else:
|
||||
if not is_k_full:
|
||||
return
|
||||
|
||||
quant_type = (scalar_types.uint4b8
|
||||
if num_bits == 4 else scalar_types.uint8b128)
|
||||
@@ -246,6 +251,7 @@ def test_fused_marlin_moe(
|
||||
w1_scale=scales1,
|
||||
w2_scale=scales2,
|
||||
num_bits=num_bits,
|
||||
is_k_full=is_k_full,
|
||||
)
|
||||
|
||||
assert compute_max_diff(marlin_output, triton_output) < 4e-2
|
||||
@@ -290,6 +296,7 @@ def test_fused_marlin_moe(
|
||||
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
|
||||
@pytest.mark.parametrize("act_order", [True, False])
|
||||
@pytest.mark.parametrize("num_bits", [4, 8])
|
||||
@pytest.mark.parametrize("is_k_full", [True, False])
|
||||
def test_single_marlin_moe_multiply(
|
||||
m: int,
|
||||
n: int,
|
||||
@@ -299,6 +306,7 @@ def test_single_marlin_moe_multiply(
|
||||
group_size: int,
|
||||
act_order: bool,
|
||||
num_bits: int,
|
||||
is_k_full: bool,
|
||||
):
|
||||
if topk > e:
|
||||
return
|
||||
@@ -309,6 +317,9 @@ def test_single_marlin_moe_multiply(
|
||||
return
|
||||
if group_size == k:
|
||||
return
|
||||
else:
|
||||
if not is_k_full:
|
||||
return
|
||||
|
||||
quant_type = (scalar_types.uint4b8
|
||||
if num_bits == 4 else scalar_types.uint8b128)
|
||||
@@ -339,15 +350,18 @@ def test_single_marlin_moe_multiply(
|
||||
sort_indices = stack_and_dev(sort_indices_l)
|
||||
|
||||
score = torch.randn((m, e), device="cuda", dtype=dtype)
|
||||
marlin_output = single_marlin_moe(a,
|
||||
qweight,
|
||||
scales,
|
||||
score,
|
||||
g_idx,
|
||||
sort_indices,
|
||||
topk,
|
||||
renormalize=False,
|
||||
num_bits=num_bits)
|
||||
marlin_output = single_marlin_moe(
|
||||
a,
|
||||
qweight,
|
||||
scales,
|
||||
score,
|
||||
g_idx,
|
||||
sort_indices,
|
||||
topk,
|
||||
renormalize=False,
|
||||
num_bits=num_bits,
|
||||
is_k_full=is_k_full,
|
||||
)
|
||||
torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk)
|
||||
|
||||
assert compute_max_diff(marlin_output, torch_output) < 1e-2
|
||||
|
||||
Reference in New Issue
Block a user