diff --git a/docs/features/lora.md b/docs/features/lora.md
index cf868eb14..2e7b36545 100644
--- a/docs/features/lora.md
+++ b/docs/features/lora.md
@@ -389,3 +389,17 @@ vllm serve model --enable-lora --max-lora-rank 64
 # Bad: unnecessarily high, wastes memory
 vllm serve model --enable-lora --max-lora-rank 256
 ```
+
+### Restricting LoRA to Specific Modules
+
+The `--lora-target-modules` parameter allows you to restrict which model modules have LoRA applied at deployment time. This is useful for performance tuning when you only need LoRA on specific layers:
+
+```bash
+# Apply LoRA only to output projection layers
+vllm serve model --enable-lora --lora-target-modules o_proj
+
+# Apply LoRA to multiple specific modules
+vllm serve model --enable-lora --lora-target-modules o_proj qkv_proj down_proj
+```
+
+When `--lora-target-modules` is not specified, LoRA will be applied to all supported modules in the model. This parameter accepts module suffixes (the last component of the module name), such as `o_proj`, `qkv_proj`, `gate_proj`, etc.
diff --git a/tests/entrypoints/openai/test_cli_args.py b/tests/entrypoints/openai/test_cli_args.py
index ccf145a0c..58dd328b3 100644
--- a/tests/entrypoints/openai/test_cli_args.py
+++ b/tests/entrypoints/openai/test_cli_args.py
@@ -291,3 +291,32 @@ def test_served_model_name_parsing(tmp_path, vllm_parser, args, raises):
     else:
         with pytest.raises(raises):
             vllm_parser.parse_args(args=args)
+
+
+### Tests for LoRA target modules parsing
+def test_lora_target_modules_single(serve_parser):
+    """Test parsing single lora-target-modules argument"""
+    args = serve_parser.parse_args(
+        args=["--enable-lora", "--lora-target-modules", "o_proj"]
+    )
+    assert args.lora_target_modules == ["o_proj"]
+
+
+def test_lora_target_modules_multiple(serve_parser):
+    """Test parsing multiple lora-target-modules arguments"""
+    args = serve_parser.parse_args(
+        args=[
+            "--enable-lora",
+            "--lora-target-modules",
+            "o_proj",
+            "qkv_proj",
+            "down_proj",
+        ]
+    )
+    assert args.lora_target_modules == ["o_proj", "qkv_proj", "down_proj"]
+
+
+def test_lora_target_modules_default_none(serve_parser):
+    """Test that lora-target-modules defaults to None"""
+    args = serve_parser.parse_args(args=[])
+    assert args.lora_target_modules is None
diff --git a/tests/lora/test_lora_manager.py b/tests/lora/test_lora_manager.py
index d2a7cd155..e7addab11 100644
--- a/tests/lora/test_lora_manager.py
+++ b/tests/lora/test_lora_manager.py
@@ -711,3 +711,192 @@ def test_packed_loras(default_vllm_config, dist_init, dummy_model_gate_up, devic
     torch.testing.assert_close(
         packed_lora1.lora_b[1], model_lora_clone1.get_lora("up_proj").lora_b
     )
+
+
+def _test_target_modules(
+    model,
+    target_modules: list[str] | None,
+    device: str,
+    expected_lora: list[tuple[str, type]],
+    expected_no_lora: list[tuple[str, type]],
+):
+    """Create a LoRAModelManager and assert which modules have LoRA applied."""
+    LoRAModelManager(
+        model,
+        2,
+        2,
+        2,
+        LoRAConfig(
+            max_lora_rank=8,
+            max_cpu_loras=2,
+            max_loras=2,
+            lora_dtype=DEFAULT_DTYPE,
+            target_modules=target_modules,
+        ),
+        device=device,
+    )
+    for module_path, lora_cls in expected_lora:
+        assert isinstance(model.get_submodule(module_path), lora_cls)
+    for module_path, lora_cls in expected_no_lora:
+        assert not isinstance(model.get_submodule(module_path), lora_cls)
+
+
+@pytest.mark.parametrize("device", DEVICES)
+def test_target_modules_config(default_vllm_config, dist_init, dummy_model, device):
+    """Test that target_modules config restricts which modules get LoRA
+    applied."""
+    _test_target_modules(
+        dummy_model,
+        ["dense1"],
+        device,
+        expected_lora=[
+            ("dense1", ColumnParallelLinearWithLoRA),
+            ("layer1.dense1", ColumnParallelLinearWithLoRA),
+        ],
+        expected_no_lora=[
+            ("dense2", RowParallelLinearWithLoRA),
+            ("layer1.dense2", RowParallelLinearWithLoRA),
+        ],
+    )
+
+
+@pytest.mark.parametrize("device", DEVICES)
+def test_target_modules_multiple(default_vllm_config, dist_init, dummy_model, device):
+    """Test that multiple target_modules work correctly."""
+    _test_target_modules(
+        dummy_model,
+        ["dense1", "dense2"],
+        device,
+        expected_lora=[
+            ("dense1", ColumnParallelLinearWithLoRA),
+            ("layer1.dense1", ColumnParallelLinearWithLoRA),
+            ("dense2", RowParallelLinearWithLoRA),
+            ("layer1.dense2", RowParallelLinearWithLoRA),
+        ],
+        expected_no_lora=[],
+    )
+
+
+@pytest.mark.parametrize("device", DEVICES)
+def test_target_modules_none_uses_all(
+    default_vllm_config, dist_init, dummy_model, device
+):
+    """Test that target_modules=None uses all supported modules."""
+    _test_target_modules(
+        dummy_model,
+        None,
+        device,
+        expected_lora=[
+            ("dense1", ColumnParallelLinearWithLoRA),
+            ("layer1.dense1", ColumnParallelLinearWithLoRA),
+            ("dense2", RowParallelLinearWithLoRA),
+            ("layer1.dense2", RowParallelLinearWithLoRA),
+        ],
+        expected_no_lora=[],
+    )
+
+
+@pytest.mark.parametrize("device", DEVICES)
+def test_load_adapter_warns_on_unsupported_modules(
+    default_vllm_config, dist_init, dummy_model_gate_up, device, tmp_path
+):
+    """Test that _load_adapter warns when a LoRA adapter contains modules
+    not in the model's supported LoRA target modules."""
+    from unittest.mock import patch
+
+    import vllm.lora.worker_manager as wm_module
+
+    lora_config = LoRAConfig(
+        max_lora_rank=8, max_cpu_loras=4, max_loras=4, lora_dtype=DEFAULT_DTYPE
+    )
+
+    dummy_lora_files = f"{tmp_path}/lora_adapter"
+    os.makedirs(dummy_lora_files, exist_ok=True)
+    create_peft_lora(
+        dummy_model_gate_up,
+        save_dir=dummy_lora_files,
+        target_modules=["layer1.dense1", "dense2"],
+        lora_dtype=DEFAULT_DTYPE,
+    )
+
+    model_config = ModelConfig(max_model_len=16)
+    vllm_config = VllmConfig(model_config=model_config, lora_config=lora_config)
+    vllm_config.scheduler_config.max_num_seqs = 4
+    vllm_config.scheduler_config.max_num_batched_tokens = 2
+
+    worker_manager = WorkerLoRAManager(vllm_config, device, EMBEDDING_MODULES)
+    worker_manager.vocab_size = dummy_model_gate_up.unpadded_vocab_size
+    worker_manager.create_lora_manager(dummy_model_gate_up)
+
+    # Patch from_local_checkpoint to inject an unsupported module
+    original_from_checkpoint = LoRAModel.from_local_checkpoint
+
+    def patched_from_checkpoint(*args, **kwargs):
+        lora = original_from_checkpoint(*args, **kwargs)
+        lora.loras["unsupported_module"] = LoRALayerWeights(
+            module_name="unsupported_module",
+            rank=8,
+            lora_alpha=16,
+            lora_a=torch.randn(8, 10),
+            lora_b=torch.randn(10, 8),
+        )
+        return lora
+
+    lora_request = LoRARequest("test", 1, dummy_lora_files)
+    with (
+        patch.object(LoRAModel, "from_local_checkpoint", patched_from_checkpoint),
+        patch.object(wm_module.logger, "warning_once") as mock_warning,
+    ):
+        worker_manager._load_adapter(lora_request)
+    warning_args = mock_warning.call_args_list
+    found = any("unsupported_module" in str(call) for call in warning_args)
+    assert found, (
+        f"Expected warning about 'unsupported_module', got: {warning_args}"
+    )
+
+
+@pytest.mark.parametrize("device", DEVICES)
+def test_load_adapter_warns_on_target_modules_restriction(
+    default_vllm_config, dist_init, dummy_model_gate_up, device, tmp_path
+):
+    """Test that _load_adapter warns when a LoRA adapter contains modules
+    excluded by the deployment-time target_modules restriction."""
+    from unittest.mock import patch
+
+    import vllm.lora.worker_manager as wm_module
+
+    # Restrict to only dense2 — adapter has dense1 which will be excluded
+    lora_config = LoRAConfig(
+        max_lora_rank=8,
+        max_cpu_loras=4,
+        max_loras=4,
+        lora_dtype=DEFAULT_DTYPE,
+        target_modules=["dense2"],
+    )
+
+    dummy_lora_files = f"{tmp_path}/lora_adapter"
+    os.makedirs(dummy_lora_files, exist_ok=True)
+    create_peft_lora(
+        dummy_model_gate_up,
+        save_dir=dummy_lora_files,
+        target_modules=["layer1.dense1", "dense2"],
+        lora_dtype=DEFAULT_DTYPE,
+    )
+
+    model_config = ModelConfig(max_model_len=16)
+    vllm_config = VllmConfig(model_config=model_config, lora_config=lora_config)
+    vllm_config.scheduler_config.max_num_seqs = 4
+    vllm_config.scheduler_config.max_num_batched_tokens = 2
+
+    worker_manager = WorkerLoRAManager(vllm_config, device, EMBEDDING_MODULES)
+    worker_manager.vocab_size = dummy_model_gate_up.unpadded_vocab_size
+    worker_manager.create_lora_manager(dummy_model_gate_up)
+
+    lora_request = LoRARequest("test", 1, dummy_lora_files)
+    with patch.object(wm_module.logger, "warning_once") as mock_warning:
+        worker_manager._load_adapter(lora_request)
+    warning_args = mock_warning.call_args_list
+    # dense1 is supported by the model but excluded by target_modules
+    found = any("target_modules" in str(call) for call in warning_args)
+    assert found, (
+        f"Expected warning about target_modules restriction, got: {warning_args}"
+    )
diff --git a/tests/lora/test_lora_utils.py b/tests/lora/test_lora_utils.py
new file mode 100644
index 000000000..da66aa60b
--- /dev/null
+++ b/tests/lora/test_lora_utils.py
@@ -0,0 +1,60 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+
+from vllm.lora.utils import is_in_target_modules, is_supported_lora_module
+
+
+class TestIsSupportedLoraModule:
+    """Tests for is_supported_lora_module (model-definition check)."""
+
+    def test_suffix_match(self):
+        assert is_supported_lora_module(
+            "model.layers.0.self_attn.o_proj", ["o_proj", "q_proj"]
+        )
+
+    def test_no_match(self):
+        assert not is_supported_lora_module(
+            "model.layers.0.self_attn.o_proj", ["q_proj", "k_proj"]
+        )
+
+    def test_exact_match(self):
+        assert is_supported_lora_module("o_proj", ["o_proj"])
+
+    def test_regex_suffix_matching(self):
+        """Regex anchors to end — partial suffix should not match."""
+        assert not is_supported_lora_module("model.layers.0.self_attn.o_proj", ["proj"])
+
+    def test_empty_supported_modules(self):
+        assert not is_supported_lora_module("model.layers.0.self_attn.o_proj", [])
+
+    def test_multiple_supported_modules(self):
+        supported = ["q_proj", "k_proj", "v_proj", "o_proj"]
+        assert is_supported_lora_module("model.layers.0.self_attn.v_proj", supported)
+        assert not is_supported_lora_module("model.layers.0.mlp.gate_proj", supported)
+
+
+class TestIsInTargetModules:
+    """Tests for is_in_target_modules (deployment-time filter)."""
+
+    def test_none_allows_all(self):
+        assert is_in_target_modules("model.layers.0.self_attn.o_proj", None)
+
+    def test_suffix_in_target(self):
+        assert is_in_target_modules(
+            "model.layers.0.self_attn.o_proj", ["o_proj", "q_proj"]
+        )
+
+    def test_suffix_not_in_target(self):
+        assert not is_in_target_modules(
+            "model.layers.0.self_attn.o_proj", ["q_proj", "k_proj"]
+        )
+
+    def test_empty_target_modules(self):
+        assert not is_in_target_modules("model.layers.0.self_attn.o_proj", [])
+
+    def test_exact_name_match(self):
+        assert is_in_target_modules("dense1", ["dense1", "dense2"])
+
+    def test_exact_name_no_match(self):
+        assert not is_in_target_modules("dense3", ["dense1", "dense2"])
diff --git a/vllm/config/lora.py b/vllm/config/lora.py
index 0d310c87e..bfef0efa3 100644
--- a/vllm/config/lora.py
+++ b/vllm/config/lora.py
@@ -43,6 +43,10 @@ class LoRAConfig:
     `max_loras`."""
     lora_dtype: torch.dtype | LoRADType = "auto"
    """Data type for LoRA. If auto, will default to base model dtype."""
+    target_modules: list[str] | None = None
+    """Restrict LoRA to specific module suffixes (e.g., ["o_proj", "qkv_proj"]).
+    If None, all supported LoRA modules are used. This allows deployment-time
+    control over which modules have LoRA applied, useful for performance tuning."""
     default_mm_loras: dict[str, str] | None = None
     """Dictionary mapping specific modalities to LoRA model paths; this field
     is only applicable to multimodal models and should be leveraged when a
@@ -84,6 +88,10 @@
         factors.append(self.fully_sharded_loras)
         factors.append(self.lora_dtype)
         factors.append(self.enable_tower_connector_lora)
+        # target_modules affects which modules get LoRA applied
+        factors.append(
+            tuple(sorted(self.target_modules)) if self.target_modules else None
+        )
         hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
         return hash_str
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 8fac21687..2c04c06e7 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -506,6 +506,7 @@ class EngineArgs:
     fully_sharded_loras: bool = LoRAConfig.fully_sharded_loras
     max_cpu_loras: int | None = LoRAConfig.max_cpu_loras
     lora_dtype: str | torch.dtype | None = LoRAConfig.lora_dtype
+    lora_target_modules: list[str] | None = LoRAConfig.target_modules
     enable_tower_connector_lora: bool = LoRAConfig.enable_tower_connector_lora
     specialize_active_lora: bool = LoRAConfig.specialize_active_lora
@@ -1107,6 +1108,9 @@
         lora_group.add_argument(
             "--fully-sharded-loras", **lora_kwargs["fully_sharded_loras"]
         )
+        lora_group.add_argument(
+            "--lora-target-modules", **lora_kwargs["target_modules"]
+        )
         lora_group.add_argument("--default-mm-loras", **lora_kwargs["default_mm_loras"])
         lora_group.add_argument(
             "--specialize-active-lora", **lora_kwargs["specialize_active_lora"]
         )
@@ -1800,6 +1804,7 @@
             default_mm_loras=self.default_mm_loras,
             fully_sharded_loras=self.fully_sharded_loras,
             lora_dtype=self.lora_dtype,
+            target_modules=self.lora_target_modules,
             enable_tower_connector_lora=self.enable_tower_connector_lora,
             specialize_active_lora=self.specialize_active_lora,
             max_cpu_loras=self.max_cpu_loras
diff --git a/vllm/lora/model_manager.py b/vllm/lora/model_manager.py
index a97c13022..12d6f719a 100644
--- a/vllm/lora/model_manager.py
+++ b/vllm/lora/model_manager.py
@@ -5,7 +5,6 @@ import math
 from collections.abc import Callable
 from typing import TypeVar
 
-import regex as re
 import torch
 from torch import nn
@@ -25,7 +24,9 @@ from vllm.lora.utils import (
     from_layer,
     from_layer_logits_processor,
     get_supported_lora_modules,
+    is_in_target_modules,
     is_moe_model,
+    is_supported_lora_module,
     process_packed_modules_mapping,
     replace_submodule,
 )
@@ -541,14 +542,23 @@ class LoRAModelManager:
         model.loras[module_name] = lora
         return model
 
-    def _match_target_modules(self, module_name: str):
-        return any(
-            re.match(
-                r".*\.{target_module}$".format(target_module=target_module), module_name
-            )
-            or target_module == module_name
-            for target_module in self.supported_lora_modules
-        )
+    def _match_target_modules(self, module_name: str) -> bool:
+        """Check if a module should have LoRA applied.
+
+        This method first checks if the module is in vLLM's supported LoRA
+        modules, then applies deployment-time restrictions based on
+        LoRAConfig.target_modules.
+
+        Args:
+            module_name: Full dot-separated module name (e.g.,
+                "model.layers.0.self_attn.o_proj")
+
+        Returns:
+            True if LoRA should be applied to this module, False otherwise.
+        """
+        if not is_supported_lora_module(module_name, self.supported_lora_modules):
+            return False
+        return is_in_target_modules(module_name, self.lora_config.target_modules)
 
     def _get_punica_wrapper(self, module_name: str) -> PunicaWrapperBase | None:
         """
diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py
index 6fef61dba..2349ace70 100644
--- a/vllm/lora/utils.py
+++ b/vllm/lora/utils.py
@@ -5,6 +5,7 @@ import os
 from typing import TYPE_CHECKING
 
 import huggingface_hub
+import regex as re
 from huggingface_hub.utils import HfHubHTTPError, HFValidationError
 from torch import nn
 from transformers import PretrainedConfig
@@ -226,6 +227,57 @@ def get_supported_lora_modules(model: nn.Module) -> list[str]:
     return list(supported_lora_modules)
+
+
+def is_supported_lora_module(
+    module_name: str,
+    supported_lora_modules: list[str],
+) -> bool:
+    """Check if a module is in the model's supported LoRA modules.
+
+    Uses regex suffix matching against the model-defined supported modules
+    list (e.g., matching "model.layers.0.self_attn.o_proj" against
+    "o_proj").
+
+    Args:
+        module_name: Full dot-separated module name.
+        supported_lora_modules: List of module suffixes supported by the
+            model.
+
+    Returns:
+        True if the module is supported, False otherwise.
+    """
+    return any(
+        re.match(
+            r".*\.{target_module}$".format(target_module=target_module),
+            module_name,
+        )
+        or target_module == module_name
+        for target_module in supported_lora_modules
+    )
+
+
+def is_in_target_modules(
+    module_name: str,
+    target_modules: list[str] | None,
+) -> bool:
+    """Check if a module passes the deployment-time target_modules filter.
+
+    When target_modules is None (no restriction), all modules pass.
+    Otherwise, the module's suffix must be in the target_modules list.
+
+    Args:
+        module_name: Full dot-separated module name.
+        target_modules: Optional deployment-time restriction list from
+            LoRAConfig.target_modules.
+
+    Returns:
+        True if the module passes the filter, False otherwise.
+    """
+    if target_modules is None:
+        return True
+    module_suffix = module_name.split(".")[-1]
+    return module_suffix in set(target_modules)
+
+
 def get_adapter_absolute_path(lora_path: str) -> str:
     """
     Resolves the given lora_path to an absolute local path.
diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py
index c5c0b7d33..9a0a13912 100644
--- a/vllm/lora/worker_manager.py
+++ b/vllm/lora/worker_manager.py
@@ -17,7 +17,11 @@ from vllm.lora.model_manager import (
 )
 from vllm.lora.peft_helper import PEFTHelper
 from vllm.lora.request import LoRARequest
-from vllm.lora.utils import get_adapter_absolute_path
+from vllm.lora.utils import (
+    get_adapter_absolute_path,
+    is_in_target_modules,
+    is_supported_lora_module,
+)
 
 logger = init_logger(__name__)
@@ -142,6 +146,29 @@ class WorkerLoRAManager:
                 skip_prefixes=lora_skip_prefixes,
             )
 
+            # Warn about adapter modules that will be ignored.
+            target_modules = self.lora_config.target_modules
+            for module_name in lora.loras:
+                if not is_supported_lora_module(module_name, supported_lora_modules):
+                    logger.warning_once(
+                        "LoRA module '%s' in adapter '%s' is not in the "
+                        "model's supported LoRA target modules [%s]. "
+                        "These parameters will be ignored, which may "
+                        "cause abnormal model behavior.",
+                        module_name,
+                        lora_request.lora_path,
+                        ", ".join(sorted(supported_lora_modules)),
+                    )
+                elif not is_in_target_modules(module_name, target_modules):
+                    logger.warning_once(
+                        "LoRA module '%s' in adapter '%s' is not in the "
+                        "deployment-time target_modules restriction [%s]."
+                        " These parameters will be ignored.",
+                        module_name,
+                        lora_request.lora_path,
+                        ", ".join(sorted(target_modules)),
+                    )
+
         except FileNotFoundError as e:
             # FileNotFoundError should be raised if both
             # - No adapter found to download from huggingface (or in