[Kernel][Quantization] add w4a8 support for marlin kernel (#24722)
Signed-off-by: Jinzhen Lin <jinzhen.ljz@antgroup.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Michael Goin <mgoin@redhat.com>
This commit is contained in:
@@ -41,6 +41,8 @@ from vllm.model_executor.layers.quantization.utils.gptq_utils import (
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
check_marlin_supported,
|
||||
check_moe_marlin_supports_layer,
|
||||
get_marlin_input_dtype,
|
||||
marlin_act_int8_process_scales,
|
||||
marlin_make_workspace_new,
|
||||
marlin_moe_permute_scales,
|
||||
marlin_permute_bias,
|
||||
@@ -251,8 +253,21 @@ class GPTQMarlinConfig(QuantizationConfig):
|
||||
return MoeWNA16Config.from_config(self.full_config).get_quant_method(
|
||||
layer, prefix
|
||||
)
|
||||
return get_moe_quant_method(self, layer, prefix, GPTQMarlinMoEMethod)
|
||||
return get_linear_quant_method(self, layer, prefix, GPTQMarlinLinearMethod)
|
||||
moe_quant_method = get_moe_quant_method(
|
||||
self, layer, prefix, GPTQMarlinMoEMethod
|
||||
)
|
||||
if moe_quant_method is None:
|
||||
return None
|
||||
moe_quant_method.input_dtype = get_marlin_input_dtype(prefix)
|
||||
return moe_quant_method
|
||||
|
||||
quant_method = get_linear_quant_method(
|
||||
self, layer, prefix, GPTQMarlinLinearMethod
|
||||
)
|
||||
if quant_method is None:
|
||||
return None
|
||||
quant_method.input_dtype = get_marlin_input_dtype(prefix)
|
||||
return quant_method
|
||||
|
||||
@classmethod
|
||||
def is_gptq_marlin_compatible(cls, quant_config: dict[str, Any]):
|
||||
@@ -319,6 +334,8 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
|
||||
|
||||
def __init__(self, quant_config: GPTQMarlinConfig) -> None:
|
||||
self.quant_config = quant_config
|
||||
self.input_dtype = None
|
||||
self.quant_type = self.quant_config.quant_type
|
||||
|
||||
# Verify supported on platform.
|
||||
verify_marlin_supported(
|
||||
@@ -339,6 +356,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
is_row_parallel = input_size != input_size_per_partition
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
input_dtype = self.input_dtype
|
||||
|
||||
mp_linear_kernel_config = MPLinearLayerConfig(
|
||||
full_weight_shape=(input_size, output_size),
|
||||
@@ -347,7 +365,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
|
||||
output_size_per_partition,
|
||||
),
|
||||
weight_type=self.quant_config.quant_type,
|
||||
act_type=params_dtype,
|
||||
act_type=params_dtype if input_dtype is None else input_dtype,
|
||||
group_size=self.quant_config.group_size,
|
||||
zero_points=False,
|
||||
has_g_idx=self.quant_config.desc_act,
|
||||
@@ -482,6 +500,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
self.quant_type = scalar_types.uint8b128
|
||||
else:
|
||||
raise ValueError("GPTQMarlinMoEMethod only supports int4 and int8 now.")
|
||||
self.input_dtype = None
|
||||
self.use_marlin = True
|
||||
|
||||
def create_weights(
|
||||
@@ -493,6 +512,14 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
layer.input_dtype = self.input_dtype
|
||||
is_a_8bit = self.input_dtype is not None and self.input_dtype.itemsize == 1
|
||||
|
||||
if is_a_8bit:
|
||||
assert self.quant_type == scalar_types.uint4b8, (
|
||||
"W8A8-INT8 is not supported by marlin kernel."
|
||||
)
|
||||
|
||||
intermediate_size_full = extra_weight_attrs.pop("intermediate_size_full")
|
||||
|
||||
self.is_k_full = (not self.quant_config.desc_act) or (
|
||||
@@ -513,6 +540,9 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
scales_size2 = 1
|
||||
strategy = FusedMoeWeightScaleSupported.CHANNEL.value
|
||||
|
||||
layer.num_groups_w13 = scales_size13
|
||||
layer.num_groups_w2 = scales_size2
|
||||
|
||||
extra_weight_attrs.update({"quant_method": strategy, "is_transposed": True})
|
||||
# Fused gate_up_proj (column parallel)
|
||||
w13_qweight = torch.nn.Parameter(
|
||||
@@ -630,6 +660,19 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
layer.workspace = marlin_make_workspace_new(device, 4)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
is_a_8bit = self.input_dtype is not None and self.input_dtype.itemsize == 1
|
||||
|
||||
if is_a_8bit:
|
||||
assert self.quant_type == scalar_types.uint4b8, (
|
||||
"W8A8-INT8 is not supported by marlin kernel."
|
||||
)
|
||||
|
||||
if self.input_dtype == torch.float8_e4m3fn:
|
||||
ops.marlin_int4_fp8_preprocess(layer.w13_qweight, inplace=True)
|
||||
ops.marlin_int4_fp8_preprocess(layer.w2_qweight, inplace=True)
|
||||
layer.w13_scales.data = layer.w13_scales.data * 512
|
||||
layer.w2_scales.data = layer.w2_scales.data * 512
|
||||
|
||||
# Process act_order
|
||||
if self.quant_config.desc_act:
|
||||
# Get sorting based on g_idx
|
||||
@@ -678,6 +721,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
layer.w13_qweight.shape[1] * self.quant_config.pack_factor,
|
||||
layer.w13_qweight.shape[2],
|
||||
self.quant_config.quant_type.size_bits,
|
||||
is_a_8bit=is_a_8bit,
|
||||
)
|
||||
replace_parameter(layer, "w13_qweight", marlin_w13_qweight)
|
||||
marlin_w2_qweight = ops.gptq_marlin_moe_repack(
|
||||
@@ -686,6 +730,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
layer.w2_qweight.shape[1] * self.quant_config.pack_factor,
|
||||
layer.w2_qweight.shape[2],
|
||||
self.quant_config.quant_type.size_bits,
|
||||
is_a_8bit=is_a_8bit,
|
||||
)
|
||||
replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
|
||||
# Repack scales
|
||||
@@ -694,7 +739,17 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
size_k=layer.intermediate_size_per_partition,
|
||||
size_n=layer.w13_scales.shape[2],
|
||||
group_size=self.quant_config.group_size,
|
||||
is_a_8bit=is_a_8bit,
|
||||
)
|
||||
if self.input_dtype == torch.int8 and layer.num_groups_w13 > 1:
|
||||
marlin_w13_scales, w13_input_global_scale = marlin_act_int8_process_scales(
|
||||
marlin_w13_scales
|
||||
)
|
||||
layer.register_parameter(
|
||||
"w13_input_global_scale",
|
||||
torch.nn.Parameter(w13_input_global_scale, requires_grad=False),
|
||||
)
|
||||
|
||||
replace_parameter(layer, "w13_scales", marlin_w13_scales)
|
||||
marlin_w2_scales = marlin_moe_permute_scales(
|
||||
s=layer.w2_scales,
|
||||
@@ -706,7 +761,17 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
),
|
||||
size_n=layer.w2_scales.shape[2],
|
||||
group_size=self.quant_config.group_size,
|
||||
is_a_8bit=is_a_8bit,
|
||||
)
|
||||
if self.input_dtype == torch.int8 and layer.num_groups_w2 > 1:
|
||||
marlin_w2_scales, w2_input_global_scale = marlin_act_int8_process_scales(
|
||||
marlin_w2_scales
|
||||
)
|
||||
layer.register_parameter(
|
||||
"w2_input_global_scale",
|
||||
torch.nn.Parameter(w2_input_global_scale, requires_grad=False),
|
||||
)
|
||||
|
||||
replace_parameter(layer, "w2_scales", marlin_w2_scales)
|
||||
|
||||
if hasattr(layer, "w13_bias") and layer.w13_bias is not None:
|
||||
@@ -761,6 +826,8 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
router_logits,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
input_global_scale1=getattr(layer, "w13_input_global_scale", None),
|
||||
input_global_scale2=getattr(layer, "w2_input_global_scale", None),
|
||||
quant_type_id=self.quant_type.id,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
global_num_experts=global_num_experts,
|
||||
@@ -771,4 +838,5 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
sort_indices2=layer.w2_g_idx_sort_indices,
|
||||
workspace=layer.workspace,
|
||||
is_k_full=self.is_k_full,
|
||||
input_dtype=self.input_dtype,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user