[mypy][5/N] Support all typing on model executor (#4427)

This commit is contained in:
SangBin Cho
2024-04-29 11:01:26 +09:00
committed by GitHub
parent 03dd7d52bf
commit df29793dc7
10 changed files with 61 additions and 34 deletions

View File

@@ -128,7 +128,8 @@ class LinearBase(torch.nn.Module):
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
if quant_config is None:
self.quant_method = UnquantizedLinearMethod()
self.quant_method: Optional[
QuantizeMethodBase] = UnquantizedLinearMethod()
else:
self.quant_method = quant_config.get_quant_method(self)
@@ -160,6 +161,8 @@ class ReplicatedLinear(LinearBase):
super().__init__(input_size, output_size, skip_bias_add, params_dtype,
quant_config)
# All linear layers support a quant method.
assert self.quant_method is not None
self.quant_method.create_weights(self, self.input_size,
[self.output_size], self.input_size,
self.output_size, self.params_dtype)
@@ -173,6 +176,7 @@ class ReplicatedLinear(LinearBase):
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Run the layer's quant method on ``x`` and return the result with any deferred bias.

    When ``self.skip_bias_add`` is falsy, ``self.bias`` is handed to
    ``self.quant_method.apply`` and the second returned element is ``None``.
    When it is truthy, ``None`` is handed to ``apply`` instead and
    ``self.bias`` is returned as the second element.

    Args:
        x: Input passed through to ``self.quant_method.apply``.

    Returns:
        A 2-tuple ``(output, output_bias)``.

    Raises:
        AssertionError: If ``self.quant_method`` is ``None``.

    NOTE(review): the signature declares ``torch.Tensor`` as the return
    type, but a 2-tuple is returned -- the annotation should be confirmed
    and corrected alongside the module's typing imports.
    """
    # Bias handed to the quant method; withheld when the add is skipped.
    fused_bias = None if self.skip_bias_add else self.bias
    # Narrows the Optional type of quant_method for the type checker.
    assert self.quant_method is not None
    result = self.quant_method.apply(self, x, fused_bias)
    if self.skip_bias_add:
        return result, self.bias
    return result, None
@@ -221,6 +225,8 @@ class ColumnParallelLinear(LinearBase):
self.output_size_per_partition = divide(output_size, tp_size)
if output_sizes is None:
output_sizes = [output_size]
# All linear layers support a quant method.
assert self.quant_method is not None
self.quant_method.create_weights(self,
self.input_size,
[x // tp_size for x in output_sizes],
@@ -255,6 +261,7 @@ class ColumnParallelLinear(LinearBase):
bias = self.bias if not self.skip_bias_add else None
# Matrix multiply.
assert self.quant_method is not None
output_parallel = self.quant_method.apply(self, input_, bias)
if self.gather_output:
# All-gather across the partitions.
@@ -579,6 +586,8 @@ class RowParallelLinear(LinearBase):
# Divide the weight matrix along the last dimension.
self.tp_size = get_tensor_model_parallel_world_size()
self.input_size_per_partition = divide(input_size, self.tp_size)
# All linear layers support a quant method.
assert self.quant_method is not None
self.quant_method.create_weights(self,
self.input_size_per_partition,
[self.output_size],
@@ -624,6 +633,7 @@ class RowParallelLinear(LinearBase):
input_parallel = splitted_input[tp_rank].contiguous()
# Matrix multiply.
assert self.quant_method is not None
output_parallel = self.quant_method.apply(self, input_parallel)
if self.reduce_results and self.tp_size > 1:
output_ = tensor_model_parallel_all_reduce(output_parallel)