[Misc] Fused MoE Marlin support for GPTQ (#8217)

This commit is contained in:
Dipika Sikka
2024-09-09 23:02:52 -04:00
committed by GitHub
parent c7cb5c3335
commit 6cd5e5b07e
19 changed files with 912 additions and 204 deletions

View File

@@ -1,5 +1,5 @@
"""This file is used for /tests and /benchmarks"""
from typing import List
from typing import List, Optional
import numpy
import torch
@@ -53,7 +53,10 @@ def get_pack_factor(num_bits):
return 32 // num_bits
def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int):
def permute_rows(q_w: torch.Tensor,
w_ref: torch.Tensor,
group_size: int,
test_perm: Optional[torch.Tensor] = None):
assert q_w.shape == w_ref.shape
orig_device = q_w.device
@@ -64,7 +67,7 @@ def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int):
g_idx[i] = i // group_size
# Simulate act_order by doing a random permutation on K
rand_perm = torch.randperm(k_size)
rand_perm = test_perm if test_perm is not None else torch.randperm(k_size)
g_idx = g_idx[rand_perm].contiguous()
q_w = q_w[rand_perm, :].contiguous()
@@ -164,8 +167,11 @@ def quantize_weights(w: torch.Tensor,
)
def gptq_quantize_weights(w: torch.Tensor, quant_type: ScalarType,
group_size: int, act_order: bool):
def gptq_quantize_weights(w: torch.Tensor,
quant_type: ScalarType,
group_size: int,
act_order: bool,
test_perm: Optional[torch.Tensor] = None):
size_k, _ = w.shape
assert w.is_floating_point(), "w must be float"
@@ -186,7 +192,8 @@ def gptq_quantize_weights(w: torch.Tensor, quant_type: ScalarType,
), "For act_order, groupsize = {} must be less than size_k = {}".format(
group_size, size_k)
w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size)
w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size,
test_perm)
return w_ref, w_q, w_s, g_idx, rand_perm