[Misc] Fused MoE Marlin support for GPTQ (#8217)
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
"""This file is used for /tests and /benchmarks"""
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy
|
||||
import torch
|
||||
@@ -53,7 +53,10 @@ def get_pack_factor(num_bits):
|
||||
return 32 // num_bits
|
||||
|
||||
|
||||
def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int):
|
||||
def permute_rows(q_w: torch.Tensor,
|
||||
w_ref: torch.Tensor,
|
||||
group_size: int,
|
||||
test_perm: Optional[torch.Tensor] = None):
|
||||
assert q_w.shape == w_ref.shape
|
||||
|
||||
orig_device = q_w.device
|
||||
@@ -64,7 +67,7 @@ def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int):
|
||||
g_idx[i] = i // group_size
|
||||
|
||||
# Simulate act_order by doing a random permutation on K
|
||||
rand_perm = torch.randperm(k_size)
|
||||
rand_perm = test_perm if test_perm is not None else torch.randperm(k_size)
|
||||
|
||||
g_idx = g_idx[rand_perm].contiguous()
|
||||
q_w = q_w[rand_perm, :].contiguous()
|
||||
@@ -164,8 +167,11 @@ def quantize_weights(w: torch.Tensor,
|
||||
)
|
||||
|
||||
|
||||
def gptq_quantize_weights(w: torch.Tensor, quant_type: ScalarType,
|
||||
group_size: int, act_order: bool):
|
||||
def gptq_quantize_weights(w: torch.Tensor,
|
||||
quant_type: ScalarType,
|
||||
group_size: int,
|
||||
act_order: bool,
|
||||
test_perm: Optional[torch.Tensor] = None):
|
||||
size_k, _ = w.shape
|
||||
|
||||
assert w.is_floating_point(), "w must be float"
|
||||
@@ -186,7 +192,8 @@ def gptq_quantize_weights(w: torch.Tensor, quant_type: ScalarType,
|
||||
), "For act_order, groupsize = {} must be less than size_k = {}".format(
|
||||
group_size, size_k)
|
||||
|
||||
w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size)
|
||||
w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size,
|
||||
test_perm)
|
||||
|
||||
return w_ref, w_q, w_s, g_idx, rand_perm
|
||||
|
||||
|
||||
Reference in New Issue
Block a user