[Core] Set linear_weights directly on the layer (#3977)
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -28,19 +28,24 @@ class LinearMethodBase(ABC):
|
||||
"""Base class for different (maybe quantized) linear methods."""
|
||||
|
||||
@abstractmethod
|
||||
def create_weights(self, input_size_per_partition: int,
|
||||
def create_weights(self, layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_size_per_partition: int, input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
"""Create weights for a linear layer."""
|
||||
output_size: int, params_dtype: torch.dtype,
|
||||
**extra_weight_attrs):
|
||||
"""Create weights for a linear layer.
|
||||
|
||||
The weights will be set as attributes of the layer."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def apply_weights(self,
|
||||
weights: Dict[str, torch.Tensor],
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
"""Apply the weights to the input tensor."""
|
||||
"""Apply the weights in layer to the input tensor.
|
||||
|
||||
Expects create_weights to have been called before on the layer."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -55,22 +60,24 @@ class UnquantizedLinearMethod(LinearMethodBase):
|
||||
def __init__(self, separate_bias_add: bool = False):
|
||||
self.separate_bias_add = separate_bias_add
|
||||
|
||||
def create_weights(self, input_size_per_partition: int,
|
||||
def create_weights(self, layer: torch.nn.Module,
|
||||
input_size_per_partition: int,
|
||||
output_size_per_partition: int, input_size: int,
|
||||
output_size: int,
|
||||
params_dtype: torch.dtype) -> Dict[str, Any]:
|
||||
output_size: int, params_dtype: torch.dtype,
|
||||
**extra_weight_attrs):
|
||||
weight = Parameter(torch.empty(output_size_per_partition,
|
||||
input_size_per_partition,
|
||||
dtype=params_dtype),
|
||||
requires_grad=False)
|
||||
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
|
||||
return {"weight": weight}
|
||||
layer.register_parameter("weight", weight)
|
||||
set_weight_attrs(weight, extra_weight_attrs)
|
||||
|
||||
def apply_weights(self,
|
||||
weights: Dict[str, torch.Tensor],
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
weight = weights["weight"]
|
||||
weight = layer.weight
|
||||
if self.separate_bias_add:
|
||||
if bias is not None:
|
||||
return F.linear(x, weight) + bias
|
||||
@@ -111,12 +118,9 @@ class ReplicatedLinear(torch.nn.Module):
|
||||
if linear_method is None:
|
||||
linear_method = UnquantizedLinearMethod()
|
||||
self.linear_method = linear_method
|
||||
self.linear_weights = self.linear_method.create_weights(
|
||||
self.input_size, self.output_size, self.input_size,
|
||||
self.output_size, self.params_dtype)
|
||||
for name, weight in self.linear_weights.items():
|
||||
if isinstance(weight, torch.Tensor):
|
||||
self.register_parameter(name, weight)
|
||||
self.linear_method.create_weights(self, self.input_size,
|
||||
self.output_size, self.input_size,
|
||||
self.output_size, self.params_dtype)
|
||||
if bias:
|
||||
self.bias = Parameter(
|
||||
torch.empty(self.output_size, dtype=self.params_dtype))
|
||||
@@ -126,7 +130,7 @@ class ReplicatedLinear(torch.nn.Module):
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
bias = self.bias if not self.skip_bias_add else None
|
||||
output = self.linear_method.apply_weights(self.linear_weights, x, bias)
|
||||
output = self.linear_method.apply_weights(self, x, bias)
|
||||
output_bias = self.bias if self.skip_bias_add else None
|
||||
return output, output_bias
|
||||
|
||||
@@ -177,13 +181,13 @@ class ColumnParallelLinear(torch.nn.Module):
|
||||
if linear_method is None:
|
||||
linear_method = UnquantizedLinearMethod()
|
||||
self.linear_method = linear_method
|
||||
self.linear_weights = self.linear_method.create_weights(
|
||||
self.input_size, self.output_size_per_partition, self.input_size,
|
||||
self.output_size, self.params_dtype)
|
||||
for name, weight in self.linear_weights.items():
|
||||
if isinstance(weight, torch.Tensor):
|
||||
self.register_parameter(name, weight)
|
||||
set_weight_attrs(weight, {"weight_loader": self.weight_loader})
|
||||
self.linear_method.create_weights(self,
|
||||
self.input_size,
|
||||
self.output_size_per_partition,
|
||||
self.input_size,
|
||||
self.output_size,
|
||||
self.params_dtype,
|
||||
weight_loader=self.weight_loader)
|
||||
if bias:
|
||||
self.bias = Parameter(
|
||||
torch.empty(self.output_size_per_partition,
|
||||
@@ -211,8 +215,7 @@ class ColumnParallelLinear(torch.nn.Module):
|
||||
bias = self.bias if not self.skip_bias_add else None
|
||||
|
||||
# Matrix multiply.
|
||||
output_parallel = self.linear_method.apply_weights(
|
||||
self.linear_weights, input_, bias)
|
||||
output_parallel = self.linear_method.apply_weights(self, input_, bias)
|
||||
if self.gather_output:
|
||||
# All-gather across the partitions.
|
||||
output = tensor_model_parallel_all_gather(output_parallel)
|
||||
@@ -523,13 +526,13 @@ class RowParallelLinear(torch.nn.Module):
|
||||
if linear_method is None:
|
||||
linear_method = UnquantizedLinearMethod()
|
||||
self.linear_method = linear_method
|
||||
self.linear_weights = self.linear_method.create_weights(
|
||||
self.input_size_per_partition, self.output_size, self.input_size,
|
||||
self.output_size, self.params_dtype)
|
||||
for name, weight in self.linear_weights.items():
|
||||
if isinstance(weight, torch.Tensor):
|
||||
self.register_parameter(name, weight)
|
||||
set_weight_attrs(weight, {"weight_loader": self.weight_loader})
|
||||
self.linear_method.create_weights(self,
|
||||
self.input_size_per_partition,
|
||||
self.output_size,
|
||||
self.input_size,
|
||||
self.output_size,
|
||||
self.params_dtype,
|
||||
weight_loader=self.weight_loader)
|
||||
|
||||
if not reduce_results and (bias and not skip_bias_add):
|
||||
raise ValueError("When not reduce the results, adding bias to the "
|
||||
@@ -569,7 +572,7 @@ class RowParallelLinear(torch.nn.Module):
|
||||
|
||||
# Matrix multiply.
|
||||
output_parallel = self.linear_method.apply_weights(
|
||||
self.linear_weights, input_parallel)
|
||||
self, input_parallel)
|
||||
if self.reduce_results and self.tp_size > 1:
|
||||
output_ = tensor_model_parallel_all_reduce(output_parallel)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user