[Kernel] Initial Machete W4A8 support + Refactors (#9855)
Signed-off-by: Lucas Wilkinson <lwilkinson@neuralmagic.com>
This commit is contained in:
@@ -79,7 +79,9 @@ class MacheteLinearKernel(MPLinearKernel):
 c.weight_type,
 packed_dim=0)
 x.data = ops.machete_prepack_B(x.data.t().contiguous().t(),
-self.config.weight_type)
+a_type=c.act_type,
+b_type=c.weight_type,
+group_scales_type=c.act_type)
 return x
 
 def transform_w_s(x):
@@ -105,12 +107,12 @@ class MacheteLinearKernel(MPLinearKernel):
 if c.has_g_idx:
 x_2d = self.act_perm(x_2d)
 
-output = ops.machete_gemm(a=x_2d,
-b_q=w_q,
-b_type=c.weight_type,
-b_zeros=None,
-b_scales=w_s,
-b_group_size=c.group_size)
+output = ops.machete_mm(a=x_2d,
+b_q=w_q,
+b_type=c.weight_type,
+b_group_zeros=None,
+b_group_scales=w_s,
+b_group_size=c.group_size)
 
 if bias is not None:
 output.add_(bias) # In-place add
Reference in New Issue
Block a user