Enable ZP Support for Machete (#20268)

Signed-off-by: czhu-cohere <conway.zhu@cohere.com>
This commit is contained in:
czhu-cohere
2025-07-01 00:12:20 -07:00
committed by GitHub
parent 22e9d42040
commit 9909726d2a
3 changed files with 19 additions and 5 deletions

View File

@@ -33,8 +33,6 @@ class MacheteLinearKernel(MPLinearKernel):
return False, "Act reordering currently not supported by Machete, "\
"when the input features are partitioned across "\
"devices"
if c.zero_points:
return False, "Zero points currently not supported by Machete"
if c.weight_type not in query_machete_supported_quant_types(
c.zero_points):
@@ -53,6 +51,7 @@ class MacheteLinearKernel(MPLinearKernel):
# note assumes that
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
# `weight_scale` is: {input_dim = 0, output_dim = 1}
# `weight_zp` is: {input_dim = 0, output_dim = 1, packed_dim = 1}
def process_weights_after_loading(self, layer: torch.nn.Module):
c = self.config
@@ -90,16 +89,29 @@ class MacheteLinearKernel(MPLinearKernel):
x.data = x.data.contiguous()
return x
def transform_w_zp(x):
assert isinstance(x, BasevLLMParameter)
permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=1)
x_unpacked = unpack_quantized_values_into_int32(x.data,
c.weight_type,
packed_dim=1)
w_s = getattr(layer, self.w_s_name).data
# pre-apply scales to zero-points
x.data = (-1.0 * w_s * (x_unpacked.to(w_s.dtype))).contiguous()
return x
# Repack weights and scales for Machete
self._transform_param(layer, self.w_q_name, transform_w_q)
self._transform_param(layer, self.w_s_name, transform_w_s)
if c.zero_points:
self._transform_param(layer, self.w_zp_name, transform_w_zp)
def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
c = self.config
w_q, w_s, _, _ = self._get_weight_params(layer)
w_q, w_s, w_zp, _ = self._get_weight_params(layer)
x_2d = x.reshape(-1, x.shape[-1])
out_shape = x.shape[:-1] + (c.partition_weight_shape[1], )
@@ -110,7 +122,7 @@ class MacheteLinearKernel(MPLinearKernel):
output = ops.machete_mm(a=x_2d,
b_q=w_q,
b_type=c.weight_type,
b_group_zeros=None,
b_group_zeros=w_zp,
b_group_scales=w_s,
b_group_size=c.group_size)