[Kernel][Quantization] add w4a8 support for marlin kernel (#24722)
Signed-off-by: Jinzhen Lin <jinzhen.ljz@antgroup.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Michael Goin <mgoin@redhat.com>
@@ -9,6 +9,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     MARLIN_SUPPORTED_GROUP_SIZES,
     apply_gptq_marlin_linear,
     check_marlin_supports_shape,
+    marlin_act_int8_process_scales,
     marlin_is_k_full,
     marlin_make_empty_g_idx,
     marlin_make_workspace_new,
@@ -21,6 +22,7 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
 )
 from vllm.model_executor.parameter import BasevLLMParameter, permute_param_layout_
 from vllm.platforms import current_platform
+from vllm.scalar_type import scalar_types
 
 from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
 
@@ -65,6 +67,18 @@ class MarlinLinearKernel(MPLinearKernel):
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         device = getattr(layer, self.w_q_name).device
         c = self.config
+        is_a_8bit = c.act_type is not None and c.act_type.itemsize == 1
+
+        if is_a_8bit:
+            assert c.weight_type == scalar_types.uint4b8, (
+                "W8A8 is not supported by marlin kernel."
+            )
+
+            if c.act_type == torch.float8_e4m3fn:
+                ops.marlin_int4_fp8_preprocess(getattr(layer, self.w_q_name), inplace=True)
+                getattr(layer, self.w_s_name).data = (
+                    getattr(layer, self.w_s_name).data * 512
+                )
 
         row_parallel = c.partition_weight_shape[0] != c.full_weight_shape[0]
         self.is_k_full = marlin_is_k_full(c.has_g_idx, row_parallel)
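Note on the hunk above: `is_a_8bit` keys off `itemsize`, which is 1 byte for both `torch.int8` and `torch.float8_e4m3fn`, so a single check covers the int8 and fp8 activation paths. The `* 512` on the weight scales presumably compensates for a power-of-two shift (512 = 2**9) introduced by the bit-level int4-to-fp8 repack in `marlin_int4_fp8_preprocess`; that factor's role is an assumption here, but the arithmetic of folding it into the scales can be sanity-checked:

import torch

# Both 8-bit activation dtypes report itemsize == 1, so both take the new path.
assert torch.int8.itemsize == 1
assert torch.float8_e4m3fn.itemsize == 1
assert torch.float16.itemsize == 2  # 16-bit activations keep the old path

# Toy check of the scale fold, assuming the repack divides stored weight
# values by 2**9: growing the scales by 512 keeps dequantization unchanged.
w_int4 = 5 - 8                  # uint4b8 stores 5; bias 8 gives logical -3
scale = 0.02
w_repacked = w_int4 / 512       # hypothetical post-repack magnitude
assert abs(w_int4 * scale - w_repacked * (scale * 512)) < 1e-12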
@@ -88,6 +102,7 @@ class MarlinLinearKernel(MPLinearKernel):
                 size_k=c.partition_weight_shape[0],
                 size_n=c.partition_weight_shape[1],
                 num_bits=c.weight_type.size_bits,
+                is_a_8bit=is_a_8bit,
             )
             return x
 
@@ -99,7 +114,22 @@ class MarlinLinearKernel(MPLinearKernel):
                 size_k=c.partition_weight_shape[0],
                 size_n=c.partition_weight_shape[1],
                 group_size=c.group_size,
+                is_a_8bit=is_a_8bit,
             )
+
+            if c.group_size == -1:
+                num_groups = 1
+            else:
+                num_groups = c.partition_weight_shape[0] // c.group_size
+
+            if c.act_type == torch.int8 and num_groups > 1:
+                x.data, input_global_scale = marlin_act_int8_process_scales(x.data)
+                layer.register_parameter(
+                    "input_global_scale",
+                    torch.nn.Parameter(input_global_scale, requires_grad=False),
+                )
+            else:
+                layer.input_global_scale = None
             return x
 
         if c.has_g_idx:
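For int8 activations with more than one scale group, the hunk above rewrites the weight scales through `marlin_act_int8_process_scales` and registers the returned per-tensor `input_global_scale` on the layer; otherwise the attribute is set to `None`. The helper's exact math lives in `marlin_utils`; below is a minimal sketch of the general idea, with the max-based split rule being an assumption, not vLLM's implementation:

import torch

def split_global_scale_sketch(scales: torch.Tensor):
    # Hypothetical illustration: factor one per-tensor scale out of the
    # per-group scales so each group keeps a relative factor <= 1.
    # Dequantization is unchanged because (scales / g) * g == scales.
    global_scale = scales.abs().max()
    return scales / global_scale, global_scale

per_group, g = split_global_scale_sketch(torch.tensor([0.5, 0.25, 1.0]))
assert torch.allclose(per_group * g, torch.tensor([0.5, 0.25, 1.0]))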
@@ -129,6 +159,7 @@ class MarlinLinearKernel(MPLinearKernel):
                     size_k=grouped_k,
                     size_n=c.partition_weight_shape[1],
                     num_bits=c.weight_type.size_bits,
+                    is_a_8bit=is_a_8bit,
                 ),
             )
         else:
@@ -150,6 +181,7 @@ class MarlinLinearKernel(MPLinearKernel):
 
         # `process_weights_after_loading` will ensure w_zp and w_gidx are not
         # None for marlin
+
         return apply_gptq_marlin_linear(
             input=x,
             weight=w_q,
@@ -162,5 +194,7 @@ class MarlinLinearKernel(MPLinearKernel):
             input_size_per_partition=c.partition_weight_shape[0],
             output_size_per_partition=c.partition_weight_shape[1],
             is_k_full=self.is_k_full,
+            input_global_scale=getattr(layer, "input_global_scale", None),
             bias=bias,
+            input_dtype=c.act_type,
         )
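At the call site, `input_global_scale` is read back with a `getattr` default, so the parameter stays optional: layers that never ran the int8-with-groups branch hand `None` to the kernel. A minimal sketch of that lookup pattern (the toy module below is a stand-in, not vLLM's layer class):

import torch

class _ToyLayer(torch.nn.Module):  # hypothetical stand-in
    pass

layer = _ToyLayer()
print(getattr(layer, "input_global_scale", None))  # None: nothing registered

layer.register_parameter(
    "input_global_scale",
    torch.nn.Parameter(torch.tensor(1.0), requires_grad=False),
)
print(getattr(layer, "input_global_scale", None))  # the registered scale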