[Misc] Update compressed-tensors WNA16 to support zero-points (#14211)
This commit is contained in:
@@ -26,17 +26,14 @@ class MacheteLinearKernel(MPLinearKernel):
|
||||
@classmethod
|
||||
def can_implement(cls,
|
||||
c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
|
||||
|
||||
if c.has_g_idx and\
|
||||
c.partition_weight_shape[0] != c.full_weight_shape[0]:
|
||||
return False, "Act reordering currently not supported by Machete, "\
|
||||
"when the input features are partitioned across "\
|
||||
"devices"
|
||||
|
||||
if c.zero_points:
|
||||
return False, "Zero points currently not supported by "\
|
||||
" Compressed Tensors + Machete. (Kernel supports it"\
|
||||
" but CompressedTensorsWNA16 does not so support has"\
|
||||
" not been added to MacheteWNA16Kernel yet"
|
||||
return False, "Zero points currently not supported by Machete"
|
||||
|
||||
if c.weight_type not in query_machete_supported_quant_types(
|
||||
c.zero_points):
|
||||
|
||||
@@ -9,7 +9,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
MARLIN_SUPPORTED_GROUP_SIZES, apply_gptq_marlin_linear,
|
||||
check_marlin_supports_shape, marlin_is_k_full, marlin_make_empty_g_idx,
|
||||
marlin_make_workspace, marlin_permute_scales, marlin_sort_g_idx,
|
||||
query_marlin_supported_quant_types)
|
||||
marlin_zero_points, query_marlin_supported_quant_types, unpack_cols)
|
||||
from vllm.model_executor.parameter import (BasevLLMParameter,
|
||||
permute_param_layout_)
|
||||
|
||||
@@ -25,10 +25,6 @@ class MarlinLinearKernel(MPLinearKernel):
|
||||
@classmethod
|
||||
def can_implement(cls,
|
||||
c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
|
||||
if c.zero_points:
|
||||
return False, "Zero points currently not supported by "\
|
||||
" MarlinLinearKernel. Will be added when AWQMarlin "\
|
||||
"is migrated over to using MPLinearKernel backend"
|
||||
|
||||
quant_types = query_marlin_supported_quant_types(c.zero_points)
|
||||
if c.weight_type not in quant_types:
|
||||
@@ -67,28 +63,6 @@ class MarlinLinearKernel(MPLinearKernel):
|
||||
if self.w_zp_name is None:
|
||||
self.w_zp_name = "w_zp"
|
||||
|
||||
if c.has_g_idx:
|
||||
g_idx, g_idx_sort_indices = marlin_sort_g_idx(
|
||||
getattr(layer, self.w_gidx_name))
|
||||
self._transform_param(layer, self.w_gidx_name, lambda _: g_idx)
|
||||
layer.g_idx_sort_indices = g_idx_sort_indices
|
||||
else:
|
||||
setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device))
|
||||
layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
|
||||
|
||||
if c.zero_points:
|
||||
pass
|
||||
# TODO (lucas): add the following when AWQMarlin is migrated over to
|
||||
# using MPLinearKernel backend
|
||||
# self._transform_param(layer, self.w_zp_name, lambda x: \
|
||||
# marlin_zero_points(
|
||||
# x,
|
||||
# size_k=c.partition_weight_shape[0],
|
||||
# size_n=c.partition_weight_shape[1],
|
||||
# num_bits=c.weight_type.size_bits))
|
||||
else:
|
||||
setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device))
|
||||
|
||||
def transform_w_q(x):
|
||||
assert isinstance(x, BasevLLMParameter)
|
||||
permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
|
||||
@@ -108,6 +82,28 @@ class MarlinLinearKernel(MPLinearKernel):
|
||||
group_size=c.group_size)
|
||||
return x
|
||||
|
||||
if c.has_g_idx:
|
||||
g_idx, g_idx_sort_indices = marlin_sort_g_idx(
|
||||
getattr(layer, self.w_gidx_name))
|
||||
self._transform_param(layer, self.w_gidx_name, lambda _: g_idx)
|
||||
layer.g_idx_sort_indices = g_idx_sort_indices
|
||||
else:
|
||||
setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device))
|
||||
layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
|
||||
|
||||
if c.zero_points:
|
||||
grouped_k = (c.partition_weight_shape[0] //
|
||||
c.group_size if c.group_size != -1 else 1)
|
||||
self._transform_param(layer, self.w_zp_name, lambda x: \
|
||||
marlin_zero_points(
|
||||
unpack_cols(x.t(), c.weight_type.size_bits,
|
||||
grouped_k,
|
||||
c.partition_weight_shape[1]),
|
||||
size_k=grouped_k,
|
||||
size_n=c.partition_weight_shape[1],
|
||||
num_bits=c.weight_type.size_bits))
|
||||
else:
|
||||
setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device))
|
||||
self._transform_param(layer, self.w_q_name, transform_w_q)
|
||||
self._transform_param(layer, self.w_s_name, transform_w_s)
|
||||
|
||||
@@ -131,5 +127,6 @@ class MarlinLinearKernel(MPLinearKernel):
|
||||
wtype=c.weight_type,
|
||||
input_size_per_partition=c.partition_weight_shape[0],
|
||||
output_size_per_partition=c.partition_weight_shape[1],
|
||||
has_zp=self.config.zero_points,
|
||||
is_k_full=self.is_k_full,
|
||||
bias=bias)
|
||||
|
||||
Reference in New Issue
Block a user