From 410d30089310c3e6bc385230f2e1bb2ccdc72edc Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Mon, 23 Mar 2026 08:36:08 +0100 Subject: [PATCH] [ROCm][Refactor] Enable AWQMarlinConfig on ROCm to use choose_mp_linear_kernel (#36505) Signed-off-by: Matthias Gehre Co-authored-by: Michael Goin --- .../kernels/linear/mixed_precision/conch.py | 2 + .../layers/quantization/awq_marlin.py | 226 +++++++++++------- .../layers/quantization/utils/marlin_utils.py | 17 +- 3 files changed, 143 insertions(+), 102 deletions(-) diff --git a/vllm/model_executor/kernels/linear/mixed_precision/conch.py b/vllm/model_executor/kernels/linear/mixed_precision/conch.py index 82dd32da1..cd371581b 100644 --- a/vllm/model_executor/kernels/linear/mixed_precision/conch.py +++ b/vllm/model_executor/kernels/linear/mixed_precision/conch.py @@ -113,6 +113,8 @@ class ConchLinearKernel(MPLinearKernel): self._transform_param(layer, self.w_s_name, transform_w_s) if self.config.zero_points: self._transform_param(layer, self.w_zp_name, transform_w_zp) + elif self.w_zp_name is not None: + layer.register_parameter(self.w_zp_name, None) def apply_weights( self, diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index 5b7af3193..0350a2a8b 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -10,6 +10,10 @@ from torch.nn import Parameter import vllm.model_executor.layers.fused_moe # noqa from vllm import _custom_ops as ops from vllm.logger import init_logger +from vllm.model_executor.kernels.linear import ( + MPLinearLayerConfig, + choose_mp_linear_kernel, +) from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, FusedMoEQuantConfig, @@ -34,21 +38,16 @@ from vllm.model_executor.layers.quantization.base_config import ( ) from vllm.model_executor.layers.quantization.utils import replace_parameter from vllm.model_executor.layers.quantization.utils.marlin_utils import ( - apply_awq_marlin_linear, - awq_to_marlin_zero_points, check_marlin_supported, check_marlin_supports_layer, check_moe_marlin_supports_layer, get_marlin_input_dtype, marlin_act_int8_process_scales, - marlin_make_empty_g_idx, marlin_make_workspace_new, marlin_moe_permute_scales, marlin_permute_bias, - marlin_permute_scales, moe_awq_to_marlin_zero_points, verify_marlin_supported, - verify_marlin_supports_shape, ) from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead @@ -63,6 +62,90 @@ if TYPE_CHECKING: logger = init_logger(__name__) +# AWQ uses a non-standard packing order within int32 values. +# For 4-bit: standard order stores values at bit positions [0,4,8,12,16,20,24,28] +# for indices [0,1,2,3,4,5,6,7], while AWQ stores them for indices +# [0,4,1,5,2,6,3,7]. This permutation reverses that ordering. +_REVERSE_AWQ_PACK_ORDER = [0, 4, 1, 5, 2, 6, 3, 7] + + +def _convert_awq_to_standard_format( + layer: torch.nn.Module, + w_q_name: str, + w_zp_name: str, + size_bits: int, +) -> None: + """Convert AWQ weight and zero-point tensors to standard GPTQ-like format. + + AWQ packs qweight along the output dim with a non-standard bit order. + This converts to standard bit order and repacks qweight along the input + dim, matching the format expected by the MPLinearKernel framework. 
+ """ + pack_factor = 32 // size_bits + mask = (1 << size_bits) - 1 + device = getattr(layer, w_q_name).device + reverse_order = torch.tensor( + _REVERSE_AWQ_PACK_ORDER, dtype=torch.long, device=device + ) + shifts = torch.arange(0, 32, size_bits, dtype=torch.int32, device=device) + + # --- Convert qweight: (K, N // pack) packed_dim=1 → (K // pack, N) packed_dim=0 + qw = getattr(layer, w_q_name).data + K, N_packed = qw.shape + N = N_packed * pack_factor + + # Unpack int32 → individual values, fix AWQ ordering + unpacked = (qw.unsqueeze(-1) >> shifts) & mask # (K, N_packed, pack_factor) + unpacked = unpacked[:, :, reverse_order] + unpacked = unpacked.reshape(K, N) # (K, N) + + # Repack along input dim (dim 0) + unpacked = unpacked.reshape(K // pack_factor, pack_factor, N) + new_qw = (unpacked.to(torch.int32) << shifts[None, :, None]).sum( + dim=1, dtype=torch.int32 + ) + + def _noop_loader(*args, **kwargs): + pass + + new_param = PackedvLLMParameter( + data=new_qw.contiguous(), + input_dim=0, + output_dim=1, + packed_dim=0, + packed_factor=pack_factor, + weight_loader=_noop_loader, + ) + setattr(layer, w_q_name, new_param) + + # --- Convert qzeros: fix AWQ bit ordering and repack + # AWQ qzeros: (G, N // pack) packed along dim 1, AWQ bit order + # Target: (N // pack, G) packed along dim 0, standard bit order + # This matches the CompressedTensors layout expected by the kernels. + qz = getattr(layer, w_zp_name).data + G, _ = qz.shape + + unpacked_zp = (qz.unsqueeze(-1) >> shifts) & mask # (G, N_packed, pack_factor) + unpacked_zp = unpacked_zp[:, :, reverse_order] + unpacked_zp = unpacked_zp.reshape(G, N) # (G, N) individual values + + # Transpose and repack along dim 0 (output dim) + unpacked_zp = unpacked_zp.T # (N, G) + unpacked_zp = unpacked_zp.reshape(N // pack_factor, pack_factor, G) + new_qz = (unpacked_zp.to(torch.int32) << shifts[None, :, None]).sum( + dim=1, dtype=torch.int32 + ) + + new_zp_param = PackedvLLMParameter( + data=new_qz.contiguous(), + output_dim=0, + input_dim=1, + packed_dim=0, + packed_factor=pack_factor, + weight_loader=_noop_loader, + ) + setattr(layer, w_zp_name, new_zp_param) + class AWQMarlinConfig(QuantizationConfig): """Config class for AWQ Marlin""" @@ -226,7 +309,7 @@ class AWQMarlinConfig(QuantizationConfig): group_size = quant_config.get("group_size") zero_point = quant_config.get("zero_point") - if not current_platform.is_cuda(): + if not (current_platform.is_cuda_alike() or current_platform.is_cpu()): return False if quant_method != "awq": @@ -268,15 +351,26 @@ class AWQMarlinConfig(QuantizationConfig): class AWQMarlinLinearMethod(LinearMethodBase): """Linear method for AWQ Marlin. + Uses choose_mp_linear_kernel to select the best available kernel + (Conch, Exllama, or Marlin) for the current platform. + Args: quant_config: The AWQ Marlin quantization config. 
""" + _kernel_backends_being_used: set[str] = set() + def __init__(self, quant_config: AWQMarlinConfig) -> None: self.quant_config = quant_config self.quant_type = scalar_types.uint4 self.input_dtype = None + verify_marlin_supported( + quant_type=self.quant_config.quant_type, + group_size=self.quant_config.group_size, + has_zp=self.quant_config.zero_point, + ) + def create_weights( self, layer: torch.nn.Module, @@ -287,23 +381,35 @@ class AWQMarlinLinearMethod(LinearMethodBase): params_dtype: torch.dtype, **extra_weight_attrs, ) -> None: - del output_size output_size_per_partition = sum(output_partition_sizes) weight_loader = extra_weight_attrs.get("weight_loader") - # Normalize group_size if self.quant_config.group_size != -1: group_size = self.quant_config.group_size else: group_size = input_size - verify_marlin_supports_shape( - output_size_per_partition=output_size_per_partition, - input_size_per_partition=input_size_per_partition, - input_size=input_size, - group_size=group_size, + mp_linear_kernel_config = MPLinearLayerConfig( + full_weight_shape=(input_size, output_size), + partition_weight_shape=( + input_size_per_partition, + output_size_per_partition, + ), + weight_type=self.quant_config.quant_type, + act_type=params_dtype if self.input_dtype is None else self.input_dtype, + group_size=self.quant_config.group_size, + zero_points=self.quant_config.zero_point, + has_g_idx=False, ) + kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config) + + if kernel_type.__name__ not in self._kernel_backends_being_used: + logger.info("Using %s for AWQMarlinLinearMethod", kernel_type.__name__) + self._kernel_backends_being_used.add(kernel_type.__name__) + + # Weights are loaded in AWQ checkpoint format (packed along output dim). + # Conversion to GPTQ-like format happens in process_weights_after_loading. qweight = PackedvLLMParameter( data=torch.empty( input_size_per_partition, @@ -318,7 +424,6 @@ class AWQMarlinLinearMethod(LinearMethodBase): ) num_groups = input_size_per_partition // group_size - layer.num_groups = num_groups qzeros = PackedvLLMParameter( data=torch.empty( @@ -348,73 +453,22 @@ class AWQMarlinLinearMethod(LinearMethodBase): layer.register_parameter("qzeros", qzeros) layer.register_parameter("scales", scales) - layer.input_size_per_partition = input_size_per_partition - layer.output_size_per_partition = output_size_per_partition - layer.num_groups = num_groups + self.kernel = kernel_type( + mp_linear_kernel_config, + w_q_param_name="qweight", + w_s_param_name="scales", + w_zp_param_name="qzeros", + ) - # TODO: Update this docs - # Checkpoints are serialized in AutoAWQ format, which is different from the - # marlin format. This function is called after the weights are loaded. - # Here, we handle the repacking def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - device = layer.qweight.device - layer.qweight = torch.nn.Parameter(layer.qweight.data, requires_grad=False) - layer.qzeros = torch.nn.Parameter(layer.qzeros.data, requires_grad=False) - layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False) - - # Allocate marlin workspace - layer.workspace = marlin_make_workspace_new(device) - - is_a_8bit = self.input_dtype is not None and self.input_dtype.itemsize == 1 - - if self.input_dtype == torch.float8_e4m3fn: - ops.marlin_int4_fp8_preprocess(layer.qweight, layer.qzeros, inplace=True) - layer.scales.data = layer.scales.data * 512 - - # Repack weights from AWQ format to marlin format. 
- marlin_qweight = ops.awq_marlin_repack( - layer.qweight, - size_k=layer.input_size_per_partition, - size_n=layer.output_size_per_partition, - num_bits=self.quant_config.quant_type.size_bits, - is_a_8bit=is_a_8bit, + # AWQ checkpoints use a non-standard packing order and pack qweight + # along the output dimension. Convert to the standard format + # (GPTQ-like: standard bit order, qweight packed along input dim) + # before handing off to the kernel. + _convert_awq_to_standard_format( + layer, "qweight", "qzeros", self.quant_config.quant_type.size_bits ) - replace_parameter(layer, "qweight", marlin_qweight) - - # Permute scales from AWQ format to marlin format. - marlin_scales = marlin_permute_scales( - layer.scales, - size_k=layer.input_size_per_partition, - size_n=layer.output_size_per_partition, - group_size=self.quant_config.group_size, - is_a_8bit=is_a_8bit, - ) - if self.input_dtype == torch.int8 and layer.num_groups > 1: - marlin_scales, input_global_scale = marlin_act_int8_process_scales( - marlin_scales - ) - layer.register_parameter( - "input_global_scale", Parameter(input_global_scale, requires_grad=False) - ) - - replace_parameter(layer, "scales", marlin_scales) - - # Permute zero-points from AWQ format to marlin format. - marlin_zp = awq_to_marlin_zero_points( - layer.qzeros, - size_k=layer.num_groups, - size_n=layer.output_size_per_partition, - num_bits=self.quant_config.quant_type.size_bits, - is_a_8bit=is_a_8bit, - ) - replace_parameter(layer, "qzeros", marlin_zp) - - # Not-used - layer.g_idx = marlin_make_empty_g_idx(device) - layer.g_idx_sort_indices = marlin_make_empty_g_idx(device) - - if hasattr(layer, "bias") and layer.bias is not None: - layer.bias.data = marlin_permute_bias(layer.bias) + self.kernel.process_weights_after_loading(layer) def apply( self, @@ -422,21 +476,7 @@ class AWQMarlinLinearMethod(LinearMethodBase): x: torch.Tensor, bias: torch.Tensor | None = None, ) -> torch.Tensor: - return apply_awq_marlin_linear( - input=x, - weight=layer.qweight, - weight_scale=layer.scales, - weight_zp=layer.qzeros, - g_idx=layer.g_idx, - g_idx_sort_indices=layer.g_idx_sort_indices, - workspace=layer.workspace, - quant_type=self.quant_config.quant_type, - output_size_per_partition=layer.output_size_per_partition, - input_size_per_partition=layer.input_size_per_partition, - input_global_scale=getattr(layer, "input_global_scale", None), - bias=bias, - input_dtype=self.input_dtype, - ) + return self.kernel.apply_weights(layer, x, bias) class AWQMarlinMoEMethod(FusedMoEMethodBase): diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index 23ccfc536..d659effd7 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -46,14 +46,15 @@ def query_marlin_supported_quant_types( if current_platform.is_cpu(): return _query_cpu_marlin_supported_quant_types(has_zp, include_fp_type) - if device_capability is None: - capability_tuple = current_platform.get_device_capability() - device_capability = ( - -1 if capability_tuple is None else capability_tuple.to_int() - ) + if not current_platform.is_rocm(): + if device_capability is None: + capability_tuple = current_platform.get_device_capability() + device_capability = ( + -1 if capability_tuple is None else capability_tuple.to_int() + ) - if device_capability < 75: - return [] + if device_capability < 75: + return [] # - has_zp is True: return quant_types that has zero 
points # - has_zp is False: return quant_types that has not zero points @@ -210,8 +211,6 @@ def check_marlin_supports_shape( def check_marlin_supports_layer(layer: LinearBase, group_size: int) -> bool: - if current_platform.is_rocm(): - return False output_size_per_partition = ( getattr(layer, "output_size_per_partition", None) or layer.output_size )
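
Note (illustration only, not part of the patch): the core of this change is the new _convert_awq_to_standard_format helper, which undoes AWQ's interleaved nibble order and repacks qweight along the input dimension (and moves qzeros into the transposed (N // pack_factor, G) layout) so the tensors match what the MPLinearKernel framework expects. The sketch below is a minimal standalone reproduction of just the qweight half of that conversion, checked against a directly packed GPTQ-like reference. The helper names pack_rows and awq_to_standard, and the K/N sizes, are ad-hoc to this sketch and are not vLLM APIs; only the bit layouts mirror the patch.

# Standalone sketch: verify that unpacking AWQ's nibble order and repacking
# along the input dim reproduces a standard (GPTQ-like) packing.
import torch

SIZE_BITS = 4
PACK = 32 // SIZE_BITS                          # 8 values per int32
MASK = (1 << SIZE_BITS) - 1
AWQ_ORDER = [0, 2, 4, 6, 1, 3, 5, 7]            # nibble j of an AWQ word holds element AWQ_ORDER[j]
REVERSE_AWQ_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]    # inverse permutation, as used in the patch
SHIFTS = torch.arange(0, 32, SIZE_BITS, dtype=torch.int32)


def pack_rows(vals: torch.Tensor, order: list[int], dim: int) -> torch.Tensor:
    """Pack 4-bit values into int32 along `dim`, placing element order[j] in nibble j."""
    vals = vals.movedim(dim, -1)
    grouped = vals.reshape(*vals.shape[:-1], vals.shape[-1] // PACK, PACK)
    grouped = grouped[..., order]               # choose which element lands in each nibble
    packed = (grouped.to(torch.int32) << SHIFTS).sum(dim=-1, dtype=torch.int32)
    return packed.movedim(-1, dim)


def awq_to_standard(qweight_awq: torch.Tensor) -> torch.Tensor:
    """Mirror the patch's qweight conversion: undo the AWQ nibble order along the
    output dim, then repack along the input dim in standard order."""
    k, n_packed = qweight_awq.shape
    unpacked = (qweight_awq.unsqueeze(-1) >> SHIFTS) & MASK   # (K, N//8, 8)
    unpacked = unpacked[:, :, REVERSE_AWQ_ORDER].reshape(k, n_packed * PACK)
    return pack_rows(unpacked, list(range(PACK)), dim=0)      # (K//8, N)


torch.manual_seed(0)
K, N = 32, 64
w = torch.randint(0, 16, (K, N), dtype=torch.int32)           # unpacked 4-bit weights

qweight_awq = pack_rows(w, AWQ_ORDER, dim=1)                  # AWQ checkpoint layout: (K, N//8)
qweight_std = pack_rows(w, list(range(PACK)), dim=0)          # GPTQ-like layout:      (K//8, N)

assert torch.equal(awq_to_standard(qweight_awq), qweight_std)
print("AWQ -> standard repack round-trips OK:", tuple(qweight_std.shape))

The qzeros conversion in the patch applies the same unpack-and-reorder step, then additionally transposes and repacks along the output dim to reach the (N // pack_factor, G) layout that, per the patch's own comment, matches the CompressedTensors layout the kernels expect.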