[Misc] Delete unused LoRA modules (#13151)
This commit is contained in:
@@ -5,7 +5,8 @@ import math
|
||||
import os
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, Dict, List, Optional, Sequence, Type, Union
|
||||
from typing import (Any, Callable, Dict, List, Optional, Sequence, Set, Type,
|
||||
Union)
|
||||
|
||||
import safetensors.torch
|
||||
import torch
|
||||
@@ -619,12 +620,14 @@ class LoRAModelManager(AdapterModelManager):
|
||||
def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None:
|
||||
for module_name, new_module_names in self.packed_modules.items():
|
||||
replacement_loras: List[Optional[LoRALayerWeights]] = []
|
||||
replaced_module: Set[str] = set()
|
||||
has_replacement = False
|
||||
for r in new_module_names:
|
||||
lora = lora_model.get_lora(r)
|
||||
replacement_loras.append(lora)
|
||||
if lora:
|
||||
has_replacement = True
|
||||
replaced_module.add(r)
|
||||
if not has_replacement:
|
||||
continue
|
||||
for i in range(len(replacement_loras)):
|
||||
@@ -633,6 +636,9 @@ class LoRAModelManager(AdapterModelManager):
|
||||
replacement_loras[i] = None
|
||||
lora_model.loras[module_name] = PackedLoRALayerWeights.pack(
|
||||
replacement_loras)
|
||||
# Remove the modules that have been replaced.
|
||||
for module in replaced_module:
|
||||
lora_model.loras.pop(module, None)
|
||||
|
||||
def deactivate_adapter(self, adapter_id: int) -> bool:
|
||||
return deactivate_adapter(adapter_id, self._active_adapters,
|
||||
|
||||
@@ -147,7 +147,7 @@ class PunicaWrapperBase(PunicaWrapperABC):
|
||||
dtype=torch.long,
|
||||
device=device)
|
||||
|
||||
# 5 is the number of indicies tensors.
|
||||
# 5 is the number of indices tensors.
|
||||
# base_indices, sampler_indices, sampler_indices_padded,
|
||||
# embeddings_indices,long_lora_indices
|
||||
self.indices_len: List[Optional[int]] = [None] * 5
|
||||
|
||||
Reference in New Issue
Block a user