[Core] Optimize LoRA weight loading (#25403)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
Jee Jee Li
2025-09-23 18:19:45 +08:00
committed by GitHub
parent 231c2c63e4
commit 273690a50a
10 changed files with 83 additions and 83 deletions

View File

@@ -164,8 +164,8 @@ def populate_loras(
weight=layer_weights,
generate_embeddings_tensor=generate_embeddings_tensor,
)
sublora.lora_b = sublora.lora_b[:, (sublora_len *
i):(sublora_len * (i + 1))]
sublora.lora_b = sublora.lora_b[(sublora_len *
i):(sublora_len * (i + 1)), :]
sublora.optimize()
subloras.append(sublora)
@@ -304,9 +304,9 @@ def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
result = embedding(input_)
after_a = F.embedding(
input_,
lora.lora_a,
lora.lora_a.T,
)
result += (after_a @ lora.lora_b)
result += (after_a @ lora.lora_b.T)
expected_results.append(result)
expected_result = torch.cat(expected_results)
@@ -445,9 +445,9 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
result = expanded_embedding(input_)
after_a = F.embedding(
original_input_,
lora.lora_a,
lora.lora_a.T,
)
result += (after_a @ lora.lora_b)
result += (after_a @ lora.lora_b.T)
expected_results.append(result)
expected_result = torch.cat(expected_results)
@@ -575,7 +575,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
lm_head=linear,
embedding_bias=None)
result[:, vocab_size + embeddings_tensor_len:] = float("-inf")
result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
result += input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling
expected_results.append(result)
expected_result = torch.cat(expected_results)
logits_processor.org_vocab_size = vocab_size
@@ -692,9 +692,10 @@ def test_linear_replicated(
expected_results: list[torch.Tensor] = []
for input_, lora_id in zip(inputs, prompt_mapping):
lora = lora_dict[lora_id]
result = linear(input_)[0]
result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
result += input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling
expected_results.append(result)
expected_result = torch.cat(expected_results)
@@ -817,7 +818,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
for input_, lora_id in zip(inputs, prompt_mapping):
lora = lora_dict[lora_id]
result = linear(input_)[0]
result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling
result += input_ @ lora.lora_a.T @ lora.lora_b.T * lora.scaling
expected_results.append(result)
expected_result = torch.cat(expected_results)
@@ -965,9 +966,10 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
result = linear(input_)[0]
subloras = sublora_dict[lora_id]
for i, sublora in enumerate(subloras):
result[:, sublora.lora_b.shape[1] * i:sublora.lora_b.shape[1] *
(i + 1)] += (input_ @ sublora.lora_a @ sublora.lora_b *
sublora.scaling)
result[:, sublora.lora_b.shape[0] * i:sublora.lora_b.shape[0] *
(i + 1)] += (
input_ @ sublora.lora_a.T @ sublora.lora_b.T *
sublora.scaling)
expected_results.append(result)
expected_result = torch.cat(expected_results)