[Misc] Update gptq_marlin to use new vLLMParameters (#7281)
This commit is contained in:
@@ -9,7 +9,7 @@ from vllm.logger import init_logger
|
||||
__all__ = [
|
||||
"BasevLLMParameter", "PackedvLLMParameter", "PerTensorScaleParameter",
|
||||
"ModelWeightParameter", "ChannelQuantScaleParameter",
|
||||
"GroupQuantScaleParameter"
|
||||
"GroupQuantScaleParameter", "PackedColumnParameter", "RowvLLMParameter"
|
||||
]
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -92,7 +92,8 @@ class _ColumnvLLMParameter(BasevLLMParameter):
|
||||
shard_size = kwargs.get("shard_size")
|
||||
if isinstance(
|
||||
self,
|
||||
PackedvLLMParameter) and self.packed_dim == self.output_dim:
|
||||
(PackedColumnParameter,
|
||||
PackedvLLMParameter)) and self.packed_dim == self.output_dim:
|
||||
shard_size, shard_offset = self.adjust_shard_indexes_for_packing(
|
||||
shard_offset=shard_offset, shard_size=shard_size)
|
||||
|
||||
@@ -115,7 +116,8 @@ class _ColumnvLLMParameter(BasevLLMParameter):
|
||||
|
||||
if isinstance(
|
||||
self,
|
||||
PackedvLLMParameter) and self.output_dim == self.packed_dim:
|
||||
(PackedColumnParameter,
|
||||
PackedvLLMParameter)) and self.output_dim == self.packed_dim:
|
||||
shard_size, shard_offset = self.adjust_shard_indexes_for_packing(
|
||||
shard_offset=shard_offset, shard_size=shard_size)
|
||||
|
||||
@@ -131,12 +133,12 @@ class _ColumnvLLMParameter(BasevLLMParameter):
|
||||
param_data.copy_(loaded_weight)
|
||||
|
||||
|
||||
class ModelWeightParameter(_ColumnvLLMParameter):
|
||||
class RowvLLMParameter(BasevLLMParameter):
|
||||
"""
|
||||
Parameter class for linear layer weights. Extends the
|
||||
_ColumnvLLMParameter by adding loading functionality
|
||||
for linear layers with row parallel functionality.
|
||||
Requires an input dimension to be defined.
|
||||
Parameter class defining weight_loading functionality
|
||||
(load_row_parallel_weight) for parameters being loaded
|
||||
into linear layers with row parallel functionality.
|
||||
Requires an input_dim to be defined.
|
||||
"""
|
||||
|
||||
def __init__(self, input_dim: int, **kwargs):
|
||||
@@ -160,10 +162,18 @@ class ModelWeightParameter(_ColumnvLLMParameter):
|
||||
self.data.copy_(loaded_weight)
|
||||
|
||||
|
||||
class GroupQuantScaleParameter(ModelWeightParameter):
|
||||
class ModelWeightParameter(_ColumnvLLMParameter, RowvLLMParameter):
    """
    Parameter class for linear layer weights. Uses both column and
    row parallelism.

    Adds no behavior of its own: everything is inherited from
    _ColumnvLLMParameter and RowvLLMParameter.
    """
    pass
|
||||
|
||||
|
||||
class GroupQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter):
    """
    Parameter class for weight scales loaded for weights with
    grouped quantization. Uses both column and row parallelism.

    Adds no behavior of its own: everything is inherited from
    _ColumnvLLMParameter and RowvLLMParameter.
    """
    pass
|
||||
|
||||
@@ -171,7 +181,7 @@ class GroupQuantScaleParameter(ModelWeightParameter):
|
||||
class ChannelQuantScaleParameter(_ColumnvLLMParameter):
    """
    Parameter class for weight scales loaded for weights with
    channel-wise quantization. Equivalent to _ColumnvLLMParameter.
    """
    pass
|
||||
|
||||
@@ -181,7 +191,7 @@ class PerTensorScaleParameter(BasevLLMParameter):
|
||||
Parameter class for scales where the number of scales is
|
||||
equivalent to the number of logical matrices in fused linear
|
||||
layers (e.g. for QKV, there are 3 scales loaded from disk).
|
||||
This is relevant to weights with per-tensor quantization.
|
||||
This is relevant to weights with per-tensor quantization.
|
||||
Adds functionality to map the scalers to a shard during
|
||||
weight loading.
|
||||
|
||||
@@ -232,15 +242,11 @@ class PerTensorScaleParameter(BasevLLMParameter):
|
||||
param_data.copy_(loaded_weight)
|
||||
|
||||
|
||||
class PackedvLLMParameter(ModelWeightParameter):
|
||||
class PackedColumnParameter(_ColumnvLLMParameter):
|
||||
"""
|
||||
Parameter for model weights which are packed on disk.
|
||||
Example: GPTQ Marlin weights are int4 or int8, packed into int32.
|
||||
Extends the ModelWeightParameter to take in the
|
||||
packed factor, the packed dimension, and optionally, marlin
|
||||
tile size for marlin kernels. Adjusts the shard_size and
|
||||
shard_offset for fused linear layers model weight loading
|
||||
by accounting for packing and optionally, marlin tile size.
|
||||
Parameter for model parameters which are packed on disk
|
||||
and support column parallelism only. See PackedvLLMParameter
|
||||
for more details on the packed properties.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
@@ -250,7 +256,7 @@ class PackedvLLMParameter(ModelWeightParameter):
|
||||
**kwargs):
|
||||
self._packed_factor = packed_factor
|
||||
self._packed_dim = packed_dim
|
||||
self._marlin_tile = marlin_tile_size
|
||||
self._marlin_tile_size = marlin_tile_size
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@property
|
||||
@@ -262,16 +268,70 @@ class PackedvLLMParameter(ModelWeightParameter):
|
||||
return self._packed_factor
|
||||
|
||||
@property
|
||||
def marlin_tile(self):
|
||||
return self._marlin_tile
|
||||
|
||||
def _adjust_shard_indexes_for_marlin(self, shard_size, shard_offset):
|
||||
return shard_size * self.marlin_tile, shard_offset * self.marlin_tile
|
||||
def marlin_tile_size(self):
|
||||
return self._marlin_tile_size
|
||||
|
||||
def adjust_shard_indexes_for_packing(self, shard_size, shard_offset):
|
||||
shard_size = shard_size // self.packed_factor
|
||||
shard_offset = shard_offset // self.packed_factor
|
||||
if self.marlin_tile is not None:
|
||||
return self._adjust_shard_indexes_for_marlin(
|
||||
shard_size, shard_offset)
|
||||
return shard_size, shard_offset
|
||||
return _adjust_shard_indexes_for_packing(
|
||||
shard_size=shard_size,
|
||||
shard_offset=shard_offset,
|
||||
packed_factor=self.packed_factor,
|
||||
marlin_tile_size=self.marlin_tile_size)
|
||||
|
||||
|
||||
class PackedvLLMParameter(ModelWeightParameter):
    """
    Parameter for model weights which are packed on disk.
    Example: GPTQ Marlin weights are int4 or int8, packed into int32.
    Extends the ModelWeightParameter to take in the
    packed factor, the packed dimension, and optionally, marlin
    tile size for marlin kernels. Adjusts the shard_size and
    shard_offset for fused linear layers model weight loading
    by accounting for packing and optionally, marlin tile size.
    """

    def __init__(self,
                 packed_factor: int,
                 packed_dim: int,
                 marlin_tile_size: Optional[int] = None,
                 **kwargs):
        """
        Record the packing metadata and forward the rest to the parent.

        Args:
            packed_factor: divisor applied to shard sizes and offsets
                when adjusting them for packing.
            packed_dim: dimension along which the weight is packed.
            marlin_tile_size: optional multiplier applied after the
                packing adjustment; None disables that step.
            **kwargs: passed unchanged to the parent initializer.
        """
        # The packing attributes are assigned before the parent
        # initializer runs; this ordering is kept as in the original.
        self._packed_factor = packed_factor
        self._packed_dim = packed_dim
        self._marlin_tile_size = marlin_tile_size
        super().__init__(**kwargs)

    @property
    def packed_dim(self):
        """Dimension along which the weight is packed."""
        return self._packed_dim

    @property
    def packed_factor(self):
        """Divisor used to convert unpacked shard indexes to packed ones."""
        return self._packed_factor

    @property
    def marlin_tile_size(self):
        """Marlin tile size, or None when no marlin adjustment applies."""
        return self._marlin_tile_size

    def adjust_shard_indexes_for_packing(self, shard_size, shard_offset):
        """
        Convert a shard's size and offset to their packed equivalents.

        Delegates to the module-level _adjust_shard_indexes_for_packing
        helper using this parameter's packed_factor and marlin_tile_size.

        Returns:
            The adjusted (shard_size, shard_offset) tuple.
        """
        return _adjust_shard_indexes_for_packing(
            shard_size=shard_size,
            shard_offset=shard_offset,
            packed_factor=self.packed_factor,
            marlin_tile_size=self.marlin_tile_size)
|
||||
|
||||
|
||||
def _adjust_shard_indexes_for_marlin(shard_size, shard_offset,
|
||||
marlin_tile_size):
|
||||
return shard_size * marlin_tile_size, shard_offset * marlin_tile_size
|
||||
|
||||
|
||||
def _adjust_shard_indexes_for_packing(shard_size, shard_offset, packed_factor,
|
||||
marlin_tile_size):
|
||||
shard_size = shard_size // packed_factor
|
||||
shard_offset = shard_offset // packed_factor
|
||||
if marlin_tile_size is not None:
|
||||
return _adjust_shard_indexes_for_marlin(
|
||||
shard_size=shard_size,
|
||||
shard_offset=shard_offset,
|
||||
marlin_tile_size=marlin_tile_size)
|
||||
return shard_size, shard_offset
|
||||
|
||||
Reference in New Issue
Block a user