Improve configs - LoRAConfig + PromptAdapterConfig (#16980)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
@@ -31,6 +31,8 @@ DEVICES = ([
     f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
 ] if current_platform.is_cuda_alike() else ["cpu"])
 
+DEFAULT_DTYPE = torch.get_default_dtype()
+
 
 @pytest.fixture(scope="function", autouse=True)
 def use_v0_only(monkeypatch: pytest.MonkeyPatch):
@@ -125,8 +127,10 @@ def test_replace_submodules(dist_init, dummy_model):
     model = dummy_model
     manager = LoRAModelManager(
         model, 1, 1, 1,
-        LoRAConfig(max_lora_rank=8, max_cpu_loras=8, max_loras=8),
-        torch.device(DEVICES[0]))
+        LoRAConfig(max_lora_rank=8,
+                   max_cpu_loras=8,
+                   max_loras=8,
+                   lora_dtype=DEFAULT_DTYPE), torch.device(DEVICES[0]))
     model = manager.model
     assert isinstance(model.get_submodule("dense1"),
                       ColumnParallelLinearWithLoRA)
@@ -155,7 +159,8 @@ def test_lora_model_manager(dist_init, dummy_model, device):
                                2,
                                LoRAConfig(max_lora_rank=8,
                                           max_cpu_loras=3,
-                                          max_loras=2),
+                                          max_loras=2,
+                                          lora_dtype=DEFAULT_DTYPE),
                                device=device)
     assert all(x is None for x in manager.lora_index_to_id)
     assert manager.add_adapter(model_lora1)
@@ -221,7 +226,8 @@ def test_lora_lru_cache_model_manager(dist_init, dummy_model, device):
                                2,
                                LoRAConfig(max_lora_rank=8,
                                           max_cpu_loras=3,
-                                          max_loras=2),
+                                          max_loras=2,
+                                          lora_dtype=DEFAULT_DTYPE),
                                device=device)
     assert all(x is None for x in manager.lora_index_to_id)
     assert manager.add_adapter(model_lora1)
@@ -316,7 +322,8 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device):
                                2,
                                LoRAConfig(max_lora_rank=8,
                                           max_cpu_loras=2,
-                                          max_loras=2),
+                                          max_loras=2,
+                                          lora_dtype=DEFAULT_DTYPE),
                                device=device)
 
     assert all(x is None for x in manager.lora_index_to_id)
@@ -424,7 +431,10 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device):
 @pytest.mark.parametrize("device", DEVICES)
 def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
                                           sql_lora_files, device):
-    lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
+    lora_config = LoRAConfig(max_lora_rank=8,
+                             max_cpu_loras=4,
+                             max_loras=4,
+                             lora_dtype=DEFAULT_DTYPE)
     worker_adapter_manager = LRUCacheWorkerLoRAManager(
         4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
         lora_config.lora_extra_vocab_size, lora_config, device,
@@ -504,7 +514,10 @@ def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
 def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings,
                                 sql_lora_files, device):
     # Should remove every LoRA not specified in the request.
-    lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
+    lora_config = LoRAConfig(max_lora_rank=8,
+                             max_cpu_loras=4,
+                             max_loras=4,
+                             lora_dtype=DEFAULT_DTYPE)
     worker_adapter_manager = WorkerLoRAManager(
         4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
         lora_config.lora_extra_vocab_size, lora_config, device,
@@ -600,7 +613,8 @@ def test_packed_loras(dist_init, dummy_model_gate_up, device):
                                2,
                                LoRAConfig(max_lora_rank=8,
                                           max_cpu_loras=2,
-                                          max_loras=2),
+                                          max_loras=2,
+                                          lora_dtype=DEFAULT_DTYPE),
                                device=device)
     model = manager.model
 
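
For orientation only, a minimal sketch of the pattern the updated tests follow: resolve the process-wide default dtype once and pass it to LoRAConfig explicitly as lora_dtype instead of relying on the config's default. The import path vllm.config.LoRAConfig is an assumption about the surrounding code base; the keyword arguments are taken from the diff above.

    # Minimal sketch (not part of the commit) of the updated test pattern.
    # Assumption: LoRAConfig is importable from vllm.config.
    import torch

    from vllm.config import LoRAConfig

    # Resolve the process-wide default dtype once, as the tests now do.
    DEFAULT_DTYPE = torch.get_default_dtype()

    # Pass lora_dtype explicitly rather than relying on LoRAConfig's default.
    lora_config = LoRAConfig(max_lora_rank=8,
                             max_cpu_loras=4,
                             max_loras=4,
                             lora_dtype=DEFAULT_DTYPE)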