[Core] Optimize LoRA weight loading (#25403)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Author: Jee Jee Li
Date: 2025-09-23 18:19:45 +08:00
Committed by: GitHub
Parent commit: 231c2c63e4
Commit: 273690a50a
10 changed files with 83 additions and 83 deletions


@@ -63,9 +63,9 @@ def test_from_lora_tensors(sql_lora_files, device):
 assert lora.lora_b is not None
 assert lora.lora_a.device == torch.device(device)
 assert lora.lora_b.device == torch.device(device)
-assert (lora.lora_a.shape[1] == lora.lora_b.shape[0]
+assert (lora.lora_a.shape[0] == lora.lora_b.shape[1]
 ), f"{lora.lora_a.shape=}, {lora.lora_b.shape=}"
-assert lora.lora_a.shape[1] == 8
+assert lora.lora_a.shape[0] == 8
 embeddings_module = next(
 (k for k in EMBEDDING_MODULES if k in module_name), None)
 if embeddings_module:
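
The updated assertions encode the new tensor layout: lora_a is stored as [rank, in_features] and lora_b as [out_features, rank], so the rank-8 dimension moves to lora_a.shape[0] and lora_b.shape[1]. A minimal sketch of that convention (the nn.Linear layer and its sizes are hypothetical, not taken from the diff):

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for a module wrapped by LoRA; sizes are illustrative.
layer = nn.Linear(in_features=32, out_features=64)
w = layer.weight            # nn.Linear stores weight as [out_features, in_features]

rank = 8
lora_a = torch.rand(rank, w.shape[1])   # [rank, in_features]
lora_b = torch.rand(w.shape[0], rank)   # [out_features, rank]

# Mirrors the updated test assertions.
assert lora_a.shape[0] == lora_b.shape[1]
assert lora_a.shape[0] == rank
```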
@@ -86,8 +86,8 @@ def create_lora(lora_id: int, model: nn.Module, sub_modules: list[str],
 name,
 8,
 16,
-torch.rand([w.shape[1], 8], device=device),
-torch.rand([8, w.shape[0]], device=device),
+torch.rand([8, w.shape[1]], device=device),
+torch.rand([w.shape[0], 8], device=device),
 )
 return LoRAModel(lora_id, 8, loras)
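
With this layout, lora_b @ lora_a already has the base weight's [out_features, in_features] shape, which is presumably why create_lora can now draw the random factors in this orientation directly rather than transposing later. A sketch of the resulting update, reusing the same hypothetical sizes as above:

```python
import torch

in_features, out_features, rank = 32, 64, 8   # illustrative sizes
w = torch.rand(out_features, in_features)     # base weight in nn.Linear layout
lora_a = torch.rand(rank, in_features)        # [rank, in_features]
lora_b = torch.rand(out_features, rank)       # [out_features, rank]

delta = lora_b @ lora_a                       # [out_features, in_features], same shape as w
assert delta.shape == w.shape

x = torch.rand(4, in_features)                # a batch of four inputs
y = x @ (w + delta).T                         # forward pass with the merged weight
assert y.shape == (4, out_features)
```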
@@ -109,8 +109,8 @@ def create_packed_lora(
 replaced_module_name,
 8,
 16,
-torch.rand([w.shape[1], 8], device=device),
-torch.rand([8, w.shape[0] // len(replaced_module_names)],
+torch.rand([8, w.shape[1]], device=device),
+torch.rand([w.shape[0] // len(replaced_module_names), 8],
 device=device),
 )
 return LoRAModel(lora_id, 8, loras)
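
For packed modules, the diff keeps the split on the output dimension: each sub-module's lora_b now has shape [out_features // len(replaced_module_names), rank]. A sketch of that shape arithmetic, assuming a hypothetical fused projection split into three equal sub-modules:

```python
import torch

rank = 8
out_features, in_features = 96, 32
w = torch.rand(out_features, in_features)            # hypothetical packed (fused) weight
replaced_module_names = ["sub_0", "sub_1", "sub_2"]  # illustrative names
n = len(replaced_module_names)

# One (lora_a, lora_b) pair per sub-module; each lora_b covers an
# out_features // n slice of the packed output dimension.
subloras = [
    (torch.rand(rank, in_features),          # lora_a: [rank, in_features]
     torch.rand(out_features // n, rank))    # lora_b: [out_features // n, rank]
    for _ in replaced_module_names
]

# Concatenating the per-sub-module updates reconstructs a full-size delta.
delta = torch.cat([b @ a for a, b in subloras], dim=0)
assert delta.shape == w.shape
```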