Add GPTQ support (#916)
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Dict, List, Optional
+from typing import Any, Dict, List, Optional
 
 import torch
 import torch.nn.functional as F
@@ -21,8 +21,10 @@ class LinearMethodBase(ABC):
     """Base class for different (maybe quantized) linear methods."""
 
     @abstractmethod
-    def create_weights(self, input_size: int, output_size: int,
-                       params_dtype: torch.dtype) -> Dict[str, torch.Tensor]:
+    def create_weights(self, input_size_per_partition: int,
+                       output_size_per_partition: int, input_size: int,
+                       output_size: int,
+                       params_dtype: torch.dtype) -> Dict[str, Any]:
         """Create weights for a linear layer."""
         raise NotImplementedError
 
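The widened signature is the core of this change: a quantized method needs the per-partition sizes to allocate its packed tensors and the full sizes to reason about sharding, and the return type is relaxed to Dict[str, Any] so the returned dict can also carry non-tensor metadata. As a rough illustration only (not the GPTQ method actually added by this PR; the class name, bit width, and group-size handling are assumptions, and a complete method would also implement the base class's apply-weights hook, which lies outside this excerpt), a packed 4-bit method could fill in the new interface along these lines, reusing the module's existing torch/Parameter imports:

# Illustrative sketch only -- NOT the GPTQ linear method shipped in this PR.
# Assumes 4-bit weights packed 8-per-int32 along the input dimension and a
# fixed quantization group size; only the create_weights half is sketched.
class PackedQuantLinearMethod(LinearMethodBase):

    def __init__(self, bits: int = 4, group_size: int = 128):
        self.bits = bits
        self.group_size = group_size
        self.pack_factor = 32 // bits  # int4 values packed per int32

    def create_weights(self, input_size_per_partition: int,
                       output_size_per_partition: int, input_size: int,
                       output_size: int,
                       params_dtype: torch.dtype) -> Dict[str, Any]:
        if input_size_per_partition % self.group_size != 0:
            raise ValueError("Each partition must hold whole quantization "
                             "groups.")
        # Packed integer weights, shrunk along the (partitioned) input dim.
        qweight = Parameter(torch.empty(
            input_size_per_partition // self.pack_factor,
            output_size_per_partition,
            device=torch.cuda.current_device(),
            dtype=torch.int32),
                            requires_grad=False)
        # One scale per quantization group per output column.
        scales = Parameter(torch.empty(
            input_size_per_partition // self.group_size,
            output_size_per_partition,
            device=torch.cuda.current_device(),
            dtype=params_dtype),
                           requires_grad=False)
        return {"qweight": qweight, "scales": scales}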
@@ -46,10 +48,12 @@ class UnquantizedLinearMethod(LinearMethodBase):
     def __init__(self, separate_bias_add: bool = False):
         self.separate_bias_add = separate_bias_add
 
-    def create_weights(self, input_size: int, output_size: int,
-                       params_dtype: torch.dtype) -> Dict[str, torch.Tensor]:
-        weight = Parameter(torch.empty(output_size,
-                                       input_size,
+    def create_weights(self, input_size_per_partition: int,
+                       output_size_per_partition: int, input_size: int,
+                       output_size: int,
+                       params_dtype: torch.dtype) -> Dict[str, Any]:
+        weight = Parameter(torch.empty(output_size_per_partition,
+                                       input_size_per_partition,
                                        device=torch.cuda.current_device(),
                                        dtype=params_dtype),
                            requires_grad=False)
@@ -102,9 +106,11 @@ class ReplicatedLinear(torch.nn.Module):
             linear_method = UnquantizedLinearMethod()
         self.linear_method = linear_method
         self.linear_weights = self.linear_method.create_weights(
-            self.input_size, self.output_size, self.params_dtype)
+            self.input_size, self.output_size, self.input_size,
+            self.output_size, self.params_dtype)
         for name, weight in self.linear_weights.items():
-            self.register_parameter(name, weight)
+            if isinstance(weight, torch.Tensor):
+                self.register_parameter(name, weight)
         if bias:
             self.bias = Parameter(
                 torch.empty(self.output_size,
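The new isinstance guard exists because, with the relaxed Dict[str, Any] return type, a linear method may put non-tensor entries (scalars, packing metadata) into its weights dict, and only genuine tensors should be registered as parameters. A hedged usage sketch, assuming ReplicatedLinear's constructor accepts these keyword arguments and reusing the hypothetical method sketched earlier:

# Hedged usage sketch (argument names are assumptions; CUDA is required).
layer = ReplicatedLinear(input_size=4096,
                         output_size=4096,
                         bias=False,
                         linear_method=PackedQuantLinearMethod())
# Only tensor-valued entries from create_weights() become parameters; any
# non-tensor entries stay reachable via layer.linear_weights but are skipped
# by the isinstance(weight, torch.Tensor) check above.
print(sorted(name for name, _ in layer.named_parameters()))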
@@ -168,10 +174,12 @@ class ColumnParallelLinear(torch.nn.Module):
             linear_method = UnquantizedLinearMethod()
         self.linear_method = linear_method
         self.linear_weights = self.linear_method.create_weights(
-            self.input_size, self.output_size_per_partition, self.params_dtype)
+            self.input_size, self.output_size_per_partition, self.input_size,
+            self.output_size, self.params_dtype)
         for name, weight in self.linear_weights.items():
-            self.register_parameter(name, weight)
-            set_weight_attrs(weight, {"weight_loader": self.weight_loader})
+            if isinstance(weight, torch.Tensor):
+                self.register_parameter(name, weight)
+                set_weight_attrs(weight, {"weight_loader": self.weight_loader})
         if bias:
             self.bias = Parameter(
                 torch.empty(self.output_size_per_partition,
@@ -295,10 +303,12 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                  shard_size)
         else:
-            logger.warning(
-                "Loading a weight without `output_dim` attribute in "
-                "MergedColumnParallelLinear, assume the weight is "
-                "the same for all partitions.")
+            ignore_warning = getattr(param, "ignore_warning", False)
+            if not ignore_warning:
+                logger.warning(
+                    "Loading a weight without `output_dim` attribute in "
+                    "MergedColumnParallelLinear, assume the weight is "
+                    "the same for all partitions.")
         assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)
 
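The new ignore_warning flag is an opt-out for weights that legitimately lack an output_dim attribute, such as quantization metadata that is replicated across all partitions, so loading them does not flood the log. A minimal sketch, assuming set_weight_attrs simply attaches the given attributes to the parameter:

# Hedged sketch: a parameter that is intentionally identical on every
# partition opts out of the "missing output_dim" warning.
meta = Parameter(torch.zeros(1), requires_grad=False)
set_weight_attrs(meta, {"ignore_warning": True})
# weight_loader() then finds getattr(param, "ignore_warning", False) truthy
# and skips the logger.warning call shown in the hunk above.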
@@ -418,10 +428,12 @@ class QKVParallelLinear(ColumnParallelLinear):
             loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                  shard_size)
         else:
-            logger.warning(
-                "Loading a weight without `output_dim` attribute in "
-                "QKVParallelLinear, assume the weight is the same "
-                "for all partitions.")
+            ignore_warning = getattr(param, "ignore_warning", False)
+            if not ignore_warning:
+                logger.warning(
+                    "Loading a weight without `output_dim` attribute in "
+                    "QKVParallelLinear, assume the weight is the same "
+                    "for all partitions.")
         assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)
 
@@ -481,10 +493,12 @@ class RowParallelLinear(torch.nn.Module):
             linear_method = UnquantizedLinearMethod()
         self.linear_method = linear_method
         self.linear_weights = self.linear_method.create_weights(
-            self.input_size_per_partition, self.output_size, self.params_dtype)
+            self.input_size_per_partition, self.output_size, self.input_size,
+            self.output_size, self.params_dtype)
         for name, weight in self.linear_weights.items():
-            self.register_parameter(name, weight)
-            set_weight_attrs(weight, {"weight_loader": self.weight_loader})
+            if isinstance(weight, torch.Tensor):
+                self.register_parameter(name, weight)
+                set_weight_attrs(weight, {"weight_loader": self.weight_loader})
 
         if not reduce_results and (bias and not skip_bias_add):
             raise ValueError("When not reduce the results, adding bias to the "
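Across ReplicatedLinear, ColumnParallelLinear, QKVParallelLinear, and RowParallelLinear the same pattern repeats: both the per-partition and the full sizes are now forwarded to create_weights, so a quantized method can tell how the layer is sharded and confirm that each shard still holds whole quantization groups. A small illustrative helper (the function name and group-size default are assumptions, not part of this diff):

# Hedged sketch: deduce the tensor-parallel degree from the two sizes that
# create_weights() now receives, and validate group alignment per shard.
def infer_tp_size(size_per_partition: int, full_size: int,
                  group_size: int = 128) -> int:
    assert full_size % size_per_partition == 0
    if size_per_partition % group_size != 0:
        raise ValueError("Each shard must contain whole quantization groups.")
    return full_size // size_per_partition

# e.g. RowParallelLinear with input_size=4096 split across 2 GPUs:
# infer_tp_size(2048, 4096) -> 2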