AQLM CUDA support (#3287)
Co-authored-by: mgoin <michael@neuralmagic.com>
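Summary: refactors the (maybe quantized) linear-method interface so that `create_weights` receives `output_partition_sizes: List[int]` in place of a single `output_size_per_partition` int. Callers pass one entry per logical output packed into the layer (e.g. q/k/v), which lets quantization methods with per-output metadata, such as AQLM codebooks, size and load their parameters shard by shard via the new `is_metadata` path in the weight loaders.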
@@ -31,7 +31,7 @@ class LinearMethodBase(ABC):
     @abstractmethod
     def create_weights(self, layer: torch.nn.Module,
                        input_size_per_partition: int,
-                       output_size_per_partition: int, input_size: int,
+                       output_partition_sizes: List[int], input_size: int,
                        output_size: int, params_dtype: torch.dtype,
                        **extra_weight_attrs):
         """Create weights for a linear layer.
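For context, a minimal sketch of a method implementing the revised interface. It is not part of this diff; the class name and dense weight layout are hypothetical, and only the signature mirrors `LinearMethodBase` above:

from typing import List

import torch

class SketchLinearMethod:  # hypothetical, for illustration only
    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_partition_sizes: List[int], input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        # Per the diff, the per-partition width is recovered by summing
        # the logical output partitions (e.g. [q, k, v] for a fused layer).
        output_size_per_partition = sum(output_partition_sizes)
        layer.weight = torch.nn.Parameter(
            torch.empty(output_size_per_partition,
                        input_size_per_partition,
                        dtype=params_dtype),
            requires_grad=False)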
@@ -70,9 +70,10 @@ class UnquantizedLinearMethod(LinearMethodBase):
 
     def create_weights(self, layer: torch.nn.Module,
                        input_size_per_partition: int,
-                       output_size_per_partition: int, input_size: int,
+                       output_partition_sizes: List[int], input_size: int,
                        output_size: int, params_dtype: torch.dtype,
                        **extra_weight_attrs):
+        output_size_per_partition = sum(output_partition_sizes)
         weight = Parameter(torch.empty(output_size_per_partition,
                                        input_size_per_partition,
                                        dtype=params_dtype),
@@ -127,7 +128,7 @@ class ReplicatedLinear(torch.nn.Module):
             linear_method = UnquantizedLinearMethod()
         self.linear_method = linear_method
         self.linear_method.create_weights(self, self.input_size,
-                                          self.output_size, self.input_size,
+                                          [self.output_size], self.input_size,
                                           self.output_size, self.params_dtype)
         if bias:
             self.bias = Parameter(
@@ -161,6 +162,8 @@ class ColumnParallelLinear(torch.nn.Module):
                        skip adding bias but instead return it.
         params_dtype: Data type for the parameters.
         linear_method: (Maybe quantized) linear method.
+        output_sizes: list of output sizes packed into one output, like for QKV
+                       the list would be size 3.
     """
 
     def __init__(
@@ -172,6 +175,7 @@ class ColumnParallelLinear(torch.nn.Module):
         skip_bias_add: bool = False,
         params_dtype: Optional[torch.dtype] = None,
         linear_method: Optional[LinearMethodBase] = None,
+        output_sizes: Optional[List[int]] = None,
     ):
         super().__init__()
 
@@ -188,10 +192,12 @@ class ColumnParallelLinear(torch.nn.Module):
         self.params_dtype = params_dtype
         if linear_method is None:
             linear_method = UnquantizedLinearMethod()
+        if output_sizes is None:
+            output_sizes = [output_size]
         self.linear_method = linear_method
         self.linear_method.create_weights(self,
                                           self.input_size,
-                                          self.output_size_per_partition,
+                                          [x // tp_size for x in output_sizes],
                                           self.input_size,
                                           self.output_size,
                                           self.params_dtype,
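As a worked example of the `[x // tp_size for x in output_sizes]` argument above, with hypothetical sizes: each rank creates weights for its slice of every packed output:

output_sizes = [4096, 1024, 1024]          # e.g. fused q/k/v widths
tp_size = 2
per_rank = [x // tp_size for x in output_sizes]
assert per_rank == [2048, 512, 512]
assert sum(per_rank) == sum(output_sizes) // tp_size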
@@ -268,14 +274,17 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
         tp_size = get_tensor_model_parallel_world_size()
         assert all(output_size % tp_size == 0 for output_size in output_sizes)
         super().__init__(input_size, sum(output_sizes), bias, gather_output,
-                         skip_bias_add, params_dtype, linear_method)
+                         skip_bias_add, params_dtype, linear_method,
+                         self.output_sizes)
 
     def weight_loader(self,
                       param: Parameter,
                       loaded_weight: torch.Tensor,
                       loaded_shard_id: Optional[int] = None):
+
         param_data = param.data
         output_dim = getattr(param, "output_dim", None)
+        is_metadata = getattr(param, "is_metadata", False)
         if loaded_shard_id is None:
             # Loaded weight is already packed.
             if output_dim is None:
@@ -328,6 +337,11 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             start_idx = tp_rank * shard_size
             loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                  shard_size)
+        elif is_metadata:
+            # metadata indicates fixed size concatenated along dim 0
+            shard_size = loaded_weight.shape[0]
+            shard_offset = loaded_shard_id * shard_size
+            param_data = param_data.narrow(0, shard_offset, shard_size)
         else:
             ignore_warning = getattr(param, "ignore_warning", False)
             if not ignore_warning:
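A small sketch of the `is_metadata` branch above, using made-up shapes: each shard occupies a fixed-size slot along dim 0 of the parameter, so shard `i` lands at rows `[i * n, (i + 1) * n)`:

import torch

param_data = torch.zeros(6, 4)        # hypothetical: two 3-row metadata slots
loaded_weight = torch.ones(3, 4)      # one shard's metadata
loaded_shard_id = 1
shard_size = loaded_weight.shape[0]
shard_offset = loaded_shard_id * shard_size
param_data.narrow(0, shard_offset, shard_size).copy_(loaded_weight)
assert param_data[3:].eq(1).all() and param_data[:3].eq(0).all()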
@@ -393,8 +407,14 @@ class QKVParallelLinear(ColumnParallelLinear):
         input_size = self.hidden_size
         output_size = (self.num_heads +
                        2 * self.num_kv_heads) * tp_size * self.head_size
+        output_sizes = [
+            self.num_heads * tp_size * self.head_size,
+            self.num_kv_heads * tp_size * self.head_size,
+            self.num_kv_heads * tp_size * self.head_size
+        ]
+
         super().__init__(input_size, output_size, bias, False, skip_bias_add,
-                         params_dtype, linear_method)
+                         params_dtype, linear_method, output_sizes)
 
     def weight_loader(self,
                       param: Parameter,
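For concreteness, the sizes above with hypothetical numbers (32 query heads, 8 KV heads, head size 128, tp_size 2):

num_heads, num_kv_heads, head_size, tp_size = 32, 8, 128, 2
output_sizes = [
    num_heads * tp_size * head_size,       # 8192 (q)
    num_kv_heads * tp_size * head_size,    # 2048 (k)
    num_kv_heads * tp_size * head_size,    # 2048 (v)
]
output_size = (num_heads + 2 * num_kv_heads) * tp_size * head_size
assert sum(output_sizes) == output_size    # 12288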
@@ -402,6 +422,7 @@ class QKVParallelLinear(ColumnParallelLinear):
                       loaded_shard_id: Optional[str] = None):
         param_data = param.data
         output_dim = getattr(param, "output_dim", None)
+        is_metadata = getattr(param, "is_metadata", False)
 
         if loaded_shard_id is None:
             # Loaded weight is already packed.
@@ -469,6 +490,12 @@ class QKVParallelLinear(ColumnParallelLinear):
             start_idx = shard_id * shard_size
             loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                  shard_size)
+        elif is_metadata:
+            # metadata indicates fixed size concatenated along dim 0
+            shard_size = loaded_weight.shape[0]
+            shard_index = ["q", "k", "v"].index(loaded_shard_id)
+            param_data = param_data.narrow(0, shard_index * shard_size,
+                                           shard_size)
         else:
             ignore_warning = getattr(param, "ignore_warning", False)
             if not ignore_warning:
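The QKV variant differs only in mapping the string shard id to a fixed slot in q, k, v order; a sketch with hypothetical shapes:

import torch

param_data = torch.zeros(9, 4)        # hypothetical: three 3-row slots (q, k, v)
loaded_weight = torch.ones(3, 4)
shard_size = loaded_weight.shape[0]
shard_index = ["q", "k", "v"].index("k")   # -> 1
param_data.narrow(0, shard_index * shard_size,
                  shard_size).copy_(loaded_weight)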
@@ -536,7 +563,7 @@ class RowParallelLinear(torch.nn.Module):
         self.linear_method = linear_method
         self.linear_method.create_weights(self,
                                           self.input_size_per_partition,
-                                          self.output_size,
+                                          [self.output_size],
                                           self.input_size,
                                           self.output_size,
                                           self.params_dtype,