[Misc][LoRA] Add --lora-target-modules to restrict LoRA to specific modules (#34984)

Signed-off-by: Bhoomit Vasani <bhoomit.2010@gmail.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Bhoomit
2026-03-17 07:36:41 -07:00
committed by GitHub
parent ecfcdd2ce4
commit 3717a4dd47
9 changed files with 404 additions and 10 deletions

View File

@@ -711,3 +711,192 @@ def test_packed_loras(default_vllm_config, dist_init, dummy_model_gate_up, devic
torch.testing.assert_close(
packed_lora1.lora_b[1], model_lora_clone1.get_lora("up_proj").lora_b
)
def _test_target_modules(
    model,
    target_modules: list[str] | None,
    device: str,
    expected_lora: list[tuple[str, type]],
    expected_no_lora: list[tuple[str, type]],
):
    """Build a LoRAModelManager over *model* and check LoRA wrapping.

    Every module path in *expected_lora* must have been replaced by the given
    LoRA layer class; every path in *expected_no_lora* must not.
    """
    lora_config = LoRAConfig(
        max_lora_rank=8,
        max_cpu_loras=2,
        max_loras=2,
        lora_dtype=DEFAULT_DTYPE,
        target_modules=target_modules,
    )
    # Constructing the manager wraps the model's modules in place.
    LoRAModelManager(model, 2, 2, 2, lora_config, device=device)

    for path, wrapper_cls in expected_lora:
        assert isinstance(model.get_submodule(path), wrapper_cls)
    for path, wrapper_cls in expected_no_lora:
        assert not isinstance(model.get_submodule(path), wrapper_cls)
@pytest.mark.parametrize("device", DEVICES)
def test_target_modules_config(default_vllm_config, dist_init, dummy_model, device):
    """Test that target_modules config restricts which modules get LoRA applied."""
    # Only dense1 (at any depth) is targeted; dense2 must stay vanilla.
    wrapped = [
        ("dense1", ColumnParallelLinearWithLoRA),
        ("layer1.dense1", ColumnParallelLinearWithLoRA),
    ]
    untouched = [
        ("dense2", RowParallelLinearWithLoRA),
        ("layer1.dense2", RowParallelLinearWithLoRA),
    ]
    _test_target_modules(
        dummy_model,
        ["dense1"],
        device,
        expected_lora=wrapped,
        expected_no_lora=untouched,
    )
@pytest.mark.parametrize("device", DEVICES)
def test_target_modules_multiple(default_vllm_config, dist_init, dummy_model, device):
    """Test that multiple target_modules work correctly."""
    # Both linear names are targeted, so every occurrence gets wrapped and
    # nothing is left out.
    wrapped = [
        ("dense1", ColumnParallelLinearWithLoRA),
        ("layer1.dense1", ColumnParallelLinearWithLoRA),
        ("dense2", RowParallelLinearWithLoRA),
        ("layer1.dense2", RowParallelLinearWithLoRA),
    ]
    _test_target_modules(
        dummy_model,
        ["dense1", "dense2"],
        device,
        expected_lora=wrapped,
        expected_no_lora=[],
    )
@pytest.mark.parametrize("device", DEVICES)
def test_target_modules_none_uses_all(
    default_vllm_config, dist_init, dummy_model, device
):
    """Test that target_modules=None uses all supported modules."""
    # With no restriction configured, every supported module gets LoRA.
    wrapped = [
        ("dense1", ColumnParallelLinearWithLoRA),
        ("layer1.dense1", ColumnParallelLinearWithLoRA),
        ("dense2", RowParallelLinearWithLoRA),
        ("layer1.dense2", RowParallelLinearWithLoRA),
    ]
    _test_target_modules(
        dummy_model,
        None,
        device,
        expected_lora=wrapped,
        expected_no_lora=[],
    )
@pytest.mark.parametrize("device", DEVICES)
def test_load_adapter_warns_on_unsupported_modules(
    default_vllm_config, dist_init, dummy_model_gate_up, device, tmp_path
):
    """Test that _load_adapter warns when a LoRA adapter contains modules
    not in the model's supported LoRA target modules."""
    from unittest.mock import patch
    import vllm.lora.worker_manager as wm_module
    lora_config = LoRAConfig(
        max_lora_rank=8, max_cpu_loras=4, max_loras=4, lora_dtype=DEFAULT_DTYPE
    )
    # Write a real PEFT adapter to disk covering two modules the model supports.
    dummy_lora_files = f"{tmp_path}/lora_adapter"
    os.makedirs(dummy_lora_files, exist_ok=True)
    create_peft_lora(
        dummy_model_gate_up,
        save_dir=dummy_lora_files,
        target_modules=["layer1.dense1", "dense2"],
        lora_dtype=DEFAULT_DTYPE,
    )
    model_config = ModelConfig(max_model_len=16)
    vllm_config = VllmConfig(model_config=model_config, lora_config=lora_config)
    vllm_config.scheduler_config.max_num_seqs = 4
    vllm_config.scheduler_config.max_num_batched_tokens = 2
    worker_manager = WorkerLoRAManager(vllm_config, device, EMBEDDING_MODULES)
    worker_manager.vocab_size = dummy_model_gate_up.unpadded_vocab_size
    worker_manager.create_lora_manager(dummy_model_gate_up)
    # Patch from_local_checkpoint to inject an unsupported module
    original_from_checkpoint = LoRAModel.from_local_checkpoint
    def patched_from_checkpoint(*args, **kwargs):
        # Delegate to the real loader, then append weights under a module name
        # that matches no layer in the model.
        lora = original_from_checkpoint(*args, **kwargs)
        # rank=8 matches max_lora_rank above; weight shapes are presumably
        # (rank, in) / (out, rank) — TODO confirm against LoRALayerWeights.
        lora.loras["unsupported_module"] = LoRALayerWeights(
            module_name="unsupported_module",
            rank=8,
            lora_alpha=16,
            lora_a=torch.randn(8, 10),
            lora_b=torch.randn(10, 8),
        )
        return lora
    lora_request = LoRARequest("test", 1, dummy_lora_files)
    with (
        patch.object(LoRAModel, "from_local_checkpoint", patched_from_checkpoint),
        patch.object(wm_module.logger, "warning_once") as mock_warning,
    ):
        worker_manager._load_adapter(lora_request)
    # The injected name must show up in at least one warning_once call.
    warning_args = mock_warning.call_args_list
    found = any("unsupported_module" in str(call) for call in warning_args)
    assert found, (
        f"Expected warning about 'unsupported_module', got: {warning_args}"
    )
@pytest.mark.parametrize("device", DEVICES)
def test_load_adapter_warns_on_target_modules_restriction(
    default_vllm_config, dist_init, dummy_model_gate_up, device, tmp_path
):
    """Test that _load_adapter warns when a LoRA adapter contains modules
    excluded by the deployment-time target_modules restriction."""
    from unittest.mock import patch
    import vllm.lora.worker_manager as wm_module

    # Deployment allows only dense2; the adapter below also trains dense1,
    # which is therefore excluded and should trigger a warning.
    restricted_config = LoRAConfig(
        max_lora_rank=8,
        max_cpu_loras=4,
        max_loras=4,
        lora_dtype=DEFAULT_DTYPE,
        target_modules=["dense2"],
    )
    adapter_dir = f"{tmp_path}/lora_adapter"
    os.makedirs(adapter_dir, exist_ok=True)
    create_peft_lora(
        dummy_model_gate_up,
        save_dir=adapter_dir,
        target_modules=["layer1.dense1", "dense2"],
        lora_dtype=DEFAULT_DTYPE,
    )

    vllm_config = VllmConfig(
        model_config=ModelConfig(max_model_len=16),
        lora_config=restricted_config,
    )
    vllm_config.scheduler_config.max_num_seqs = 4
    vllm_config.scheduler_config.max_num_batched_tokens = 2
    manager = WorkerLoRAManager(vllm_config, device, EMBEDDING_MODULES)
    manager.vocab_size = dummy_model_gate_up.unpadded_vocab_size
    manager.create_lora_manager(dummy_model_gate_up)

    request = LoRARequest("test", 1, adapter_dir)
    with patch.object(wm_module.logger, "warning_once") as mock_warning:
        manager._load_adapter(request)

    # dense1 is supported by the model but excluded by target_modules;
    # some warning must mention the restriction.
    warning_args = mock_warning.call_args_list
    assert any("target_modules" in str(call) for call in warning_args), (
        f"Expected warning about target_modules restriction, got: {warning_args}"
    )

View File

@@ -0,0 +1,60 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.lora.utils import is_in_target_modules, is_supported_lora_module
class TestIsSupportedLoraModule:
    """Tests for is_supported_lora_module (model-definition check)."""

    def test_suffix_match(self):
        """A dotted path matches when its final segment is supported."""
        path = "model.layers.0.self_attn.o_proj"
        assert is_supported_lora_module(path, ["o_proj", "q_proj"])

    def test_no_match(self):
        """No match when the final segment is absent from the list."""
        path = "model.layers.0.self_attn.o_proj"
        assert not is_supported_lora_module(path, ["q_proj", "k_proj"])

    def test_exact_match(self):
        """A bare module name matches itself."""
        assert is_supported_lora_module("o_proj", ["o_proj"])

    def test_regex_suffix_matching(self):
        """Regex anchors to end — partial suffix should not match."""
        path = "model.layers.0.self_attn.o_proj"
        assert not is_supported_lora_module(path, ["proj"])

    def test_empty_supported_modules(self):
        """An empty supported list matches nothing."""
        assert not is_supported_lora_module("model.layers.0.self_attn.o_proj", [])

    def test_multiple_supported_modules(self):
        """Matching works against a list of several candidates."""
        candidates = ["q_proj", "k_proj", "v_proj", "o_proj"]
        assert is_supported_lora_module("model.layers.0.self_attn.v_proj", candidates)
        assert not is_supported_lora_module("model.layers.0.mlp.gate_proj", candidates)
class TestIsInTargetModules:
    """Tests for is_in_target_modules (deployment-time filter)."""

    def test_none_allows_all(self):
        """With no restriction configured, every module passes."""
        assert is_in_target_modules("model.layers.0.self_attn.o_proj", None)

    def test_suffix_in_target(self):
        """A dotted path passes when its final segment is targeted."""
        path = "model.layers.0.self_attn.o_proj"
        assert is_in_target_modules(path, ["o_proj", "q_proj"])

    def test_suffix_not_in_target(self):
        """A dotted path is filtered out when its final segment is not targeted."""
        path = "model.layers.0.self_attn.o_proj"
        assert not is_in_target_modules(path, ["q_proj", "k_proj"])

    def test_empty_target_modules(self):
        """An empty restriction list filters out everything."""
        assert not is_in_target_modules("model.layers.0.self_attn.o_proj", [])

    def test_exact_name_match(self):
        """A bare module name passes when listed."""
        assert is_in_target_modules("dense1", ["dense1", "dense2"])

    def test_exact_name_no_match(self):
        """A bare module name is filtered out when not listed."""
        assert not is_in_target_modules("dense3", ["dense1", "dense2"])