[ROCm][Refactor] Enable AWQMarlinConfig on ROCm to use choose_mp_linear_kernel (#36505)
Signed-off-by: Matthias Gehre <matthias.gehre@amd.com> Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
@@ -113,6 +113,8 @@ class ConchLinearKernel(MPLinearKernel):
|
||||
self._transform_param(layer, self.w_s_name, transform_w_s)
|
||||
if self.config.zero_points:
|
||||
self._transform_param(layer, self.w_zp_name, transform_w_zp)
|
||||
elif self.w_zp_name is not None:
|
||||
layer.register_parameter(self.w_zp_name, None)
|
||||
|
||||
def apply_weights(
|
||||
self,
|
||||
|
||||
@@ -10,6 +10,10 @@ from torch.nn import Parameter
|
||||
import vllm.model_executor.layers.fused_moe # noqa
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.kernels.linear import (
|
||||
MPLinearLayerConfig,
|
||||
choose_mp_linear_kernel,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEQuantConfig,
|
||||
@@ -34,21 +38,16 @@ from vllm.model_executor.layers.quantization.base_config import (
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils import replace_parameter
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
apply_awq_marlin_linear,
|
||||
awq_to_marlin_zero_points,
|
||||
check_marlin_supported,
|
||||
check_marlin_supports_layer,
|
||||
check_moe_marlin_supports_layer,
|
||||
get_marlin_input_dtype,
|
||||
marlin_act_int8_process_scales,
|
||||
marlin_make_empty_g_idx,
|
||||
marlin_make_workspace_new,
|
||||
marlin_moe_permute_scales,
|
||||
marlin_permute_bias,
|
||||
marlin_permute_scales,
|
||||
moe_awq_to_marlin_zero_points,
|
||||
verify_marlin_supported,
|
||||
verify_marlin_supports_shape,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
|
||||
@@ -63,6 +62,90 @@ if TYPE_CHECKING:
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# AWQ uses a non-standard packing order within int32 values.
|
||||
# For 4-bit: standard order stores values at bit positions [0,4,8,12,16,20,24,28]
|
||||
# for indices [0,1,2,3,4,5,6,7], while AWQ stores them for indices
|
||||
# [0,4,1,5,2,6,3,7]. This permutation reverses that ordering.
|
||||
_REVERSE_AWQ_PACK_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]
|
||||
|
||||
|
||||
def _convert_awq_to_standard_format(
|
||||
layer: torch.nn.Module,
|
||||
w_q_name: str,
|
||||
w_zp_name: str,
|
||||
size_bits: int,
|
||||
) -> None:
|
||||
"""Convert AWQ weight and zero-point tensors to standard GPTQ-like format.
|
||||
|
||||
AWQ packs qweight along the output dim with a non-standard bit order.
|
||||
This converts to standard bit order and repacks qweight along the input
|
||||
dim, matching the format expected by the MPLinearKernel framework.
|
||||
"""
|
||||
pack_factor = 32 // size_bits
|
||||
mask = (1 << size_bits) - 1
|
||||
device = getattr(layer, w_q_name).device
|
||||
reverse_order = torch.tensor(
|
||||
_REVERSE_AWQ_PACK_ORDER, dtype=torch.long, device=device
|
||||
)
|
||||
shifts = torch.arange(0, 32, size_bits, dtype=torch.int32, device=device)
|
||||
|
||||
# --- Convert qweight: (K, N // pack) packed_dim=1 → (K // pack, N) packed_dim=0
|
||||
qw = getattr(layer, w_q_name).data
|
||||
K, N_packed = qw.shape
|
||||
N = N_packed * pack_factor
|
||||
|
||||
# Unpack int32 → individual values, fix AWQ ordering
|
||||
unpacked = (qw.unsqueeze(-1) >> shifts) & mask # (K, N_packed, pack_factor)
|
||||
unpacked = unpacked[:, :, reverse_order]
|
||||
unpacked = unpacked.reshape(K, N) # (K, N)
|
||||
|
||||
# Repack along input dim (dim 0)
|
||||
unpacked = unpacked.reshape(K // pack_factor, pack_factor, N)
|
||||
new_qw = (unpacked.to(torch.int32) << shifts[None, :, None]).sum(
|
||||
dim=1, dtype=torch.int32
|
||||
)
|
||||
|
||||
def _noop_loader(*args, **kwargs):
|
||||
pass
|
||||
|
||||
new_param = PackedvLLMParameter(
|
||||
data=new_qw.contiguous(),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=0,
|
||||
packed_factor=pack_factor,
|
||||
weight_loader=_noop_loader,
|
||||
)
|
||||
setattr(layer, w_q_name, new_param)
|
||||
|
||||
# --- Convert qzeros: fix AWQ bit ordering and repack
|
||||
# AWQ qzeros: (G, N // pack) packed along dim 1, AWQ bit order
|
||||
# Target: (N // pack, G) packed along dim 0, standard bit order
|
||||
# This matches the CompressedTensors layout expected by the kernels.
|
||||
qz = getattr(layer, w_zp_name).data
|
||||
G, _ = qz.shape
|
||||
|
||||
unpacked_zp = (qz.unsqueeze(-1) >> shifts) & mask # (G, N_packed, pack_factor)
|
||||
unpacked_zp = unpacked_zp[:, :, reverse_order]
|
||||
unpacked_zp = unpacked_zp.reshape(G, N) # (G, N) individual values
|
||||
|
||||
# Transpose and repack along dim 0 (output dim)
|
||||
unpacked_zp = unpacked_zp.T # (N, G)
|
||||
unpacked_zp = unpacked_zp.reshape(N // pack_factor, pack_factor, G)
|
||||
new_qz = (unpacked_zp.to(torch.int32) << shifts[None, :, None]).sum(
|
||||
dim=1, dtype=torch.int32
|
||||
)
|
||||
|
||||
new_zp_param = PackedvLLMParameter(
|
||||
data=new_qz.contiguous(),
|
||||
output_dim=0,
|
||||
input_dim=1,
|
||||
packed_dim=0,
|
||||
packed_factor=pack_factor,
|
||||
weight_loader=_noop_loader,
|
||||
)
|
||||
setattr(layer, w_zp_name, new_zp_param)
|
||||
|
||||
|
||||
class AWQMarlinConfig(QuantizationConfig):
|
||||
"""Config class for AWQ Marlin"""
|
||||
@@ -226,7 +309,7 @@ class AWQMarlinConfig(QuantizationConfig):
|
||||
group_size = quant_config.get("group_size")
|
||||
zero_point = quant_config.get("zero_point")
|
||||
|
||||
if not current_platform.is_cuda():
|
||||
if not (current_platform.is_cuda_alike() or current_platform.is_cpu()):
|
||||
return False
|
||||
|
||||
if quant_method != "awq":
|
||||
@@ -268,15 +351,26 @@ class AWQMarlinConfig(QuantizationConfig):
|
||||
class AWQMarlinLinearMethod(LinearMethodBase):
|
||||
"""Linear method for AWQ Marlin.
|
||||
|
||||
Uses choose_mp_linear_kernel to select the best available kernel
|
||||
(Conch, Exllama, or Marlin) for the current platform.
|
||||
|
||||
Args:
|
||||
quant_config: The AWQ Marlin quantization config.
|
||||
"""
|
||||
|
||||
_kernel_backends_being_used: set[str] = set()
|
||||
|
||||
def __init__(self, quant_config: AWQMarlinConfig) -> None:
|
||||
self.quant_config = quant_config
|
||||
self.quant_type = scalar_types.uint4
|
||||
self.input_dtype = None
|
||||
|
||||
verify_marlin_supported(
|
||||
quant_type=self.quant_config.quant_type,
|
||||
group_size=self.quant_config.group_size,
|
||||
has_zp=self.quant_config.zero_point,
|
||||
)
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
@@ -287,23 +381,35 @@ class AWQMarlinLinearMethod(LinearMethodBase):
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
) -> None:
|
||||
del output_size
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
weight_loader = extra_weight_attrs.get("weight_loader")
|
||||
|
||||
# Normalize group_size
|
||||
if self.quant_config.group_size != -1:
|
||||
group_size = self.quant_config.group_size
|
||||
else:
|
||||
group_size = input_size
|
||||
|
||||
verify_marlin_supports_shape(
|
||||
output_size_per_partition=output_size_per_partition,
|
||||
input_size_per_partition=input_size_per_partition,
|
||||
input_size=input_size,
|
||||
group_size=group_size,
|
||||
mp_linear_kernel_config = MPLinearLayerConfig(
|
||||
full_weight_shape=(input_size, output_size),
|
||||
partition_weight_shape=(
|
||||
input_size_per_partition,
|
||||
output_size_per_partition,
|
||||
),
|
||||
weight_type=self.quant_config.quant_type,
|
||||
act_type=params_dtype if self.input_dtype is None else self.input_dtype,
|
||||
group_size=self.quant_config.group_size,
|
||||
zero_points=self.quant_config.zero_point,
|
||||
has_g_idx=False,
|
||||
)
|
||||
|
||||
kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config)
|
||||
|
||||
if kernel_type.__name__ not in self._kernel_backends_being_used:
|
||||
logger.info("Using %s for AWQMarlinLinearMethod", kernel_type.__name__)
|
||||
self._kernel_backends_being_used.add(kernel_type.__name__)
|
||||
|
||||
# Weights are loaded in AWQ checkpoint format (packed along output dim).
|
||||
# Conversion to GPTQ-like format happens in process_weights_after_loading.
|
||||
qweight = PackedvLLMParameter(
|
||||
data=torch.empty(
|
||||
input_size_per_partition,
|
||||
@@ -318,7 +424,6 @@ class AWQMarlinLinearMethod(LinearMethodBase):
|
||||
)
|
||||
|
||||
num_groups = input_size_per_partition // group_size
|
||||
layer.num_groups = num_groups
|
||||
|
||||
qzeros = PackedvLLMParameter(
|
||||
data=torch.empty(
|
||||
@@ -348,73 +453,22 @@ class AWQMarlinLinearMethod(LinearMethodBase):
|
||||
layer.register_parameter("qzeros", qzeros)
|
||||
layer.register_parameter("scales", scales)
|
||||
|
||||
layer.input_size_per_partition = input_size_per_partition
|
||||
layer.output_size_per_partition = output_size_per_partition
|
||||
layer.num_groups = num_groups
|
||||
self.kernel = kernel_type(
|
||||
mp_linear_kernel_config,
|
||||
w_q_param_name="qweight",
|
||||
w_s_param_name="scales",
|
||||
w_zp_param_name="qzeros",
|
||||
)
|
||||
|
||||
# TODO: Update this docs
|
||||
# Checkpoints are serialized in AutoAWQ format, which is different from the
|
||||
# marlin format. This function is called after the weights are loaded.
|
||||
# Here, we handle the repacking
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
device = layer.qweight.device
|
||||
layer.qweight = torch.nn.Parameter(layer.qweight.data, requires_grad=False)
|
||||
layer.qzeros = torch.nn.Parameter(layer.qzeros.data, requires_grad=False)
|
||||
layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False)
|
||||
|
||||
# Allocate marlin workspace
|
||||
layer.workspace = marlin_make_workspace_new(device)
|
||||
|
||||
is_a_8bit = self.input_dtype is not None and self.input_dtype.itemsize == 1
|
||||
|
||||
if self.input_dtype == torch.float8_e4m3fn:
|
||||
ops.marlin_int4_fp8_preprocess(layer.qweight, layer.qzeros, inplace=True)
|
||||
layer.scales.data = layer.scales.data * 512
|
||||
|
||||
# Repack weights from AWQ format to marlin format.
|
||||
marlin_qweight = ops.awq_marlin_repack(
|
||||
layer.qweight,
|
||||
size_k=layer.input_size_per_partition,
|
||||
size_n=layer.output_size_per_partition,
|
||||
num_bits=self.quant_config.quant_type.size_bits,
|
||||
is_a_8bit=is_a_8bit,
|
||||
# AWQ checkpoints use a non-standard packing order and pack qweight
|
||||
# along the output dimension. Convert to the standard format
|
||||
# (GPTQ-like: standard bit order, qweight packed along input dim)
|
||||
# before handing off to the kernel.
|
||||
_convert_awq_to_standard_format(
|
||||
layer, "qweight", "qzeros", self.quant_config.quant_type.size_bits
|
||||
)
|
||||
replace_parameter(layer, "qweight", marlin_qweight)
|
||||
|
||||
# Permute scales from AWQ format to marlin format.
|
||||
marlin_scales = marlin_permute_scales(
|
||||
layer.scales,
|
||||
size_k=layer.input_size_per_partition,
|
||||
size_n=layer.output_size_per_partition,
|
||||
group_size=self.quant_config.group_size,
|
||||
is_a_8bit=is_a_8bit,
|
||||
)
|
||||
if self.input_dtype == torch.int8 and layer.num_groups > 1:
|
||||
marlin_scales, input_global_scale = marlin_act_int8_process_scales(
|
||||
marlin_scales
|
||||
)
|
||||
layer.register_parameter(
|
||||
"input_global_scale", Parameter(input_global_scale, requires_grad=False)
|
||||
)
|
||||
|
||||
replace_parameter(layer, "scales", marlin_scales)
|
||||
|
||||
# Permute zero-points from AWQ format to marlin format.
|
||||
marlin_zp = awq_to_marlin_zero_points(
|
||||
layer.qzeros,
|
||||
size_k=layer.num_groups,
|
||||
size_n=layer.output_size_per_partition,
|
||||
num_bits=self.quant_config.quant_type.size_bits,
|
||||
is_a_8bit=is_a_8bit,
|
||||
)
|
||||
replace_parameter(layer, "qzeros", marlin_zp)
|
||||
|
||||
# Not-used
|
||||
layer.g_idx = marlin_make_empty_g_idx(device)
|
||||
layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
|
||||
|
||||
if hasattr(layer, "bias") and layer.bias is not None:
|
||||
layer.bias.data = marlin_permute_bias(layer.bias)
|
||||
self.kernel.process_weights_after_loading(layer)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
@@ -422,21 +476,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
return apply_awq_marlin_linear(
|
||||
input=x,
|
||||
weight=layer.qweight,
|
||||
weight_scale=layer.scales,
|
||||
weight_zp=layer.qzeros,
|
||||
g_idx=layer.g_idx,
|
||||
g_idx_sort_indices=layer.g_idx_sort_indices,
|
||||
workspace=layer.workspace,
|
||||
quant_type=self.quant_config.quant_type,
|
||||
output_size_per_partition=layer.output_size_per_partition,
|
||||
input_size_per_partition=layer.input_size_per_partition,
|
||||
input_global_scale=getattr(layer, "input_global_scale", None),
|
||||
bias=bias,
|
||||
input_dtype=self.input_dtype,
|
||||
)
|
||||
return self.kernel.apply_weights(layer, x, bias)
|
||||
|
||||
|
||||
class AWQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
|
||||
@@ -46,14 +46,15 @@ def query_marlin_supported_quant_types(
|
||||
if current_platform.is_cpu():
|
||||
return _query_cpu_marlin_supported_quant_types(has_zp, include_fp_type)
|
||||
|
||||
if device_capability is None:
|
||||
capability_tuple = current_platform.get_device_capability()
|
||||
device_capability = (
|
||||
-1 if capability_tuple is None else capability_tuple.to_int()
|
||||
)
|
||||
if not current_platform.is_rocm():
|
||||
if device_capability is None:
|
||||
capability_tuple = current_platform.get_device_capability()
|
||||
device_capability = (
|
||||
-1 if capability_tuple is None else capability_tuple.to_int()
|
||||
)
|
||||
|
||||
if device_capability < 75:
|
||||
return []
|
||||
if device_capability < 75:
|
||||
return []
|
||||
|
||||
# - has_zp is True: return quant_types that has zero points
|
||||
# - has_zp is False: return quant_types that has not zero points
|
||||
@@ -210,8 +211,6 @@ def check_marlin_supports_shape(
|
||||
|
||||
|
||||
def check_marlin_supports_layer(layer: LinearBase, group_size: int) -> bool:
|
||||
if current_platform.is_rocm():
|
||||
return False
|
||||
output_size_per_partition = (
|
||||
getattr(layer, "output_size_per_partition", None) or layer.output_size
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user