[Quantization][Deprecation] Remove Marlin 24 (#32688)
Signed-off-by: Robert Shaw <robshaw@redhat.com> Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Co-authored-by: Robert Shaw <robshaw@redhat.com> Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -18,7 +18,6 @@ QuantizationMethods = Literal[
|
||||
"modelopt",
|
||||
"modelopt_fp4",
|
||||
"gguf",
|
||||
"gptq_marlin_24",
|
||||
"gptq_marlin",
|
||||
"awq_marlin",
|
||||
"gptq",
|
||||
@@ -41,7 +40,6 @@ DEPRECATED_QUANTIZATION_METHODS = [
|
||||
"ptpc_fp8",
|
||||
"fbgemm_fp8",
|
||||
"fp_quant",
|
||||
"gptq_marlin_24",
|
||||
"experts_int8",
|
||||
"ipex",
|
||||
"petit_nvfp4",
|
||||
@@ -122,7 +120,6 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
|
||||
from .gguf import GGUFConfig
|
||||
from .gptq import GPTQConfig
|
||||
from .gptq_marlin import GPTQMarlinConfig
|
||||
from .gptq_marlin_24 import GPTQMarlin24Config
|
||||
from .inc import INCConfig
|
||||
from .ipex_quant import IPEXConfig
|
||||
from .modelopt import ModelOptFp8Config, ModelOptNvFp4Config
|
||||
@@ -140,7 +137,6 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
|
||||
"modelopt": ModelOptFp8Config,
|
||||
"modelopt_fp4": ModelOptNvFp4Config,
|
||||
"gguf": GGUFConfig,
|
||||
"gptq_marlin_24": GPTQMarlin24Config,
|
||||
"gptq_marlin": GPTQMarlinConfig,
|
||||
"awq_marlin": AWQMarlinConfig,
|
||||
"gptq": GPTQConfig,
|
||||
|
||||
@@ -40,7 +40,6 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
|
||||
CompressedTensorsMoEMethod,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
W4A16SPARSE24_SUPPORTED_BITS,
|
||||
WNA16_SUPPORTED_BITS,
|
||||
CompressedTensors24,
|
||||
CompressedTensorsScheme,
|
||||
@@ -49,7 +48,6 @@ from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsW4A8Int,
|
||||
CompressedTensorsW4A16Fp4,
|
||||
CompressedTensorsW4A16Mxfp4,
|
||||
CompressedTensorsW4A16Sparse24,
|
||||
CompressedTensorsW8A8Fp8,
|
||||
CompressedTensorsW8A8Int8,
|
||||
CompressedTensorsW8A16Fp8,
|
||||
@@ -610,29 +608,19 @@ class CompressedTensorsConfig(QuantizationConfig):
|
||||
actorder=weight_quant.actorder,
|
||||
)
|
||||
|
||||
if self._is_wNa16_group_channel(weight_quant, input_quant):
|
||||
if (
|
||||
format == CompressionFormat.marlin_24.value
|
||||
and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS
|
||||
):
|
||||
assert weight_quant.symmetric
|
||||
return CompressedTensorsW4A16Sparse24(
|
||||
strategy=weight_quant.strategy,
|
||||
num_bits=weight_quant.num_bits,
|
||||
group_size=weight_quant.group_size,
|
||||
)
|
||||
if (
|
||||
format == CompressionFormat.pack_quantized.value
|
||||
and weight_quant.num_bits in WNA16_SUPPORTED_BITS
|
||||
):
|
||||
return CompressedTensorsWNA16(
|
||||
num_bits=weight_quant.num_bits,
|
||||
strategy=weight_quant.strategy,
|
||||
symmetric=weight_quant.symmetric,
|
||||
group_size=weight_quant.group_size,
|
||||
actorder=weight_quant.actorder,
|
||||
layer_name=layer_name,
|
||||
)
|
||||
if (
|
||||
self._is_wNa16_group_channel(weight_quant, input_quant)
|
||||
and (format == CompressionFormat.pack_quantized.value)
|
||||
and (weight_quant.num_bits in WNA16_SUPPORTED_BITS)
|
||||
):
|
||||
return CompressedTensorsWNA16(
|
||||
num_bits=weight_quant.num_bits,
|
||||
strategy=weight_quant.strategy,
|
||||
symmetric=weight_quant.symmetric,
|
||||
group_size=weight_quant.group_size,
|
||||
actorder=weight_quant.actorder,
|
||||
layer_name=layer_name,
|
||||
)
|
||||
|
||||
act_quant_format = is_activation_quantization_format(format)
|
||||
if act_quant_format:
|
||||
|
||||
@@ -5,10 +5,6 @@ from .compressed_tensors_scheme import CompressedTensorsScheme
|
||||
from .compressed_tensors_w4a4_nvfp4 import CompressedTensorsW4A4Fp4
|
||||
from .compressed_tensors_w4a8_fp8 import CompressedTensorsW4A8Fp8
|
||||
from .compressed_tensors_w4a8_int import CompressedTensorsW4A8Int
|
||||
from .compressed_tensors_w4a16_24 import (
|
||||
W4A16SPARSE24_SUPPORTED_BITS,
|
||||
CompressedTensorsW4A16Sparse24,
|
||||
)
|
||||
from .compressed_tensors_w4a16_mxfp4 import CompressedTensorsW4A16Mxfp4
|
||||
from .compressed_tensors_w4a16_nvfp4 import CompressedTensorsW4A16Fp4
|
||||
from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8
|
||||
@@ -23,11 +19,9 @@ __all__ = [
|
||||
"CompressedTensorsScheme",
|
||||
"CompressedTensorsWNA16",
|
||||
"CompressedTensorsW8A16Fp8",
|
||||
"CompressedTensorsW4A16Sparse24",
|
||||
"CompressedTensorsW8A8Int8",
|
||||
"CompressedTensorsW8A8Fp8",
|
||||
"WNA16_SUPPORTED_BITS",
|
||||
"W4A16SPARSE24_SUPPORTED_BITS",
|
||||
"CompressedTensors24",
|
||||
"CompressedTensorsW4A16Fp4",
|
||||
"CompressedTensorsW4A16Mxfp4",
|
||||
|
||||
@@ -1,176 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
from torch.nn import Parameter
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
|
||||
CompressedTensorsScheme,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
|
||||
GPTQ_MARLIN_24_MAX_PARALLEL,
|
||||
GPTQ_MARLIN_24_MIN_THREAD_N,
|
||||
)
|
||||
from vllm.model_executor.parameter import (
|
||||
BasevLLMParameter,
|
||||
ChannelQuantScaleParameter,
|
||||
GroupQuantScaleParameter,
|
||||
PackedvLLMParameter,
|
||||
)
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
__all__ = ["CompressedTensorsW4A16Sparse24"]
|
||||
W4A16SPARSE24_SUPPORTED_TYPES_MAP = {
|
||||
4: scalar_types.uint4b8,
|
||||
}
|
||||
W4A16SPARSE24_SUPPORTED_BITS = list(W4A16SPARSE24_SUPPORTED_TYPES_MAP.keys())
|
||||
|
||||
|
||||
class CompressedTensorsW4A16Sparse24(CompressedTensorsScheme):
|
||||
def __init__(self, strategy: str, num_bits: int, group_size: int | None = None):
|
||||
self.strategy = strategy
|
||||
self.group_size = group_size
|
||||
self.tile_size = 16
|
||||
|
||||
if num_bits not in W4A16SPARSE24_SUPPORTED_TYPES_MAP:
|
||||
raise ValueError(
|
||||
f"Unsupported num_bits = {num_bits}. "
|
||||
f"Supported num_bits = {W4A16SPARSE24_SUPPORTED_BITS}"
|
||||
)
|
||||
|
||||
self.quant_type = W4A16SPARSE24_SUPPORTED_TYPES_MAP[num_bits]
|
||||
|
||||
if self.strategy == "group" and self.group_size is None:
|
||||
raise ValueError("group_size must be given when using strategy group")
|
||||
|
||||
@classmethod
|
||||
def get_min_capability(cls) -> int:
|
||||
# ampere + up
|
||||
return 80
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
# required by torch.compile to be torch.nn.Parameter
|
||||
layer.weight_packed = Parameter(layer.weight_packed.data, requires_grad=False)
|
||||
layer.scale_packed = Parameter(layer.scale_packed.data, requires_grad=False)
|
||||
layer.meta = Parameter(layer.meta.data, requires_grad=False)
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size_per_partition: int,
|
||||
params_dtype: torch.dtype,
|
||||
weight_loader: Callable,
|
||||
**kwargs,
|
||||
):
|
||||
assert params_dtype == torch.float16, (
|
||||
"float16 is required for marlin24 compressed models. Set dtype=torch.float16" # noqa: E501
|
||||
)
|
||||
|
||||
pack_factor = 32 // self.quant_type.size_bits
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
|
||||
qweight = PackedvLLMParameter(
|
||||
data=torch.empty(
|
||||
input_size_per_partition // self.tile_size // 2,
|
||||
output_size_per_partition * self.tile_size // pack_factor,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=pack_factor,
|
||||
marlin_tile_size=self.tile_size,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
input_groups = (
|
||||
1
|
||||
if self.group_size is None
|
||||
else input_size_per_partition // self.group_size
|
||||
)
|
||||
|
||||
weight_scale_args = {
|
||||
"data": torch.empty(
|
||||
input_groups,
|
||||
output_size_per_partition,
|
||||
dtype=params_dtype,
|
||||
),
|
||||
"weight_loader": weight_loader,
|
||||
}
|
||||
|
||||
if self.group_size is not None:
|
||||
scales = GroupQuantScaleParameter(
|
||||
output_dim=1, input_dim=0, **weight_scale_args
|
||||
)
|
||||
else:
|
||||
scales = ChannelQuantScaleParameter(output_dim=1, **weight_scale_args)
|
||||
|
||||
weight_shape = BasevLLMParameter(
|
||||
data=torch.empty(2, dtype=torch.int64), weight_loader=weight_loader
|
||||
)
|
||||
|
||||
meta = PackedvLLMParameter(
|
||||
data=torch.empty(
|
||||
input_size_per_partition // 8 // 2 // 2,
|
||||
output_size_per_partition * 2,
|
||||
dtype=torch.int16,
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=1,
|
||||
marlin_tile_size=2,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
layer.register_parameter("weight_packed", qweight)
|
||||
layer.register_parameter("weight_shape", weight_shape)
|
||||
layer.register_parameter("scale_packed", scales)
|
||||
layer.register_parameter("meta", meta)
|
||||
|
||||
max_workspace_size = (
|
||||
output_size_per_partition // GPTQ_MARLIN_24_MIN_THREAD_N
|
||||
) * GPTQ_MARLIN_24_MAX_PARALLEL
|
||||
|
||||
workspace = Parameter(
|
||||
torch.zeros(max_workspace_size, dtype=torch.int), requires_grad=False
|
||||
)
|
||||
layer.workspace = workspace
|
||||
|
||||
def apply_weights(
|
||||
self, layer: torch.nn.Module, x: torch.Tensor, bias: torch.Tensor | None
|
||||
) -> torch.Tensor:
|
||||
qweight = layer.weight_packed
|
||||
meta = layer.meta
|
||||
scales = layer.scale_packed
|
||||
workspace = layer.workspace
|
||||
|
||||
x_2d = x.view(-1, x.shape[-1])
|
||||
|
||||
size_m = x_2d.shape[0]
|
||||
size_k = x_2d.shape[1]
|
||||
size_n = scales.shape[1]
|
||||
|
||||
output_2d = ops.gptq_marlin_24_gemm(
|
||||
x_2d,
|
||||
qweight,
|
||||
meta,
|
||||
scales,
|
||||
workspace,
|
||||
self.quant_type,
|
||||
size_m,
|
||||
size_n,
|
||||
size_k,
|
||||
)
|
||||
|
||||
output = output_2d.view(x.shape[:-1] + (output_2d.shape[1],))
|
||||
|
||||
if bias is not None:
|
||||
output.add_(bias) # In-place add
|
||||
|
||||
return output
|
||||
@@ -1,320 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
||||
from vllm.model_executor.layers.quantization import (
|
||||
QuantizationConfig,
|
||||
QuantizationMethods,
|
||||
)
|
||||
from vllm.model_executor.parameter import (
|
||||
BasevLLMParameter,
|
||||
ChannelQuantScaleParameter,
|
||||
GroupQuantScaleParameter,
|
||||
PackedvLLMParameter,
|
||||
)
|
||||
from vllm.scalar_type import scalar_types
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
GPTQ_MARLIN_24_TILE = 16
|
||||
GPTQ_MARLIN_24_MIN_THREAD_N = 128
|
||||
GPTQ_MARLIN_24_MIN_THREAD_K = 128
|
||||
GPTQ_MARLIN_24_MAX_PARALLEL = 64
|
||||
|
||||
GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128]
|
||||
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES = [-1, 128]
|
||||
|
||||
|
||||
class GPTQMarlin24Config(QuantizationConfig):
|
||||
"""Config class for Marlin24."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight_bits: int,
|
||||
group_size: int,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
quant_type = {
|
||||
4: scalar_types.uint4b8,
|
||||
8: scalar_types.uint8b128,
|
||||
}.get(weight_bits)
|
||||
|
||||
self.group_size = group_size
|
||||
|
||||
# Verify
|
||||
if quant_type is None or quant_type not in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES:
|
||||
raise ValueError(
|
||||
f"Marlin_24 does not support quant_type = {quant_type}. "
|
||||
f"Only weight_bits = {GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES} "
|
||||
"are supported."
|
||||
)
|
||||
if self.group_size not in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES:
|
||||
raise ValueError(
|
||||
f"Marlin_24 does not support group_size = {self.group_size}. "
|
||||
f"Only group_sizes = {GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES} "
|
||||
"are supported."
|
||||
)
|
||||
|
||||
self.quant_type = quant_type
|
||||
|
||||
# 4 Bits packed into 32 bit datatype.
|
||||
self.pack_factor = 32 // self.quant_type.size_bits
|
||||
|
||||
# Tile size used by marlin kernels.
|
||||
self.tile_size = 16
|
||||
|
||||
# Min out_features dim
|
||||
self.min_n_threads = GPTQ_MARLIN_24_MIN_THREAD_N
|
||||
|
||||
# Min in_features dim
|
||||
self.min_k_threads = GPTQ_MARLIN_24_MIN_THREAD_K
|
||||
|
||||
# Max parallel problems to solve at once (improves large
|
||||
# batch performance)
|
||||
self.max_parallel = GPTQ_MARLIN_24_MAX_PARALLEL
|
||||
|
||||
# Permutation length used by the marlin kernels.
|
||||
self.perm_len = 1024
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return "Marlin24Config(quant_type={}, group_size={})".format(
|
||||
self.quant_type, self.group_size
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_name(cls) -> QuantizationMethods:
|
||||
return "gptq_marlin_24"
|
||||
|
||||
@classmethod
|
||||
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
|
||||
return [torch.half]
|
||||
|
||||
@classmethod
|
||||
# Need to figure it out
|
||||
def get_min_capability(cls) -> int:
|
||||
return 80
|
||||
|
||||
@classmethod
|
||||
def get_config_filenames(cls) -> list[str]:
|
||||
return ["quantize_config.json"]
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: dict[str, Any]) -> "GPTQMarlin24Config":
|
||||
weight_bits = cls.get_from_keys(config, ["bits"])
|
||||
group_size = cls.get_from_keys(config, ["group_size"])
|
||||
return cls(weight_bits, group_size)
|
||||
|
||||
@classmethod
|
||||
def override_quantization_method(
|
||||
cls, hf_quant_cfg, user_quant
|
||||
) -> QuantizationMethods | None:
|
||||
is_marlin_24_format = hf_quant_cfg.get("checkpoint_format") == "marlin_24"
|
||||
|
||||
is_valid_user_quant = (
|
||||
user_quant is None or user_quant == "gptq" or user_quant == "gptq_marlin_24"
|
||||
)
|
||||
|
||||
if is_marlin_24_format and is_valid_user_quant:
|
||||
msg = "The model is serialized in {} format. Using {} kernel.".format(
|
||||
cls.get_name(), cls.get_name()
|
||||
)
|
||||
logger.info(msg)
|
||||
return cls.get_name()
|
||||
|
||||
return None
|
||||
|
||||
def get_quant_method(
|
||||
self, layer: torch.nn.Module, prefix: str
|
||||
) -> Optional["GPTQMarlin24LinearMethod"]:
|
||||
if isinstance(layer, LinearBase):
|
||||
return GPTQMarlin24LinearMethod(self)
|
||||
return None
|
||||
|
||||
|
||||
class GPTQMarlin24LinearMethod(LinearMethodBase):
|
||||
"""Linear method for Marlin24.
|
||||
|
||||
Args:
|
||||
quant_config: The Marlin24 quantization config.
|
||||
"""
|
||||
|
||||
def __init__(self, quant_config: GPTQMarlin24Config):
|
||||
self.quant_config = quant_config
|
||||
|
||||
def create_weights(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_partition_sizes: list[int],
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
del output_size # Unused.
|
||||
weight_loader = extra_weight_attrs["weight_loader"]
|
||||
if params_dtype != torch.float16:
|
||||
raise ValueError(
|
||||
f"The params dtype must be float16, but got {params_dtype}"
|
||||
)
|
||||
|
||||
# Validate output_size_per_partition
|
||||
output_size_per_partition = sum(output_partition_sizes)
|
||||
if output_size_per_partition % self.quant_config.min_n_threads != 0:
|
||||
raise ValueError(
|
||||
f"Weight output_size_per_partition = "
|
||||
f"{output_size_per_partition} is not divisible by "
|
||||
f"min_n_threads = {self.quant_config.min_n_threads}."
|
||||
)
|
||||
if output_size_per_partition % self.quant_config.pack_factor != 0:
|
||||
raise ValueError(
|
||||
f"Weight output_size_per_partition = "
|
||||
f"{output_size_per_partition} is not divisible by "
|
||||
f"pack_factor = {self.quant_config.pack_factor}."
|
||||
)
|
||||
|
||||
# Validate input_size_per_partition
|
||||
if input_size_per_partition % self.quant_config.min_k_threads != 0:
|
||||
raise ValueError(
|
||||
f"Weight input_size_per_partition = "
|
||||
f"{input_size_per_partition} is not divisible by "
|
||||
f"min_k_threads = {self.quant_config.min_k_threads}."
|
||||
)
|
||||
if (
|
||||
self.quant_config.group_size != -1
|
||||
and input_size_per_partition % self.quant_config.group_size != 0
|
||||
):
|
||||
raise ValueError(
|
||||
f"Weight input_size_per_partition = "
|
||||
f"{input_size_per_partition} is not divisible by "
|
||||
f"group_size = {self.quant_config.group_size}."
|
||||
)
|
||||
|
||||
# Check that we have at least 4 tiles horizontally in the shard
|
||||
num_tiles_per_perm = self.quant_config.perm_len // (
|
||||
self.quant_config.tile_size**2
|
||||
)
|
||||
if output_size_per_partition % num_tiles_per_perm != 0:
|
||||
raise ValueError("Each permutation group must reside on the same gpu")
|
||||
|
||||
# Quantized 4Bit weights packed into Int32.
|
||||
qweight = PackedvLLMParameter(
|
||||
data=torch.empty(
|
||||
input_size_per_partition // self.quant_config.tile_size // 2,
|
||||
output_size_per_partition
|
||||
* self.quant_config.tile_size
|
||||
// self.quant_config.pack_factor,
|
||||
device="cuda",
|
||||
dtype=torch.int32,
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=self.quant_config.pack_factor,
|
||||
marlin_tile_size=self.quant_config.tile_size,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
# Meta
|
||||
meta = PackedvLLMParameter(
|
||||
data=torch.empty(
|
||||
input_size_per_partition // 8 // 2 // 2,
|
||||
output_size_per_partition * 2,
|
||||
device="cuda",
|
||||
dtype=torch.int16,
|
||||
),
|
||||
input_dim=0,
|
||||
output_dim=1,
|
||||
packed_dim=1,
|
||||
packed_factor=1,
|
||||
marlin_tile_size=2,
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
# Determine if channelwise or not
|
||||
input_groups = (
|
||||
1
|
||||
if self.quant_config.group_size == -1
|
||||
else input_size_per_partition // self.quant_config.group_size
|
||||
)
|
||||
|
||||
weight_scale_args = {
|
||||
"data": torch.empty(
|
||||
input_groups,
|
||||
output_size_per_partition,
|
||||
device="cuda",
|
||||
dtype=params_dtype,
|
||||
),
|
||||
"weight_loader": weight_loader,
|
||||
}
|
||||
if input_groups == 1:
|
||||
scales = ChannelQuantScaleParameter(output_dim=1, **weight_scale_args)
|
||||
else:
|
||||
scales = GroupQuantScaleParameter(
|
||||
output_dim=1, input_dim=0, **weight_scale_args
|
||||
)
|
||||
|
||||
# Allocate workspace (Used for internal locking mechanism)
|
||||
max_workspace_size = (
|
||||
output_size_per_partition // self.quant_config.min_n_threads
|
||||
) * self.quant_config.max_parallel
|
||||
|
||||
workspace = BasevLLMParameter(
|
||||
data=torch.zeros(max_workspace_size, device="cuda", dtype=torch.int),
|
||||
weight_loader=weight_loader,
|
||||
)
|
||||
|
||||
layer.register_parameter("B_24", qweight)
|
||||
layer.register_parameter("B_meta", meta)
|
||||
layer.register_parameter("s", scales)
|
||||
layer.register_parameter("workspace", workspace)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
# required by torch.compile
|
||||
layer.B_24 = Parameter(layer.B_24.data, requires_grad=False)
|
||||
layer.s = Parameter(layer.s.data, requires_grad=False)
|
||||
layer.B_meta = Parameter(layer.B_meta.data, requires_grad=False)
|
||||
layer.workspace = Parameter(layer.workspace.data, requires_grad=False)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
qweight = layer.B_24
|
||||
meta = layer.B_meta
|
||||
scales = layer.s
|
||||
workspace = layer.workspace
|
||||
|
||||
x_2d = x.view(-1, x.shape[-1])
|
||||
|
||||
size_m = x_2d.shape[0]
|
||||
size_k = x_2d.shape[1]
|
||||
size_n = scales.shape[1]
|
||||
|
||||
output_2d = ops.gptq_marlin_24_gemm(
|
||||
x_2d,
|
||||
qweight,
|
||||
meta,
|
||||
scales,
|
||||
workspace,
|
||||
self.quant_config.quant_type,
|
||||
size_m,
|
||||
size_n,
|
||||
size_k,
|
||||
)
|
||||
|
||||
output = output_2d.view(x.shape[:-1] + (output_2d.shape[1],))
|
||||
|
||||
if bias is not None:
|
||||
output.add_(bias) # In-place add
|
||||
|
||||
return output
|
||||
@@ -1,467 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Utility functions used for tests and benchmarks"""
|
||||
|
||||
import random
|
||||
|
||||
import numpy
|
||||
import torch
|
||||
|
||||
from vllm.scalar_type import ScalarType
|
||||
|
||||
from .marlin_utils_test import marlin_weights
|
||||
from .quant_utils import gptq_quantize_weights
|
||||
|
||||
|
||||
# This is PyTorch implementation of main part of reorder_meta()
|
||||
# function, from tools/util/include/cutlass/util/host_reorder.h file
|
||||
# of CUTLASS source tree. Furthermore, CUTLASS template for sparse
|
||||
# GEMM decides upon layout of this matrix, and at the moment for the
|
||||
# sparse GEMM executed on tensor cores, this is layout described by
|
||||
# ColumnMajorInterleaved<2> data structure, in
|
||||
# include/cutlass/layout/matrix.h of CUTLASS source tree. The
|
||||
# reordering of meta matrix into meta_reordered matrix calculated
|
||||
# according to these segments of CUTLASS code is re-implemented here.
|
||||
# Note that this calculation produces offsets for scattering metadata
|
||||
# matrix elements into reordered metadata matrix elements (or,
|
||||
# equivalently, for gathering reordered metadata matrix element back
|
||||
# into metadata matrix elements).
|
||||
def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, device):
|
||||
dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols)
|
||||
dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1)
|
||||
|
||||
# Reorder the rows, then swizzle the 2x2 blocks.
|
||||
group_x = 64
|
||||
group_y = 32 if meta_dtype.itemsize == 2 else 16
|
||||
|
||||
dst_rows = (
|
||||
dst_rows // group_x * group_x
|
||||
+ (dst_rows % 2) * 2
|
||||
+ (dst_rows % 8) // 4
|
||||
+ ((dst_rows % group_y) % 4) // 2 * 32
|
||||
+ ((dst_rows % group_x) // 8) * 4
|
||||
)
|
||||
|
||||
topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8)
|
||||
bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8)
|
||||
dst_rows += topright - bottomleft
|
||||
dst_cols -= topright - bottomleft
|
||||
|
||||
# Assumed that meta tensor is to be stored in CUTLASS
|
||||
# InterleavedColumnMajor layout, and reverse engineered
|
||||
# corresponding code to store values into this tensor.
|
||||
interleave = 2
|
||||
cols_maj = dst_cols // interleave
|
||||
cols_min = dst_cols % interleave
|
||||
return (cols_maj * m * interleave + dst_rows * interleave + cols_min).view(-1)
|
||||
|
||||
|
||||
# This function converts dense matrix into sparse semi-structured
|
||||
# representation, producing "compressed" matrix, in the layout used by
|
||||
# CUTLASS backend, and corresponding metadata matrix.
|
||||
def sparse_semi_structured_from_dense_cutlass(dense):
|
||||
if dense.dim() != 2:
|
||||
raise RuntimeError(
|
||||
f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor" # noqa: E501
|
||||
)
|
||||
|
||||
m, k = dense.shape
|
||||
device = dense.device
|
||||
|
||||
meta_dtype = torch.int8
|
||||
if dense.dtype == torch.int8:
|
||||
meta_dtype = torch.int32
|
||||
elif dense.dtype in [torch.half, torch.bfloat16, torch.float, torch.int32]:
|
||||
meta_dtype = torch.int16
|
||||
else:
|
||||
raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix")
|
||||
quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4
|
||||
if quadbits_per_meta_elem not in (4, 8):
|
||||
raise RuntimeError("Invalid number of elements per meta element calculated")
|
||||
|
||||
if meta_dtype == torch.int32:
|
||||
if m % 16 != 0:
|
||||
raise RuntimeError(
|
||||
f"Number of rows of dense matrix {m} must be divisible by 16"
|
||||
)
|
||||
else:
|
||||
if m % 32 != 0:
|
||||
raise RuntimeError(
|
||||
f"Number of rows of dense matrix {m} must be divisible by 32"
|
||||
)
|
||||
if k % (4 * quadbits_per_meta_elem) != 0:
|
||||
raise RuntimeError(
|
||||
f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}" # noqa: E501
|
||||
)
|
||||
|
||||
if dense.dtype != torch.float:
|
||||
ksparse = 4
|
||||
dense_4 = dense.view(-1, k // ksparse, ksparse)
|
||||
m0, m1, m2, m3 = (dense_4 != 0).unbind(-1)
|
||||
else:
|
||||
ksparse = 2
|
||||
dense_2 = dense.view(-1, k // ksparse, ksparse)
|
||||
m0, m2 = m1, m3 = (dense_2 != 0).unbind(-1)
|
||||
meta_ncols = k // (ksparse * quadbits_per_meta_elem)
|
||||
|
||||
# Encoding quadruples of True/False values as follows:
|
||||
# [True, True, False, False] -> 0b0100
|
||||
# [True, False, True, False] -> 0b1000
|
||||
# [False, True, True, False] -> 0b1001
|
||||
# [True, False, False, True ] -> 0b1100
|
||||
# [False, True, False, True ] -> 0b1101
|
||||
# [False, False, True, True ] -> 0b1110
|
||||
# Thus, lower two bits in the encoding are index of the True value
|
||||
# at the lowest index in the quadruple, and the higher two bits in
|
||||
# the encoding are index of the other True value in the quadruple.
|
||||
# In case there are less than two True values, than False value or
|
||||
# values at some index or indices are considered True for the
|
||||
# encoding. In case there are more than two True values, then the
|
||||
# excess True value(s) at some indices are considered False for
|
||||
# the encoding. The exact encodings used for these cases are as
|
||||
# follows:
|
||||
# [False, False, False, False] -> 0b1110
|
||||
# [False, False, False, True ] -> 0b1110
|
||||
# [False, False, True, False] -> 0b1110
|
||||
# [False, True, False, False] -> 0b1001
|
||||
# [False, True, True, True ] -> 0b1101
|
||||
# [True, False, False, False] -> 0b1000
|
||||
# [True, False, True, True ] -> 0b1100
|
||||
# [True, True, False, True ] -> 0b0100
|
||||
# [True, True, True, False] -> 0b0100
|
||||
# [True, True, True, True ] -> 0b0100
|
||||
# These particular encodings are chosen, with the help of Espresso
|
||||
# logic minimizer software, for the purpose of minimization of
|
||||
# corresponding Boolean functions, that translate non-zero flags
|
||||
# into encoding bits. Note also possible choices for the first
|
||||
# and last of these encodings were limited only to (0b0100,
|
||||
# 0b1110), in order to produce valid encodings for 1:2 sparsity
|
||||
# case.
|
||||
|
||||
expr0 = m0 & m1
|
||||
expr1 = ~m0 & m1
|
||||
expr2 = ~m0 & ~m1
|
||||
bit0 = expr1
|
||||
bit1 = expr2
|
||||
bit2 = expr0 | expr2 | m3
|
||||
bit3 = expr1 | ~m1
|
||||
idxs0 = bit0 | (bit1.to(torch.int64) << 1)
|
||||
idxs1 = bit2 | (bit3.to(torch.int64) << 1)
|
||||
|
||||
if dense.dtype != torch.float:
|
||||
sparse0 = dense_4.gather(-1, idxs0.unsqueeze(-1)) # type: ignore[possibly-undefined]
|
||||
sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1))
|
||||
sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2)
|
||||
else:
|
||||
sparse = dense_2.gather(-1, idxs0.unsqueeze(-1) // 2).view(m, k // 2) # type: ignore[possibly-undefined]
|
||||
|
||||
meta_4 = idxs0 | (idxs1 << 2)
|
||||
meta_n = meta_4.view((-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype)
|
||||
|
||||
if quadbits_per_meta_elem == 4:
|
||||
meta = (
|
||||
meta_n[:, :, 0]
|
||||
| (meta_n[:, :, 1] << 4)
|
||||
| (meta_n[:, :, 2] << 8)
|
||||
| (meta_n[:, :, 3] << 12)
|
||||
)
|
||||
elif quadbits_per_meta_elem == 8:
|
||||
meta = (
|
||||
meta_n[:, :, 0]
|
||||
| (meta_n[:, :, 1] << 4)
|
||||
| (meta_n[:, :, 2] << 8)
|
||||
| (meta_n[:, :, 3] << 12)
|
||||
| (meta_n[:, :, 4] << 16)
|
||||
| (meta_n[:, :, 5] << 20)
|
||||
| (meta_n[:, :, 6] << 24)
|
||||
| (meta_n[:, :, 7] << 28)
|
||||
)
|
||||
|
||||
# Reorder meta tensor elements.
|
||||
meta_reordered = meta.new_empty((m * meta_ncols,)) # type: ignore[possibly-undefined]
|
||||
meta_offsets = _calculate_meta_reordering_scatter_offsets(
|
||||
m, meta_ncols, meta_dtype, device
|
||||
)
|
||||
meta_reordered.scatter_(0, meta_offsets, meta.view(-1))
|
||||
|
||||
return (sparse, meta_reordered.view(m, meta_ncols))
|
||||
|
||||
|
||||
# This function performs reverse of the function above - it
|
||||
# reconstructs dense matrix from a pair of "compressed" matrix, given
|
||||
# in the layout used by CUTLASS backend, and accompanying metadata
|
||||
# matrix.
|
||||
def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered):
|
||||
if sparse.dim() != 2:
|
||||
raise RuntimeError(
|
||||
f"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor" # noqa: E501
|
||||
)
|
||||
|
||||
m, k = sparse.shape
|
||||
device = sparse.device
|
||||
|
||||
if meta_reordered.dim() != 2:
|
||||
raise RuntimeError(
|
||||
f"Expected 2-dimensional meta tensor, got {meta_reordered.dim()}-dimensional tensor" # noqa: E501
|
||||
)
|
||||
if meta_reordered.device != device:
|
||||
raise RuntimeError(
|
||||
f"Expected meta matrix to be on {device} device, got matrix on {meta_reordered.device} device" # noqa: E501
|
||||
)
|
||||
|
||||
meta_dtype = meta_reordered.dtype
|
||||
if meta_dtype not in (torch.int16, torch.int32):
|
||||
raise RuntimeError(f"Invalid datatype {meta_dtype} of meta matrix")
|
||||
quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4
|
||||
|
||||
ksparse = 4 if sparse.dtype != torch.float else 2
|
||||
|
||||
meta_nrows, meta_ncols = meta_reordered.shape
|
||||
if meta_nrows != m:
|
||||
raise RuntimeError(
|
||||
f"Number of rows of meta matrix {meta_nrows} must be equal to number of columns of spase matrix {m}" # noqa: E501
|
||||
)
|
||||
if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k:
|
||||
raise RuntimeError(
|
||||
f"Number of columns of sparse matrix {k} different from the {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, " # noqa: E501
|
||||
"expected according to the number of columns of meta matrix"
|
||||
)
|
||||
|
||||
# Undo meta tensor elements reordering.
|
||||
meta_offsets = _calculate_meta_reordering_scatter_offsets(
|
||||
m, meta_ncols, meta_dtype, device
|
||||
)
|
||||
meta = torch.gather(meta_reordered.view(-1), 0, meta_offsets).view(m, meta_ncols)
|
||||
|
||||
# Unpack sparse tensor back to original dense tensor, using
|
||||
# information provided by meta tensor. Note that torch.float
|
||||
# datatype is handled pretty much the same as
|
||||
# torch.half/torch.bfloat16, as metadata for a pair of torch.float
|
||||
# value is encoded as if underlying 8 bytes contain four
|
||||
# torch.half/torch.bfloat16 values, where either first two or last
|
||||
# two are zeros.
|
||||
meta_2 = torch.empty(
|
||||
(m, meta_ncols, 2 * quadbits_per_meta_elem),
|
||||
dtype=meta_dtype,
|
||||
device=device,
|
||||
)
|
||||
if quadbits_per_meta_elem == 4:
|
||||
meta_2[:, :, 0] = meta & 0b11
|
||||
meta_2[:, :, 1] = (meta >> 2) & 0b11
|
||||
meta_2[:, :, 2] = (meta >> 4) & 0b11
|
||||
meta_2[:, :, 3] = (meta >> 6) & 0b11
|
||||
meta_2[:, :, 4] = (meta >> 8) & 0b11
|
||||
meta_2[:, :, 5] = (meta >> 10) & 0b11
|
||||
meta_2[:, :, 6] = (meta >> 12) & 0b11
|
||||
meta_2[:, :, 7] = (meta >> 14) & 0b11
|
||||
elif quadbits_per_meta_elem == 8:
|
||||
meta_2[:, :, 0] = meta & 0b11
|
||||
meta_2[:, :, 1] = (meta >> 2) & 0b11
|
||||
meta_2[:, :, 2] = (meta >> 4) & 0b11
|
||||
meta_2[:, :, 3] = (meta >> 6) & 0b11
|
||||
meta_2[:, :, 4] = (meta >> 8) & 0b11
|
||||
meta_2[:, :, 5] = (meta >> 10) & 0b11
|
||||
meta_2[:, :, 6] = (meta >> 12) & 0b11
|
||||
meta_2[:, :, 7] = (meta >> 14) & 0b11
|
||||
meta_2[:, :, 8] = (meta >> 16) & 0b11
|
||||
meta_2[:, :, 9] = (meta >> 18) & 0b11
|
||||
meta_2[:, :, 10] = (meta >> 20) & 0b11
|
||||
meta_2[:, :, 11] = (meta >> 22) & 0b11
|
||||
meta_2[:, :, 12] = (meta >> 24) & 0b11
|
||||
meta_2[:, :, 13] = (meta >> 26) & 0b11
|
||||
meta_2[:, :, 14] = (meta >> 28) & 0b11
|
||||
meta_2[:, :, 15] = (meta >> 30) & 0b11
|
||||
|
||||
dense_offsets = meta_2.view(-1) + (
|
||||
torch.arange(0, 2 * m * k // ksparse, device=device) * 4
|
||||
).view(-1, 1).repeat(1, 2).view(-1)
|
||||
|
||||
dense = torch.zeros((m * 2 * k,), dtype=sparse.dtype, device=device)
|
||||
if sparse.dtype != torch.float:
|
||||
# dense.scatter_(0, dense_offsets, sparse.view(-1))
|
||||
dense.scatter_(0, dense_offsets, sparse.reshape(-1))
|
||||
else:
|
||||
dense.view(torch.half).scatter_(
|
||||
0, dense_offsets, sparse.view(torch.half).view(-1)
|
||||
)
|
||||
|
||||
return dense.view(m, 2 * k)
|
||||
|
||||
|
||||
def mask_creator(tensor):
|
||||
"""
|
||||
Class for creating N:M sparsity masks.
|
||||
Masks will be created using the N:M ratio, where for every block of
|
||||
M weights, N will be pruned based on ranked weight value. Each mask
|
||||
will correspond to the given tensor.
|
||||
|
||||
:param N: The number of weights in a group to keep
|
||||
:param M: The size of a weight group
|
||||
"""
|
||||
N = 2
|
||||
M = 4
|
||||
|
||||
mask = None
|
||||
# for i, tensor in enumerate(tensors):
|
||||
if tensor.numel() % M != 0:
|
||||
raise ValueError(
|
||||
f"Tensor of size {tensor.shape} can't be evenly divided into {M} groups"
|
||||
)
|
||||
|
||||
num_groups = tensor.numel() // M
|
||||
|
||||
# N:M sparsity for linear layers
|
||||
tensor_temp = tensor.detach().abs().reshape(num_groups, M)
|
||||
index = torch.argsort(tensor_temp, dim=1)[:, : int(M - N)]
|
||||
|
||||
w_b = torch.ones(tensor_temp.shape, device=tensor_temp.device)
|
||||
mask = w_b.scatter_(dim=1, index=index, value=0).reshape(tensor.shape)
|
||||
|
||||
return mask
|
||||
|
||||
|
||||
def inject_24(w, size_k, size_n):
|
||||
assert w.shape == (size_k, size_n)
|
||||
|
||||
mask = mask_creator(w.t()).t().cuda().bool()
|
||||
|
||||
return (mask * w).contiguous(), mask.contiguous()
|
||||
|
||||
|
||||
def check_24(w, num_rows_to_sample=50, _verbose=False):
|
||||
BLOCK_SIZE = 4
|
||||
MAX_NON_ZEROS = 2
|
||||
|
||||
w = w.t().contiguous()
|
||||
|
||||
print("check_24: w.shape = {}".format(w.shape))
|
||||
|
||||
num_rows, num_cols = w.shape
|
||||
sampled_row_idxs = random.choices(range(num_rows), k=num_rows_to_sample)
|
||||
if _verbose:
|
||||
print(f"Sampled row idxs = {sampled_row_idxs}")
|
||||
|
||||
total_segments = 0
|
||||
non_24_segments = 0
|
||||
for i in sampled_row_idxs:
|
||||
for j in range(0, num_cols - BLOCK_SIZE, BLOCK_SIZE):
|
||||
total_segments += 1
|
||||
block = w[i, j : j + BLOCK_SIZE]
|
||||
num_nonzero = torch.count_nonzero(block)
|
||||
if num_nonzero > MAX_NON_ZEROS:
|
||||
print("i = {} j = {} block = {}".format(i, j, block))
|
||||
non_24_segments += 1
|
||||
|
||||
print(f"{non_24_segments} / {total_segments} do not have 2:4 structure.")
|
||||
|
||||
|
||||
def compress_quantized_24_weight(q_24, size_k, size_n, wtype: ScalarType):
|
||||
assert q_24.shape == (size_k, size_n)
|
||||
|
||||
# Remove bias to normalize over 0
|
||||
q_24_no_zp = q_24 - wtype.bias
|
||||
|
||||
# Compress
|
||||
q_24_no_zp = q_24_no_zp.t().contiguous()
|
||||
q_24_no_zp_comp, meta = sparse_semi_structured_from_dense_cutlass(q_24_no_zp)
|
||||
q_24_no_zp_comp = q_24_no_zp_comp.t().contiguous()
|
||||
|
||||
# Restore bias
|
||||
q_24_comp = q_24_no_zp_comp + wtype.bias
|
||||
|
||||
# Resize meta to its actual shape (without moving any data)
|
||||
meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2)
|
||||
|
||||
return q_24_comp, meta
|
||||
|
||||
|
||||
def get_scale_perms_24():
|
||||
scale_perm: list[int] = []
|
||||
for i in range(8):
|
||||
scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]])
|
||||
scale_perm_single: list[int] = []
|
||||
for i in range(8):
|
||||
scale_perm_single.extend([8 * i + j for j in [0, 1, 2, 3, 4, 5, 6, 7]])
|
||||
return scale_perm, scale_perm_single
|
||||
|
||||
|
||||
def get_weight_perm_24(num_bits: int):
|
||||
perm_list: list[int] = []
|
||||
for i in range(32):
|
||||
perm1: list[int] = []
|
||||
col = i // 4
|
||||
col_o = col // 2
|
||||
for block in [0, 1]:
|
||||
for row in [
|
||||
2 * (i % 4),
|
||||
2 * (i % 4) + 1,
|
||||
2 * (i % 4 + 4),
|
||||
2 * (i % 4 + 4) + 1,
|
||||
]:
|
||||
perm1.append(16 * row + col_o * 256 + 8 * (col % 2) + 4 * block)
|
||||
for j in range(4):
|
||||
perm_list.extend([p + 1 * j for p in perm1])
|
||||
perm = numpy.array(perm_list)
|
||||
|
||||
if num_bits == 4:
|
||||
interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
|
||||
elif num_bits == 8:
|
||||
interleave = numpy.array([0, 2, 1, 3])
|
||||
else:
|
||||
raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits))
|
||||
|
||||
perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel()
|
||||
perm = torch.from_numpy(perm)
|
||||
return perm
|
||||
|
||||
|
||||
def marlin_permute_scales_24(
|
||||
s: torch.Tensor, size_k: int, size_n: int, group_size: int
|
||||
) -> torch.Tensor:
|
||||
scale_perm, scale_perm_single = get_scale_perms_24()
|
||||
if group_size < size_k and group_size != -1:
|
||||
s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
|
||||
else:
|
||||
s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
|
||||
s = s.reshape((-1, size_n)).contiguous()
|
||||
|
||||
return s
|
||||
|
||||
|
||||
def marlin_24_quantize(
|
||||
w: torch.Tensor,
|
||||
quant_type: ScalarType,
|
||||
group_size: int,
|
||||
):
|
||||
size_k, size_n = w.shape
|
||||
|
||||
# Normalize group_size
|
||||
if group_size == -1:
|
||||
group_size = size_k
|
||||
assert group_size <= size_k
|
||||
|
||||
# Inject 2:4 sparsity
|
||||
w_24, mask_24 = inject_24(w, size_k, size_n)
|
||||
|
||||
# Quantize
|
||||
w_24_ref, q_w_24, s, g_idx, rand_perm = gptq_quantize_weights(
|
||||
w_24, quant_type, group_size, act_order=False
|
||||
)
|
||||
|
||||
# Compress quantized weight
|
||||
q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n, quant_type)
|
||||
size_k_comp = size_k // 2
|
||||
|
||||
# Reformat to marlin
|
||||
weight_perm = get_weight_perm_24(quant_type.size_bits)
|
||||
marlin_24_q_w_comp = marlin_weights(
|
||||
q_w_24_comp, size_k_comp, size_n, quant_type.size_bits, weight_perm
|
||||
)
|
||||
marlin_24_s = marlin_permute_scales_24(s, size_k, size_n, group_size)
|
||||
|
||||
# Create result
|
||||
res_list = [w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s]
|
||||
for i in range(len(res_list)):
|
||||
res_list[i] = res_list[i].to(w.device)
|
||||
|
||||
return res_list
|
||||
Reference in New Issue
Block a user