[Misc][LoRA] Add --lora-target-modules to restrict LoRA to specific modules (#34984)
Signed-off-by: Bhoomit Vasani <bhoomit.2010@gmail.com> Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -711,3 +711,192 @@ def test_packed_loras(default_vllm_config, dist_init, dummy_model_gate_up, devic
|
||||
torch.testing.assert_close(
|
||||
packed_lora1.lora_b[1], model_lora_clone1.get_lora("up_proj").lora_b
|
||||
)
|
||||
|
||||
|
||||
def _test_target_modules(
    model,
    target_modules: list[str] | None,
    device: str,
    expected_lora: list[tuple[str, type]],
    expected_no_lora: list[tuple[str, type]],
):
    """Apply LoRA to *model* and verify which submodules were wrapped.

    Builds a LoRAModelManager with the given ``target_modules`` restriction,
    then asserts that every path in ``expected_lora`` was replaced by the
    paired LoRA layer class and every path in ``expected_no_lora`` was not.
    """
    lora_config = LoRAConfig(
        max_lora_rank=8,
        max_cpu_loras=2,
        max_loras=2,
        lora_dtype=DEFAULT_DTYPE,
        target_modules=target_modules,
    )
    # The manager rewrites `model` in place; the instance itself is unused.
    LoRAModelManager(model, 2, 2, 2, lora_config, device=device)

    for path, wrapped_cls in expected_lora:
        assert isinstance(model.get_submodule(path), wrapped_cls)
    for path, wrapped_cls in expected_no_lora:
        assert not isinstance(model.get_submodule(path), wrapped_cls)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", DEVICES)
def test_target_modules_config(default_vllm_config, dist_init, dummy_model, device):
    """Restricting target_modules to dense1 must leave dense2 un-wrapped."""
    col, row = ColumnParallelLinearWithLoRA, RowParallelLinearWithLoRA
    _test_target_modules(
        dummy_model,
        ["dense1"],
        device,
        expected_lora=[(path, col) for path in ("dense1", "layer1.dense1")],
        expected_no_lora=[(path, row) for path in ("dense2", "layer1.dense2")],
    )
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", DEVICES)
def test_target_modules_multiple(default_vllm_config, dist_init, dummy_model, device):
    """With both dense1 and dense2 targeted, every linear layer gets LoRA."""
    col, row = ColumnParallelLinearWithLoRA, RowParallelLinearWithLoRA
    wrapped = [(path, col) for path in ("dense1", "layer1.dense1")]
    wrapped += [(path, row) for path in ("dense2", "layer1.dense2")]
    _test_target_modules(
        dummy_model,
        ["dense1", "dense2"],
        device,
        expected_lora=wrapped,
        expected_no_lora=[],
    )
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", DEVICES)
def test_target_modules_none_uses_all(
    default_vllm_config, dist_init, dummy_model, device
):
    """With target_modules unset (None), all supported modules receive LoRA."""
    col, row = ColumnParallelLinearWithLoRA, RowParallelLinearWithLoRA
    _test_target_modules(
        dummy_model,
        None,
        device,
        expected_lora=[
            ("dense1", col),
            ("layer1.dense1", col),
            ("dense2", row),
            ("layer1.dense2", row),
        ],
        expected_no_lora=[],
    )
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", DEVICES)
def test_load_adapter_warns_on_unsupported_modules(
    default_vllm_config, dist_init, dummy_model_gate_up, device, tmp_path
):
    """_load_adapter must warn when an adapter carries weights for a module
    name the model has no supported LoRA target for."""
    from unittest.mock import patch

    import vllm.lora.worker_manager as wm_module

    lora_config = LoRAConfig(
        max_lora_rank=8, max_cpu_loras=4, max_loras=4, lora_dtype=DEFAULT_DTYPE
    )

    # Save a small PEFT adapter that targets modules the model does support.
    dummy_lora_files = f"{tmp_path}/lora_adapter"
    os.makedirs(dummy_lora_files, exist_ok=True)
    create_peft_lora(
        dummy_model_gate_up,
        save_dir=dummy_lora_files,
        target_modules=["layer1.dense1", "dense2"],
        lora_dtype=DEFAULT_DTYPE,
    )

    vllm_config = VllmConfig(
        model_config=ModelConfig(max_model_len=16), lora_config=lora_config
    )
    vllm_config.scheduler_config.max_num_seqs = 4
    vllm_config.scheduler_config.max_num_batched_tokens = 2

    worker_manager = WorkerLoRAManager(vllm_config, device, EMBEDDING_MODULES)
    worker_manager.vocab_size = dummy_model_gate_up.unpadded_vocab_size
    worker_manager.create_lora_manager(dummy_model_gate_up)

    # Wrap the real checkpoint loader so the loaded LoRAModel additionally
    # contains weights under a module name the model cannot resolve.
    real_loader = LoRAModel.from_local_checkpoint

    def patched_from_checkpoint(*args, **kwargs):
        lora = real_loader(*args, **kwargs)
        lora.loras["unsupported_module"] = LoRALayerWeights(
            module_name="unsupported_module",
            rank=8,
            lora_alpha=16,
            lora_a=torch.randn(8, 10),
            lora_b=torch.randn(10, 8),
        )
        return lora

    lora_request = LoRARequest("test", 1, dummy_lora_files)
    with (
        patch.object(LoRAModel, "from_local_checkpoint", patched_from_checkpoint),
        patch.object(wm_module.logger, "warning_once") as mock_warning,
    ):
        worker_manager._load_adapter(lora_request)
    warning_args = mock_warning.call_args_list
    found = any("unsupported_module" in str(call) for call in warning_args)
    assert found, (
        f"Expected warning about 'unsupported_module', got: {warning_args}"
    )
|
||||
|
||||
|
||||
@pytest.mark.parametrize("device", DEVICES)
def test_load_adapter_warns_on_target_modules_restriction(
    default_vllm_config, dist_init, dummy_model_gate_up, device, tmp_path
):
    """_load_adapter must warn when adapter modules are dropped by the
    deployment-time target_modules restriction (not by model support)."""
    from unittest.mock import patch

    import vllm.lora.worker_manager as wm_module

    # Only dense2 is allowed; the adapter below also trains dense1, which
    # the model supports but the restriction excludes.
    lora_config = LoRAConfig(
        max_lora_rank=8,
        max_cpu_loras=4,
        max_loras=4,
        lora_dtype=DEFAULT_DTYPE,
        target_modules=["dense2"],
    )

    dummy_lora_files = f"{tmp_path}/lora_adapter"
    os.makedirs(dummy_lora_files, exist_ok=True)
    create_peft_lora(
        dummy_model_gate_up,
        save_dir=dummy_lora_files,
        target_modules=["layer1.dense1", "dense2"],
        lora_dtype=DEFAULT_DTYPE,
    )

    vllm_config = VllmConfig(
        model_config=ModelConfig(max_model_len=16), lora_config=lora_config
    )
    vllm_config.scheduler_config.max_num_seqs = 4
    vllm_config.scheduler_config.max_num_batched_tokens = 2

    worker_manager = WorkerLoRAManager(vllm_config, device, EMBEDDING_MODULES)
    worker_manager.vocab_size = dummy_model_gate_up.unpadded_vocab_size
    worker_manager.create_lora_manager(dummy_model_gate_up)

    lora_request = LoRARequest("test", 1, dummy_lora_files)
    with patch.object(wm_module.logger, "warning_once") as mock_warning:
        worker_manager._load_adapter(lora_request)
    warning_args = mock_warning.call_args_list
    # dense1 was trained in the adapter but filtered out by target_modules.
    found = any("target_modules" in str(call) for call in warning_args)
    assert found, (
        f"Expected warning about target_modules restriction, got: {warning_args}"
    )
|
||||
|
||||
60
tests/lora/test_lora_utils.py
Normal file
60
tests/lora/test_lora_utils.py
Normal file
@@ -0,0 +1,60 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
from vllm.lora.utils import is_in_target_modules, is_supported_lora_module
|
||||
|
||||
|
||||
class TestIsSupportedLoraModule:
    """Unit tests for is_supported_lora_module (model-definition check)."""

    def test_suffix_match(self):
        supported = ["o_proj", "q_proj"]
        assert is_supported_lora_module("model.layers.0.self_attn.o_proj", supported)

    def test_no_match(self):
        supported = ["q_proj", "k_proj"]
        assert not is_supported_lora_module(
            "model.layers.0.self_attn.o_proj", supported
        )

    def test_exact_match(self):
        # A bare module name with no dotted prefix should still match itself.
        assert is_supported_lora_module("o_proj", ["o_proj"])

    def test_regex_suffix_matching(self):
        """A partial suffix such as "proj" must not match "o_proj" — the
        match is anchored to the end of the module path."""
        assert not is_supported_lora_module("model.layers.0.self_attn.o_proj", ["proj"])

    def test_empty_supported_modules(self):
        # An empty support list matches nothing.
        assert not is_supported_lora_module("model.layers.0.self_attn.o_proj", [])

    def test_multiple_supported_modules(self):
        supported = ["q_proj", "k_proj", "v_proj", "o_proj"]
        assert is_supported_lora_module("model.layers.0.self_attn.v_proj", supported)
        assert not is_supported_lora_module("model.layers.0.mlp.gate_proj", supported)
|
||||
|
||||
|
||||
class TestIsInTargetModules:
    """Unit tests for is_in_target_modules (deployment-time filter)."""

    def test_none_allows_all(self):
        # No restriction configured -> every module passes the filter.
        assert is_in_target_modules("model.layers.0.self_attn.o_proj", None)

    def test_suffix_in_target(self):
        targets = ["o_proj", "q_proj"]
        assert is_in_target_modules("model.layers.0.self_attn.o_proj", targets)

    def test_suffix_not_in_target(self):
        targets = ["q_proj", "k_proj"]
        assert not is_in_target_modules("model.layers.0.self_attn.o_proj", targets)

    def test_empty_target_modules(self):
        # An empty list is an explicit "nothing allowed", unlike None.
        assert not is_in_target_modules("model.layers.0.self_attn.o_proj", [])

    def test_exact_name_match(self):
        assert is_in_target_modules("dense1", ["dense1", "dense2"])

    def test_exact_name_no_match(self):
        assert not is_in_target_modules("dense3", ["dense1", "dense2"])
|
||||
Reference in New Issue
Block a user