[Core] Optimize LoRA weight loading (#25403)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
@@ -63,9 +63,9 @@ def test_from_lora_tensors(sql_lora_files, device):
         assert lora.lora_b is not None
         assert lora.lora_a.device == torch.device(device)
         assert lora.lora_b.device == torch.device(device)
-        assert (lora.lora_a.shape[1] == lora.lora_b.shape[0]
+        assert (lora.lora_a.shape[0] == lora.lora_b.shape[1]
                 ), f"{lora.lora_a.shape=}, {lora.lora_b.shape=}"
-        assert lora.lora_a.shape[1] == 8
+        assert lora.lora_a.shape[0] == 8
         embeddings_module = next(
             (k for k in EMBEDDING_MODULES if k in module_name), None)
         if embeddings_module:
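The flipped indices in these asserts reflect a change in the stored layout: lora_a goes from [in_features, rank] to [rank, in_features], and lora_b from [rank, out_features] to [out_features, rank]. A minimal before/after sketch of the two conventions (dimension names here are illustrative, not taken from the vLLM sources):

import torch

rank, in_features, out_features = 8, 4096, 11008

# Before: tensors transposed at load time.
old_a = torch.rand(in_features, rank)
old_b = torch.rand(rank, out_features)
assert old_a.shape[1] == old_b.shape[0]  # the old assertion

# After: checkpoint-native orientation.
new_a = torch.rand(rank, in_features)
new_b = torch.rand(out_features, rank)
assert new_a.shape[0] == new_b.shape[1]  # the new assertion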
@@ -86,8 +86,8 @@ def create_lora(lora_id: int, model: nn.Module, sub_modules: list[str],
             name,
             8,
             16,
-            torch.rand([w.shape[1], 8], device=device),
-            torch.rand([8, w.shape[0]], device=device),
+            torch.rand([8, w.shape[1]], device=device),
+            torch.rand([w.shape[0], 8], device=device),
         )
     return LoRAModel(lora_id, 8, loras)

@@ -109,8 +109,8 @@ def create_packed_lora(
             replaced_module_name,
             8,
             16,
-            torch.rand([w.shape[1], 8], device=device),
-            torch.rand([8, w.shape[0] // len(replaced_module_names)],
+            torch.rand([8, w.shape[1]], device=device),
+            torch.rand([w.shape[0] // len(replaced_module_names), 8],
                        device=device),
         )
     return LoRAModel(lora_id, 8, loras)
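For context, the new orientation matches the nn.Linear weight convention that PEFT-style LoRA checkpoints use, so tensors can presumably be kept as loaded rather than transposed one by one, which would be the load-time optimization the commit title refers to. A sketch of applying weights in the new layout (tensor names and dimensions are illustrative, and the actual vLLM kernels differ):

import torch

rank, in_features, out_features = 8, 4096, 11008
lora_a = torch.rand(rank, in_features)   # [rank, in_features]
lora_b = torch.rand(out_features, rank)  # [out_features, rank]

# LoRA delta for activations x of shape [batch, in_features]:
x = torch.rand(2, in_features)
delta = x @ lora_a.T @ lora_b.T  # -> [batch, out_features]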