Improve configs - LoRAConfig + PromptAdapterConfig (#16980)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Author: Harry Mellor
Date: 2025-04-24 18:29:34 +01:00 (committed by GitHub)
Parent: 0422ce109f
Commit: 0fa939e2d1
3 changed files with 130 additions and 91 deletions
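
The hunks below all follow one pattern: each LoRAConfig in the test file gains an explicit lora_dtype=DEFAULT_DTYPE argument instead of relying on the config's default dtype. A minimal sketch of that pattern, assuming vLLM's LoRAConfig is importable from vllm.config (the max_* values here are illustrative, not read from this diff):

    import torch
    from vllm.config import LoRAConfig

    # Capture the process-wide default dtype once, at module import time,
    # so every LoRAConfig below is built with the same explicit dtype.
    DEFAULT_DTYPE = torch.get_default_dtype()

    lora_config = LoRAConfig(max_lora_rank=8,
                             max_cpu_loras=4,
                             max_loras=4,
                             lora_dtype=DEFAULT_DTYPE)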


@@ -31,6 +31,8 @@ DEVICES = ([
     f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
 ] if current_platform.is_cuda_alike() else ["cpu"])
+DEFAULT_DTYPE = torch.get_default_dtype()
+
 
 @pytest.fixture(scope="function", autouse=True)
 def use_v0_only(monkeypatch: pytest.MonkeyPatch):
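
Since DEFAULT_DTYPE is bound at import time, it reflects whatever default dtype is active when the test module loads. A standalone illustration of the PyTorch behavior being captured (not part of the diff):

    import torch

    print(torch.get_default_dtype())        # torch.float32 in a fresh process
    torch.set_default_dtype(torch.float64)
    print(torch.get_default_dtype())        # torch.float64 from here on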
@@ -125,8 +127,10 @@ def test_replace_submodules(dist_init, dummy_model):
     model = dummy_model
     manager = LoRAModelManager(
         model, 1, 1, 1,
-        LoRAConfig(max_lora_rank=8, max_cpu_loras=8, max_loras=8),
-        torch.device(DEVICES[0]))
+        LoRAConfig(max_lora_rank=8,
+                   max_cpu_loras=8,
+                   max_loras=8,
+                   lora_dtype=DEFAULT_DTYPE), torch.device(DEVICES[0]))
     model = manager.model
     assert isinstance(model.get_submodule("dense1"),
                       ColumnParallelLinearWithLoRA)
@@ -155,7 +159,8 @@ def test_lora_model_manager(dist_init, dummy_model, device):
         2,
         LoRAConfig(max_lora_rank=8,
                    max_cpu_loras=3,
-                   max_loras=2),
+                   max_loras=2,
+                   lora_dtype=DEFAULT_DTYPE),
         device=device)
     assert all(x is None for x in manager.lora_index_to_id)
     assert manager.add_adapter(model_lora1)
@@ -221,7 +226,8 @@ def test_lora_lru_cache_model_manager(dist_init, dummy_model, device):
         2,
         LoRAConfig(max_lora_rank=8,
                    max_cpu_loras=3,
-                   max_loras=2),
+                   max_loras=2,
+                   lora_dtype=DEFAULT_DTYPE),
         device=device)
     assert all(x is None for x in manager.lora_index_to_id)
     assert manager.add_adapter(model_lora1)
@@ -316,7 +322,8 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device):
         2,
         LoRAConfig(max_lora_rank=8,
                    max_cpu_loras=2,
-                   max_loras=2),
+                   max_loras=2,
+                   lora_dtype=DEFAULT_DTYPE),
         device=device)
     assert all(x is None for x in manager.lora_index_to_id)
@@ -424,7 +431,10 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device):
 @pytest.mark.parametrize("device", DEVICES)
 def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
                                           sql_lora_files, device):
-    lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
+    lora_config = LoRAConfig(max_lora_rank=8,
+                             max_cpu_loras=4,
+                             max_loras=4,
+                             lora_dtype=DEFAULT_DTYPE)
     worker_adapter_manager = LRUCacheWorkerLoRAManager(
         4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
         lora_config.lora_extra_vocab_size, lora_config, device,
@@ -504,7 +514,10 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
 def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
                                 sql_lora_files, device):
     # Should remove every LoRA not specified in the request.
-    lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
+    lora_config = LoRAConfig(max_lora_rank=8,
+                             max_cpu_loras=4,
+                             max_loras=4,
+                             lora_dtype=DEFAULT_DTYPE)
     worker_adapter_manager = WorkerLoRAManager(
         4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
         lora_config.lora_extra_vocab_size, lora_config, device,
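
Both worker-manager tests compute the base vocabulary size as unpadded_vocab_size - lora_config.lora_extra_vocab_size. A worked arithmetic sketch, assuming the Llama-2 base vocabulary of 32000 and vLLM's default lora_extra_vocab_size of 256 (both illustrative assumptions, not values read from this diff):

    lora_extra_vocab_size = 256                          # assumed LoRAConfig default
    unpadded_vocab_size = 32000 + lora_extra_vocab_size  # base vocab plus LoRA slots
    base_vocab = unpadded_vocab_size - lora_extra_vocab_size
    assert base_vocab == 32000                           # value passed to the worker manager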
@@ -600,7 +613,8 @@ def test_packed_loras(dist_init, dummy_model_gate_up, device):
         2,
         LoRAConfig(max_lora_rank=8,
                    max_cpu_loras=2,
-                   max_loras=2),
+                   max_loras=2,
+                   lora_dtype=DEFAULT_DTYPE),
         device=device)
     model = manager.model