[Core] Initialize LoRA support for tower and connector in multi-modal models (#26674)
Signed-off-by: bk-201 <joy25810@foxmail.com> Signed-off-by: Jee Jee Li <pandaleefree@gmail.com> Signed-off-by: prashanth058 <prashanth.dannamaneni@uipath.com> Co-authored-by: bk-201 <joy25810@foxmail.com> Co-authored-by: prashanth058 <prashanth.dannamaneni@uipath.com> Co-authored-by: Anexdeus <5142168@mail.ru>
This commit is contained in:
@@ -18,6 +18,7 @@ from vllm.lora.layers import (
|
||||
from vllm.lora.lora_model import LoRAModel
|
||||
from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
|
||||
from vllm.lora.model_manager import (
|
||||
DEFAULT_LANGUAGE_WRAPPER_KEY,
|
||||
LoRAMapping,
|
||||
LoRAModelManager,
|
||||
LRUCacheLoRAModelManager,
|
||||
@@ -183,9 +184,11 @@ def test_lora_model_manager(dist_init, dummy_model, device):
|
||||
assert manager.activate_adapter(2)
|
||||
assert manager.lora_index_to_id[0] == 3
|
||||
assert manager.lora_index_to_id[1] == 2
|
||||
|
||||
assert manager.device == device
|
||||
assert manager.punica_wrapper.device == device
|
||||
assert (
|
||||
manager.punica_wrapper_mapping.get(DEFAULT_LANGUAGE_WRAPPER_KEY).device
|
||||
== device
|
||||
)
|
||||
assert hasattr(manager, "supported_lora_modules")
|
||||
assert sorted(manager.supported_lora_modules) == [
|
||||
"dense1",
|
||||
@@ -278,8 +281,10 @@ def test_lora_lru_cache_model_manager(dist_init, dummy_model, device):
|
||||
assert manager.remove_adapter(3)
|
||||
with pytest.raises(ValueError):
|
||||
assert manager.pin_adapter(3)
|
||||
|
||||
assert manager.punica_wrapper.device == device
|
||||
assert (
|
||||
manager.punica_wrapper_mapping.get(DEFAULT_LANGUAGE_WRAPPER_KEY).device
|
||||
== device
|
||||
)
|
||||
assert manager.device == device
|
||||
|
||||
|
||||
@@ -402,7 +407,10 @@ def test_lru_lora_model_manager(dist_init, dummy_model, device):
|
||||
assert manager.remove_oldest_adapter()
|
||||
|
||||
assert set(manager.list_adapters()) == {1}
|
||||
assert manager.punica_wrapper.device == device
|
||||
assert (
|
||||
manager.punica_wrapper_mapping.get(DEFAULT_LANGUAGE_WRAPPER_KEY).device
|
||||
== device
|
||||
)
|
||||
assert manager.device == device
|
||||
|
||||
|
||||
@@ -514,7 +522,10 @@ def test_lru_cache_worker_adapter_manager(dist_init, dummy_model, device, tmp_pa
|
||||
)
|
||||
|
||||
assert worker_adapter_manager.device == device
|
||||
assert worker_adapter_manager._adapter_manager.punica_wrapper.device == device
|
||||
punica_wrapper = worker_adapter_manager._adapter_manager.punica_wrapper_mapping.get(
|
||||
DEFAULT_LANGUAGE_WRAPPER_KEY
|
||||
)
|
||||
assert punica_wrapper.device == device
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
@@ -618,7 +629,10 @@ def test_worker_adapter_manager(dist_init, dummy_model_gate_up, device, tmp_path
|
||||
)
|
||||
|
||||
assert worker_adapter_manager.device == device
|
||||
assert worker_adapter_manager._adapter_manager.punica_wrapper.device == device
|
||||
punica_wrapper = worker_adapter_manager._adapter_manager.punica_wrapper_mapping.get(
|
||||
DEFAULT_LANGUAGE_WRAPPER_KEY
|
||||
)
|
||||
assert punica_wrapper.device == device
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", DEVICES)
|
||||
|
||||
Reference in New Issue
Block a user