[Misc] Update gptq_marlin to use new vLLMParameters (#7281)

This commit is contained in:
Dipika Sikka
2024-08-13 14:30:11 -04:00
committed by GitHub
parent 181abbc27d
commit fb377d7e74
8 changed files with 234 additions and 98 deletions

View File

@@ -9,7 +9,7 @@ from vllm.logger import init_logger
__all__ = [
"BasevLLMParameter", "PackedvLLMParameter", "PerTensorScaleParameter",
"ModelWeightParameter", "ChannelQuantScaleParameter",
"GroupQuantScaleParameter"
"GroupQuantScaleParameter", "PackedColumnParameter", "RowvLLMParameter"
]
logger = init_logger(__name__)
@@ -92,7 +92,8 @@ class _ColumnvLLMParameter(BasevLLMParameter):
shard_size = kwargs.get("shard_size")
if isinstance(
self,
PackedvLLMParameter) and self.packed_dim == self.output_dim:
(PackedColumnParameter,
PackedvLLMParameter)) and self.packed_dim == self.output_dim:
shard_size, shard_offset = self.adjust_shard_indexes_for_packing(
shard_offset=shard_offset, shard_size=shard_size)
@@ -115,7 +116,8 @@ class _ColumnvLLMParameter(BasevLLMParameter):
if isinstance(
self,
PackedvLLMParameter) and self.output_dim == self.packed_dim:
(PackedColumnParameter,
PackedvLLMParameter)) and self.output_dim == self.packed_dim:
shard_size, shard_offset = self.adjust_shard_indexes_for_packing(
shard_offset=shard_offset, shard_size=shard_size)
@@ -131,12 +133,12 @@ class _ColumnvLLMParameter(BasevLLMParameter):
param_data.copy_(loaded_weight)
class ModelWeightParameter(_ColumnvLLMParameter):
class RowvLLMParameter(BasevLLMParameter):
"""
Parameter class for linear layer weights. Extends the
_ColumnvLLMParameter by adding loading functionality
for linear layers with row parallel functionality.
Requires an input dimension to be defined.
Parameter class defining weight_loading functionality
(load_row_parallel_weight) for parameters being loaded
into linear layers with row parallel functionality.
Requires an input_dim to be defined.
"""
def __init__(self, input_dim: int, **kwargs):
@@ -160,10 +162,18 @@ class ModelWeightParameter(_ColumnvLLMParameter):
self.data.copy_(loaded_weight)
class GroupQuantScaleParameter(ModelWeightParameter):
class ModelWeightParameter(_ColumnvLLMParameter, RowvLLMParameter):
    """
    Parameter class for linear layer weights. Combines the
    column-parallel loading of _ColumnvLLMParameter with the
    row-parallel loading of RowvLLMParameter, so it can be used for
    linear layers sharded along either dimension.
    """
    pass
class GroupQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter):
"""
Parameter class for weight scales loaded for weights with
grouped quantization. Equivalent to ModelWeightParameter.
grouped quantization. Uses both column and row parallelism.
"""
pass
@@ -171,7 +181,7 @@ class GroupQuantScaleParameter(ModelWeightParameter):
class ChannelQuantScaleParameter(_ColumnvLLMParameter):
    """
    Parameter class for weight scales loaded for weights with
    channel-wise quantization. Equivalent to _ColumnvLLMParameter.
    """
    pass
@@ -181,7 +191,7 @@ class PerTensorScaleParameter(BasevLLMParameter):
Parameter class for scales where the number of scales is
equivalent to the number of logical matrices in fused linear
layers (e.g. for QKV, there are 3 scales loaded from disk).
This is relevant to weights with per-tensor quantization.
This is relevant to weights with per-tensor quantization.
Adds functionality to map the scalers to a shard during
weight loading.
@@ -232,15 +242,11 @@ class PerTensorScaleParameter(BasevLLMParameter):
param_data.copy_(loaded_weight)
class PackedvLLMParameter(ModelWeightParameter):
class PackedColumnParameter(_ColumnvLLMParameter):
"""
Parameter for model weights which are packed on disk.
Example: GPTQ Marlin weights are int4 or int8, packed into int32.
Extends the ModelWeightParameter to take in the
packed factor, the packed dimension, and optionally, marlin
tile size for marlin kernels. Adjusts the shard_size and
shard_offset for fused linear layers model weight loading
by accounting for packing and optionally, marlin tile size.
Parameter for model parameters which are packed on disk
and support column parallelism only. See PackedvLLMParameter
for more details on the packed properties.
"""
def __init__(self,
@@ -250,7 +256,7 @@ class PackedvLLMParameter(ModelWeightParameter):
**kwargs):
self._packed_factor = packed_factor
self._packed_dim = packed_dim
self._marlin_tile = marlin_tile_size
self._marlin_tile_size = marlin_tile_size
super().__init__(**kwargs)
@property
@@ -262,16 +268,70 @@ class PackedvLLMParameter(ModelWeightParameter):
return self._packed_factor
@property
def marlin_tile(self):
return self._marlin_tile
def _adjust_shard_indexes_for_marlin(self, shard_size, shard_offset):
return shard_size * self.marlin_tile, shard_offset * self.marlin_tile
def marlin_tile_size(self):
return self._marlin_tile_size
def adjust_shard_indexes_for_packing(self, shard_size, shard_offset):
shard_size = shard_size // self.packed_factor
shard_offset = shard_offset // self.packed_factor
if self.marlin_tile is not None:
return self._adjust_shard_indexes_for_marlin(
shard_size, shard_offset)
return shard_size, shard_offset
return _adjust_shard_indexes_for_packing(
shard_size=shard_size,
shard_offset=shard_offset,
packed_factor=self.packed_factor,
marlin_tile_size=self.marlin_tile_size)
class PackedvLLMParameter(ModelWeightParameter):
    """
    Model-weight parameter whose serialized form is packed on disk,
    e.g. GPTQ Marlin weights, where int4/int8 values are packed into
    int32 storage.

    Adds the packed factor, the packed dimension and, optionally, a
    marlin tile size on top of ModelWeightParameter, and provides
    adjust_shard_indexes_for_packing so fused-linear-layer weight
    loading can correct shard_size/shard_offset for the packing and,
    when set, the marlin tiling.
    """

    def __init__(self,
                 packed_factor: int,
                 packed_dim: int,
                 marlin_tile_size: Optional[int] = None,
                 **kwargs):
        # Record the packing metadata first; the base constructor
        # consumes the remaining keyword arguments.
        self._packed_factor = packed_factor
        self._packed_dim = packed_dim
        self._marlin_tile_size = marlin_tile_size
        super().__init__(**kwargs)

    @property
    def packed_dim(self):
        return self._packed_dim

    @property
    def packed_factor(self):
        return self._packed_factor

    @property
    def marlin_tile_size(self):
        return self._marlin_tile_size

    def adjust_shard_indexes_for_packing(self, shard_size, shard_offset):
        # The shared module-level helper performs both the packing
        # division and the optional marlin tile scaling.
        return _adjust_shard_indexes_for_packing(
            shard_size=shard_size,
            shard_offset=shard_offset,
            packed_factor=self._packed_factor,
            marlin_tile_size=self._marlin_tile_size)
def _adjust_shard_indexes_for_marlin(shard_size, shard_offset,
marlin_tile_size):
return shard_size * marlin_tile_size, shard_offset * marlin_tile_size
def _adjust_shard_indexes_for_packing(shard_size, shard_offset, packed_factor,
marlin_tile_size):
shard_size = shard_size // packed_factor
shard_offset = shard_offset // packed_factor
if marlin_tile_size is not None:
return _adjust_shard_indexes_for_marlin(
shard_size=shard_size,
shard_offset=shard_offset,
marlin_tile_size=marlin_tile_size)
return shard_size, shard_offset