Refactor Linear handling in TransformersModel (#12727)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
@@ -2,7 +2,7 @@
 
 import itertools
 from abc import abstractmethod
-from typing import Dict, List, Optional, Tuple
+from typing import Optional
 
 import torch
 import torch.nn.functional as F
@@ -47,8 +47,8 @@ def adjust_marlin_shard(param, shard_size, shard_offset):
 
 
 def adjust_bitsandbytes_4bit_shard(param: Parameter,
-                                   shard_offsets: Dict[str, Tuple[int, int]],
-                                   loaded_shard_id: str) -> Tuple[int, int]:
+                                   shard_offsets: dict[str, tuple[int, int]],
+                                   loaded_shard_id: str) -> tuple[int, int]:
     """Adjust the quantization offsets and sizes for BitsAndBytes sharding."""
 
     total, _ = shard_offsets["total"]
@@ -90,7 +90,7 @@ class LinearMethodBase(QuantizeMethodBase):
     @abstractmethod
     def create_weights(self, layer: torch.nn.Module,
                        input_size_per_partition: int,
-                       output_partition_sizes: List[int], input_size: int,
+                       output_partition_sizes: list[int], input_size: int,
                        output_size: int, params_dtype: torch.dtype,
                        **extra_weight_attrs):
         """Create weights for a linear layer.
@@ -123,7 +123,7 @@ class UnquantizedLinearMethod(LinearMethodBase):
 
     def create_weights(self, layer: torch.nn.Module,
                        input_size_per_partition: int,
-                       output_partition_sizes: List[int], input_size: int,
+                       output_partition_sizes: list[int], input_size: int,
                        output_size: int, params_dtype: torch.dtype,
                        **extra_weight_attrs):
         weight = Parameter(torch.empty(sum(output_partition_sizes),
@@ -179,7 +179,8 @@ class LinearBase(torch.nn.Module):
         self.quant_method = quant_config.get_quant_method(self,
                                                           prefix=prefix)
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
+    def forward(self,
+                x: torch.Tensor) -> tuple[torch.Tensor, Optional[Parameter]]:
         raise NotImplementedError
 
 
@@ -240,9 +241,8 @@ class ReplicatedLinear(LinearBase):
         assert param.size() == loaded_weight.size()
         param.data.copy_(loaded_weight)
 
-    def forward(
-        self, x: torch.Tensor
-    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
+    def forward(self,
+                x: torch.Tensor) -> tuple[torch.Tensor, Optional[Parameter]]:
         bias = self.bias if not self.skip_bias_add else None
         assert self.quant_method is not None
         output = self.quant_method.apply(self, x, bias)
@@ -288,7 +288,7 @@ class ColumnParallelLinear(LinearBase):
                  skip_bias_add: bool = False,
                  params_dtype: Optional[torch.dtype] = None,
                  quant_config: Optional[QuantizationConfig] = None,
-                 output_sizes: Optional[List[int]] = None,
+                 output_sizes: Optional[list[int]] = None,
                  prefix: str = ""):
         super().__init__(input_size, output_size, skip_bias_add, params_dtype,
                          quant_config, prefix)
@@ -374,7 +374,7 @@ class ColumnParallelLinear(LinearBase):
             loaded_weight = loaded_weight.reshape(1)
         param.load_column_parallel_weight(loaded_weight=loaded_weight)
 
-    def forward(self, input_):
+    def forward(self, input_) -> tuple[torch.Tensor, Optional[Parameter]]:
         bias = self.bias if not self.skip_bias_add else None
 
         # Matrix multiply.
@@ -422,7 +422,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
 
     def __init__(self,
                  input_size: int,
-                 output_sizes: List[int],
+                 output_sizes: list[int],
                  bias: bool = True,
                  gather_output: bool = False,
                  skip_bias_add: bool = False,
@@ -500,7 +500,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             current_shard_offset = 0
             use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
                                             False)
-            shard_offsets: List[Tuple[int, int, int]] = []
+            shard_offsets: list[tuple[int, int, int]] = []
            for i, output_size in enumerate(self.output_sizes):
                 shard_offsets.append((i, current_shard_offset, output_size))
                 current_shard_offset += output_size
@@ -602,7 +602,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
         """
 
         current_shard_offset = 0
-        shard_offsets: List[Tuple[int, int, int]] = []
+        shard_offsets: list[tuple[int, int, int]] = []
         for i, output_size in enumerate(self.output_sizes):
             shard_offsets.append((i, current_shard_offset, output_size))
             current_shard_offset += output_size
@@ -1124,7 +1124,7 @@ class RowParallelLinear(LinearBase):
 
         param.load_row_parallel_weight(loaded_weight=loaded_weight)
 
-    def forward(self, input_):
+    def forward(self, input_) -> tuple[torch.Tensor, Optional[Parameter]]:
         if self.input_is_parallel:
             input_parallel = input_
         else:
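
For context (not part of the commit itself): after this refactor every forward() in these Linear layers is annotated as returning an (output, bias) tuple rather than a bare tensor. A minimal sketch of the resulting calling convention, assuming an unquantized ReplicatedLinear in a plain single-process setup with illustrative sizes:

    import torch

    from vllm.model_executor.layers.linear import ReplicatedLinear

    # Hypothetical layer; the sizes here are illustrative only.
    layer = ReplicatedLinear(input_size=16, output_size=32, bias=True,
                             skip_bias_add=True)

    x = torch.randn(4, 16)
    # forward() -> tuple[torch.Tensor, Optional[Parameter]]
    output, bias = layer(x)
    if bias is not None:
        # With skip_bias_add=True the bias is returned instead of applied,
        # so the caller adds (or fuses) it itself.
        output = output + bias

With the default skip_bias_add=False the bias is applied inside forward() and the second element is None, so callers can simply take the first element of the tuple.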