[Core] Set linear_weights directly on the layer (#3977)

Antoni Baum authored on 2024-04-11 13:35:51 -07:00 (committed by GitHub)
parent 8afca50889 · commit a10d3056da
8 changed files with 114 additions and 102 deletions
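In short: LinearMethodBase.create_weights no longer returns a Dict[str, Any] of tensors for each layer to register itself. It now receives the layer, registers the parameters on it directly, and attaches any extra keyword arguments (such as weight_loader) as weight attributes; apply_weights likewise takes the layer instead of a weights dict. A minimal before/after sketch of a call site, condensed from the diff below:

# Before: create_weights returned a dict and every layer repeated
# the same registration loop.
self.linear_weights = self.linear_method.create_weights(
    self.input_size, self.output_size, self.input_size,
    self.output_size, self.params_dtype)
for name, weight in self.linear_weights.items():
    if isinstance(weight, torch.Tensor):
        self.register_parameter(name, weight)
output = self.linear_method.apply_weights(self.linear_weights, x, bias)

# After: create_weights mutates the layer in place and apply_weights
# reads the parameters back off it.
self.linear_method.create_weights(self, self.input_size,
                                  self.output_size, self.input_size,
                                  self.output_size, self.params_dtype)
output = self.linear_method.apply_weights(self, x, bias)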

@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Optional
+from typing import List, Optional
 
 import torch
 import torch.nn.functional as F
@@ -28,19 +28,24 @@ class LinearMethodBase(ABC):
     """Base class for different (maybe quantized) linear methods."""
 
     @abstractmethod
-    def create_weights(self, input_size_per_partition: int,
+    def create_weights(self, layer: torch.nn.Module,
+                       input_size_per_partition: int,
                        output_size_per_partition: int, input_size: int,
-                       output_size: int,
-                       params_dtype: torch.dtype) -> Dict[str, Any]:
-        """Create weights for a linear layer."""
+                       output_size: int, params_dtype: torch.dtype,
+                       **extra_weight_attrs):
+        """Create weights for a linear layer.
+
+        The weights will be set as attributes of the layer."""
         raise NotImplementedError
 
     @abstractmethod
     def apply_weights(self,
-                      weights: Dict[str, torch.Tensor],
+                      layer: torch.nn.Module,
                       x: torch.Tensor,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-        """Apply the weights to the input tensor."""
+        """Apply the weights in layer to the input tensor.
+
+        Expects create_weights to have been called before on the layer."""
         raise NotImplementedError
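The layer-mutating contract pays off most for quantized methods that own several tensors: instead of returning them in a dict for the layer to loop over, the method registers each one on the layer itself. A hypothetical sketch of such a subclass (not part of this commit; the Int8LinearMethod name and the per-tensor scale scheme are illustrative only, and LinearMethodBase / set_weight_attrs are assumed importable from this module):

from typing import Optional

import torch
from torch.nn.parameter import Parameter


class Int8LinearMethod(LinearMethodBase):
    """Illustrative only: an int8 weight with one dequantization scale."""

    def create_weights(self, layer: torch.nn.Module,
                       input_size_per_partition: int,
                       output_size_per_partition: int, input_size: int,
                       output_size: int, params_dtype: torch.dtype,
                       **extra_weight_attrs):
        qweight = Parameter(torch.empty(output_size_per_partition,
                                        input_size_per_partition,
                                        dtype=torch.int8),
                            requires_grad=False)
        scale = Parameter(torch.ones(1, dtype=params_dtype),
                          requires_grad=False)
        # Both tensors land directly on the layer; no dict round-trip.
        layer.register_parameter("qweight", qweight)
        layer.register_parameter("scale", scale)
        # Attach caller-supplied attributes (e.g. weight_loader).
        set_weight_attrs(qweight, extra_weight_attrs)

    def apply_weights(self, layer: torch.nn.Module,
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Dequantize on the fly and fall back to a plain GEMM.
        weight = layer.qweight.to(x.dtype) * layer.scale
        return torch.nn.functional.linear(x, weight, bias)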
@@ -55,22 +60,24 @@ class UnquantizedLinearMethod(LinearMethodBase):
     def __init__(self, separate_bias_add: bool = False):
         self.separate_bias_add = separate_bias_add
 
-    def create_weights(self, input_size_per_partition: int,
+    def create_weights(self, layer: torch.nn.Module,
+                       input_size_per_partition: int,
                        output_size_per_partition: int, input_size: int,
-                       output_size: int,
-                       params_dtype: torch.dtype) -> Dict[str, Any]:
+                       output_size: int, params_dtype: torch.dtype,
+                       **extra_weight_attrs):
         weight = Parameter(torch.empty(output_size_per_partition,
                                        input_size_per_partition,
                                        dtype=params_dtype),
                            requires_grad=False)
         set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
-        return {"weight": weight}
+        layer.register_parameter("weight", weight)
+        set_weight_attrs(weight, extra_weight_attrs)
 
     def apply_weights(self,
-                      weights: Dict[str, torch.Tensor],
+                      layer: torch.nn.Module,
                       x: torch.Tensor,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-        weight = weights["weight"]
+        weight = layer.weight
         if self.separate_bias_add:
             if bias is not None:
                 return F.linear(x, weight) + bias
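From the layer side, the flow is now: call create_weights once in __init__ (it registers self.weight as a side effect), then route forward through apply_weights. A standalone toy consumer, assuming UnquantizedLinearMethod is importable from this module; the sizes are arbitrary and the weight is left uninitialized (a real model fills it through the weight loader):

import torch

method = UnquantizedLinearMethod()


class ToyLinear(torch.nn.Module):

    def __init__(self):
        super().__init__()
        # Registers self.weight on this module as a side effect.
        method.create_weights(self,
                              input_size_per_partition=16,
                              output_size_per_partition=32,
                              input_size=16,
                              output_size=32,
                              params_dtype=torch.float32)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Reads self.weight back off the module.
        return method.apply_weights(self, x)


layer = ToyLinear()
out = layer(torch.randn(4, 16))
assert out.shape == (4, 32)
assert isinstance(layer.weight, torch.nn.Parameter)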
@@ -111,12 +118,9 @@ class ReplicatedLinear(torch.nn.Module):
         if linear_method is None:
             linear_method = UnquantizedLinearMethod()
         self.linear_method = linear_method
-        self.linear_weights = self.linear_method.create_weights(
-            self.input_size, self.output_size, self.input_size,
-            self.output_size, self.params_dtype)
-        for name, weight in self.linear_weights.items():
-            if isinstance(weight, torch.Tensor):
-                self.register_parameter(name, weight)
+        self.linear_method.create_weights(self, self.input_size,
+                                          self.output_size, self.input_size,
+                                          self.output_size, self.params_dtype)
         if bias:
             self.bias = Parameter(
                 torch.empty(self.output_size, dtype=self.params_dtype))
@@ -126,7 +130,7 @@ class ReplicatedLinear(torch.nn.Module):
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         bias = self.bias if not self.skip_bias_add else None
-        output = self.linear_method.apply_weights(self.linear_weights, x, bias)
+        output = self.linear_method.apply_weights(self, x, bias)
         output_bias = self.bias if self.skip_bias_add else None
         return output, output_bias
@@ -177,13 +181,13 @@ class ColumnParallelLinear(torch.nn.Module):
         if linear_method is None:
             linear_method = UnquantizedLinearMethod()
         self.linear_method = linear_method
-        self.linear_weights = self.linear_method.create_weights(
-            self.input_size, self.output_size_per_partition, self.input_size,
-            self.output_size, self.params_dtype)
-        for name, weight in self.linear_weights.items():
-            if isinstance(weight, torch.Tensor):
-                self.register_parameter(name, weight)
-                set_weight_attrs(weight, {"weight_loader": self.weight_loader})
+        self.linear_method.create_weights(self,
+                                          self.input_size,
+                                          self.output_size_per_partition,
+                                          self.input_size,
+                                          self.output_size,
+                                          self.params_dtype,
+                                          weight_loader=self.weight_loader)
         if bias:
             self.bias = Parameter(
                 torch.empty(self.output_size_per_partition,
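Note how the old explicit set_weight_attrs(weight, {"weight_loader": self.weight_loader}) call disappears: weight_loader= now rides through **extra_weight_attrs and is attached to the parameter inside create_weights. A sketch of the resulting invariant, assuming a tensor-parallel runtime is already initialized (the sizes are illustrative):

layer = ColumnParallelLinear(input_size=16, output_size=32, bias=False)

# create_weights attached the loader via **extra_weight_attrs, so the
# weight-loading code can reach it straight off the parameter:
assert layer.weight.weight_loader == layer.weight_loader

# Hand a (full, unsharded) checkpoint tensor to the parameter's loader.
layer.weight.weight_loader(layer.weight, torch.randn(32, 16))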
@@ -211,8 +215,7 @@ class ColumnParallelLinear(torch.nn.Module):
         bias = self.bias if not self.skip_bias_add else None
 
         # Matrix multiply.
-        output_parallel = self.linear_method.apply_weights(
-            self.linear_weights, input_, bias)
+        output_parallel = self.linear_method.apply_weights(self, input_, bias)
         if self.gather_output:
             # All-gather across the partitions.
             output = tensor_model_parallel_all_gather(output_parallel)
@@ -523,13 +526,13 @@ class RowParallelLinear(torch.nn.Module):
         if linear_method is None:
             linear_method = UnquantizedLinearMethod()
         self.linear_method = linear_method
-        self.linear_weights = self.linear_method.create_weights(
-            self.input_size_per_partition, self.output_size, self.input_size,
-            self.output_size, self.params_dtype)
-        for name, weight in self.linear_weights.items():
-            if isinstance(weight, torch.Tensor):
-                self.register_parameter(name, weight)
-                set_weight_attrs(weight, {"weight_loader": self.weight_loader})
+        self.linear_method.create_weights(self,
+                                          self.input_size_per_partition,
+                                          self.output_size,
+                                          self.input_size,
+                                          self.output_size,
+                                          self.params_dtype,
+                                          weight_loader=self.weight_loader)
 
         if not reduce_results and (bias and not skip_bias_add):
             raise ValueError("When not reduce the results, adding bias to the "
@@ -569,7 +572,7 @@ class RowParallelLinear(torch.nn.Module):
 
         # Matrix multiply.
         output_parallel = self.linear_method.apply_weights(
-            self.linear_weights, input_parallel)
+            self, input_parallel)
         if self.reduce_results and self.tp_size > 1:
             output_ = tensor_model_parallel_all_reduce(output_parallel)
         else: