[Kernel] Enable 8-bit weights in Fused Marlin MoE (#8032)

Co-authored-by: Dipika <dipikasikka1@gmail.com>
This commit is contained in:
ElizaWszola
2024-09-16 17:47:19 +02:00
committed by GitHub
parent fc990f9795
commit a091e2da3e
12 changed files with 452 additions and 184 deletions

View File

@@ -140,6 +140,7 @@ def compute_max_diff(output, output_ref):
@pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
@pytest.mark.parametrize("act_order", [True, False])
@pytest.mark.parametrize("num_bits", [4, 8])
def test_fused_marlin_moe(
m: int,
n: int,
@@ -148,6 +149,7 @@ def test_fused_marlin_moe(
topk: int,
group_size: int,
act_order: bool,
num_bits: int,
):
torch.manual_seed(7)
@@ -161,13 +163,12 @@ def test_fused_marlin_moe(
if group_size in (k, n):
return
quant_type = scalar_types.uint4b8
quant_type = (scalar_types.uint4b8
if num_bits == 4 else scalar_types.uint8b128)
dtype = torch.float16
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
for i in range(w2.shape[0]):
w2[0] = torch.eye(k, n, device="cuda", dtype=dtype)
w_ref1_l = []
qweight1_l = []
@@ -240,6 +241,7 @@ def test_fused_marlin_moe(
topk_ids,
w1_scale=scales1,
w2_scale=scales2,
num_bits=num_bits,
)
assert compute_max_diff(marlin_output, triton_output) < 4e-2
@@ -254,7 +256,8 @@ def test_fused_marlin_moe(
@pytest.mark.parametrize("topk", [2, 6])
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
@pytest.mark.parametrize("act_order", [True, False])
def test_marlin_moe_mmm(
@pytest.mark.parametrize("num_bits", [4, 8])
def test_single_marlin_moe_multiply(
m: int,
n: int,
k: int,
@@ -262,6 +265,7 @@ def test_marlin_moe_mmm(
topk: int,
group_size: int,
act_order: bool,
num_bits: int,
):
if topk > e:
return
@@ -273,7 +277,8 @@ def test_marlin_moe_mmm(
if group_size == k:
return
quant_type = scalar_types.uint4b8
quant_type = (scalar_types.uint4b8
if num_bits == 4 else scalar_types.uint8b128)
dtype = torch.float16
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10
@@ -308,7 +313,8 @@ def test_marlin_moe_mmm(
g_idx,
sort_indices,
topk,
renormalize=False)
renormalize=False,
num_bits=num_bits)
torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk)
assert compute_max_diff(marlin_output, torch_output) < 1e-2