[Misc] Update compressed-tensors WNA16 to support zero-points (#14211)

Author: Dipika Sikka
Date: 2025-04-15 09:33:51 -04:00
Committed by: GitHub
Parent: 280d62b8a2
Commit: 54a66e5fee
6 changed files with 85 additions and 45 deletions
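
Background for the diff below (not part of the commit): WNA16 schemes store weights as low-bit integers while activations stay 16-bit, dequantizing one group of input channels at a time. Without zero points the mapping is symmetric (w ≈ s * q around a fixed midpoint); with zero points it is asymmetric, w ≈ s * (q - z), with one z per (group, output-channel) pair. A minimal sketch of that dequantization, assuming unsigned 4-bit values and illustrative names (not vLLM APIs):

import torch

def dequant_wna16(q, scales, zeros, group_size):
    # q: (K, N) integer weights in [0, 15]
    # scales, zeros: (K // group_size, N), one entry per group of K rows
    groups = torch.arange(q.shape[0], device=q.device) // group_size
    return (q.float() - zeros[groups].float()) * scales[groups]

Symmetric schemes pin z at the midpoint (8 for 4-bit); this commit wires per-group zero points from compressed-tensors checkpoints through to the Marlin kernel.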

machete.py:

@@ -26,17 +26,14 @@ class MacheteLinearKernel(MPLinearKernel):
     @classmethod
     def can_implement(cls,
                       c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
         if c.has_g_idx and\
             c.partition_weight_shape[0] != c.full_weight_shape[0]:
             return False, "Act reordering currently not supported by Machete, "\
                           "when the input features are partitioned across "\
                           "devices"
 
         if c.zero_points:
-            return False, "Zero points currently not supported by "\
-                          " Compressed Tensors + Machete. (Kernel supports it"\
-                          " but CompressedTensorsWNA16 does not so support has"\
-                          " not been added to MacheteWNA16Kernel yet"
+            return False, "Zero points currently not supported by Machete"
 
         if c.weight_type not in query_machete_supported_quant_types(
                 c.zero_points):

marlin.py:

@@ -9,7 +9,7 @@
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     MARLIN_SUPPORTED_GROUP_SIZES, apply_gptq_marlin_linear,
     check_marlin_supports_shape, marlin_is_k_full, marlin_make_empty_g_idx,
     marlin_make_workspace, marlin_permute_scales, marlin_sort_g_idx,
-    query_marlin_supported_quant_types)
+    marlin_zero_points, query_marlin_supported_quant_types, unpack_cols)
 from vllm.model_executor.parameter import (BasevLLMParameter,
                                            permute_param_layout_)
@@ -25,10 +25,6 @@ class MarlinLinearKernel(MPLinearKernel):
     @classmethod
     def can_implement(cls,
                       c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
-        if c.zero_points:
-            return False, "Zero points currently not supported by "\
-                          " MarlinLinearKernel. Will be added when AWQMarlin "\
-                          "is migrated over to using MPLinearKernel backend"
 
         quant_types = query_marlin_supported_quant_types(c.zero_points)
         if c.weight_type not in quant_types:
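
For context: the MPLinearKernel backend selects an implementation by asking each candidate kernel whether it can serve the layer config, which is why deleting this early-return is what actually lets zero-point configs reach Marlin. A simplified sketch of that selection pattern (illustrative; the real selector in vLLM also honors priority ordering and user overrides):

def choose_kernel(config, kernel_classes):
    # Try kernels in priority order; keep rejection reasons for the error.
    failures = []
    for kernel_cls in kernel_classes:
        ok, reason = kernel_cls.can_implement(config)
        if ok:
            return kernel_cls
        failures.append(f"{kernel_cls.__name__}: {reason}")
    raise ValueError("No compatible kernel:\n" + "\n".join(failures))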
@@ -67,28 +63,6 @@ class MarlinLinearKernel(MPLinearKernel):
         if self.w_zp_name is None:
             self.w_zp_name = "w_zp"
 
-        if c.has_g_idx:
-            g_idx, g_idx_sort_indices = marlin_sort_g_idx(
-                getattr(layer, self.w_gidx_name))
-            self._transform_param(layer, self.w_gidx_name, lambda _: g_idx)
-            layer.g_idx_sort_indices = g_idx_sort_indices
-        else:
-            setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device))
-            layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
-
-        if c.zero_points:
-            pass
-            # TODO (lucas): add the following when AWQMarlin is migrated over to
-            #  using MPLinearKernel backend
-            # self._transform_param(layer, self.w_zp_name, lambda x: \
-            #     marlin_zero_points(
-            #         x,
-            #         size_k=c.partition_weight_shape[0],
-            #         size_n=c.partition_weight_shape[1],
-            #         num_bits=c.weight_type.size_bits))
-        else:
-            setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device))
-
         def transform_w_q(x):
             assert isinstance(x, BasevLLMParameter)
             permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
@@ -108,6 +82,28 @@ class MarlinLinearKernel(MPLinearKernel):
                 group_size=c.group_size)
             return x
 
+        if c.has_g_idx:
+            g_idx, g_idx_sort_indices = marlin_sort_g_idx(
+                getattr(layer, self.w_gidx_name))
+            self._transform_param(layer, self.w_gidx_name, lambda _: g_idx)
+            layer.g_idx_sort_indices = g_idx_sort_indices
+        else:
+            setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device))
+            layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
+
+        if c.zero_points:
+            grouped_k = (c.partition_weight_shape[0] //
+                         c.group_size if c.group_size != -1 else 1)
+            self._transform_param(layer, self.w_zp_name, lambda x: \
+                marlin_zero_points(
+                    unpack_cols(x.t(), c.weight_type.size_bits,
+                                grouped_k,
+                                c.partition_weight_shape[1]),
+                    size_k=grouped_k,
+                    size_n=c.partition_weight_shape[1],
+                    num_bits=c.weight_type.size_bits))
+        else:
+            setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device))
+
         self._transform_param(layer, self.w_q_name, transform_w_q)
         self._transform_param(layer, self.w_s_name, transform_w_s)
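
The loaded zero points arrive packed along the output dimension, so they are first column-unpacked and then repacked into Marlin's layout by marlin_zero_points. A rough sketch of the unpacking step, assuming 4-bit values stored eight per int32 in low-to-high nibble order (the exact layout handled by unpack_cols may differ):

import torch

def unpack_int4_cols(packed, size_k, size_n):
    # packed: (size_k, size_n // 8) int32; each int32 holds 8 nibbles.
    shifts = torch.arange(0, 32, 4, device=packed.device)
    nibbles = (packed.unsqueeze(-1) >> shifts) & 0xF  # (size_k, size_n//8, 8)
    return nibbles.reshape(size_k, size_n)

Note grouped_k: with one zero point per quantization group, the unpacked tensor has partition_weight_shape[0] // group_size rows (or a single row for channelwise quantization, group_size == -1), matching the size_k passed to unpack_cols above.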
@@ -131,5 +127,6 @@ class MarlinLinearKernel(MPLinearKernel):
             wtype=c.weight_type,
             input_size_per_partition=c.partition_weight_shape[0],
             output_size_per_partition=c.partition_weight_shape[1],
+            has_zp=self.config.zero_points,
             is_k_full=self.is_k_full,
             bias=bias)
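
To make the shapes concrete, a worked example with illustrative values: for a layer with K = 4096 input features, N = 11008 outputs, and group_size = 128, grouped_k = 4096 // 128 = 32, so the unpacked zero points form a (32, 11008) tensor (packed eight per int32: (32, 1376)). With has_zp=True, apply_gptq_marlin_linear tells the fused kernel to subtract the per-group zero point before scaling, i.e. the w ≈ s * (q - z) mapping sketched earlier, instead of assuming a fixed symmetric midpoint.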