[ROCm][Refactor] Enable AWQMarlinConfig on ROCm to use choose_mp_linear_kernel (#36505)

Signed-off-by: Matthias Gehre <matthias.gehre@amd.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
Matthias Gehre
2026-03-23 08:36:08 +01:00
committed by GitHub
parent d3fe857135
commit 410d300893
3 changed files with 143 additions and 102 deletions

View File

@@ -113,6 +113,8 @@ class ConchLinearKernel(MPLinearKernel):
self._transform_param(layer, self.w_s_name, transform_w_s)
if self.config.zero_points:
self._transform_param(layer, self.w_zp_name, transform_w_zp)
elif self.w_zp_name is not None:
layer.register_parameter(self.w_zp_name, None)
def apply_weights(
self,

View File

@@ -10,6 +10,10 @@ from torch.nn import Parameter
import vllm.model_executor.layers.fused_moe # noqa
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.kernels.linear import (
MPLinearLayerConfig,
choose_mp_linear_kernel,
)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
@@ -34,21 +38,16 @@ from vllm.model_executor.layers.quantization.base_config import (
)
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
apply_awq_marlin_linear,
awq_to_marlin_zero_points,
check_marlin_supported,
check_marlin_supports_layer,
check_moe_marlin_supports_layer,
get_marlin_input_dtype,
marlin_act_int8_process_scales,
marlin_make_empty_g_idx,
marlin_make_workspace_new,
marlin_moe_permute_scales,
marlin_permute_bias,
marlin_permute_scales,
moe_awq_to_marlin_zero_points,
verify_marlin_supported,
verify_marlin_supports_shape,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
@@ -63,6 +62,90 @@ if TYPE_CHECKING:
logger = init_logger(__name__)
# AWQ uses a non-standard packing order within int32 values.
# For 4-bit: standard order stores values at bit positions [0,4,8,12,16,20,24,28]
# for indices [0,1,2,3,4,5,6,7], while AWQ stores them for indices
# [0,4,1,5,2,6,3,7]. This permutation reverses that ordering.
_REVERSE_AWQ_PACK_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]
def _convert_awq_to_standard_format(
    layer: torch.nn.Module,
    w_q_name: str,
    w_zp_name: str,
    size_bits: int,
) -> None:
    """Rewrite AWQ-packed weight/zero-point tensors into the standard layout.

    AWQ checkpoints pack ``qweight`` along the output dimension using a
    non-standard within-int32 bit order. The MPLinearKernel framework expects
    a GPTQ-like layout instead: standard bit order, ``qweight`` packed along
    the input dimension, and ``qzeros`` transposed to ``(N // pack, G)`` and
    packed along dim 0. This function replaces both attributes on ``layer``
    with freshly packed ``PackedvLLMParameter`` objects.
    """
    pack_factor = 32 // size_bits
    value_mask = (1 << size_bits) - 1
    device = getattr(layer, w_q_name).device

    # Permutation that undoes AWQ's within-word ordering (4-bit layout).
    awq_order = torch.tensor(
        _REVERSE_AWQ_PACK_ORDER, dtype=torch.long, device=device
    )
    # Bit offsets of each sub-value inside a packed int32 word.
    bit_offsets = torch.arange(0, 32, size_bits, dtype=torch.int32, device=device)

    def _unpack(packed: torch.Tensor) -> torch.Tensor:
        # (R, C) packed int32 -> (R, C * pack_factor) individual values in
        # standard (ascending-index) order.
        values = (packed.unsqueeze(-1) >> bit_offsets) & value_mask
        return values[:, :, awq_order].reshape(packed.shape[0], -1)

    def _pack_dim0(values: torch.Tensor) -> torch.Tensor:
        # (R, C) individual values -> (R // pack_factor, C) int32, packed
        # along dim 0 with standard bit order.
        rows, cols = values.shape
        grouped = values.reshape(rows // pack_factor, pack_factor, cols)
        words = (grouped.to(torch.int32) << bit_offsets[None, :, None]).sum(
            dim=1, dtype=torch.int32
        )
        return words.contiguous()

    def _noop_loader(*args, **kwargs):
        # The checkpoint has already been loaded; re-loading is a no-op.
        pass

    # qweight: (K, N // pack) packed_dim=1  ->  (K // pack, N) packed_dim=0.
    new_qweight = _pack_dim0(_unpack(getattr(layer, w_q_name).data))
    setattr(
        layer,
        w_q_name,
        PackedvLLMParameter(
            data=new_qweight,
            input_dim=0,
            output_dim=1,
            packed_dim=0,
            packed_factor=pack_factor,
            weight_loader=_noop_loader,
        ),
    )

    # qzeros: (G, N // pack) packed_dim=1 -> unpack, transpose to (N, G),
    # then repack along dim 0 to (N // pack, G). This matches the
    # CompressedTensors layout expected by the kernels.
    new_qzeros = _pack_dim0(_unpack(getattr(layer, w_zp_name).data).T)
    setattr(
        layer,
        w_zp_name,
        PackedvLLMParameter(
            data=new_qzeros,
            output_dim=0,
            input_dim=1,
            packed_dim=0,
            packed_factor=pack_factor,
            weight_loader=_noop_loader,
        ),
    )
class AWQMarlinConfig(QuantizationConfig):
"""Config class for AWQ Marlin"""
@@ -226,7 +309,7 @@ class AWQMarlinConfig(QuantizationConfig):
group_size = quant_config.get("group_size")
zero_point = quant_config.get("zero_point")
if not current_platform.is_cuda():
if not (current_platform.is_cuda_alike() or current_platform.is_cpu()):
return False
if quant_method != "awq":
@@ -268,15 +351,26 @@ class AWQMarlinConfig(QuantizationConfig):
class AWQMarlinLinearMethod(LinearMethodBase):
"""Linear method for AWQ Marlin.
Uses choose_mp_linear_kernel to select the best available kernel
(Conch, Exllama, or Marlin) for the current platform.
Args:
quant_config: The AWQ Marlin quantization config.
"""
_kernel_backends_being_used: set[str] = set()
def __init__(self, quant_config: AWQMarlinConfig) -> None:
self.quant_config = quant_config
self.quant_type = scalar_types.uint4
self.input_dtype = None
verify_marlin_supported(
quant_type=self.quant_config.quant_type,
group_size=self.quant_config.group_size,
has_zp=self.quant_config.zero_point,
)
def create_weights(
self,
layer: torch.nn.Module,
@@ -287,23 +381,35 @@ class AWQMarlinLinearMethod(LinearMethodBase):
params_dtype: torch.dtype,
**extra_weight_attrs,
) -> None:
del output_size
output_size_per_partition = sum(output_partition_sizes)
weight_loader = extra_weight_attrs.get("weight_loader")
# Normalize group_size
if self.quant_config.group_size != -1:
group_size = self.quant_config.group_size
else:
group_size = input_size
verify_marlin_supports_shape(
output_size_per_partition=output_size_per_partition,
input_size_per_partition=input_size_per_partition,
input_size=input_size,
group_size=group_size,
mp_linear_kernel_config = MPLinearLayerConfig(
full_weight_shape=(input_size, output_size),
partition_weight_shape=(
input_size_per_partition,
output_size_per_partition,
),
weight_type=self.quant_config.quant_type,
act_type=params_dtype if self.input_dtype is None else self.input_dtype,
group_size=self.quant_config.group_size,
zero_points=self.quant_config.zero_point,
has_g_idx=False,
)
kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config)
if kernel_type.__name__ not in self._kernel_backends_being_used:
logger.info("Using %s for AWQMarlinLinearMethod", kernel_type.__name__)
self._kernel_backends_being_used.add(kernel_type.__name__)
# Weights are loaded in AWQ checkpoint format (packed along output dim).
# Conversion to GPTQ-like format happens in process_weights_after_loading.
qweight = PackedvLLMParameter(
data=torch.empty(
input_size_per_partition,
@@ -318,7 +424,6 @@ class AWQMarlinLinearMethod(LinearMethodBase):
)
num_groups = input_size_per_partition // group_size
layer.num_groups = num_groups
qzeros = PackedvLLMParameter(
data=torch.empty(
@@ -348,73 +453,22 @@ class AWQMarlinLinearMethod(LinearMethodBase):
layer.register_parameter("qzeros", qzeros)
layer.register_parameter("scales", scales)
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
layer.num_groups = num_groups
self.kernel = kernel_type(
mp_linear_kernel_config,
w_q_param_name="qweight",
w_s_param_name="scales",
w_zp_param_name="qzeros",
)
# TODO: Update this docs
# Checkpoints are serialized in AutoAWQ format, which is different from the
# marlin format. This function is called after the weights are loaded.
# Here, we handle the repacking
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
device = layer.qweight.device
layer.qweight = torch.nn.Parameter(layer.qweight.data, requires_grad=False)
layer.qzeros = torch.nn.Parameter(layer.qzeros.data, requires_grad=False)
layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False)
# Allocate marlin workspace
layer.workspace = marlin_make_workspace_new(device)
is_a_8bit = self.input_dtype is not None and self.input_dtype.itemsize == 1
if self.input_dtype == torch.float8_e4m3fn:
ops.marlin_int4_fp8_preprocess(layer.qweight, layer.qzeros, inplace=True)
layer.scales.data = layer.scales.data * 512
# Repack weights from AWQ format to marlin format.
marlin_qweight = ops.awq_marlin_repack(
layer.qweight,
size_k=layer.input_size_per_partition,
size_n=layer.output_size_per_partition,
num_bits=self.quant_config.quant_type.size_bits,
is_a_8bit=is_a_8bit,
# AWQ checkpoints use a non-standard packing order and pack qweight
# along the output dimension. Convert to the standard format
# (GPTQ-like: standard bit order, qweight packed along input dim)
# before handing off to the kernel.
_convert_awq_to_standard_format(
layer, "qweight", "qzeros", self.quant_config.quant_type.size_bits
)
replace_parameter(layer, "qweight", marlin_qweight)
# Permute scales from AWQ format to marlin format.
marlin_scales = marlin_permute_scales(
layer.scales,
size_k=layer.input_size_per_partition,
size_n=layer.output_size_per_partition,
group_size=self.quant_config.group_size,
is_a_8bit=is_a_8bit,
)
if self.input_dtype == torch.int8 and layer.num_groups > 1:
marlin_scales, input_global_scale = marlin_act_int8_process_scales(
marlin_scales
)
layer.register_parameter(
"input_global_scale", Parameter(input_global_scale, requires_grad=False)
)
replace_parameter(layer, "scales", marlin_scales)
# Permute zero-points from AWQ format to marlin format.
marlin_zp = awq_to_marlin_zero_points(
layer.qzeros,
size_k=layer.num_groups,
size_n=layer.output_size_per_partition,
num_bits=self.quant_config.quant_type.size_bits,
is_a_8bit=is_a_8bit,
)
replace_parameter(layer, "qzeros", marlin_zp)
# Not-used
layer.g_idx = marlin_make_empty_g_idx(device)
layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
if hasattr(layer, "bias") and layer.bias is not None:
layer.bias.data = marlin_permute_bias(layer.bias)
self.kernel.process_weights_after_loading(layer)
def apply(
self,
@@ -422,21 +476,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
x: torch.Tensor,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
return apply_awq_marlin_linear(
input=x,
weight=layer.qweight,
weight_scale=layer.scales,
weight_zp=layer.qzeros,
g_idx=layer.g_idx,
g_idx_sort_indices=layer.g_idx_sort_indices,
workspace=layer.workspace,
quant_type=self.quant_config.quant_type,
output_size_per_partition=layer.output_size_per_partition,
input_size_per_partition=layer.input_size_per_partition,
input_global_scale=getattr(layer, "input_global_scale", None),
bias=bias,
input_dtype=self.input_dtype,
)
return self.kernel.apply_weights(layer, x, bias)
class AWQMarlinMoEMethod(FusedMoEMethodBase):

View File

@@ -46,14 +46,15 @@ def query_marlin_supported_quant_types(
if current_platform.is_cpu():
return _query_cpu_marlin_supported_quant_types(has_zp, include_fp_type)
if device_capability is None:
capability_tuple = current_platform.get_device_capability()
device_capability = (
-1 if capability_tuple is None else capability_tuple.to_int()
)
if not current_platform.is_rocm():
if device_capability is None:
capability_tuple = current_platform.get_device_capability()
device_capability = (
-1 if capability_tuple is None else capability_tuple.to_int()
)
if device_capability < 75:
return []
if device_capability < 75:
return []
# - has_zp is True: return quant_types that has zero points
# - has_zp is False: return quant_types that has not zero points
@@ -210,8 +211,6 @@ def check_marlin_supports_shape(
def check_marlin_supports_layer(layer: LinearBase, group_size: int) -> bool:
if current_platform.is_rocm():
return False
output_size_per_partition = (
getattr(layer, "output_size_per_partition", None) or layer.output_size
)