[mypy] Forward pass function type hints in lora (#11740)

Signed-off-by: lucast2021 <lucast2021@headroyce.org>
Co-authored-by: lucast2021 <lucast2021@headroyce.org>
This commit is contained in:
Lucas Tucker
2025-01-06 01:59:36 -06:00
committed by GitHub
parent 022c5c6944
commit 9c749713f6
3 changed files with 14 additions and 5 deletions

View File

@@ -238,7 +238,9 @@ class ReplicatedLinear(LinearBase):
assert param.size() == loaded_weight.size()
param.data.copy_(loaded_weight)
def forward(self, x: torch.Tensor) -> torch.Tensor:
def forward(
self, x: torch.Tensor
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
bias = self.bias if not self.skip_bias_add else None
assert self.quant_method is not None
output = self.quant_method.apply(self, x, bias)