[Kernel] Enable 8-bit weights in Fused Marlin MoE (#8032)
Co-authored-by: Dipika <dipikasikka1@gmail.com>
This commit is contained in:
@@ -140,6 +140,7 @@ def compute_max_diff(output, output_ref):
|
||||
@pytest.mark.parametrize("topk", [2, 6])
|
||||
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
|
||||
@pytest.mark.parametrize("act_order", [True, False])
|
||||
@pytest.mark.parametrize("num_bits", [4, 8])
|
||||
def test_fused_marlin_moe(
|
||||
m: int,
|
||||
n: int,
|
||||
@@ -148,6 +149,7 @@ def test_fused_marlin_moe(
|
||||
topk: int,
|
||||
group_size: int,
|
||||
act_order: bool,
|
||||
num_bits: int,
|
||||
):
|
||||
torch.manual_seed(7)
|
||||
|
||||
@@ -161,13 +163,12 @@ def test_fused_marlin_moe(
|
||||
if group_size in (k, n):
|
||||
return
|
||||
|
||||
-quant_type = scalar_types.uint4b8
|
||||
+quant_type = (scalar_types.uint4b8
|
||||
+if num_bits == 4 else scalar_types.uint8b128)
|
||||
dtype = torch.float16
|
||||
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
||||
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
|
||||
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
|
||||
-for i in range(w2.shape[0]):
|
||||
-w2[0] = torch.eye(k, n, device="cuda", dtype=dtype)
|
||||
|
||||
w_ref1_l = []
|
||||
qweight1_l = []
|
||||
@@ -240,6 +241,7 @@ def test_fused_marlin_moe(
|
||||
topk_ids,
|
||||
w1_scale=scales1,
|
||||
w2_scale=scales2,
|
||||
num_bits=num_bits,
|
||||
)
|
||||
|
||||
assert compute_max_diff(marlin_output, triton_output) < 4e-2
|
||||
@@ -254,7 +256,8 @@ def test_fused_marlin_moe(
|
||||
@pytest.mark.parametrize("topk", [2, 6])
|
||||
@pytest.mark.parametrize("group_size", [-1, 32, 64, 128])
|
||||
@pytest.mark.parametrize("act_order", [True, False])
|
||||
-def test_marlin_moe_mmm(
|
||||
+@pytest.mark.parametrize("num_bits", [4, 8])
|
||||
+def test_single_marlin_moe_multiply(
|
||||
m: int,
|
||||
n: int,
|
||||
k: int,
|
||||
@@ -262,6 +265,7 @@ def test_marlin_moe_mmm(
|
||||
topk: int,
|
||||
group_size: int,
|
||||
act_order: bool,
|
||||
num_bits: int,
|
||||
):
|
||||
if topk > e:
|
||||
return
|
||||
@@ -273,7 +277,8 @@ def test_marlin_moe_mmm(
|
||||
if group_size == k:
|
||||
return
|
||||
|
||||
-quant_type = scalar_types.uint4b8
|
||||
+quant_type = (scalar_types.uint4b8
|
||||
+if num_bits == 4 else scalar_types.uint8b128)
|
||||
dtype = torch.float16
|
||||
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
|
||||
w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10
|
||||
@@ -308,7 +313,8 @@ def test_marlin_moe_mmm(
|
||||
g_idx,
|
||||
sort_indices,
|
||||
topk,
|
||||
-renormalize=False)
|
||||
+renormalize=False,
|
||||
+num_bits=num_bits)
|
||||
torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk)
|
||||
|
||||
assert compute_max_diff(marlin_output, torch_output) < 1e-2
|
||||
|
||||
Reference in New Issue
Block a user