[Bugfix] Fix fully sharded LoRAs with Mixtral (#11390)

Signed-off-by: Jason Greene <jason.greene@redhat.com>
Jason T. Greene authored 2024-12-22 09:25:10 -06:00, committed by GitHub
parent 72d9c316d3
commit f1d1bf6288
2 changed files with 5 additions and 2 deletions


@@ -425,8 +425,9 @@ class ReplicatedLinearWithLoRA(BaseLinearLayerWithLoRA):
                        if self.base_layer.skip_bias_add else None)
         return output, output_bias
 
+    # ReplicatedLinear should always be replaced, regardless of the fully
+    # sharded LoRAs setting, because it is, by definition, copied per GPU.
     @classmethod
-    @_not_fully_sharded_can_replace
     def can_replace_layer(
         cls,
         source_layer: nn.Module,
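
For readers without the surrounding file, here is a minimal sketch of what the removed decorator does. This illustrates the gating pattern rather than reproducing vLLM's exact implementation; it assumes can_replace_layer() receives the active lora_config as a keyword argument carrying a fully_sharded_loras flag.

from functools import wraps

def _not_fully_sharded_can_replace(can_replace):
    # Sketch: wrap a can_replace_layer() classmethod so it answers False
    # whenever fully sharded LoRAs are enabled, steering the loader toward
    # the fully sharded variant of the layer instead.
    @wraps(can_replace)
    def dec(*args, **kwargs):
        condition = not kwargs["lora_config"].fully_sharded_loras
        return can_replace(*args, **kwargs) and condition
    return dec

With that decorator still in place, enabling fully sharded LoRAs made ReplicatedLinearWithLoRA.can_replace_layer() return False, so the replicated layer (presumably Mixtral's MoE gate, given the commit title) was never wrapped with its LoRA counterpart at all. Because ReplicatedLinear is, by definition, copied whole onto every GPU, no fully sharded variant exists to fall back on, which is why the fix simply drops the decorator and adds the explanatory comment.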