[LoRA] Remove linear hack outside transformers backend (#14177)
Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
|
||||
import itertools
|
||||
from abc import abstractmethod
|
||||
from typing import Optional
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -152,6 +152,7 @@ class LinearBase(torch.nn.Module):
|
||||
skip_bias_add: If true, skip adding bias but instead return it.
|
||||
params_dtype: Data type for the parameters.
|
||||
quant_config: Quantization configuration.
|
||||
return_bias: If true, return bias together with outputs in forward pass.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -162,6 +163,8 @@ class LinearBase(torch.nn.Module):
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
*,
|
||||
return_bias: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -178,9 +181,11 @@ class LinearBase(torch.nn.Module):
|
||||
else:
|
||||
self.quant_method = quant_config.get_quant_method(self,
|
||||
prefix=prefix)
|
||||
self.return_bias = return_bias
|
||||
|
||||
def forward(self,
|
||||
x: torch.Tensor) -> tuple[torch.Tensor, Optional[Parameter]]:
|
||||
def forward(
|
||||
self, x: torch.Tensor
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -198,20 +203,25 @@ class ReplicatedLinear(LinearBase):
|
||||
(e.g. model.layers.0.qkv_proj)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
bias: bool = True,
|
||||
skip_bias_add: bool = False,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = ""):
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
bias: bool = True,
|
||||
skip_bias_add: bool = False,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
*,
|
||||
return_bias: bool = True,
|
||||
):
|
||||
super().__init__(input_size,
|
||||
output_size,
|
||||
skip_bias_add,
|
||||
params_dtype,
|
||||
quant_config,
|
||||
prefix=prefix)
|
||||
prefix=prefix,
|
||||
return_bias=return_bias)
|
||||
|
||||
# All the linear layer supports quant method.
|
||||
assert self.quant_method is not None
|
||||
@@ -254,12 +264,15 @@ class ReplicatedLinear(LinearBase):
|
||||
f"to a parameter of size {param.size()}")
|
||||
param.data.copy_(loaded_weight)
|
||||
|
||||
def forward(self,
|
||||
x: torch.Tensor) -> tuple[torch.Tensor, Optional[Parameter]]:
|
||||
def forward(
|
||||
self, x: torch.Tensor
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
||||
bias = self.bias if not self.skip_bias_add else None
|
||||
assert self.quant_method is not None
|
||||
output = self.quant_method.apply(self, x, bias)
|
||||
output_bias = self.bias if self.skip_bias_add else None
|
||||
if not self.return_bias:
|
||||
return output
|
||||
return output, output_bias
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
@@ -293,16 +306,20 @@ class ColumnParallelLinear(LinearBase):
|
||||
(e.g. model.layers.0.qkv_proj)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
bias: bool = True,
|
||||
gather_output: bool = False,
|
||||
skip_bias_add: bool = False,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
output_sizes: Optional[list[int]] = None,
|
||||
prefix: str = ""):
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
bias: bool = True,
|
||||
gather_output: bool = False,
|
||||
skip_bias_add: bool = False,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
output_sizes: Optional[list[int]] = None,
|
||||
prefix: str = "",
|
||||
*,
|
||||
return_bias: bool = True,
|
||||
):
|
||||
# Divide the weight matrix along the last dimension.
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.input_size_per_partition = input_size
|
||||
@@ -315,8 +332,13 @@ class ColumnParallelLinear(LinearBase):
|
||||
for output_size in self.output_sizes
|
||||
]
|
||||
|
||||
super().__init__(input_size, output_size, skip_bias_add, params_dtype,
|
||||
quant_config, prefix)
|
||||
super().__init__(input_size,
|
||||
output_size,
|
||||
skip_bias_add,
|
||||
params_dtype,
|
||||
quant_config,
|
||||
prefix,
|
||||
return_bias=return_bias)
|
||||
|
||||
self.gather_output = gather_output
|
||||
|
||||
@@ -393,7 +415,9 @@ class ColumnParallelLinear(LinearBase):
|
||||
loaded_weight = loaded_weight.reshape(1)
|
||||
param.load_column_parallel_weight(loaded_weight=loaded_weight)
|
||||
|
||||
def forward(self, input_) -> tuple[torch.Tensor, Optional[Parameter]]:
|
||||
def forward(
|
||||
self, input_
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
||||
bias = self.bias if not self.skip_bias_add else None
|
||||
|
||||
# Matrix multiply.
|
||||
@@ -405,6 +429,8 @@ class ColumnParallelLinear(LinearBase):
|
||||
else:
|
||||
output = output_parallel
|
||||
output_bias = self.bias if self.skip_bias_add else None
|
||||
if not self.return_bias:
|
||||
return output
|
||||
return output, output_bias
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
@@ -439,15 +465,19 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
(e.g. model.layers.0.qkv_proj)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
input_size: int,
|
||||
output_sizes: list[int],
|
||||
bias: bool = True,
|
||||
gather_output: bool = False,
|
||||
skip_bias_add: bool = False,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = ""):
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_sizes: list[int],
|
||||
bias: bool = True,
|
||||
gather_output: bool = False,
|
||||
skip_bias_add: bool = False,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
*,
|
||||
return_bias: bool = True,
|
||||
):
|
||||
self.output_sizes = output_sizes
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
assert all(output_size % tp_size == 0 for output_size in output_sizes)
|
||||
@@ -458,7 +488,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
|
||||
skip_bias_add=skip_bias_add,
|
||||
params_dtype=params_dtype,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix)
|
||||
prefix=prefix,
|
||||
return_bias=return_bias)
|
||||
|
||||
def weight_loader(self,
|
||||
param: Parameter,
|
||||
@@ -711,16 +742,20 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
(e.g. model.layers.0.qkv_proj)
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
hidden_size: int,
|
||||
head_size: int,
|
||||
total_num_heads: int,
|
||||
total_num_kv_heads: Optional[int] = None,
|
||||
bias: bool = True,
|
||||
skip_bias_add: bool = False,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = ""):
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
head_size: int,
|
||||
total_num_heads: int,
|
||||
total_num_kv_heads: Optional[int] = None,
|
||||
bias: bool = True,
|
||||
skip_bias_add: bool = False,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
*,
|
||||
return_bias: bool = True,
|
||||
):
|
||||
self.hidden_size = hidden_size
|
||||
self.head_size = head_size
|
||||
self.total_num_heads = total_num_heads
|
||||
@@ -753,7 +788,8 @@ class QKVParallelLinear(ColumnParallelLinear):
|
||||
skip_bias_add=skip_bias_add,
|
||||
params_dtype=params_dtype,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix)
|
||||
prefix=prefix,
|
||||
return_bias=return_bias)
|
||||
|
||||
def _get_shard_offset_mapping(self, loaded_shard_id: str):
|
||||
shard_offset_mapping = {
|
||||
@@ -1048,16 +1084,20 @@ class RowParallelLinear(LinearBase):
|
||||
quant_config: Quantization configuration.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
bias: bool = True,
|
||||
input_is_parallel: bool = True,
|
||||
skip_bias_add: bool = False,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
reduce_results: bool = True,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = ""):
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
bias: bool = True,
|
||||
input_is_parallel: bool = True,
|
||||
skip_bias_add: bool = False,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
reduce_results: bool = True,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
*,
|
||||
return_bias: bool = True,
|
||||
):
|
||||
# Divide the weight matrix along the first dimension.
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
@@ -1065,8 +1105,13 @@ class RowParallelLinear(LinearBase):
|
||||
self.output_size_per_partition = output_size
|
||||
self.output_partition_sizes = [output_size]
|
||||
|
||||
super().__init__(input_size, output_size, skip_bias_add, params_dtype,
|
||||
quant_config, prefix)
|
||||
super().__init__(input_size,
|
||||
output_size,
|
||||
skip_bias_add,
|
||||
params_dtype,
|
||||
quant_config,
|
||||
prefix,
|
||||
return_bias=return_bias)
|
||||
|
||||
self.input_is_parallel = input_is_parallel
|
||||
self.reduce_results = reduce_results
|
||||
@@ -1145,7 +1190,9 @@ class RowParallelLinear(LinearBase):
|
||||
|
||||
param.load_row_parallel_weight(loaded_weight=loaded_weight)
|
||||
|
||||
def forward(self, input_) -> tuple[torch.Tensor, Optional[Parameter]]:
|
||||
def forward(
|
||||
self, input_
|
||||
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
|
||||
if self.input_is_parallel:
|
||||
input_parallel = input_
|
||||
else:
|
||||
@@ -1169,6 +1216,8 @@ class RowParallelLinear(LinearBase):
|
||||
|
||||
output_bias = self.bias if self.skip_bias_add else None
|
||||
|
||||
if not self.return_bias:
|
||||
return output
|
||||
return output, output_bias
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
|
||||
Reference in New Issue
Block a user