[Kernel] (2/N) Machete - Integrate into CompressedTensorsWNA16 and GPTQMarlin (#7701)
Co-authored-by: mgoin <michael@neuralmagic.com>
Co-authored-by: Divakar Verma <137818590+divakar-amd@users.noreply.github.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
@@ -328,6 +328,64 @@ class PackedvLLMParameter(ModelWeightParameter):
        marlin_tile_size=self.marlin_tile_size)


def permute_param_layout_(param: BasevLLMParameter, input_dim: int,
                          output_dim: int, **kwargs) -> BasevLLMParameter:
    """
    Permute a parameter's layout to the specified input and output dimensions.
    This is useful for forcing the parameter into a known layout; for example,
    if a packed (quantized) weight matrix needs to be in the layout
        {input_dim = 0, output_dim = 1, packed_dim = 0}
    then calling
        permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
    ensures x is in that layout, permuting it if required and asserting if it
    cannot be put into the requested layout.
    """

    curr_input_dim = getattr(param, "input_dim", None)
    curr_output_dim = getattr(param, "output_dim", None)

    if curr_input_dim is None or curr_output_dim is None:
        assert param.data.dim() == 2,\
            "permute_param_layout_ only supports 2D parameters when either "\
            "input_dim or output_dim is not set"

    # if one of the dimensions is not set, set it to the opposite of the
    # other; we can only do this since we asserted the parameter is 2D above
    if curr_input_dim is None:
        assert curr_output_dim is not None,\
            "either input or output dim must be set"
        curr_input_dim = (curr_output_dim + 1) % 2
    if curr_output_dim is None:
        assert curr_input_dim is not None,\
            "either input or output dim must be set"
        curr_output_dim = (curr_input_dim + 1) % 2

    # create a permutation from the current layout to the layout with
    # param.input_dim at input_dim and param.output_dim at output_dim,
    # preserving all other dimensions
    perm = [
        i for i in range(param.data.dim())
        if i not in [curr_input_dim, curr_output_dim]
    ]
    perm.insert(input_dim, curr_input_dim)
    perm.insert(output_dim, curr_output_dim)

    if "packed_dim" in kwargs:
        assert hasattr(param, "packed_dim") and\
            param.packed_dim == perm[kwargs["packed_dim"]],\
            "permute_param_layout_ currently doesn't support repacking"

    param.data = param.data.permute(*perm)
    if hasattr(param, "_input_dim"):
        param._input_dim = input_dim
    if hasattr(param, "_output_dim"):
        param._output_dim = output_dim
    if "packed_dim" in kwargs and hasattr(param, "_packed_dim"):
        param._packed_dim = kwargs["packed_dim"]

    return param
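
A minimal usage sketch of the helper (the tensor shape, the dummy weight
loader, and the vllm.model_executor.parameter import path below are
illustrative assumptions, not part of this diff): a weight stored transposed,
with {output_dim = 0, input_dim = 1}, can be forced into the
{input_dim = 0, output_dim = 1} layout a kernel expects:

import torch
from vllm.model_executor.parameter import (ModelWeightParameter,
                                           permute_param_layout_)

# Hypothetical weight stored transposed: dim 0 is output, dim 1 is input.
w = ModelWeightParameter(data=torch.empty(4096, 11008),
                         output_dim=0,
                         input_dim=1,
                         weight_loader=lambda *args: None)

# Force input features onto dim 0 and output features onto dim 1; the
# parameter's data is rebound to a permuted view and its _input_dim /
# _output_dim attributes are updated to match the new layout.
permute_param_layout_(w, input_dim=0, output_dim=1)
assert w.data.shape == (11008, 4096)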


def _adjust_shard_indexes_for_marlin(shard_size, shard_offset,
                                     marlin_tile_size):
    return shard_size * marlin_tile_size, shard_offset * marlin_tile_size
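
This helper scales both shard bounds by the Marlin tile size, since Marlin
repacks weights in tiles; a quick illustrative check (the values here are
made up):

assert _adjust_shard_indexes_for_marlin(
    shard_size=128, shard_offset=256,
    marlin_tile_size=16) == (128 * 16, 256 * 16)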