[Core] Change LoRA embedding sharding to support loading methods (#5038)
This commit is contained in:
@@ -102,22 +102,21 @@ def batched_generate(
|
||||
return [outputs[i].outputs[0].text.strip() for i in range(len(outputs))]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@pytest.fixture(scope="module")
|
||||
def lora_llm(long_context_infos):
|
||||
scaling_factors = [
|
||||
context_len_to_scaling_factor[info["context_length"]]
|
||||
for info in long_context_infos.values()
|
||||
]
|
||||
|
||||
llm = vllm.LLM(
|
||||
"meta-llama/Llama-2-13b-chat-hf",
|
||||
enable_lora=True,
|
||||
max_num_seqs=16,
|
||||
max_loras=2,
|
||||
long_lora_scaling_factors=tuple(scaling_factors),
|
||||
max_num_batched_tokens=4096 * 8,
|
||||
tensor_parallel_size=4,
|
||||
)
|
||||
llm = vllm.LLM("meta-llama/Llama-2-13b-chat-hf",
|
||||
enable_lora=True,
|
||||
max_num_seqs=16,
|
||||
max_loras=2,
|
||||
long_lora_scaling_factors=tuple(scaling_factors),
|
||||
max_num_batched_tokens=4096 * 8,
|
||||
tensor_parallel_size=4,
|
||||
distributed_executor_backend="mp")
|
||||
yield llm
|
||||
del llm
|
||||
|
||||
@@ -154,6 +153,7 @@ def test_rotary_emb_replaced(dist_init):
|
||||
assert rotary_emb_count == 32
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_batched_rope_kernel(lora_llm, long_context_infos):
|
||||
"""We test the batched kernel by comparing the results of batched and
|
||||
non-batched generation.
|
||||
@@ -188,6 +188,7 @@ def test_batched_rope_kernel(lora_llm, long_context_infos):
|
||||
f"same:\n{batched}\n{non_batched}")
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_self_consistency(lora_llm, long_context_infos):
|
||||
"""We test consistency of the batched kernel by permuting batched
|
||||
inputs and comparing the results to the non-permuted batched results.
|
||||
@@ -227,6 +228,7 @@ def test_self_consistency(lora_llm, long_context_infos):
|
||||
f"\n{permutated_batched_results[permutation[i]]}")
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_quality(lora_llm, long_context_infos):
|
||||
"""We test the quality of the answers given by the LoRA model by
|
||||
comparing the generated text to the merged model's outputs.
|
||||
@@ -257,6 +259,7 @@ def test_quality(lora_llm, long_context_infos):
|
||||
assert np.mean(scores) > 0.5
|
||||
|
||||
|
||||
@pytest.mark.skip_global_cleanup
|
||||
def test_max_len(lora_llm, long_context_infos):
|
||||
"""Test that we raise a ValueError when the input of a given LoRA
|
||||
model exceeds the maximum length."""
|
||||
|
||||
Reference in New Issue
Block a user