[mypy][5/N] Support all typing on model executor (#4427)
This commit is contained in:
@@ -128,7 +128,8 @@ class LinearBase(torch.nn.Module):
|
||||
params_dtype = torch.get_default_dtype()
|
||||
self.params_dtype = params_dtype
|
||||
if quant_config is None:
|
||||
self.quant_method = UnquantizedLinearMethod()
|
||||
self.quant_method: Optional[
|
||||
QuantizeMethodBase] = UnquantizedLinearMethod()
|
||||
else:
|
||||
self.quant_method = quant_config.get_quant_method(self)
|
||||
|
||||
@@ -160,6 +161,8 @@ class ReplicatedLinear(LinearBase):
|
||||
super().__init__(input_size, output_size, skip_bias_add, params_dtype,
|
||||
quant_config)
|
||||
|
||||
# All the linear layer supports quant method.
|
||||
assert self.quant_method is not None
|
||||
self.quant_method.create_weights(self, self.input_size,
|
||||
[self.output_size], self.input_size,
|
||||
self.output_size, self.params_dtype)
|
||||
@@ -173,6 +176,7 @@ class ReplicatedLinear(LinearBase):
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
bias = self.bias if not self.skip_bias_add else None
|
||||
assert self.quant_method is not None
|
||||
output = self.quant_method.apply(self, x, bias)
|
||||
output_bias = self.bias if self.skip_bias_add else None
|
||||
return output, output_bias
|
||||
@@ -221,6 +225,8 @@ class ColumnParallelLinear(LinearBase):
|
||||
self.output_size_per_partition = divide(output_size, tp_size)
|
||||
if output_sizes is None:
|
||||
output_sizes = [output_size]
|
||||
# All the linear layer supports quant method.
|
||||
assert self.quant_method is not None
|
||||
self.quant_method.create_weights(self,
|
||||
self.input_size,
|
||||
[x // tp_size for x in output_sizes],
|
||||
@@ -255,6 +261,7 @@ class ColumnParallelLinear(LinearBase):
|
||||
bias = self.bias if not self.skip_bias_add else None
|
||||
|
||||
# Matrix multiply.
|
||||
assert self.quant_method is not None
|
||||
output_parallel = self.quant_method.apply(self, input_, bias)
|
||||
if self.gather_output:
|
||||
# All-gather across the partitions.
|
||||
@@ -579,6 +586,8 @@ class RowParallelLinear(LinearBase):
|
||||
# Divide the weight matrix along the last dimension.
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.input_size_per_partition = divide(input_size, self.tp_size)
|
||||
# All the linear layer supports quant method.
|
||||
assert self.quant_method is not None
|
||||
self.quant_method.create_weights(self,
|
||||
self.input_size_per_partition,
|
||||
[self.output_size],
|
||||
@@ -624,6 +633,7 @@ class RowParallelLinear(LinearBase):
|
||||
input_parallel = splitted_input[tp_rank].contiguous()
|
||||
|
||||
# Matrix multiply.
|
||||
assert self.quant_method is not None
|
||||
output_parallel = self.quant_method.apply(self, input_parallel)
|
||||
if self.reduce_results and self.tp_size > 1:
|
||||
output_ = tensor_model_parallel_all_reduce(output_parallel)
|
||||
|
||||
Reference in New Issue
Block a user