[Core] Change LoRA embedding sharding to support loading methods (#5038)

This commit is contained in:
Antoni Baum
2024-06-06 19:07:57 -07:00
committed by GitHub
parent a31cab7556
commit ccdc490dda
11 changed files with 658 additions and 126 deletions

View File

@@ -102,22 +102,21 @@ def batched_generate(
return [outputs[i].outputs[0].text.strip() for i in range(len(outputs))]
@pytest.fixture
@pytest.fixture(scope="module")
def lora_llm(long_context_infos):
scaling_factors = [
context_len_to_scaling_factor[info["context_length"]]
for info in long_context_infos.values()
]
llm = vllm.LLM(
"meta-llama/Llama-2-13b-chat-hf",
enable_lora=True,
max_num_seqs=16,
max_loras=2,
long_lora_scaling_factors=tuple(scaling_factors),
max_num_batched_tokens=4096 * 8,
tensor_parallel_size=4,
)
llm = vllm.LLM("meta-llama/Llama-2-13b-chat-hf",
enable_lora=True,
max_num_seqs=16,
max_loras=2,
long_lora_scaling_factors=tuple(scaling_factors),
max_num_batched_tokens=4096 * 8,
tensor_parallel_size=4,
distributed_executor_backend="mp")
yield llm
del llm
@@ -154,6 +153,7 @@ def test_rotary_emb_replaced(dist_init):
assert rotary_emb_count == 32
@pytest.mark.skip_global_cleanup
def test_batched_rope_kernel(lora_llm, long_context_infos):
"""We test the batched kernel by comparing the results of batched an
non-batched generation.
@@ -188,6 +188,7 @@ def test_batched_rope_kernel(lora_llm, long_context_infos):
f"same:\n{batched}\n{non_batched}")
@pytest.mark.skip_global_cleanup
def test_self_consistency(lora_llm, long_context_infos):
"""We test consistency of the batched kernel by permuting batched
inputs and comparing the results to the non-permuted batched results.
@@ -227,6 +228,7 @@ def test_self_consistency(lora_llm, long_context_infos):
f"\n{permutated_batched_results[permutation[i]]}")
@pytest.mark.skip_global_cleanup
def test_quality(lora_llm, long_context_infos):
"""We test the quality of the answers given by the LoRA model by
comparing the generated text to the merged model's outputs.
@@ -257,6 +259,7 @@ def test_quality(lora_llm, long_context_infos):
assert np.mean(scores) > 0.5
@pytest.mark.skip_global_cleanup
def test_max_len(lora_llm, long_context_infos):
"""Test that we raise an ValueError when the input of a given LoRA
model exceeds the maximum length."""