[Kernel][Quantization] add w4a8 support for marlin kernel (#24722)

Signed-off-by: Jinzhen Lin <jinzhen.ljz@antgroup.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
Signed-off-by: Jinzhen Lin <linjinzhen@hotmail.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Co-authored-by: Michael Goin <mgoin@redhat.com>
Author: Jinzhen Lin
Date: 2025-11-29 23:19:33 +08:00
Committed by: GitHub
Parent: fa59fe417f
Commit: 1656ad3704
46 changed files with 4371 additions and 2240 deletions
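The "w4a8" in the title means 4-bit (GPTQ-style) weights combined with 8-bit activations (int8 or fp8), on top of the existing w4a16 marlin path. As a rough illustration of the arithmetic such a kernel implements, here is a standalone fake-quantization sketch (illustrative only, not the marlin kernel itself; all names and sizes are made up):

import torch

# 4-bit weights and 8-bit activations stay in low precision; the two scale
# factors are applied after the low-precision matmul.
torch.manual_seed(0)
a = torch.randn(4, 8)                       # activations (fp16/bf16 in practice)
w = torch.randn(8, 16)

# fake-quantize weights to symmetric 4-bit ([-8, 7]) with a per-column scale
w_scale = w.abs().amax(dim=0, keepdim=True) / 7.0
w_q = (w / w_scale).round().clamp(-8, 7)

# fake-quantize activations to symmetric 8-bit ([-127, 127]) with a per-row scale
a_scale = a.abs().amax(dim=1, keepdim=True) / 127.0
a_q = (a / a_scale).round().clamp(-127, 127)

# low-precision matmul followed by rescaling approximates the fp matmul
y_ref = a @ w
y_q = (a_q @ w_q) * a_scale * w_scale
print((y_ref - y_q).abs().max())            # small quantization error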

vllm/model_executor/layers/quantization/gptq_marlin.py

@@ -41,6 +41,8 @@ from vllm.model_executor.layers.quantization.utils.gptq_utils import (
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     check_marlin_supported,
     check_moe_marlin_supports_layer,
+    get_marlin_input_dtype,
+    marlin_act_int8_process_scales,
     marlin_make_workspace_new,
     marlin_moe_permute_scales,
     marlin_permute_bias,
@@ -251,8 +253,21 @@ class GPTQMarlinConfig(QuantizationConfig):
                 return MoeWNA16Config.from_config(self.full_config).get_quant_method(
                     layer, prefix
                 )
-            return get_moe_quant_method(self, layer, prefix, GPTQMarlinMoEMethod)
-        return get_linear_quant_method(self, layer, prefix, GPTQMarlinLinearMethod)
+            moe_quant_method = get_moe_quant_method(
+                self, layer, prefix, GPTQMarlinMoEMethod
+            )
+            if moe_quant_method is None:
+                return None
+            moe_quant_method.input_dtype = get_marlin_input_dtype(prefix)
+            return moe_quant_method
+        quant_method = get_linear_quant_method(
+            self, layer, prefix, GPTQMarlinLinearMethod
+        )
+        if quant_method is None:
+            return None
+        quant_method.input_dtype = get_marlin_input_dtype(prefix)
+        return quant_method

     @classmethod
     def is_gptq_marlin_compatible(cls, quant_config: dict[str, Any]):
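Both the MoE and the linear path now attach an input_dtype, chosen per layer by get_marlin_input_dtype(prefix), to the returned quant method; None keeps the existing 16-bit activation behavior, while an 8-bit dtype selects the new w4a8 path. The selection logic itself lives in marlin_utils and is not shown in this diff; the sketch below only illustrates the general shape of such a per-layer override, using a hypothetical mapping and helper name:

import torch

# Hypothetical stand-in for get_marlin_input_dtype(prefix): return an 8-bit
# activation dtype for matching layers, otherwise None to keep the default.
_OVERRIDES = {"model.layers.0.mlp": torch.float8_e4m3fn}   # made-up mapping

def pick_input_dtype(prefix: str):
    for pattern, dtype in _OVERRIDES.items():
        if prefix.startswith(pattern):
            return dtype
    return None

print(pick_input_dtype("model.layers.0.mlp.gate_up_proj"))  # torch.float8_e4m3fn
print(pick_input_dtype("model.layers.1.mlp.down_proj"))     # None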
@@ -319,6 +334,8 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
     def __init__(self, quant_config: GPTQMarlinConfig) -> None:
         self.quant_config = quant_config
+        self.input_dtype = None
+        self.quant_type = self.quant_config.quant_type

         # Verify supported on platform.
         verify_marlin_supported(
@@ -339,6 +356,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
         output_size_per_partition = sum(output_partition_sizes)
         is_row_parallel = input_size != input_size_per_partition
         weight_loader = extra_weight_attrs.get("weight_loader")
+        input_dtype = self.input_dtype

         mp_linear_kernel_config = MPLinearLayerConfig(
             full_weight_shape=(input_size, output_size),
@@ -347,7 +365,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
                 output_size_per_partition,
             ),
             weight_type=self.quant_config.quant_type,
-            act_type=params_dtype,
+            act_type=params_dtype if input_dtype is None else input_dtype,
             group_size=self.quant_config.group_size,
             zero_points=False,
             has_g_idx=self.quant_config.desc_act,
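When an 8-bit input_dtype is set, the kernel config now advertises that dtype as the activation type instead of params_dtype, which is what lets kernel selection pick a w4a8-capable marlin variant. The resolution boils down to one expression:

import torch

# How act_type is resolved for the kernel config (same expression as above):
params_dtype = torch.bfloat16           # model compute dtype
for input_dtype in (None, torch.float8_e4m3fn, torch.int8):
    act_type = params_dtype if input_dtype is None else input_dtype
    print(input_dtype, "->", act_type)
# None                -> torch.bfloat16       (existing w4a16 path)
# torch.float8_e4m3fn -> torch.float8_e4m3fn  (w4a8-fp8)
# torch.int8          -> torch.int8           (w4a8-int8)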
@@ -482,6 +500,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
             self.quant_type = scalar_types.uint8b128
         else:
             raise ValueError("GPTQMarlinMoEMethod only supports int4 and int8 now.")
+        self.input_dtype = None
         self.use_marlin = True

     def create_weights(
@@ -493,6 +512,14 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
+        layer.input_dtype = self.input_dtype
+        is_a_8bit = self.input_dtype is not None and self.input_dtype.itemsize == 1
+        if is_a_8bit:
+            assert self.quant_type == scalar_types.uint4b8, (
+                "W8A8-INT8 is not supported by marlin kernel."
+            )

         intermediate_size_full = extra_weight_attrs.pop("intermediate_size_full")

         self.is_k_full = (not self.quant_config.desc_act) or (
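The is_a_8bit test keys off dtype.itemsize so that both supported 8-bit activation dtypes (torch.int8 and torch.float8_e4m3fn) take the same branch, and the assertion restricts 8-bit activations to 4-bit (uint4b8) weights, i.e. w8a8 through this path is rejected. A quick check of what the itemsize test matches (requires a torch build with fp8 dtypes, roughly 2.1+):

import torch

# dtype.itemsize is in bytes, so the check matches every 1-byte dtype.
for dt in (torch.float16, torch.bfloat16, torch.int8, torch.float8_e4m3fn):
    print(dt, dt.itemsize, dt.itemsize == 1)
# float16 / bfloat16  -> 2 bytes -> False (16-bit activation path)
# int8, float8_e4m3fn -> 1 byte  -> True  (w4a8 path)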
@@ -513,6 +540,9 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
             scales_size2 = 1
             strategy = FusedMoeWeightScaleSupported.CHANNEL.value

+        layer.num_groups_w13 = scales_size13
+        layer.num_groups_w2 = scales_size2
+
         extra_weight_attrs.update({"quant_method": strategy, "is_transposed": True})

         # Fused gate_up_proj (column parallel)
         w13_qweight = torch.nn.Parameter(
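The group counts are stashed on the layer because process_weights_after_loading later needs to distinguish grouped scales (num_groups > 1) from channelwise scales (num_groups == 1) before deriving an int8 input global scale. For reference, this is how such counts typically fall out of a GPTQ checkpoint (example sizes only; the real values come from the layer shapes computed just above):

# How many scale groups a GPTQ checkpoint carries along the K dimension.
hidden_size = 4096                  # K of the fused gate_up projection
intermediate_size = 14336           # K of the down projection
for group_size in (128, -1):        # -1 means one scale per output channel
    groups_w13 = hidden_size // group_size if group_size != -1 else 1
    groups_w2 = intermediate_size // group_size if group_size != -1 else 1
    print(group_size, groups_w13, groups_w2)
# group_size=128 -> 32 groups for w13, 112 for w2
# group_size=-1  -> 1 "group" (channelwise scales)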
@@ -630,6 +660,19 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
         layer.workspace = marlin_make_workspace_new(device, 4)

     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        is_a_8bit = self.input_dtype is not None and self.input_dtype.itemsize == 1
+        if is_a_8bit:
+            assert self.quant_type == scalar_types.uint4b8, (
+                "W8A8-INT8 is not supported by marlin kernel."
+            )
+
+        if self.input_dtype == torch.float8_e4m3fn:
+            ops.marlin_int4_fp8_preprocess(layer.w13_qweight, inplace=True)
+            ops.marlin_int4_fp8_preprocess(layer.w2_qweight, inplace=True)
+            layer.w13_scales.data = layer.w13_scales.data * 512
+            layer.w2_scales.data = layer.w2_scales.data * 512
+
         # Process act_order
         if self.quant_config.desc_act:
             # Get sorting based on g_idx
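For fp8 activations the packed int4 weights are rewritten in place by marlin_int4_fp8_preprocess so the kernel can expand them into fp8 operands cheaply, and the weight scales are multiplied by 512 to compensate; the factor suggests the rewritten weights decode to values 2^9 times smaller than the original integers, though that encoding detail is an assumption here, not taken from the kernel source. Folding a constant into the scale is lossless, as a one-line check shows:

import torch

# If preprocessing makes the kernel see w_hat = q / 512 instead of the raw
# int4 value q, scaling the group scale s by 512 leaves the dequantized
# weight unchanged: (q / 512) * (512 * s) == q * s.
q = torch.tensor([-8., -3., 0., 5., 7.])   # int4 range after removing the bias
s = torch.tensor(0.02)                     # some group scale
print(torch.allclose(q * s, (q / 512) * (512 * s)))   # True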
@@ -678,6 +721,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
             layer.w13_qweight.shape[1] * self.quant_config.pack_factor,
             layer.w13_qweight.shape[2],
             self.quant_config.quant_type.size_bits,
+            is_a_8bit=is_a_8bit,
         )
         replace_parameter(layer, "w13_qweight", marlin_w13_qweight)
         marlin_w2_qweight = ops.gptq_marlin_moe_repack(
@@ -686,6 +730,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
             layer.w2_qweight.shape[1] * self.quant_config.pack_factor,
             layer.w2_qweight.shape[2],
             self.quant_config.quant_type.size_bits,
+            is_a_8bit=is_a_8bit,
         )
         replace_parameter(layer, "w2_qweight", marlin_w2_qweight)

         # Repack scales
@@ -694,7 +739,17 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
             size_k=layer.intermediate_size_per_partition,
             size_n=layer.w13_scales.shape[2],
             group_size=self.quant_config.group_size,
+            is_a_8bit=is_a_8bit,
         )
+        if self.input_dtype == torch.int8 and layer.num_groups_w13 > 1:
+            marlin_w13_scales, w13_input_global_scale = marlin_act_int8_process_scales(
+                marlin_w13_scales
+            )
+            layer.register_parameter(
+                "w13_input_global_scale",
+                torch.nn.Parameter(w13_input_global_scale, requires_grad=False),
+            )
         replace_parameter(layer, "w13_scales", marlin_w13_scales)
         marlin_w2_scales = marlin_moe_permute_scales(
             s=layer.w2_scales,
@@ -706,7 +761,17 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
             ),
             size_n=layer.w2_scales.shape[2],
             group_size=self.quant_config.group_size,
+            is_a_8bit=is_a_8bit,
         )
+        if self.input_dtype == torch.int8 and layer.num_groups_w2 > 1:
+            marlin_w2_scales, w2_input_global_scale = marlin_act_int8_process_scales(
+                marlin_w2_scales
+            )
+            layer.register_parameter(
+                "w2_input_global_scale",
+                torch.nn.Parameter(w2_input_global_scale, requires_grad=False),
+            )
         replace_parameter(layer, "w2_scales", marlin_w2_scales)

         if hasattr(layer, "w13_bias") and layer.w13_bias is not None:
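For int8 activations with grouped weight scales, marlin_act_int8_process_scales splits the permuted scales into an input_global_scale (registered on the layer so apply() can pass it to the kernel) plus rescaled weight scales. The actual math lives in marlin_utils; the sketch below shows one plausible factorization purely to illustrate the shapes involved, and is not the real implementation:

import torch

def process_scales_sketch(scales: torch.Tensor):
    # scales: (num_experts, num_groups, size_n) weight scales.
    # Factor out a single global scale so the per-group scales stay in a
    # narrow range that the int8 accumulation path can absorb safely.
    global_scale = scales.abs().amax()
    return scales / global_scale, global_scale

w13_scales = torch.rand(8, 32, 2048) * 0.05          # example MoE scale tensor
rescaled, global_scale = process_scales_sketch(w13_scales)
print(global_scale.item(), rescaled.max().item())    # rescaled now peaks at 1.0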
@@ -761,6 +826,8 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
             router_logits,
             topk_weights,
             topk_ids,
+            input_global_scale1=getattr(layer, "w13_input_global_scale", None),
+            input_global_scale2=getattr(layer, "w2_input_global_scale", None),
             quant_type_id=self.quant_type.id,
             apply_router_weight_on_input=apply_router_weight_on_input,
             global_num_experts=global_num_experts,
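The global input scales are fetched with getattr defaults because they only exist on layers that took the int8-with-grouped-scales branch above; for fp8 or 16-bit activations the kernel receives None and presumably falls back to dynamic activation quantization. The lookup behaves like this:

import torch

layer = torch.nn.Module()          # bare module, as if the int8 branch was skipped
print(getattr(layer, "w13_input_global_scale", None))   # None -> kernel default path

layer.register_parameter(
    "w13_input_global_scale",
    torch.nn.Parameter(torch.tensor(1.7), requires_grad=False),
)
print(getattr(layer, "w13_input_global_scale", None))   # the registered scale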
@@ -771,4 +838,5 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
             sort_indices2=layer.w2_g_idx_sort_indices,
             workspace=layer.workspace,
             is_k_full=self.is_k_full,
+            input_dtype=self.input_dtype,
         )
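The remaining new keyword argument, input_dtype, threads the per-layer choice through to fused_marlin_moe. On the activation side, an 8-bit input_dtype implies the activations get quantized before the grouped GEMM; a minimal dynamic per-token fp8 quantization routine of the kind such a path typically performs is sketched below (illustrative only, not vLLM's kernel code):

import torch

def quant_fp8_per_token(a: torch.Tensor):
    # Dynamic per-token quantization to float8_e4m3fn (max finite value ~448).
    a_scale = a.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 448.0
    a_q = (a / a_scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn)
    return a_q, a_scale

a = torch.randn(4, 4096, dtype=torch.bfloat16)
a_q, a_scale = quant_fp8_per_token(a)
print(a_q.dtype, a_scale.shape)     # torch.float8_e4m3fn, per-token scales (4, 1)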