Add LoRA support for Mixtral (#2831)
* add mixtral lora support
* formatting
* fix incorrectly ported logic
* polish tests
* minor fixes and refactoring
* minor fixes
* formatting
* rename and remove redundant logic
* refactoring
* refactoring
* minor fix
* minor refactoring
* fix code smell
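The test changes below reflect the new contract introduced by this commit: instead of the caller passing `lora_target_modules` to the LoRA managers, each model class now advertises its own LoRA metadata. A minimal sketch of that contract follows; the Mixtral-specific module names are assumptions inferred from the test changes, not copied from the model file.

```python
# Minimal sketch, not the actual vllm/model_executor/models/mixtral.py:
# a LoRA-capable model class declares which modules support LoRA, how fused
# ("packed") modules map to their original sub-projections, and which
# embedding modules are renamed or padded when extra LoRA vocab is added.
class MixtralForCausalLM:  # illustrative stand-in only
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    }
    supported_lora_modules = ["qkv_proj", "o_proj", "embed_tokens", "lm_head"]
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }
    embedding_padding_modules = ["lm_head"]
```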
@@ -11,25 +11,35 @@ from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
                               RowParallelLinearWithLoRA,
                               MergedColumnParallelLinearWithLoRA)
 from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
-from vllm.lora.models import (EMBEDDING_MODULES, LoRAModel, LoRAModelManager,
+from vllm.lora.models import (LoRAModel, LoRAModelManager,
                               LRUCacheLoRAModelManager, LoRAMapping)
 from vllm.lora.request import LoRARequest
 from vllm.lora.worker_manager import (LRUCacheWorkerLoRAManager,
                                       WorkerLoRAManager)
 from vllm.model_executor.layers.linear import RowParallelLinear
 
+EMBEDDING_MODULES = {
+    "embed_tokens": "input_embeddings",
+    "lm_head": "output_embeddings",
+}
+
+EMBEDDING_PADDING_MODULES = ["lm_head"]
+
 
 def test_from_lora_tensors(sql_lora_files):
     tensors = load_file(
         os.path.join(sql_lora_files, "adapter_model.safetensors"))
     new_embeddings = load_file(
         os.path.join(sql_lora_files, "new_embeddings.safetensors"))
-    lora_model = LoRAModel.from_lora_tensors(1,
-                                             8,
-                                             16,
-                                             tensors,
-                                             "cuda",
-                                             embeddings=new_embeddings)
+    lora_model = LoRAModel.from_lora_tensors(
+        1,
+        8,
+        16,
+        tensors,
+        "cuda",
+        embeddings=new_embeddings,
+        embedding_modules=EMBEDDING_MODULES,
+        embedding_padding_modules=EMBEDDING_PADDING_MODULES)
     for module_name, lora in lora_model.loras.items():
         assert lora.module_name == module_name
         assert lora.rank == 8
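One reason `lm_head` appears in `EMBEDDING_PADDING_MODULES`: LoRA adapters can add extra vocabulary tokens, and the output embedding is typically padded up to an alignment multiple. A minimal sketch of that arithmetic; the padding multiple of 64 and the token counts are assumptions for illustration, not values taken from this diff.

```python
# Minimal sketch of vocab padding, assuming a 64-element alignment multiple.
def padded_vocab_size(base_vocab: int, extra_lora_tokens: int,
                      multiple: int = 64) -> int:
    total = base_vocab + extra_lora_tokens
    return ((total + multiple - 1) // multiple) * multiple


# e.g. a 32000-token base vocab with 256 extra LoRA tokens pads to 32256.
assert padded_vocab_size(32000, 256) == 32256
```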
@@ -90,14 +100,11 @@ def create_packed_lora(
 
 def test_replace_submodules(dist_init, dummy_model):
     model = dummy_model
-    manager = LoRAModelManager(model,
-                               1,
-                               1,
-                               1,
-                               LoRAConfig(max_lora_rank=8,
-                                          max_cpu_loras=8,
-                                          max_loras=8),
-                               lora_target_modules=["dense1", "layer1.dense2"])
+    model.supported_lora_modules = ["dense1", "layer1.dense2"]
+    model.packed_modules_mapping = {}
+    manager = LoRAModelManager(
+        model, 1, 1, 1,
+        LoRAConfig(max_lora_rank=8, max_cpu_loras=8, max_loras=8))
     model = manager.model
 
     assert isinstance(model.get_submodule("dense1"),
@@ -111,16 +118,14 @@ def test_replace_submodules(dist_init, dummy_model):
 
 def test_lora_model_manager(dist_init, dummy_model):
     model = dummy_model
+    model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
+    model.packed_modules_mapping = {}
     model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"])
     model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"])
     model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"])
     manager = LoRAModelManager(
-        model,
-        2,
-        2,
-        2,
-        LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2),
-        lora_target_modules=["dense1", "dense2", "lm_head"])
+        model, 2, 2, 2,
+        LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2))
     assert all(x is None for x in manager.lora_index_to_id)
     assert manager.add_lora(model_lora1)
     assert manager.activate_lora(1)
@@ -159,16 +164,14 @@ def test_lora_model_manager(dist_init, dummy_model):
 
 def test_lora_lru_cache_model_manager(dist_init, dummy_model):
     model = dummy_model
+    model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
+    model.packed_modules_mapping = {}
     model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"])
     model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"])
     model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"])
     manager = LRUCacheLoRAModelManager(
-        model,
-        2,
-        2,
-        2,
-        LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2),
-        lora_target_modules=["dense1", "dense2", "lm_head"])
+        model, 2, 2, 2,
+        LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2))
     assert all(x is None for x in manager.lora_index_to_id)
     assert manager.add_lora(model_lora1)
     assert manager.activate_lora(1)
@@ -212,14 +215,15 @@ def test_lru_lora_model_manager(dist_init, dummy_model):
     # This tests just the LRU cache functionality, everything else is
     # tested in test_lora_model_manager
     model = dummy_model
+    model.supported_lora_modules = ["dense1", "dense2", "lm_head"]
+    model.packed_modules_mapping = {}
     model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"])
     model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"])
     model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"])
     model_lora4 = create_lora(4, model, ["dense1", "dense2", "lm_head"])
     manager = LRUCacheLoRAModelManager(
         model, 2, 2, 2,
-        LoRAConfig(max_lora_rank=8, max_cpu_loras=2, max_loras=2),
-        ["dense1", "dense2", "lm_head"])
+        LoRAConfig(max_lora_rank=8, max_cpu_loras=2, max_loras=2))
 
     assert all(x is None for x in manager.lora_index_to_id)
 
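The LRU tests above exercise eviction order rather than LoRA math. As a generic illustration of the behavior being checked (this is a toy cache with the same eviction rule, not vLLM's implementation):

```python
from collections import OrderedDict


# Toy LRU cache: activating a LoRA id moves it to the back; adding beyond
# capacity evicts the least recently used entry, as the tests above expect.
class ToyLRUCache:
    def __init__(self, capacity: int):
        self.capacity = capacity
        self._entries = OrderedDict()

    def activate(self, lora_id: int, value=None) -> None:
        if lora_id in self._entries:
            self._entries.move_to_end(lora_id)
        else:
            self._entries[lora_id] = value
            if len(self._entries) > self.capacity:
                self._entries.popitem(last=False)  # evict the oldest entry


cache = ToyLRUCache(capacity=2)
cache.activate(1)
cache.activate(2)
cache.activate(1)  # 1 is now most recently used
cache.activate(3)  # evicts 2, the least recently used
assert list(cache._entries) == [1, 3]
```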
@@ -289,8 +293,9 @@ def test_lru_cache_worker_lora_manager(llama_2_7b_model_extra_embeddings,
                                        sql_lora_files):
     lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
     worker_lora_manager = LRUCacheWorkerLoRAManager(
-        4, 2, llama_2_7b_model_extra_embeddings.config.vocab_size, lora_config,
-        torch.device("cuda"))
+        4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
+        lora_config.lora_extra_vocab_size, lora_config, torch.device("cuda"),
+        EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
     worker_lora_manager.create_lora_manager(llama_2_7b_model_extra_embeddings)
 
     mapping = LoRAMapping([], [])
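Both worker-manager tests now derive the base vocabulary size as `unpadded_vocab_size - lora_extra_vocab_size` instead of reading `config.vocab_size` directly. A small worked example of that arithmetic; the concrete numbers are assumptions for illustration only.

```python
# Assumed illustrative numbers: a 32000-token base vocab plus 256 extra
# LoRA vocab slots gives unpadded_vocab_size = 32256; subtracting
# lora_extra_vocab_size recovers the base vocab size the old call passed
# as config.vocab_size.
base_vocab_size = 32000
lora_extra_vocab_size = 256
unpadded_vocab_size = base_vocab_size + lora_extra_vocab_size

assert unpadded_vocab_size - lora_extra_vocab_size == base_vocab_size
```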
@@ -362,8 +367,9 @@ def test_worker_lora_manager(llama_2_7b_model_extra_embeddings,
     # Should remove every LoRA not specified in the request.
     lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4)
     worker_lora_manager = WorkerLoRAManager(
-        4, 2, llama_2_7b_model_extra_embeddings.config.vocab_size, lora_config,
-        torch.device("cuda"))
+        4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size -
+        lora_config.lora_extra_vocab_size, lora_config, torch.device("cuda"),
+        EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
     worker_lora_manager.create_lora_manager(llama_2_7b_model_extra_embeddings)
 
     mapping = LoRAMapping([], [])
@@ -428,6 +434,13 @@ def test_worker_lora_manager(llama_2_7b_model_extra_embeddings,
 
 def test_packed_loras(dist_init, dummy_model_gate_up):
     model = dummy_model_gate_up
+    model.supported_lora_modules = ["gate_up_proj"]
+    model.packed_modules_mapping = {
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
     model_lora = create_packed_lora(
         1,
         model,
@@ -443,8 +456,7 @@ def test_packed_loras(dist_init, dummy_model_gate_up):
 
     manager = LoRAModelManager(
         model, 2, 2, 2,
-        LoRAConfig(max_lora_rank=8, max_cpu_loras=2, max_loras=2),
-        ["gate_up_proj"])
+        LoRAConfig(max_lora_rank=8, max_cpu_loras=2, max_loras=2))
     model = manager.model
 
     assert isinstance(model.get_submodule("gate_up_proj"),
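`test_packed_loras` covers the case that matters for Mixtral-style fused projections: a packed module such as `gate_up_proj` is a single linear layer whose output concatenates the original `gate_proj` and `up_proj`, so per-slice LoRA deltas have to be stacked along the output dimension. A minimal sketch of that idea; the shapes and names are illustrative and this is not vLLM's implementation.

```python
import torch

# Illustrative shapes only.
rank, hidden, intermediate = 8, 16, 32
x = torch.randn(4, hidden)

# One (A, B) pair per slice of the fused gate_up_proj.
lora_a_gate, lora_b_gate = torch.randn(hidden, rank), torch.randn(rank, intermediate)
lora_a_up, lora_b_up = torch.randn(hidden, rank), torch.randn(rank, intermediate)

# The packed delta lines up with the fused layer's [gate | up] output layout.
delta = torch.cat([x @ lora_a_gate @ lora_b_gate,
                   x @ lora_a_up @ lora_b_up], dim=-1)
assert delta.shape == (4, 2 * intermediate)
```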