Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
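Note (not part of the diff below): the change is mechanical. Code formatted with yapf (continuation lines aligned under the opening bracket) and import-sorted with isort is rewritten into ruff's black-compatible style (4-space hanging indents, one argument per line when a call does not fit, trailing comma before a dedented closing bracket). A self-contained sketch of the before/after pattern, using a hypothetical helper rather than vLLM code:

# Hypothetical helper, used only to illustrate the formatting conversion.
def make_lora_config(rank: int, cpu_loras: int, gpu_loras: int, dtype: str) -> dict:
    return {"rank": rank, "cpu": cpu_loras, "gpu": gpu_loras, "dtype": dtype}


# yapf style (old): continuation lines aligned under the opening parenthesis.
old_style = make_lora_config(8,
                             cpu_loras=4,
                             gpu_loras=4,
                             dtype="float16")

# ruff format style (new): 4-space hanging indent, one argument per line,
# trailing comma, closing parenthesis dedented to the statement.
new_style = make_lora_config(
    8,
    cpu_loras=4,
    gpu_loras=4,
    dtype="float16",
)

assert old_style == new_style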
@@ -10,16 +10,21 @@ from torch import nn

from vllm.config import ModelConfig, VllmConfig
from vllm.config.lora import LoRAConfig
from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
                              MergedColumnParallelLinearWithLoRA,
                              RowParallelLinearWithLoRA)
from vllm.lora.layers import (
    ColumnParallelLinearWithLoRA,
    MergedColumnParallelLinearWithLoRA,
    RowParallelLinearWithLoRA,
)
from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.models import (LoRAMapping, LoRAModel, LoRAModelManager,
                              LRUCacheLoRAModelManager)
from vllm.lora.models import (
    LoRAMapping,
    LoRAModel,
    LoRAModelManager,
    LRUCacheLoRAModelManager,
)
from vllm.lora.peft_helper import PEFTHelper
from vllm.lora.request import LoRARequest
from vllm.lora.worker_manager import (LRUCacheWorkerLoRAManager,
                                      WorkerLoRAManager)
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager, WorkerLoRAManager
from vllm.platforms import current_platform

from .utils import create_peft_lora
@@ -31,22 +36,25 @@ EMBEDDING_MODULES = {

EMBEDDING_PADDING_MODULES = ["lm_head"]

DEVICES = ([
    f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
] if current_platform.is_cuda_alike() else ["cpu"])
DEVICES = (
    [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]
    if current_platform.is_cuda_alike()
    else ["cpu"]
)

DEFAULT_DTYPE = torch.get_default_dtype()


@pytest.mark.parametrize("device", DEVICES)
def test_from_lora_tensors(sql_lora_files, device):
    tensors = load_file(
        os.path.join(sql_lora_files, "adapter_model.safetensors"))
    tensors = load_file(os.path.join(sql_lora_files, "adapter_model.safetensors"))
    new_embeddings = load_file(
        os.path.join(sql_lora_files, "new_embeddings.safetensors"))
        os.path.join(sql_lora_files, "new_embeddings.safetensors")
    )

    peft_helper = PEFTHelper.from_local_dir(sql_lora_files,
                                            max_position_embeddings=4096)
    peft_helper = PEFTHelper.from_local_dir(
        sql_lora_files, max_position_embeddings=4096
    )
    lora_model = LoRAModel.from_lora_tensors(
        1,
        tensors,
@@ -54,7 +62,8 @@ def test_from_lora_tensors(sql_lora_files, device):
        device=device,
        embeddings=new_embeddings,
        embedding_modules=EMBEDDING_MODULES,
        embedding_padding_modules=EMBEDDING_PADDING_MODULES)
        embedding_padding_modules=EMBEDDING_PADDING_MODULES,
    )
    for module_name, lora in lora_model.loras.items():
        assert lora.module_name == module_name
        assert lora.rank == 8
@@ -63,22 +72,27 @@ def test_from_lora_tensors(sql_lora_files, device):
        assert lora.lora_b is not None
        assert lora.lora_a.device == torch.device(device)
        assert lora.lora_b.device == torch.device(device)
        assert (lora.lora_a.shape[0] == lora.lora_b.shape[1]
                ), f"{lora.lora_a.shape=}, {lora.lora_b.shape=}"
        assert lora.lora_a.shape[0] == lora.lora_b.shape[1], (
            f"{lora.lora_a.shape=}, {lora.lora_b.shape=}"
        )
        assert lora.lora_a.shape[0] == 8
        embeddings_module = next(
            (k for k in EMBEDDING_MODULES if k in module_name), None)
            (k for k in EMBEDDING_MODULES if k in module_name), None
        )
        if embeddings_module:
            assert torch.equal(
                lora.embeddings_tensor,
                new_embeddings[EMBEDDING_MODULES[embeddings_module]].to(
                    device=lora.embeddings_tensor.device))
                    device=lora.embeddings_tensor.device
                ),
            )
        else:
            assert lora.embeddings_tensor is None


def create_lora(lora_id: int, model: nn.Module, sub_modules: list[str],
                device: torch.device) -> LoRAModel:
def create_lora(
    lora_id: int, model: nn.Module, sub_modules: list[str], device: torch.device
) -> LoRAModel:
    loras: dict[str, LoRALayerWeights] = {}
    for name in sub_modules:
        w = model.get_submodule(name).weight
@@ -110,8 +124,7 @@ def create_packed_lora(
            8,
            16,
            torch.rand([8, w.shape[1]], device=device),
            torch.rand([w.shape[0] // len(replaced_module_names), 8],
                       device=device),
            torch.rand([w.shape[0] // len(replaced_module_names), 8], device=device),
        )
    return LoRAModel(lora_id, 8, loras)

@@ -119,42 +132,42 @@ def create_packed_lora(
def test_replace_submodules(dist_init, dummy_model):
    model = dummy_model
    manager = LoRAModelManager(
        model, 1, 1, 1,
        LoRAConfig(max_lora_rank=8,
                   max_cpu_loras=8,
                   max_loras=8,
                   lora_dtype=DEFAULT_DTYPE), torch.device(DEVICES[0]))
        model,
        1,
        1,
        1,
        LoRAConfig(
            max_lora_rank=8, max_cpu_loras=8, max_loras=8, lora_dtype=DEFAULT_DTYPE
        ),
        torch.device(DEVICES[0]),
    )
    model = manager.model
    assert isinstance(model.get_submodule("dense1"),
                      ColumnParallelLinearWithLoRA)
    assert isinstance(model.get_submodule("layer1.dense1"),
                      ColumnParallelLinearWithLoRA)
    assert isinstance(model.get_submodule("dense1"), ColumnParallelLinearWithLoRA)
    assert isinstance(
        model.get_submodule("layer1.dense1"), ColumnParallelLinearWithLoRA
    )
    assert isinstance(model.get_submodule("dense2"), RowParallelLinearWithLoRA)
    assert isinstance(model.get_submodule("layer1.dense2"),
                      RowParallelLinearWithLoRA)
    assert isinstance(model.get_submodule("layer1.dense2"), RowParallelLinearWithLoRA)


@pytest.mark.parametrize("device", DEVICES)
def test_lora_model_manager(dist_init, dummy_model, device):
    model = dummy_model
    model_lora1 = create_lora(1,
                              model, ["layer1.dense1", "dense2", "lm_head"],
                              device=device)
    model_lora2 = create_lora(2,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    model_lora3 = create_lora(3,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    manager = LoRAModelManager(model,
                               2,
                               2,
                               2,
                               LoRAConfig(max_lora_rank=8,
                                          max_cpu_loras=3,
                                          max_loras=2,
                                          lora_dtype=DEFAULT_DTYPE),
                               device=device)
    model_lora1 = create_lora(
        1, model, ["layer1.dense1", "dense2", "lm_head"], device=device
    )
    model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"], device=device)
    model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"], device=device)
    manager = LoRAModelManager(
        model,
        2,
        2,
        2,
        LoRAConfig(
            max_lora_rank=8, max_cpu_loras=3, max_loras=2, lora_dtype=DEFAULT_DTYPE
        ),
        device=device,
    )
    assert all(x is None for x in manager.lora_index_to_id)
    assert manager.add_adapter(model_lora1)
    assert manager.activate_adapter(1)
@@ -204,24 +217,21 @@ def test_lora_model_manager(dist_init, dummy_model, device):
@pytest.mark.parametrize("device", DEVICES)
def test_lora_lru_cache_model_manager(dist_init, dummy_model, device):
    model = dummy_model
    model_lora1 = create_lora(1,
                              model, ["layer1.dense1", "dense2", "lm_head"],
                              device=device)
    model_lora2 = create_lora(2,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    model_lora3 = create_lora(3,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    manager = LRUCacheLoRAModelManager(model,
                                       2,
                                       2,
                                       2,
                                       LoRAConfig(max_lora_rank=8,
                                                  max_cpu_loras=3,
                                                  max_loras=2,
                                                  lora_dtype=DEFAULT_DTYPE),
                                       device=device)
    model_lora1 = create_lora(
        1, model, ["layer1.dense1", "dense2", "lm_head"], device=device
    )
    model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"], device=device)
    model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"], device=device)
    manager = LRUCacheLoRAModelManager(
        model,
        2,
        2,
        2,
        LoRAConfig(
            max_lora_rank=8, max_cpu_loras=3, max_loras=2, lora_dtype=DEFAULT_DTYPE
        ),
        device=device,
    )
    assert all(x is None for x in manager.lora_index_to_id)
    assert manager.add_adapter(model_lora1)
    assert manager.activate_adapter(1)
@@ -297,27 +307,22 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device):
    # This tests just the LRU cache functionality, everything else is
    # tested in test_lora_model_manager
    model = dummy_model
    model_lora1 = create_lora(1,
                              model, ["layer1.dense1", "dense2", "lm_head"],
                              device=device)
    model_lora2 = create_lora(2,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    model_lora3 = create_lora(3,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    model_lora4 = create_lora(4,
                              model, ["dense1", "dense2", "lm_head"],
                              device=device)
    manager = LRUCacheLoRAModelManager(model,
                                       2,
                                       2,
                                       2,
                                       LoRAConfig(max_lora_rank=8,
                                                  max_cpu_loras=2,
                                                  max_loras=2,
                                                  lora_dtype=DEFAULT_DTYPE),
                                       device=device)
    model_lora1 = create_lora(
        1, model, ["layer1.dense1", "dense2", "lm_head"], device=device
    )
    model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"], device=device)
    model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"], device=device)
    model_lora4 = create_lora(4, model, ["dense1", "dense2", "lm_head"], device=device)
    manager = LRUCacheLoRAModelManager(
        model,
        2,
        2,
        2,
        LoRAConfig(
            max_lora_rank=8, max_cpu_loras=2, max_loras=2, lora_dtype=DEFAULT_DTYPE
        ),
        device=device,
    )
    assert all(x is None for x in manager.lora_index_to_id)

    # Add up to capacity
@@ -421,12 +426,10 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device):


@pytest.mark.parametrize("device", DEVICES)
def test_lru_cache_worker_adapter_manager(dist_init, dummy_model, device,
                                          tmp_path):
    lora_config = LoRAConfig(max_lora_rank=8,
                             max_cpu_loras=4,
                             max_loras=4,
                             lora_dtype=DEFAULT_DTYPE)
def test_lru_cache_worker_adapter_manager(dist_init, dummy_model, device, tmp_path):
    lora_config = LoRAConfig(
        max_lora_rank=8, max_cpu_loras=4, max_loras=4, lora_dtype=DEFAULT_DTYPE
    )

    dummy_lora_files = f"{tmp_path}/lora_adapter"
    os.makedirs(dummy_lora_files, exist_ok=True)
@@ -438,13 +441,13 @@ def test_lru_cache_worker_adapter_manager(dist_init, dummy_model, device,
    )

    model_config = ModelConfig(max_model_len=16)
    vllm_config = VllmConfig(model_config=model_config,
                             lora_config=lora_config)
    vllm_config = VllmConfig(model_config=model_config, lora_config=lora_config)

    vllm_config.scheduler_config.max_num_seqs = 4
    vllm_config.scheduler_config.max_num_batched_tokens = 2
    worker_adapter_manager = LRUCacheWorkerLoRAManager(
        vllm_config, device, EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
        vllm_config, device, EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES
    )

    worker_adapter_manager.max_num_seqs = 4
    worker_adapter_manager.max_num_batched_tokens = 2
@@ -452,52 +455,64 @@ def test_lru_cache_worker_adapter_manager(dist_init, dummy_model, device,
    worker_adapter_manager.create_lora_manager(dummy_model)

    mapping = LoRAMapping([], [])
    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, dummy_lora_files),
        LoRARequest("2", 2, dummy_lora_files)
    ], mapping)
    worker_adapter_manager.set_active_adapters(
        [LoRARequest("1", 1, dummy_lora_files), LoRARequest("2", 2, dummy_lora_files)],
        mapping,
    )
    assert worker_adapter_manager.list_adapters() == {1, 2}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2

    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, dummy_lora_files),
        LoRARequest("3", 3, dummy_lora_files),
        LoRARequest("4", 4, dummy_lora_files)
    ], mapping)
    worker_adapter_manager.set_active_adapters(
        [
            LoRARequest("1", 1, dummy_lora_files),
            LoRARequest("3", 3, dummy_lora_files),
            LoRARequest("4", 4, dummy_lora_files),
        ],
        mapping,
    )
    assert worker_adapter_manager.list_adapters() == {1, 2, 3, 4}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 3
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4

    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, dummy_lora_files),
        LoRARequest("2", 2, dummy_lora_files),
        LoRARequest("5", 5, dummy_lora_files)
    ], mapping)
    worker_adapter_manager.set_active_adapters(
        [
            LoRARequest("1", 1, dummy_lora_files),
            LoRARequest("2", 2, dummy_lora_files),
            LoRARequest("5", 5, dummy_lora_files),
        ],
        mapping,
    )
    assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4

    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, dummy_lora_files),
        LoRARequest("1", 1, dummy_lora_files),
        LoRARequest("1", 1, dummy_lora_files)
    ], mapping)
    worker_adapter_manager.set_active_adapters(
        [
            LoRARequest("1", 1, dummy_lora_files),
            LoRARequest("1", 1, dummy_lora_files),
            LoRARequest("1", 1, dummy_lora_files),
        ],
        mapping,
    )
    assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4

    worker_adapter_manager.set_active_adapters([
        LoRARequest("6", 6, dummy_lora_files),
        LoRARequest("7", 7, dummy_lora_files),
        LoRARequest("8", 8, dummy_lora_files)
    ], mapping)
    worker_adapter_manager.set_active_adapters(
        [
            LoRARequest("6", 6, dummy_lora_files),
            LoRARequest("7", 7, dummy_lora_files),
            LoRARequest("8", 8, dummy_lora_files),
        ],
        mapping,
    )
    assert worker_adapter_manager.list_adapters() == {1, 6, 7, 8}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 7
@@ -506,41 +521,40 @@ def test_lru_cache_worker_adapter_manager(dist_init, dummy_model, device,

    # Over capacity
    with pytest.raises(RuntimeError):
        worker_adapter_manager.set_active_adapters([
            LoRARequest("10", 10, dummy_lora_files),
            LoRARequest("11", 11, dummy_lora_files),
            LoRARequest("12", 12, dummy_lora_files),
            LoRARequest("13", 13, dummy_lora_files),
            LoRARequest("14", 14, dummy_lora_files)
        ], mapping)
        worker_adapter_manager.set_active_adapters(
            [
                LoRARequest("10", 10, dummy_lora_files),
                LoRARequest("11", 11, dummy_lora_files),
                LoRARequest("12", 12, dummy_lora_files),
                LoRARequest("13", 13, dummy_lora_files),
                LoRARequest("14", 14, dummy_lora_files),
            ],
            mapping,
        )

    assert worker_adapter_manager.device == device
    assert (worker_adapter_manager._adapter_manager.punica_wrapper.device ==
            device)
    assert worker_adapter_manager._adapter_manager.punica_wrapper.device == device


@pytest.mark.parametrize("device", DEVICES)
def test_worker_adapter_manager(dist_init, dummy_model_gate_up, device,
                                tmp_path):
def test_worker_adapter_manager(dist_init, dummy_model_gate_up, device, tmp_path):
    # Should remove every LoRA not specified in the request.
    lora_config = LoRAConfig(max_lora_rank=8,
                             max_cpu_loras=4,
                             max_loras=4,
                             lora_dtype=DEFAULT_DTYPE)
    lora_config = LoRAConfig(
        max_lora_rank=8, max_cpu_loras=4, max_loras=4, lora_dtype=DEFAULT_DTYPE
    )

    model_config = ModelConfig(max_model_len=16)
    vllm_config = VllmConfig(model_config=model_config,
                             lora_config=lora_config)
    vllm_config = VllmConfig(model_config=model_config, lora_config=lora_config)

    vllm_config.scheduler_config.max_num_seqs = 4
    vllm_config.scheduler_config.max_num_batched_tokens = 2

    worker_adapter_manager = WorkerLoRAManager(vllm_config, device,
                                               EMBEDDING_MODULES,
                                               EMBEDDING_PADDING_MODULES)
    worker_adapter_manager = WorkerLoRAManager(
        vllm_config, device, EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES
    )
    worker_adapter_manager.vocab_size = (
        dummy_model_gate_up.unpadded_vocab_size -
        lora_config.lora_extra_vocab_size)
        dummy_model_gate_up.unpadded_vocab_size - lora_config.lora_extra_vocab_size
    )
    worker_adapter_manager.create_lora_manager(dummy_model_gate_up)

    dummy_lora_files = f"{tmp_path}/lora_adapter"
@@ -553,49 +567,61 @@ def test_worker_adapter_manager(dist_init, dummy_model_gate_up, device,
    )

    mapping = LoRAMapping([], [])
    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, dummy_lora_files),
        LoRARequest("2", 2, dummy_lora_files)
    ], mapping)
    worker_adapter_manager.set_active_adapters(
        [LoRARequest("1", 1, dummy_lora_files), LoRARequest("2", 2, dummy_lora_files)],
        mapping,
    )
    assert worker_adapter_manager.list_adapters() == {1, 2}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2

    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, dummy_lora_files),
        LoRARequest("3", 3, dummy_lora_files),
        LoRARequest("4", 4, dummy_lora_files)
    ], mapping)
    worker_adapter_manager.set_active_adapters(
        [
            LoRARequest("1", 1, dummy_lora_files),
            LoRARequest("3", 3, dummy_lora_files),
            LoRARequest("4", 4, dummy_lora_files),
        ],
        mapping,
    )
    assert worker_adapter_manager.list_adapters() == {1, 3, 4}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 3
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 4

    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, dummy_lora_files),
        LoRARequest("2", 2, dummy_lora_files),
        LoRARequest("5", 5, dummy_lora_files)
    ], mapping)
    worker_adapter_manager.set_active_adapters(
        [
            LoRARequest("1", 1, dummy_lora_files),
            LoRARequest("2", 2, dummy_lora_files),
            LoRARequest("5", 5, dummy_lora_files),
        ],
        mapping,
    )
    assert worker_adapter_manager.list_adapters() == {1, 2, 5}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5

    worker_adapter_manager.set_active_adapters([
        LoRARequest("1", 1, dummy_lora_files),
        LoRARequest("1", 1, dummy_lora_files),
        LoRARequest("1", 1, dummy_lora_files)
    ], mapping)
    worker_adapter_manager.set_active_adapters(
        [
            LoRARequest("1", 1, dummy_lora_files),
            LoRARequest("1", 1, dummy_lora_files),
            LoRARequest("1", 1, dummy_lora_files),
        ],
        mapping,
    )
    assert worker_adapter_manager.list_adapters() == {1}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] is None
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] is None

    worker_adapter_manager.set_active_adapters([
        LoRARequest("6", 6, dummy_lora_files),
        LoRARequest("7", 7, dummy_lora_files),
        LoRARequest("8", 8, dummy_lora_files)
    ], mapping)
    worker_adapter_manager.set_active_adapters(
        [
            LoRARequest("6", 6, dummy_lora_files),
            LoRARequest("7", 7, dummy_lora_files),
            LoRARequest("8", 8, dummy_lora_files),
        ],
        mapping,
    )
    assert worker_adapter_manager.list_adapters() == {6, 7, 8}
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 8
    assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 6
@@ -603,17 +629,19 @@ def test_worker_adapter_manager(dist_init, dummy_model_gate_up, device,

    # Over capacity
    with pytest.raises(RuntimeError):
        worker_adapter_manager.set_active_adapters([
            LoRARequest("10", 10, dummy_lora_files),
            LoRARequest("11", 11, dummy_lora_files),
            LoRARequest("12", 12, dummy_lora_files),
            LoRARequest("13", 13, dummy_lora_files),
            LoRARequest("14", 14, dummy_lora_files)
        ], mapping)
        worker_adapter_manager.set_active_adapters(
            [
                LoRARequest("10", 10, dummy_lora_files),
                LoRARequest("11", 11, dummy_lora_files),
                LoRARequest("12", 12, dummy_lora_files),
                LoRARequest("13", 13, dummy_lora_files),
                LoRARequest("14", 14, dummy_lora_files),
            ],
            mapping,
        )

    assert worker_adapter_manager.device == device
    assert (worker_adapter_manager._adapter_manager.punica_wrapper.device ==
            device)
    assert worker_adapter_manager._adapter_manager.punica_wrapper.device == device


@pytest.mark.parametrize("device", DEVICES)
@@ -624,7 +652,8 @@ def test_packed_loras(dist_init, dummy_model_gate_up, device):
        model,
        module_name="gate_up_proj",
        replaced_module_names=["gate_proj", "up_proj"],
        device=device)
        device=device,
    )
    model_lora1 = create_packed_lora(
        2,
        model,
@@ -634,19 +663,21 @@ def test_packed_loras(dist_init, dummy_model_gate_up, device):
        empty_replaced_module_name="gate_proj",
    )

    manager = LoRAModelManager(model,
                               2,
                               2,
                               2,
                               LoRAConfig(max_lora_rank=8,
                                          max_cpu_loras=2,
                                          max_loras=2,
                                          lora_dtype=DEFAULT_DTYPE),
                               device=device)
    manager = LoRAModelManager(
        model,
        2,
        2,
        2,
        LoRAConfig(
            max_lora_rank=8, max_cpu_loras=2, max_loras=2, lora_dtype=DEFAULT_DTYPE
        ),
        device=device,
    )
    model = manager.model

    assert isinstance(model.get_submodule("gate_up_proj"),
                      MergedColumnParallelLinearWithLoRA)
    assert isinstance(
        model.get_submodule("gate_up_proj"), MergedColumnParallelLinearWithLoRA
    )
    # Verify packed lora is correct
    model_lora_clone = model_lora.clone(1)
    model_lora_clone1 = model_lora1.clone(1)
@@ -659,21 +690,27 @@ def test_packed_loras(dist_init, dummy_model_gate_up, device):
    packed_lora = model_lora.get_lora("gate_up_proj")
    assert packed_lora and isinstance(packed_lora, PackedLoRALayerWeights)

    torch.testing.assert_close(packed_lora.lora_a[0],
                               model_lora_clone.get_lora("gate_proj").lora_a)
    torch.testing.assert_close(packed_lora.lora_b[0],
                               model_lora_clone.get_lora("gate_proj").lora_b)
    torch.testing.assert_close(packed_lora.lora_a[1],
                               model_lora_clone.get_lora("up_proj").lora_a)
    torch.testing.assert_close(packed_lora.lora_b[1],
                               model_lora_clone.get_lora("up_proj").lora_b)
    torch.testing.assert_close(
        packed_lora.lora_a[0], model_lora_clone.get_lora("gate_proj").lora_a
    )
    torch.testing.assert_close(
        packed_lora.lora_b[0], model_lora_clone.get_lora("gate_proj").lora_b
    )
    torch.testing.assert_close(
        packed_lora.lora_a[1], model_lora_clone.get_lora("up_proj").lora_a
    )
    torch.testing.assert_close(
        packed_lora.lora_b[1], model_lora_clone.get_lora("up_proj").lora_b
    )

    packed_lora1 = model_lora1.get_lora("gate_up_proj")
    assert packed_lora1 and isinstance(packed_lora1, PackedLoRALayerWeights)

    assert packed_lora1.lora_a[0] is None
    assert packed_lora1.lora_b[0] is None
    torch.testing.assert_close(packed_lora1.lora_a[1],
                               model_lora_clone1.get_lora("up_proj").lora_a)
    torch.testing.assert_close(packed_lora1.lora_b[1],
                               model_lora_clone1.get_lora("up_proj").lora_b)
    torch.testing.assert_close(
        packed_lora1.lora_a[1], model_lora_clone1.get_lora("up_proj").lora_a
    )
    torch.testing.assert_close(
        packed_lora1.lora_b[1], model_lora_clone1.get_lora("up_proj").lora_b
    )