[Core] Optimize LoRA weight loading (#25403)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
@@ -36,10 +36,10 @@ class DummyLoRAManager:
             module_name,
             rank=rank,
             lora_alpha=1,
-            lora_a=torch.rand([weight.shape[1], rank],
+            lora_a=torch.rand([rank, weight.shape[1]],
                               dtype=weight.dtype,
                               device=self._device),
-            lora_b=torch.rand([rank, weight.shape[0]],
+            lora_b=torch.rand([weight.shape[0], rank],
                               dtype=weight.dtype,
                               device=self._device),
         )
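The swapped dimensions reflect the weight orientation this commit adopts for the dummy test tensors: lora_a stored as (rank, in_features) and lora_b as (out_features, rank), the same orientation PEFT uses for lora_A.weight / lora_B.weight. A minimal sketch of how shapes in the new layout compose (variable names are illustrative, not vLLM internals):

# Sketch only: lora_a is (r, in) and lora_b is (out, r), so the LoRA delta
# is applied as two linear maps without any transposing of stored weights.
import torch

in_features, out_features, rank = 64, 128, 8
x = torch.rand(4, in_features)              # batch of activations
lora_a = torch.rand(rank, in_features)      # (r, in)  -- new layout
lora_b = torch.rand(out_features, rank)     # (out, r) -- new layout

delta = x @ lora_a.T @ lora_b.T             # (4, out_features)
assert delta.shape == (4, out_features)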
@@ -67,8 +67,8 @@ class DummyLoRAManager:
             module_name,
             rank=rank,
             lora_alpha=1,
-            lora_a=torch.rand([input_dim, rank], device="cuda"),
-            lora_b=torch.rand([rank, output_dim], device="cuda"),
+            lora_a=torch.rand([rank, input_dim], device="cuda"),
+            lora_b=torch.rand([output_dim, input_dim], device="cuda"),
             embeddings_tensor=embeddings_tensor,
             )
         self.set_module_lora(module_name, lora)
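For context on the "optimize weight loading" angle: with the old (in, rank) / (rank, out) layout, checkpoint tensors had to be transposed at load time, whereas the new orientation can be used as-is. A hedged sketch of that difference, assuming PEFT-style checkpoint keys (the key names and this toy loader are assumptions, not vLLM's actual loading code):

# Assumed PEFT-style checkpoint layout: lora_A.weight is (r, in),
# lora_B.weight is (out, r).
import torch

rank, in_features, out_features = 8, 64, 128
checkpoint = {
    "lora_A.weight": torch.rand(rank, in_features),
    "lora_B.weight": torch.rand(out_features, rank),
}

# Old convention: transpose (and materialize a contiguous copy) on load.
old_lora_a = checkpoint["lora_A.weight"].t().contiguous()  # (in, r)
old_lora_b = checkpoint["lora_B.weight"].t().contiguous()  # (r, out)

# New convention: the checkpoint tensors are already in the right shape.
new_lora_a = checkpoint["lora_A.weight"]                    # (r, in)
new_lora_b = checkpoint["lora_B.weight"]                    # (out, r)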